From 48135ad255263d4718bbcace39c746aea5929568 Mon Sep 17 00:00:00 2001 From: Niels Zeilemaker Date: Sat, 29 Apr 2017 17:14:40 +0200 Subject: [PATCH] [AIRFLOW 1149][AIRFLOW-1149] Allow for custom filters in Jinja2 templates Closes #2258 from NielsZeilemaker/jinja_custom_filters --- airflow/models.py | 19 +++++++++++++--- docs/tutorial.rst | 10 +++++++++ tests/models.py | 55 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 3 deletions(-) diff --git a/airflow/models.py b/airflow/models.py index d2f78946e5..aab4833ae8 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -2249,11 +2249,13 @@ class BaseOperator(object): memo[id(self)] = result for k, v in list(self.__dict__.items()): - if k not in ('user_defined_macros', 'params'): + if k not in ('user_defined_macros', 'user_defined_filters', 'params'): setattr(result, k, copy.deepcopy(v, memo)) result.params = self.params if hasattr(self, 'user_defined_macros'): result.user_defined_macros = self.user_defined_macros + if hasattr(self, 'user_defined_filters'): + result.user_defined_filters = self.user_defined_filters return result def render_template_from_field(self, attr, content, context, jinja_env): @@ -2644,6 +2646,12 @@ class DAG(BaseDag, LoggingMixin): templates related to this DAG. Note that you can pass any type of object here. :type user_defined_macros: dict + :param user_defined_filters: a dictionary of filters that will be exposed + in your jinja templates. For example, passing + ``dict(hello=lambda name: 'Hello %s' % name)`` to this argument allows + you to ``{{ 'world' | hello }}`` in all jinja templates related to + this DAG. + :type user_defined_filters: dict :param default_args: A dictionary of default parameters to be used as constructor keyword parameters when initialising operators. Note that operators have the same hook, and precede those defined @@ -2684,6 +2692,7 @@ class DAG(BaseDag, LoggingMixin): full_filepath=None, template_searchpath=None, user_defined_macros=None, + user_defined_filters=None, default_args=None, concurrency=configuration.getint('core', 'dag_concurrency'), max_active_runs=configuration.getint( @@ -2696,6 +2705,7 @@ class DAG(BaseDag, LoggingMixin): params=None): self.user_defined_macros = user_defined_macros + self.user_defined_filters = user_defined_filters self.default_args = default_args or {} self.params = params or {} @@ -3034,7 +3044,7 @@ class DAG(BaseDag, LoggingMixin): def get_template_env(self): """ Returns a jinja2 Environment while taking into account the DAGs - template_searchpath and user_defined_macros + template_searchpath, user_defined_macros and user_defined_filters """ searchpath = [self.folder] if self.template_searchpath: @@ -3046,6 +3056,8 @@ class DAG(BaseDag, LoggingMixin): cache_size=0) if self.user_defined_macros: env.globals.update(self.user_defined_macros) + if self.user_defined_filters: + env.filters.update(self.user_defined_filters) return env @@ -3212,10 +3224,11 @@ class DAG(BaseDag, LoggingMixin): result = cls.__new__(cls) memo[id(self)] = result for k, v in list(self.__dict__.items()): - if k not in ('user_defined_macros', 'params'): + if k not in ('user_defined_macros', 'user_defined_filters', 'params'): setattr(result, k, copy.deepcopy(v, memo)) result.user_defined_macros = self.user_defined_macros + result.user_defined_filters = self.user_defined_filters result.params = self.params return result diff --git a/docs/tutorial.rst b/docs/tutorial.rst index d047f82550..dc09482003 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -231,6 +231,16 @@ different languages, and general flexibility in structuring pipelines. It is also possible to define your ``template_searchpath`` as pointing to any folder locations in the DAG constructor call. +Using that same DAG constructor call, it is possible to define +``user_defined_macros`` which allow you to specify your own variables. +For example, passing ``dict(foo='bar')`` to this argument allows you +to use ``{{ foo }}`` in your templates. Moreover, specifying +``user_defined_filters`` allow you to register you own filters. For example, +passing ``dict(hello=lambda name: 'Hello %s' % name)`` to this argument allows +you to use ``{{ 'world' | hello }}`` in your templates. For more information +regarding custom filters have a look at the +`Jinja Documentation `_ + For more information on the variables and macros that can be referenced in templates, make sure to read through the :ref:`macros` section diff --git a/tests/models.py b/tests/models.py index 49e5c75691..4c2a15f4fa 100644 --- a/tests/models.py +++ b/tests/models.py @@ -233,6 +233,61 @@ class DagTest(unittest.TestCase): states=[None, State.QUEUED, State.RUNNING], session=session)) session.close() + def test_render_template_field(self): + """Tests if render_template from a field works""" + + dag = DAG('test-dag', + start_date=DEFAULT_DATE) + + with dag: + task = DummyOperator(task_id='op1') + + result = task.render_template('', '{{ foo }}', dict(foo='bar')) + self.assertEqual(result, 'bar') + + def test_render_template_field_macro(self): + """ Tests if render_template from a field works, + if a custom filter was defined""" + + dag = DAG('test-dag', + start_date=DEFAULT_DATE, + user_defined_macros = dict(foo='bar')) + + with dag: + task = DummyOperator(task_id='op1') + + result = task.render_template('', '{{ foo }}', dict()) + self.assertEqual(result, 'bar') + + def test_user_defined_filters(self): + def jinja_udf(name): + return 'Hello %s' %name + + dag = models.DAG('test-dag', + start_date=DEFAULT_DATE, + user_defined_filters=dict(hello=jinja_udf)) + jinja_env = dag.get_template_env() + + self.assertIn('hello', jinja_env.filters) + self.assertEqual(jinja_env.filters['hello'], jinja_udf) + + def test_render_template_field_filter(self): + """ Tests if render_template from a field works, + if a custom filter was defined""" + + def jinja_udf(name): + return 'Hello %s' %name + + dag = DAG('test-dag', + start_date=DEFAULT_DATE, + user_defined_filters = dict(hello=jinja_udf)) + + with dag: + task = DummyOperator(task_id='op1') + + result = task.render_template('', "{{ 'world' | hello}}", dict()) + self.assertEqual(result, 'Hello world') + class DagStatTest(unittest.TestCase): def test_dagstats_crud(self):