Support for dicts and list in operators template_fields
This commit is contained in:
Родитель
9d3fa3f34b
Коммит
472d3c1221
|
@ -312,9 +312,9 @@ class HiveServer2Hook(BaseHook):
|
|||
Get a set of records from a Hive query.
|
||||
|
||||
>>> hh = HiveServer2Hook()
|
||||
>>> sql = "SELECT count(1) AS num FROM airflow.static_babynames"
|
||||
>>> hh.get_records(sql)
|
||||
[[340698]]
|
||||
>>> sql = "SELECT * FROM airflow.static_babynames LIMIT 100"
|
||||
>>> len(hh.get_records(sql))
|
||||
100
|
||||
'''
|
||||
return self.get_results(hql, schema=schema)['data']
|
||||
|
||||
|
@ -323,10 +323,10 @@ class HiveServer2Hook(BaseHook):
|
|||
Get a pandas dataframe from a Hive query
|
||||
|
||||
>>> hh = HiveServer2Hook()
|
||||
>>> sql = "SELECT count(1) AS num FROM airflow.static_babynames"
|
||||
>>> sql = "SELECT * FROM airflow.static_babynames LIMIT 100"
|
||||
>>> df = hh.get_pandas_df(sql)
|
||||
>>> df.to_dict()
|
||||
{'num': {0: 340698}}
|
||||
>>> len(df.index)
|
||||
100
|
||||
'''
|
||||
import pandas as pd
|
||||
res = self.get_results(hql, schema=schema)
|
||||
|
|
|
@ -709,10 +709,18 @@ class TaskInstance(Base):
|
|||
if self.task.dag.user_defined_macros:
|
||||
jinja_context.update(
|
||||
self.task.dag.user_defined_macros)
|
||||
|
||||
rt = self.task.render_template # shortcut to method
|
||||
for attr in task.__class__.template_fields:
|
||||
result = getattr(task, attr)
|
||||
template = self.task.get_template(attr)
|
||||
result = template.render(**jinja_context)
|
||||
content = getattr(task, attr)
|
||||
if isinstance(content, basestring):
|
||||
result = rt(content, jinja_context)
|
||||
elif isinstance(content, list):
|
||||
result = [rt(s, jinja_context) for s in content]
|
||||
elif isinstance(content, dict):
|
||||
result = {k: rt(content[k], jinja_context) for k in content}
|
||||
else:
|
||||
raise Exception("Type not supported for templating")
|
||||
setattr(task, attr, result)
|
||||
|
||||
def email_alert(self, exception, is_retry=False):
|
||||
|
@ -913,8 +921,7 @@ class BaseOperator(Base):
|
|||
'''
|
||||
pass
|
||||
|
||||
def get_template(self, attr):
|
||||
content = getattr(self, attr)
|
||||
def render_template(self, content, context):
|
||||
if hasattr(self, 'dag'):
|
||||
env = self.dag.get_template_env()
|
||||
else:
|
||||
|
@ -925,7 +932,7 @@ class BaseOperator(Base):
|
|||
template = env.get_template(content)
|
||||
else:
|
||||
template = env.from_string(content)
|
||||
return template
|
||||
return template.render(**context)
|
||||
|
||||
def prepare_template(self):
|
||||
'''
|
||||
|
@ -940,7 +947,8 @@ class BaseOperator(Base):
|
|||
# Getting the content of files for template_field / template_ext
|
||||
for attr in self.template_fields:
|
||||
content = getattr(self, attr)
|
||||
if any([content.endswith(ext) for ext in self.template_ext]):
|
||||
if (content and isinstance(content, basestring) and
|
||||
any([content.endswith(ext) for ext in self.template_ext])):
|
||||
env = self.dag.get_template_env()
|
||||
try:
|
||||
setattr(self, attr, env.loader.get_source(env, content)[0])
|
||||
|
|
|
@ -46,7 +46,7 @@ class MySqlToHiveTransfer(BaseOperator):
|
|||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'MySqlToHiveOperator'
|
||||
}
|
||||
template_fields = ('sql',)
|
||||
template_fields = ('sql', 'partition')
|
||||
template_ext = ('.sql',)
|
||||
ui_color = '#a0e08c'
|
||||
|
||||
|
@ -71,6 +71,7 @@ class MySqlToHiveTransfer(BaseOperator):
|
|||
self.delimiter = delimiter
|
||||
self.hive = HiveCliHook(hive_cli_conn_id=hive_cli_conn_id)
|
||||
self.mysql = MySqlHook(mysql_conn_id=mysql_conn_id)
|
||||
self.partition = partition or {}
|
||||
|
||||
@classmethod
|
||||
def type_map(cls, mysql_type):
|
||||
|
|
|
@ -318,9 +318,6 @@ def import_module_attrs(parent_module_globals, module_attrs_dict):
|
|||
silence the import errors for when libraries are missing. It makes
|
||||
for a clean package abstracting the underlying modules and only
|
||||
brings funcitonal operators to those namespaces.
|
||||
|
||||
>>> module_attrs = {'operators': ['BashOperator']}
|
||||
>>> import_module_attrs(globals(), module_attrs)
|
||||
'''
|
||||
imported_attrs = []
|
||||
for mod, attrs in module_attrs_dict.items():
|
||||
|
|
|
@ -621,7 +621,7 @@ class Airflow(BaseView):
|
|||
)
|
||||
else:
|
||||
html_dict[template_field] = (
|
||||
"<pre><code>" + content + "</pre></code>")
|
||||
"<pre><code>" + str(content) + "</pre></code>")
|
||||
|
||||
return self.render(
|
||||
'airflow/ti_code.html',
|
||||
|
@ -1320,6 +1320,7 @@ admin.add_view(mv)
|
|||
|
||||
|
||||
class ConnectionModelView(SuperUserMixin, ModelView):
|
||||
column_default_sort = ('conn_id', False)
|
||||
column_list = ('conn_id', 'conn_type', 'host', 'port')
|
||||
form_choices = {
|
||||
'conn_type': [
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
export AIRFLOW_HOME=${AIRFLOW_HOME:=~/airflow}
|
||||
export AIRFLOW_CONFIG=~/airflow/unittests.cfg
|
||||
export AIRFLOW_CONFIG=$AIRFLOW_HOME/unittests.cfg
|
||||
rm airflow/www/static/coverage/*
|
||||
nosetests --with-doctest --with-coverage --cover-html --cover-package=airflow -v --cover-html-dir=airflow/www/static/coverage
|
||||
|
|
Загрузка…
Ссылка в новой задаче