# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import datetime
import inspect
import logging
import os
import re
import textwrap
import time
import unittest
import urllib
import uuid
from tempfile import NamedTemporaryFile, mkdtemp

import pendulum
import six
from mock import ANY, Mock, mock_open, patch
from parameterized import parameterized
from freezegun import freeze_time
from cryptography.fernet import Fernet

from airflow import AirflowException, configuration, models, settings
from airflow.contrib.sensors.python_sensor import PythonSensor
from airflow.exceptions import AirflowDagCycleException, AirflowSkipException
from airflow.jobs import BackfillJob
from airflow.models import DAG, TaskInstance as TI
from airflow.models import DagModel, DagRun
from airflow.models import KubeResourceVersion, KubeWorkerIdentifier
from airflow.models import SkipMixin
from airflow.models import State as ST
from airflow.models import TaskReschedule as TR
from airflow.models import XCom
from airflow.models import Variable
from airflow.models import clear_task_instances
from airflow.models.connection import Connection
from airflow.operators.bash_operator import BashOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import PythonOperator
from airflow.operators.python_operator import ShortCircuitOperator
from airflow.operators.subdag_operator import SubDagOperator
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils import timezone
from airflow.utils.dag_processing import SimpleTaskInstance
from airflow.utils.db import create_session
from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule

DEFAULT_DATE = timezone.datetime(2016, 1, 1)
TEST_DAGS_FOLDER = os.path.join(
    os.path.dirname(os.path.realpath(__file__)), 'dags')


class DagTest(unittest.TestCase):

    def test_params_not_passed_is_empty_dict(self):
        """
        Test that when 'params' is _not_ passed to a new Dag, that the params
        attribute is set to an empty dictionary.
        """
        dag = models.DAG('test-dag')

        self.assertEqual(dict, type(dag.params))
        self.assertEqual(0, len(dag.params))

    def test_params_passed_and_params_in_default_args_no_override(self):
        """
        Test that when 'params' exists as a key passed to the default_args dict
        in addition to params being passed explicitly as an argument to the
        dag, that the 'params' key of the default_args dict is merged with the
        dict of the params argument.
        """
""" params1 = {'parameter1': 1} params2 = {'parameter2': 2} dag = models.DAG('test-dag', default_args={'params': params1}, params=params2) params_combined = params1.copy() params_combined.update(params2) self.assertEqual(params_combined, dag.params) def test_dag_as_context_manager(self): """ Test DAG as a context manager. When used as a context manager, Operators are automatically added to the DAG (unless they specify a different DAG) """ dag = DAG( 'dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) dag2 = DAG( 'dag2', start_date=DEFAULT_DATE, default_args={'owner': 'owner2'}) with dag: op1 = DummyOperator(task_id='op1') op2 = DummyOperator(task_id='op2', dag=dag2) self.assertIs(op1.dag, dag) self.assertEqual(op1.owner, 'owner1') self.assertIs(op2.dag, dag2) self.assertEqual(op2.owner, 'owner2') with dag2: op3 = DummyOperator(task_id='op3') self.assertIs(op3.dag, dag2) self.assertEqual(op3.owner, 'owner2') with dag: with dag2: op4 = DummyOperator(task_id='op4') op5 = DummyOperator(task_id='op5') self.assertIs(op4.dag, dag2) self.assertIs(op5.dag, dag) self.assertEqual(op4.owner, 'owner2') self.assertEqual(op5.owner, 'owner1') with DAG('creating_dag_in_cm', start_date=DEFAULT_DATE) as dag: DummyOperator(task_id='op6') self.assertEqual(dag.dag_id, 'creating_dag_in_cm') self.assertEqual(dag.tasks[0].task_id, 'op6') with dag: with dag: op7 = DummyOperator(task_id='op7') op8 = DummyOperator(task_id='op8') op9 = DummyOperator(task_id='op8') op9.dag = dag2 self.assertEqual(op7.dag, dag) self.assertEqual(op8.dag, dag) self.assertEqual(op9.dag, dag2) def test_dag_topological_sort(self): dag = DAG( 'dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) # A -> B # A -> C -> D # ordered: B, D, C, A or D, B, C, A or D, C, B, A with dag: op1 = DummyOperator(task_id='A') op2 = DummyOperator(task_id='B') op3 = DummyOperator(task_id='C') op4 = DummyOperator(task_id='D') op1.set_upstream([op2, op3]) op3.set_upstream(op4) topological_list = dag.topological_sort() logging.info(topological_list) tasks = [op2, op3, op4] self.assertTrue(topological_list[0] in tasks) tasks.remove(topological_list[0]) self.assertTrue(topological_list[1] in tasks) tasks.remove(topological_list[1]) self.assertTrue(topological_list[2] in tasks) tasks.remove(topological_list[2]) self.assertTrue(topological_list[3] == op1) dag = DAG( 'dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) # C -> (A u B) -> D # C -> E # ordered: E | D, A | B, C with dag: op1 = DummyOperator(task_id='A') op2 = DummyOperator(task_id='B') op3 = DummyOperator(task_id='C') op4 = DummyOperator(task_id='D') op5 = DummyOperator(task_id='E') op1.set_downstream(op3) op2.set_downstream(op3) op1.set_upstream(op4) op2.set_upstream(op4) op5.set_downstream(op3) topological_list = dag.topological_sort() logging.info(topological_list) set1 = [op4, op5] self.assertTrue(topological_list[0] in set1) set1.remove(topological_list[0]) set2 = [op1, op2] set2.extend(set1) self.assertTrue(topological_list[1] in set2) set2.remove(topological_list[1]) self.assertTrue(topological_list[2] in set2) set2.remove(topological_list[2]) self.assertTrue(topological_list[3] in set2) self.assertTrue(topological_list[4] == op3) dag = DAG( 'dag', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) self.assertEqual(tuple(), dag.topological_sort()) def test_dag_naive_default_args_start_date(self): dag = DAG('DAG', default_args={'start_date': datetime.datetime(2018, 1, 1)}) self.assertEqual(dag.timezone, settings.TIMEZONE) dag = DAG('DAG', 
    def test_dag_task_priority_weight_total(self):
        width = 5
        depth = 5
        weight = 5
        pattern = re.compile(r'stage(\d*).(\d*)')
        # Fully connected parallel tasks. i.e. every task at each parallel
        # stage is dependent on every task in the previous stage.
        # Default weight should be calculated using downstream descendants
        with DAG('dag', start_date=DEFAULT_DATE,
                 default_args={'owner': 'owner1'}) as dag:
            pipeline = [
                [DummyOperator(
                    task_id='stage{}.{}'.format(i, j), priority_weight=weight)
                    for j in range(0, width)] for i in range(0, depth)
            ]
            for d, stage in enumerate(pipeline):
                if d == 0:
                    continue
                for current_task in stage:
                    for prev_task in pipeline[d - 1]:
                        current_task.set_upstream(prev_task)

            for task in six.itervalues(dag.task_dict):
                match = pattern.match(task.task_id)
                task_depth = int(match.group(1))
                # the sum of each stage after this task + itself
                correct_weight = ((depth - (task_depth + 1)) * width + 1) * weight

                calculated_weight = task.priority_weight_total
                self.assertEqual(calculated_weight, correct_weight)

        # Same test as above except use 'upstream' for weight calculation
        weight = 3
        with DAG('dag', start_date=DEFAULT_DATE,
                 default_args={'owner': 'owner1'}) as dag:
            pipeline = [
                [DummyOperator(
                    task_id='stage{}.{}'.format(i, j), priority_weight=weight,
                    weight_rule=WeightRule.UPSTREAM)
                    for j in range(0, width)] for i in range(0, depth)
            ]
            for d, stage in enumerate(pipeline):
                if d == 0:
                    continue
                for current_task in stage:
                    for prev_task in pipeline[d - 1]:
                        current_task.set_upstream(prev_task)

            for task in six.itervalues(dag.task_dict):
                match = pattern.match(task.task_id)
                task_depth = int(match.group(1))
                # the sum of each stage before this task + itself
                correct_weight = (task_depth * width + 1) * weight

                calculated_weight = task.priority_weight_total
                self.assertEqual(calculated_weight, correct_weight)

        # Same test as above except use 'absolute' for weight calculation
        weight = 10
        with DAG('dag', start_date=DEFAULT_DATE,
                 default_args={'owner': 'owner1'}) as dag:
            pipeline = [
                [DummyOperator(
                    task_id='stage{}.{}'.format(i, j), priority_weight=weight,
                    weight_rule=WeightRule.ABSOLUTE)
                    for j in range(0, width)] for i in range(0, depth)
            ]
            for d, stage in enumerate(pipeline):
                if d == 0:
                    continue
                for current_task in stage:
                    for prev_task in pipeline[d - 1]:
                        current_task.set_upstream(prev_task)

            for task in six.itervalues(dag.task_dict):
                match = pattern.match(task.task_id)
                task_depth = int(match.group(1))
                # the absolute rule uses the task's own weight only
                correct_weight = weight

                calculated_weight = task.priority_weight_total
                self.assertEqual(calculated_weight, correct_weight)

        # Test if we enter an invalid weight rule
        with DAG('dag', start_date=DEFAULT_DATE,
                 default_args={'owner': 'owner1'}) as dag:
            with self.assertRaises(AirflowException):
                DummyOperator(task_id='should_fail', weight_rule='no rule')

    def test_get_num_task_instances(self):
        test_dag_id = 'test_get_num_task_instances_dag'
        test_task_id = 'task_1'

        test_dag = DAG(dag_id=test_dag_id, start_date=DEFAULT_DATE)
        test_task = DummyOperator(task_id=test_task_id, dag=test_dag)

        ti1 = TI(task=test_task, execution_date=DEFAULT_DATE)
        ti1.state = None
        ti2 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1))
        ti2.state = State.RUNNING
        ti3 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=2))
        ti3.state = State.QUEUED
        ti4 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=3))
        ti4.state = State.RUNNING
        session = settings.Session()
        session.merge(ti1)
        session.merge(ti2)
        session.merge(ti3)
        session.merge(ti4)
        session.commit()

        self.assertEqual(
            0,
            DAG.get_num_task_instances(test_dag_id, ['fakename'], session=session)
        )
        self.assertEqual(
            4,
            DAG.get_num_task_instances(test_dag_id, [test_task_id], session=session)
        )
        self.assertEqual(
            4,
            DAG.get_num_task_instances(
                test_dag_id, ['fakename', test_task_id], session=session)
        )
        self.assertEqual(
            1,
            DAG.get_num_task_instances(
                test_dag_id, [test_task_id], states=[None], session=session)
        )
        self.assertEqual(
            2,
            DAG.get_num_task_instances(
                test_dag_id, [test_task_id], states=[State.RUNNING], session=session)
        )
        self.assertEqual(
            3,
            DAG.get_num_task_instances(
                test_dag_id, [test_task_id],
                states=[None, State.RUNNING], session=session)
        )
        self.assertEqual(
            4,
            DAG.get_num_task_instances(
                test_dag_id, [test_task_id],
                states=[None, State.QUEUED, State.RUNNING], session=session)
        )
        session.close()

    def test_render_template_field(self):
        """Tests if render_template from a field works"""

        dag = DAG('test-dag',
                  start_date=DEFAULT_DATE)

        with dag:
            task = DummyOperator(task_id='op1')

        result = task.render_template('', '{{ foo }}', dict(foo='bar'))
        self.assertEqual(result, 'bar')

    def test_render_template_list_field(self):
        """Tests if render_template from a list field works"""

        dag = DAG('test-dag',
                  start_date=DEFAULT_DATE)

        with dag:
            task = DummyOperator(task_id='op1')

        self.assertListEqual(
            task.render_template('', ['{{ foo }}_1', '{{ foo }}_2'], {'foo': 'bar'}),
            ['bar_1', 'bar_2']
        )

    def test_render_template_tuple_field(self):
        """Tests if render_template from a tuple field works"""

        dag = DAG('test-dag',
                  start_date=DEFAULT_DATE)

        with dag:
            task = DummyOperator(task_id='op1')

        # tuple is replaced by a list
        self.assertListEqual(
            task.render_template('', ('{{ foo }}_1', '{{ foo }}_2'), {'foo': 'bar'}),
            ['bar_1', 'bar_2']
        )

    def test_render_template_dict_field(self):
        """Tests if render_template from a dict field works"""

        dag = DAG('test-dag',
                  start_date=DEFAULT_DATE)

        with dag:
            task = DummyOperator(task_id='op1')

        self.assertDictEqual(
            task.render_template('', {'key1': '{{ foo }}_1', 'key2': '{{ foo }}_2'},
                                 {'foo': 'bar'}),
            {'key1': 'bar_1', 'key2': 'bar_2'}
        )

    def test_render_template_dict_field_with_templated_keys(self):
        """Tests if render_template from a dict field works as expected:
        dictionary keys are not templated"""

        dag = DAG('test-dag',
                  start_date=DEFAULT_DATE)

        with dag:
            task = DummyOperator(task_id='op1')

        self.assertDictEqual(
            task.render_template('', {'key_{{ foo }}_1': 1, 'key_2': '{{ foo }}_2'},
                                 {'foo': 'bar'}),
            {'key_{{ foo }}_1': 1, 'key_2': 'bar_2'}
        )
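
    # render_template only templates strings and, recursively, the contents of
    # lists, tuples and dict values; other types -- dates, datetimes, UUIDs and
    # arbitrary objects -- are passed through unchanged, as the following
    # tests verify.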
    def test_render_template_date_field(self):
        """Tests if render_template from a date field works"""

        dag = DAG('test-dag',
                  start_date=DEFAULT_DATE)

        with dag:
            task = DummyOperator(task_id='op1')

        self.assertEqual(
            task.render_template('', datetime.date(2018, 12, 6), {'foo': 'bar'}),
            datetime.date(2018, 12, 6)
        )

    def test_render_template_datetime_field(self):
        """Tests if render_template from a datetime field works"""

        dag = DAG('test-dag',
                  start_date=DEFAULT_DATE)

        with dag:
            task = DummyOperator(task_id='op1')

        self.assertEqual(
            task.render_template('', datetime.datetime(2018, 12, 6, 10, 55),
                                 {'foo': 'bar'}),
            datetime.datetime(2018, 12, 6, 10, 55)
        )

    def test_render_template_UUID_field(self):
        """Tests if render_template from a UUID field works"""

        dag = DAG('test-dag',
                  start_date=DEFAULT_DATE)

        with dag:
            task = DummyOperator(task_id='op1')

        random_uuid = uuid.uuid4()
        self.assertIs(
            task.render_template('', random_uuid, {'foo': 'bar'}),
            random_uuid
        )

    def test_render_template_object_field(self):
        """Tests if render_template from an object field works"""

        dag = DAG('test-dag',
                  start_date=DEFAULT_DATE)

        with dag:
            task = DummyOperator(task_id='op1')

        test_object = object()
        self.assertIs(
            task.render_template('', test_object, {'foo': 'bar'}),
            test_object
        )

    def test_render_template_field_macro(self):
        """Tests if render_template from a field works,
        if a custom macro was defined"""

        dag = DAG('test-dag',
                  start_date=DEFAULT_DATE,
                  user_defined_macros=dict(foo='bar'))

        with dag:
            task = DummyOperator(task_id='op1')

        result = task.render_template('', '{{ foo }}', dict())
        self.assertEqual(result, 'bar')

    def test_render_template_numeric_field(self):
        """Tests that render_template passes a numeric field
        through unchanged"""

        dag = DAG('test-dag',
                  start_date=DEFAULT_DATE,
                  user_defined_macros=dict(foo='bar'))

        with dag:
            task = DummyOperator(task_id='op1')

        result = task.render_template('', 1, dict())
        self.assertEqual(result, 1)

    def test_user_defined_filters(self):
        def jinja_udf(name):
            return 'Hello %s' % name

        dag = models.DAG('test-dag',
                         start_date=DEFAULT_DATE,
                         user_defined_filters=dict(hello=jinja_udf))
        jinja_env = dag.get_template_env()

        self.assertIn('hello', jinja_env.filters)
        self.assertEqual(jinja_env.filters['hello'], jinja_udf)

    def test_render_template_field_filter(self):
        """Tests if render_template from a field works,
        if a custom filter was defined"""

        def jinja_udf(name):
            return 'Hello %s' % name

        dag = DAG('test-dag',
                  start_date=DEFAULT_DATE,
                  user_defined_filters=dict(hello=jinja_udf))

        with dag:
            task = DummyOperator(task_id='op1')

        result = task.render_template('', "{{ 'world' | hello}}", dict())
        self.assertEqual(result, 'Hello world')

    def test_resolve_template_files_value(self):

        with NamedTemporaryFile(suffix='.template') as f:
            f.write('{{ ds }}'.encode('utf8'))
            f.flush()
            template_dir = os.path.dirname(f.name)
            template_file = os.path.basename(f.name)

            dag = DAG('test-dag',
                      start_date=DEFAULT_DATE,
                      template_searchpath=template_dir)

            with dag:
                task = DummyOperator(task_id='op1')

            task.test_field = template_file
            task.template_fields = ('test_field',)
            task.template_ext = ('.template',)
            task.resolve_template_files()

        self.assertEqual(task.test_field, '{{ ds }}')

    def test_resolve_template_files_list(self):

        with NamedTemporaryFile(suffix='.template') as f:
            f.write('{{ ds }}'.encode('utf8'))
            f.flush()
            template_dir = os.path.dirname(f.name)
            template_file = os.path.basename(f.name)

            dag = DAG('test-dag',
                      start_date=DEFAULT_DATE,
                      template_searchpath=template_dir)

            with dag:
                task = DummyOperator(task_id='op1')

            task.test_field = [template_file, 'some_string']
            task.template_fields = ('test_field',)
            task.template_ext = ('.template',)
            task.resolve_template_files()

        self.assertEqual(task.test_field, ['{{ ds }}', 'some_string'])

    def test_cycle(self):
        # test empty
        dag = DAG(
            'dag',
            start_date=DEFAULT_DATE,
            default_args={'owner': 'owner1'})

        self.assertFalse(dag.test_cycle())

        # test single task
        dag = DAG(
            'dag',
            start_date=DEFAULT_DATE,
            default_args={'owner': 'owner1'})

        with dag:
            opA = DummyOperator(task_id='A')

        self.assertFalse(dag.test_cycle())

        # test no cycle
        dag = DAG(
            'dag',
            start_date=DEFAULT_DATE,
            default_args={'owner': 'owner1'})

        # A -> B -> C
        #      B -> D
        # E -> F
        with dag:
            opA = DummyOperator(task_id='A')
            opB = DummyOperator(task_id='B')
            opC = DummyOperator(task_id='C')
            opD = DummyOperator(task_id='D')
            opE = DummyOperator(task_id='E')
            opF = DummyOperator(task_id='F')
            opA.set_downstream(opB)
            opB.set_downstream(opC)
            opB.set_downstream(opD)
            opE.set_downstream(opF)

        self.assertFalse(dag.test_cycle())

        # test self loop
        dag = DAG(
            'dag',
            start_date=DEFAULT_DATE,
            default_args={'owner': 'owner1'})

        # A -> A
        with dag:
            opA = DummyOperator(task_id='A')
            opA.set_downstream(opA)

        with self.assertRaises(AirflowDagCycleException):
            dag.test_cycle()

        # test downstream self loop
        dag = DAG(
            'dag',
            start_date=DEFAULT_DATE,
            default_args={'owner': 'owner1'})

        # A -> B -> C -> D -> E -> E
        with dag:
            opA = DummyOperator(task_id='A')
            opB = DummyOperator(task_id='B')
            opC = DummyOperator(task_id='C')
            opD = DummyOperator(task_id='D')
            opE = DummyOperator(task_id='E')
            opA.set_downstream(opB)
            opB.set_downstream(opC)
            opC.set_downstream(opD)
            opD.set_downstream(opE)
            opE.set_downstream(opE)

        with self.assertRaises(AirflowDagCycleException):
            dag.test_cycle()

        # large loop
        dag = DAG(
            'dag',
            start_date=DEFAULT_DATE,
            default_args={'owner': 'owner1'})

        # A -> B -> C -> D -> E -> A
        with dag:
            opA = DummyOperator(task_id='A')
            opB = DummyOperator(task_id='B')
            opC = DummyOperator(task_id='C')
            opD = DummyOperator(task_id='D')
            opE = DummyOperator(task_id='E')
            opA.set_downstream(opB)
            opB.set_downstream(opC)
            opC.set_downstream(opD)
            opD.set_downstream(opE)
            opE.set_downstream(opA)

        with self.assertRaises(AirflowDagCycleException):
            dag.test_cycle()

        # test arbitrary loop
        dag = DAG(
            'dag',
            start_date=DEFAULT_DATE,
            default_args={'owner': 'owner1'})

        # E -> A -> B -> F -> A
        #        -> C -> F
        with dag:
            opA = DummyOperator(task_id='A')
            opB = DummyOperator(task_id='B')
            opC = DummyOperator(task_id='C')
            opD = DummyOperator(task_id='D')
            opE = DummyOperator(task_id='E')
            opF = DummyOperator(task_id='F')
            opA.set_downstream(opB)
            opA.set_downstream(opC)
            opE.set_downstream(opA)
            opC.set_downstream(opF)
            opB.set_downstream(opF)
            opF.set_downstream(opA)

        with self.assertRaises(AirflowDagCycleException):
            dag.test_cycle()

    def test_following_previous_schedule(self):
        """
        Make sure DST transitions are properly observed
        """
        local_tz = pendulum.timezone('Europe/Zurich')
        start = local_tz.convert(datetime.datetime(2018, 10, 28, 2, 55),
                                 dst_rule=pendulum.PRE_TRANSITION)
        self.assertEqual(start.isoformat(), "2018-10-28T02:55:00+02:00",
                         "Pre-condition: start date is in DST")

        utc = timezone.convert_to_utc(start)

        dag = DAG('tz_dag', start_date=start, schedule_interval='*/5 * * * *')
        _next = dag.following_schedule(utc)
        next_local = local_tz.convert(_next)

        self.assertEqual(_next.isoformat(), "2018-10-28T01:00:00+00:00")
        self.assertEqual(next_local.isoformat(), "2018-10-28T02:00:00+01:00")

        prev = dag.previous_schedule(utc)
        prev_local = local_tz.convert(prev)

        self.assertEqual(prev_local.isoformat(), "2018-10-28T02:50:00+02:00")

        prev = dag.previous_schedule(_next)
        prev_local = local_tz.convert(prev)

        self.assertEqual(prev_local.isoformat(), "2018-10-28T02:55:00+02:00")
        self.assertEqual(prev, utc)
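
    # Europe/Zurich leaves DST on 2018-10-28 (03:00 CEST falls back to 02:00
    # CET, i.e. +02:00 -> +01:00) and enters it on 2018-03-25 (02:00 CET
    # springs forward to 03:00 CEST). The two daily-schedule tests below cover
    # one transition each.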
"2018-10-26T03:00:00+02:00") self.assertEqual(prev.isoformat(), "2018-10-26T01:00:00+00:00") _next = dag.following_schedule(utc) next_local = local_tz.convert(_next) self.assertEqual(next_local.isoformat(), "2018-10-28T03:00:00+01:00") self.assertEqual(_next.isoformat(), "2018-10-28T02:00:00+00:00") prev = dag.previous_schedule(_next) prev_local = local_tz.convert(prev) self.assertEqual(prev_local.isoformat(), "2018-10-27T03:00:00+02:00") self.assertEqual(prev.isoformat(), "2018-10-27T01:00:00+00:00") def test_following_previous_schedule_daily_dag_CET_to_CEST(self): """ Make sure DST transitions are properly observed """ local_tz = pendulum.timezone('Europe/Zurich') start = local_tz.convert(datetime.datetime(2018, 3, 25, 2), dst_rule=pendulum.PRE_TRANSITION) utc = timezone.convert_to_utc(start) dag = DAG('tz_dag', start_date=start, schedule_interval='0 3 * * *') prev = dag.previous_schedule(utc) prev_local = local_tz.convert(prev) self.assertEqual(prev_local.isoformat(), "2018-03-24T03:00:00+01:00") self.assertEqual(prev.isoformat(), "2018-03-24T02:00:00+00:00") _next = dag.following_schedule(utc) next_local = local_tz.convert(_next) self.assertEqual(next_local.isoformat(), "2018-03-25T03:00:00+02:00") self.assertEqual(_next.isoformat(), "2018-03-25T01:00:00+00:00") prev = dag.previous_schedule(_next) prev_local = local_tz.convert(prev) self.assertEqual(prev_local.isoformat(), "2018-03-24T03:00:00+01:00") self.assertEqual(prev.isoformat(), "2018-03-24T02:00:00+00:00") @patch('airflow.models.timezone.utcnow') def test_sync_to_db(self, mock_now): dag = DAG( 'dag', start_date=DEFAULT_DATE, ) with dag: DummyOperator(task_id='task', owner='owner1') SubDagOperator( task_id='subtask', owner='owner2', subdag=DAG( 'dag.subtask', start_date=DEFAULT_DATE, ) ) now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC')) mock_now.return_value = now session = settings.Session() dag.sync_to_db(session=session) orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag').one() self.assertEqual(set(orm_dag.owners.split(', ')), {'owner1', 'owner2'}) self.assertEqual(orm_dag.last_scheduler_run, now) self.assertTrue(orm_dag.is_active) self.assertIsNone(orm_dag.default_view) self.assertEqual(orm_dag.get_default_view(), configuration.conf.get('webserver', 'dag_default_view').lower()) self.assertEqual(orm_dag.safe_dag_id, 'dag') orm_subdag = session.query(DagModel).filter( DagModel.dag_id == 'dag.subtask').one() self.assertEqual(set(orm_subdag.owners.split(', ')), {'owner1', 'owner2'}) self.assertEqual(orm_subdag.last_scheduler_run, now) self.assertTrue(orm_subdag.is_active) self.assertEqual(orm_subdag.safe_dag_id, 'dag__dot__subtask') @patch('airflow.models.timezone.utcnow') def test_sync_to_db_default_view(self, mock_now): dag = DAG( 'dag', start_date=DEFAULT_DATE, default_view="graph", ) with dag: DummyOperator(task_id='task', owner='owner1') SubDagOperator( task_id='subtask', owner='owner2', subdag=DAG( 'dag.subtask', start_date=DEFAULT_DATE, ) ) now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC')) mock_now.return_value = now session = settings.Session() dag.sync_to_db(session=session) orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag').one() self.assertIsNotNone(orm_dag.default_view) self.assertEqual(orm_dag.get_default_view(), "graph") class DagRunTest(unittest.TestCase): def create_dag_run(self, dag, state=State.RUNNING, task_states=None, execution_date=None, is_backfill=False, ): now = timezone.utcnow() if execution_date is None: execution_date = now 
        if is_backfill:
            run_id = BackfillJob.ID_PREFIX + now.isoformat()
        else:
            run_id = 'manual__' + now.isoformat()

        dag_run = dag.create_dagrun(
            run_id=run_id,
            execution_date=execution_date,
            start_date=now,
            state=state,
            external_trigger=False,
        )

        if task_states is not None:
            session = settings.Session()
            for task_id, state in task_states.items():
                ti = dag_run.get_task_instance(task_id)
                ti.set_state(state, session)
            session.close()

        return dag_run

    def test_clear_task_instances_for_backfill_dagrun(self):
        now = timezone.utcnow()
        session = settings.Session()
        dag_id = 'test_clear_task_instances_for_backfill_dagrun'
        dag = DAG(dag_id=dag_id, start_date=now)
        self.create_dag_run(dag, execution_date=now, is_backfill=True)

        task0 = DummyOperator(task_id='backfill_task_0', owner='test', dag=dag)
        ti0 = TI(task=task0, execution_date=now)
        ti0.run()

        qry = session.query(TI).filter(
            TI.dag_id == dag.dag_id).all()
        clear_task_instances(qry, session)
        session.commit()
        ti0.refresh_from_db()
        dr0 = session.query(DagRun).filter(
            DagRun.dag_id == dag_id,
            DagRun.execution_date == now
        ).first()
        self.assertEqual(dr0.state, State.RUNNING)

    def test_id_for_date(self):
        run_id = models.DagRun.id_for_date(
            timezone.datetime(2015, 1, 2, 3, 4, 5, 6))
        self.assertEqual(
            'scheduled__2015-01-02T03:04:05', run_id,
            'Generated run_id did not match expectations: {0}'.format(run_id))

    def test_dagrun_find(self):
        session = settings.Session()
        now = timezone.utcnow()

        dag_id1 = "test_dagrun_find_externally_triggered"
        dag_run = models.DagRun(
            dag_id=dag_id1,
            run_id='manual__' + now.isoformat(),
            execution_date=now,
            start_date=now,
            state=State.RUNNING,
            external_trigger=True,
        )
        session.add(dag_run)

        dag_id2 = "test_dagrun_find_not_externally_triggered"
        dag_run = models.DagRun(
            dag_id=dag_id2,
            run_id='manual__' + now.isoformat(),
            execution_date=now,
            start_date=now,
            state=State.RUNNING,
            external_trigger=False,
        )
        session.add(dag_run)

        session.commit()

        self.assertEqual(1,
                         len(models.DagRun.find(dag_id=dag_id1, external_trigger=True)))
        self.assertEqual(0,
                         len(models.DagRun.find(dag_id=dag_id1, external_trigger=False)))
        self.assertEqual(0,
                         len(models.DagRun.find(dag_id=dag_id2, external_trigger=True)))
        self.assertEqual(1,
                         len(models.DagRun.find(dag_id=dag_id2, external_trigger=False)))

    def test_dagrun_success_when_all_skipped(self):
        """
        Tests that a DAG run succeeds when all tasks are skipped
        """
        dag = DAG(
            dag_id='test_dagrun_success_when_all_skipped',
            start_date=timezone.datetime(2017, 1, 1)
        )
        dag_task1 = ShortCircuitOperator(
            task_id='test_short_circuit_false',
            dag=dag,
            python_callable=lambda: False)
        dag_task2 = DummyOperator(
            task_id='test_state_skipped1',
            dag=dag)
        dag_task3 = DummyOperator(
            task_id='test_state_skipped2',
            dag=dag)
        dag_task1.set_downstream(dag_task2)
        dag_task2.set_downstream(dag_task3)

        initial_task_states = {
            'test_short_circuit_false': State.SUCCESS,
            'test_state_skipped1': State.SKIPPED,
            'test_state_skipped2': State.SKIPPED,
        }

        dag_run = self.create_dag_run(dag=dag,
                                      state=State.RUNNING,
                                      task_states=initial_task_states)
        updated_dag_state = dag_run.update_state()
        self.assertEqual(State.SUCCESS, updated_dag_state)

    def test_dagrun_success_conditions(self):
        session = settings.Session()

        dag = DAG(
            'test_dagrun_success_conditions',
            start_date=DEFAULT_DATE,
            default_args={'owner': 'owner1'})

        # A -> B
        # A -> C -> D
        # ordered: B, D, C, A or D, B, C, A or D, C, B, A
        with dag:
            op1 = DummyOperator(task_id='A')
            op2 = DummyOperator(task_id='B')
            op3 = DummyOperator(task_id='C')
            op4 = DummyOperator(task_id='D')
            op1.set_upstream([op2, op3])
            op3.set_upstream(op4)

        dag.clear()

        now = timezone.utcnow()
        dr = dag.create_dagrun(run_id='test_dagrun_success_conditions',
                               state=State.RUNNING,
                               execution_date=now,
                               start_date=now)

        # op1 = root
        ti_op1 = dr.get_task_instance(task_id=op1.task_id)
        ti_op1.set_state(state=State.SUCCESS, session=session)

        ti_op2 = dr.get_task_instance(task_id=op2.task_id)
        ti_op3 = dr.get_task_instance(task_id=op3.task_id)
        ti_op4 = dr.get_task_instance(task_id=op4.task_id)

        # root is successful, but unfinished tasks
        state = dr.update_state()
        self.assertEqual(State.RUNNING, state)

        # one has failed, but root is successful
        ti_op2.set_state(state=State.FAILED, session=session)
        ti_op3.set_state(state=State.SUCCESS, session=session)
        ti_op4.set_state(state=State.SUCCESS, session=session)
        state = dr.update_state()
        self.assertEqual(State.SUCCESS, state)

    def test_dagrun_deadlock(self):
        session = settings.Session()
        dag = DAG(
            'test_dagrun_deadlock',
            start_date=DEFAULT_DATE,
            default_args={'owner': 'owner1'})
        with dag:
            op1 = DummyOperator(task_id='A')
            op2 = DummyOperator(task_id='B')
            op2.trigger_rule = TriggerRule.ONE_FAILED
            op2.set_upstream(op1)

        dag.clear()
        now = timezone.utcnow()
        dr = dag.create_dagrun(run_id='test_dagrun_deadlock',
                               state=State.RUNNING,
                               execution_date=now,
                               start_date=now)

        ti_op1 = dr.get_task_instance(task_id=op1.task_id)
        ti_op1.set_state(state=State.SUCCESS, session=session)
        ti_op2 = dr.get_task_instance(task_id=op2.task_id)
        ti_op2.set_state(state=State.NONE, session=session)

        dr.update_state()
        self.assertEqual(dr.state, State.RUNNING)

        ti_op2.set_state(state=State.NONE, session=session)
        op2.trigger_rule = 'invalid'
        dr.update_state()
        self.assertEqual(dr.state, State.FAILED)

    def test_dagrun_no_deadlock_with_shutdown(self):
        session = settings.Session()
        dag = DAG('test_dagrun_no_deadlock_with_shutdown',
                  start_date=DEFAULT_DATE)
        with dag:
            op1 = DummyOperator(task_id='upstream_task')
            op2 = DummyOperator(task_id='downstream_task')
            op2.set_upstream(op1)

        dr = dag.create_dagrun(run_id='test_dagrun_no_deadlock_with_shutdown',
                               state=State.RUNNING,
                               execution_date=DEFAULT_DATE,
                               start_date=DEFAULT_DATE)
        upstream_ti = dr.get_task_instance(task_id='upstream_task')
        upstream_ti.set_state(State.SHUTDOWN, session=session)

        dr.update_state()
        self.assertEqual(dr.state, State.RUNNING)

    def test_dagrun_no_deadlock_with_depends_on_past(self):
        session = settings.Session()
        dag = DAG('test_dagrun_no_deadlock',
                  start_date=DEFAULT_DATE)
        with dag:
            DummyOperator(task_id='dop', depends_on_past=True)
            DummyOperator(task_id='tc', task_concurrency=1)

        dag.clear()
        dr = dag.create_dagrun(run_id='test_dagrun_no_deadlock_1',
                               state=State.RUNNING,
                               execution_date=DEFAULT_DATE,
                               start_date=DEFAULT_DATE)
        dr2 = dag.create_dagrun(run_id='test_dagrun_no_deadlock_2',
                                state=State.RUNNING,
                                execution_date=DEFAULT_DATE + datetime.timedelta(days=1),
                                start_date=DEFAULT_DATE + datetime.timedelta(days=1))
        ti1_op1 = dr.get_task_instance(task_id='dop')
        dr2.get_task_instance(task_id='dop')
        ti2_op1 = dr.get_task_instance(task_id='tc')
        dr.get_task_instance(task_id='tc')
        ti1_op1.set_state(state=State.RUNNING, session=session)
        dr.update_state()
        dr2.update_state()
        self.assertEqual(dr.state, State.RUNNING)
        self.assertEqual(dr2.state, State.RUNNING)

        ti2_op1.set_state(state=State.RUNNING, session=session)
        dr.update_state()
        dr2.update_state()
        self.assertEqual(dr.state, State.RUNNING)
        self.assertEqual(dr2.state, State.RUNNING)
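
    # update_state() also fires the DAG-level on_success_callback /
    # on_failure_callback, passing a context that includes the dag_run;
    # the two tests below assert on that context from inside the callback.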
    def test_dagrun_success_callback(self):
        def on_success_callable(context):
            self.assertEqual(
                context['dag_run'].dag_id,
                'test_dagrun_success_callback'
            )

        dag = DAG(
            dag_id='test_dagrun_success_callback',
            start_date=datetime.datetime(2017, 1, 1),
            on_success_callback=on_success_callable,
        )
        dag_task1 = DummyOperator(
            task_id='test_state_succeeded1',
            dag=dag)
        dag_task2 = DummyOperator(
            task_id='test_state_succeeded2',
            dag=dag)
        dag_task1.set_downstream(dag_task2)

        initial_task_states = {
            'test_state_succeeded1': State.SUCCESS,
            'test_state_succeeded2': State.SUCCESS,
        }

        dag_run = self.create_dag_run(dag=dag,
                                      state=State.RUNNING,
                                      task_states=initial_task_states)
        updated_dag_state = dag_run.update_state()
        self.assertEqual(State.SUCCESS, updated_dag_state)

    def test_dagrun_failure_callback(self):
        def on_failure_callable(context):
            self.assertEqual(
                context['dag_run'].dag_id,
                'test_dagrun_failure_callback'
            )

        dag = DAG(
            dag_id='test_dagrun_failure_callback',
            start_date=datetime.datetime(2017, 1, 1),
            on_failure_callback=on_failure_callable,
        )
        dag_task1 = DummyOperator(
            task_id='test_state_succeeded1',
            dag=dag)
        dag_task2 = DummyOperator(
            task_id='test_state_failed2',
            dag=dag)

        initial_task_states = {
            'test_state_succeeded1': State.SUCCESS,
            'test_state_failed2': State.FAILED,
        }
        dag_task1.set_downstream(dag_task2)

        dag_run = self.create_dag_run(dag=dag,
                                      state=State.RUNNING,
                                      task_states=initial_task_states)
        updated_dag_state = dag_run.update_state()
        self.assertEqual(State.FAILED, updated_dag_state)

    def test_dagrun_set_state_end_date(self):
        session = settings.Session()

        dag = DAG(
            'test_dagrun_set_state_end_date',
            start_date=DEFAULT_DATE,
            default_args={'owner': 'owner1'})

        dag.clear()

        now = timezone.utcnow()
        dr = dag.create_dagrun(run_id='test_dagrun_set_state_end_date',
                               state=State.RUNNING,
                               execution_date=now,
                               start_date=now)

        # Initial end_date should be NULL.
        # State.SUCCESS and State.FAILED are both ending states and should set
        # end_date; State.RUNNING sets end_date back to NULL.
        session.add(dr)
        session.commit()
        self.assertIsNone(dr.end_date)

        dr.set_state(State.SUCCESS)
        session.merge(dr)
        session.commit()

        dr_database = session.query(DagRun).filter(
            DagRun.run_id == 'test_dagrun_set_state_end_date'
        ).one()
        self.assertIsNotNone(dr_database.end_date)
        self.assertEqual(dr.end_date, dr_database.end_date)

        dr.set_state(State.RUNNING)
        session.merge(dr)
        session.commit()

        dr_database = session.query(DagRun).filter(
            DagRun.run_id == 'test_dagrun_set_state_end_date'
        ).one()

        self.assertIsNone(dr_database.end_date)

        dr.set_state(State.FAILED)
        session.merge(dr)
        session.commit()
        dr_database = session.query(DagRun).filter(
            DagRun.run_id == 'test_dagrun_set_state_end_date'
        ).one()

        self.assertIsNotNone(dr_database.end_date)
        self.assertEqual(dr.end_date, dr_database.end_date)

    def test_dagrun_update_state_end_date(self):
        session = settings.Session()

        dag = DAG(
            'test_dagrun_update_state_end_date',
            start_date=DEFAULT_DATE,
            default_args={'owner': 'owner1'})

        # A -> B
        with dag:
            op1 = DummyOperator(task_id='A')
            op2 = DummyOperator(task_id='B')
            op1.set_upstream(op2)

        dag.clear()

        now = timezone.utcnow()
        dr = dag.create_dagrun(run_id='test_dagrun_update_state_end_date',
                               state=State.RUNNING,
                               execution_date=now,
                               start_date=now)

        # Initial end_date should be NULL.
        # State.SUCCESS and State.FAILED are both ending states and should set
        # end_date; State.RUNNING sets end_date back to NULL.
        session.merge(dr)
        session.commit()
        self.assertIsNone(dr.end_date)

        ti_op1 = dr.get_task_instance(task_id=op1.task_id)
        ti_op1.set_state(state=State.SUCCESS, session=session)
        ti_op2 = dr.get_task_instance(task_id=op2.task_id)
        ti_op2.set_state(state=State.SUCCESS, session=session)

        dr.update_state()
        dr_database = session.query(DagRun).filter(
            DagRun.run_id == 'test_dagrun_update_state_end_date'
        ).one()
        self.assertIsNotNone(dr_database.end_date)
        self.assertEqual(dr.end_date, dr_database.end_date)

        ti_op1.set_state(state=State.RUNNING, session=session)
        ti_op2.set_state(state=State.RUNNING, session=session)
        dr.update_state()

        dr_database = session.query(DagRun).filter(
            DagRun.run_id == 'test_dagrun_update_state_end_date'
        ).one()

        self.assertEqual(dr._state, State.RUNNING)
        self.assertIsNone(dr.end_date)
        self.assertIsNone(dr_database.end_date)

        ti_op1.set_state(state=State.FAILED, session=session)
        ti_op2.set_state(state=State.FAILED, session=session)
        dr.update_state()

        dr_database = session.query(DagRun).filter(
            DagRun.run_id == 'test_dagrun_update_state_end_date'
        ).one()

        self.assertIsNotNone(dr_database.end_date)
        self.assertEqual(dr.end_date, dr_database.end_date)

    def test_get_task_instance_on_empty_dagrun(self):
        """
        Make sure that a proper value is returned when a dagrun has no task instances
        """
        dag = DAG(
            dag_id='test_get_task_instance_on_empty_dagrun',
            start_date=timezone.datetime(2017, 1, 1)
        )
        ShortCircuitOperator(
            task_id='test_short_circuit_false',
            dag=dag,
            python_callable=lambda: False)

        session = settings.Session()

        now = timezone.utcnow()

        # Don't use create_dagrun since it will create the task instances too which we
        # don't want
        dag_run = models.DagRun(
            dag_id=dag.dag_id,
            run_id='manual__' + now.isoformat(),
            execution_date=now,
            start_date=now,
            state=State.RUNNING,
            external_trigger=False,
        )
        session.add(dag_run)
        session.commit()

        ti = dag_run.get_task_instance('test_short_circuit_false')
        self.assertEqual(None, ti)

    def test_get_latest_runs(self):
        session = settings.Session()
        dag = DAG(
            dag_id='test_latest_runs_1',
            start_date=DEFAULT_DATE)
        self.create_dag_run(dag, execution_date=timezone.datetime(2015, 1, 1))
        self.create_dag_run(dag, execution_date=timezone.datetime(2015, 1, 2))
        dagruns = models.DagRun.get_latest_runs(session)
        session.close()
        for dagrun in dagruns:
            if dagrun.dag_id == 'test_latest_runs_1':
                self.assertEqual(dagrun.execution_date, timezone.datetime(2015, 1, 2))

    def test_is_backfill(self):
        dag = DAG(dag_id='test_is_backfill', start_date=DEFAULT_DATE)

        dagrun = self.create_dag_run(dag, execution_date=DEFAULT_DATE)
        dagrun.run_id = BackfillJob.ID_PREFIX + '_sfddsffds'

        dagrun2 = self.create_dag_run(
            dag, execution_date=DEFAULT_DATE + datetime.timedelta(days=1))

        dagrun3 = self.create_dag_run(
            dag, execution_date=DEFAULT_DATE + datetime.timedelta(days=2))
        dagrun3.run_id = None

        self.assertTrue(dagrun.is_backfill)
        self.assertFalse(dagrun2.is_backfill)
        self.assertFalse(dagrun3.is_backfill)

    def test_removed_task_instances_can_be_restored(self):
        def with_all_tasks_removed(dag):
            return DAG(dag_id=dag.dag_id, start_date=dag.start_date)

        dag = DAG('test_task_restoration', start_date=DEFAULT_DATE)
        dag.add_task(DummyOperator(task_id='flaky_task', owner='test'))

        dagrun = self.create_dag_run(dag)
        flaky_ti = dagrun.get_task_instances()[0]
        self.assertEqual('flaky_task', flaky_ti.task_id)
        self.assertEqual(State.NONE, flaky_ti.state)

        dagrun.dag = with_all_tasks_removed(dag)

        dagrun.verify_integrity()
        flaky_ti.refresh_from_db()
        self.assertEqual(State.NONE, flaky_ti.state)

        dagrun.dag.add_task(DummyOperator(task_id='flaky_task', owner='test'))

        dagrun.verify_integrity()
        flaky_ti.refresh_from_db()
        self.assertEqual(State.NONE, flaky_ti.state)
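

# The DagBag tests point dag_folder at an empty temporary directory, so nothing
# is picked up implicitly; DAG files are loaded either via include_examples=True
# or by calling process_file() explicitly.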
class DagBagTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.empty_dir = mkdtemp()

    @classmethod
    def tearDownClass(cls):
        os.rmdir(cls.empty_dir)

    def test_get_existing_dag(self):
        """
        Test that we're able to parse some example DAGs and retrieve them
        """
        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=True)

        some_expected_dag_ids = ["example_bash_operator",
                                 "example_branch_operator"]

        for dag_id in some_expected_dag_ids:
            dag = dagbag.get_dag(dag_id)

            self.assertIsNotNone(dag)
            self.assertEqual(dag_id, dag.dag_id)

        self.assertGreaterEqual(dagbag.size(), 7)

    def test_get_non_existing_dag(self):
        """
        test that retrieving a non existing dag id returns None without crashing
        """
        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)

        non_existing_dag_id = "non_existing_dag_id"
        self.assertIsNone(dagbag.get_dag(non_existing_dag_id))

    def test_process_file_that_contains_multi_bytes_char(self):
        """
        test that we're able to parse a file that contains multi-byte chars
        """
        f = NamedTemporaryFile()
        f.write('\u3042'.encode('utf8'))  # write multi-byte char (hiragana)
        f.flush()

        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
        self.assertEqual([], dagbag.process_file(f.name))

    def test_zip_skip_log(self):
        """
        test the loading of a DAG from within a zip file that skips another file
        because it doesn't have "airflow" and "DAG"
        """
        with patch('airflow.models.DagBag.log') as log_mock:
            log_mock.info = Mock()
            test_zip_path = os.path.join(TEST_DAGS_FOLDER, "test_zip.zip")
            dagbag = models.DagBag(dag_folder=test_zip_path, include_examples=False)

            self.assertTrue(dagbag.has_logged)
            log_mock.info.assert_any_call("File %s assumed to contain no DAGs. Skipping.",
                                          test_zip_path)

    def test_zip(self):
        """
        test the loading of a DAG within a zip file that includes dependencies
        """
        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
        dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_zip.zip"))
        self.assertTrue(dagbag.get_dag("test_zip_dag"))

    def test_process_file_cron_validity_check(self):
        """
        test if an invalid cron expression
        as schedule interval can be identified
        """
        invalid_dag_files = ["test_invalid_cron.py", "test_zip_invalid_cron.zip"]
        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)

        self.assertEqual(len(dagbag.import_errors), 0)
        for d in invalid_dag_files:
            dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, d))
        self.assertEqual(len(dagbag.import_errors), len(invalid_dag_files))

    @patch.object(DagModel, 'get_current')
    def test_get_dag_without_refresh(self, mock_dagmodel):
        """
        Test that, once a DAG is loaded, it doesn't get refreshed again if it
        hasn't been expired.
        """
        dag_id = 'example_bash_operator'

        mock_dagmodel.return_value = DagModel()
        mock_dagmodel.return_value.last_expired = None
        mock_dagmodel.return_value.fileloc = 'foo'

        class TestDagBag(models.DagBag):
            process_file_calls = 0

            def process_file(self, filepath, only_if_updated=True, safe_mode=True):
                if 'example_bash_operator.py' == os.path.basename(filepath):
                    TestDagBag.process_file_calls += 1
                super(TestDagBag, self).process_file(filepath, only_if_updated, safe_mode)

        dagbag = TestDagBag(include_examples=True)

        # Should not call process_file again, since it's already loaded during init.
        self.assertEqual(1, dagbag.process_file_calls)
        self.assertIsNotNone(dagbag.get_dag(dag_id))
        self.assertEqual(1, dagbag.process_file_calls)

    def test_get_dag_fileloc(self):
        """
        Test that fileloc is correctly set when we load example DAGs,
        specifically SubDAGs.
        """
""" dagbag = models.DagBag(include_examples=True) expected = { 'example_bash_operator': 'example_bash_operator.py', 'example_subdag_operator': 'example_subdag_operator.py', 'example_subdag_operator.section-1': 'subdags/subdag.py' } for dag_id, path in expected.items(): dag = dagbag.get_dag(dag_id) self.assertTrue( dag.fileloc.endswith('airflow/example_dags/' + path)) def process_dag(self, create_dag): """ Helper method to process a file generated from the input create_dag function. """ # write source to file source = textwrap.dedent(''.join( inspect.getsource(create_dag).splitlines(True)[1:-1])) f = NamedTemporaryFile() f.write(source.encode('utf8')) f.flush() dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False) found_dags = dagbag.process_file(f.name) return dagbag, found_dags, f.name def validate_dags(self, expected_parent_dag, actual_found_dags, actual_dagbag, should_be_found=True): expected_dag_ids = list(map(lambda dag: dag.dag_id, expected_parent_dag.subdags)) expected_dag_ids.append(expected_parent_dag.dag_id) actual_found_dag_ids = list(map(lambda dag: dag.dag_id, actual_found_dags)) for dag_id in expected_dag_ids: actual_dagbag.log.info('validating %s' % dag_id) self.assertEqual( dag_id in actual_found_dag_ids, should_be_found, 'dag "%s" should %shave been found after processing dag "%s"' % (dag_id, '' if should_be_found else 'not ', expected_parent_dag.dag_id) ) self.assertEqual( dag_id in actual_dagbag.dags, should_be_found, 'dag "%s" should %sbe in dagbag.dags after processing dag "%s"' % (dag_id, '' if should_be_found else 'not ', expected_parent_dag.dag_id) ) def test_load_subdags(self): # Define Dag to load def standard_subdag(): from airflow.models import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.operators.subdag_operator import SubDagOperator import datetime DAG_NAME = 'master' DEFAULT_ARGS = { 'owner': 'owner1', 'start_date': datetime.datetime(2016, 1, 1) } dag = DAG( DAG_NAME, default_args=DEFAULT_ARGS) # master: # A -> opSubDag_0 # master.opsubdag_0: # -> subdag_0.task # A -> opSubDag_1 # master.opsubdag_1: # -> subdag_1.task with dag: def subdag_0(): subdag_0 = DAG('master.opSubdag_0', default_args=DEFAULT_ARGS) DummyOperator(task_id='subdag_0.task', dag=subdag_0) return subdag_0 def subdag_1(): subdag_1 = DAG('master.opSubdag_1', default_args=DEFAULT_ARGS) DummyOperator(task_id='subdag_1.task', dag=subdag_1) return subdag_1 opSubdag_0 = SubDagOperator( task_id='opSubdag_0', dag=dag, subdag=subdag_0()) opSubdag_1 = SubDagOperator( task_id='opSubdag_1', dag=dag, subdag=subdag_1()) opA = DummyOperator(task_id='A') opA.set_downstream(opSubdag_0) opA.set_downstream(opSubdag_1) return dag testDag = standard_subdag() # sanity check to make sure DAG.subdag is still functioning properly self.assertEqual(len(testDag.subdags), 2) # Perform processing dag dagbag, found_dags, _ = self.process_dag(standard_subdag) # Validate correctness # all dags from testDag should be listed self.validate_dags(testDag, found_dags, dagbag) # Define Dag to load def nested_subdags(): from airflow.models import DAG from airflow.operators.dummy_operator import DummyOperator from airflow.operators.subdag_operator import SubDagOperator import datetime DAG_NAME = 'master' DEFAULT_ARGS = { 'owner': 'owner1', 'start_date': datetime.datetime(2016, 1, 1) } dag = DAG( DAG_NAME, default_args=DEFAULT_ARGS) # master: # A -> opSubdag_0 # master.opSubdag_0: # -> opSubDag_A # master.opSubdag_0.opSubdag_A: # -> subdag_A.task # -> opSubdag_B # 
            #                 master.opSubdag_0.opSubdag_B:
            #                     -> subdag_B.task
            #     A -> opSubdag_1
            #          master.opSubdag_1:
            #              -> opSubdag_C
            #                 master.opSubdag_1.opSubdag_C:
            #                     -> subdag_C.task
            #              -> opSubdag_D
            #                 master.opSubdag_1.opSubdag_D:
            #                     -> subdag_D.task
            with dag:
                def subdag_A():
                    subdag_A = DAG(
                        'master.opSubdag_0.opSubdag_A', default_args=DEFAULT_ARGS)
                    DummyOperator(task_id='subdag_A.task', dag=subdag_A)
                    return subdag_A

                def subdag_B():
                    subdag_B = DAG(
                        'master.opSubdag_0.opSubdag_B', default_args=DEFAULT_ARGS)
                    DummyOperator(task_id='subdag_B.task', dag=subdag_B)
                    return subdag_B

                def subdag_C():
                    subdag_C = DAG(
                        'master.opSubdag_1.opSubdag_C', default_args=DEFAULT_ARGS)
                    DummyOperator(task_id='subdag_C.task', dag=subdag_C)
                    return subdag_C

                def subdag_D():
                    subdag_D = DAG(
                        'master.opSubdag_1.opSubdag_D', default_args=DEFAULT_ARGS)
                    DummyOperator(task_id='subdag_D.task', dag=subdag_D)
                    return subdag_D

                def subdag_0():
                    subdag_0 = DAG('master.opSubdag_0', default_args=DEFAULT_ARGS)
                    SubDagOperator(task_id='opSubdag_A', dag=subdag_0, subdag=subdag_A())
                    SubDagOperator(task_id='opSubdag_B', dag=subdag_0, subdag=subdag_B())
                    return subdag_0

                def subdag_1():
                    subdag_1 = DAG('master.opSubdag_1', default_args=DEFAULT_ARGS)
                    SubDagOperator(task_id='opSubdag_C', dag=subdag_1, subdag=subdag_C())
                    SubDagOperator(task_id='opSubdag_D', dag=subdag_1, subdag=subdag_D())
                    return subdag_1

                opSubdag_0 = SubDagOperator(
                    task_id='opSubdag_0', dag=dag, subdag=subdag_0())
                opSubdag_1 = SubDagOperator(
                    task_id='opSubdag_1', dag=dag, subdag=subdag_1())

                opA = DummyOperator(task_id='A')
                opA.set_downstream(opSubdag_0)
                opA.set_downstream(opSubdag_1)

            return dag

        testDag = nested_subdags()
        # sanity check to make sure DAG.subdag is still functioning properly
        self.assertEqual(len(testDag.subdags), 6)

        # Perform processing dag
        dagbag, found_dags, _ = self.process_dag(nested_subdags)

        # Validate correctness
        # all dags from testDag should be listed
        self.validate_dags(testDag, found_dags, dagbag)

    def test_skip_cycle_dags(self):
        """
        Don't crash when loading an invalid (contains a cycle) DAG file.
        Don't load the dag into the DagBag either
        """
        # Define Dag to load
        def basic_cycle():
            from airflow.models import DAG
            from airflow.operators.dummy_operator import DummyOperator
            import datetime
            DAG_NAME = 'cycle_dag'
            DEFAULT_ARGS = {
                'owner': 'owner1',
                'start_date': datetime.datetime(2016, 1, 1)
            }
            dag = DAG(
                DAG_NAME,
                default_args=DEFAULT_ARGS)

            # A -> A
            with dag:
                opA = DummyOperator(task_id='A')
                opA.set_downstream(opA)

            return dag

        testDag = basic_cycle()
        # sanity check to make sure DAG.subdag is still functioning properly
        self.assertEqual(len(testDag.subdags), 0)

        # Perform processing dag
        dagbag, found_dags, file_path = self.process_dag(basic_cycle)

        # Validate correctness
        # None of the dags should be found
        self.validate_dags(testDag, found_dags, dagbag, should_be_found=False)
        self.assertIn(file_path, dagbag.import_errors)

        # Define Dag to load
        def nested_subdag_cycle():
            from airflow.models import DAG
            from airflow.operators.dummy_operator import DummyOperator
            from airflow.operators.subdag_operator import SubDagOperator
            import datetime
            DAG_NAME = 'nested_cycle'
            DEFAULT_ARGS = {
                'owner': 'owner1',
                'start_date': datetime.datetime(2016, 1, 1)
            }
            dag = DAG(
                DAG_NAME,
                default_args=DEFAULT_ARGS)

            # cycle:
            #     A -> opSubdag_0
            #          cycle.opSubdag_0:
            #              -> opSubdag_A
            #                 cycle.opSubdag_0.opSubdag_A:
            #                     -> subdag_A.task
            #              -> opSubdag_B
            #                 cycle.opSubdag_0.opSubdag_B:
            #                     -> subdag_B.task
            #     A -> opSubdag_1
            #          cycle.opSubdag_1:
            #              -> opSubdag_C
            #                 cycle.opSubdag_1.opSubdag_C:
            #                     -> subdag_C.task -> subdag_C.task  >Invalid Loop<
            #              -> opSubdag_D
            #                 cycle.opSubdag_1.opSubdag_D:
            #                     -> subdag_D.task
            with dag:
                def subdag_A():
                    subdag_A = DAG(
                        'nested_cycle.opSubdag_0.opSubdag_A', default_args=DEFAULT_ARGS)
                    DummyOperator(task_id='subdag_A.task', dag=subdag_A)
                    return subdag_A

                def subdag_B():
                    subdag_B = DAG(
                        'nested_cycle.opSubdag_0.opSubdag_B', default_args=DEFAULT_ARGS)
                    DummyOperator(task_id='subdag_B.task', dag=subdag_B)
                    return subdag_B

                def subdag_C():
                    subdag_C = DAG(
                        'nested_cycle.opSubdag_1.opSubdag_C', default_args=DEFAULT_ARGS)
                    opSubdag_C_task = DummyOperator(
                        task_id='subdag_C.task', dag=subdag_C)
                    # introduce a loop in opSubdag_C
                    opSubdag_C_task.set_downstream(opSubdag_C_task)
                    return subdag_C

                def subdag_D():
                    subdag_D = DAG(
                        'nested_cycle.opSubdag_1.opSubdag_D', default_args=DEFAULT_ARGS)
                    DummyOperator(task_id='subdag_D.task', dag=subdag_D)
                    return subdag_D

                def subdag_0():
                    subdag_0 = DAG('nested_cycle.opSubdag_0', default_args=DEFAULT_ARGS)
                    SubDagOperator(task_id='opSubdag_A', dag=subdag_0, subdag=subdag_A())
                    SubDagOperator(task_id='opSubdag_B', dag=subdag_0, subdag=subdag_B())
                    return subdag_0

                def subdag_1():
                    subdag_1 = DAG('nested_cycle.opSubdag_1', default_args=DEFAULT_ARGS)
                    SubDagOperator(task_id='opSubdag_C', dag=subdag_1, subdag=subdag_C())
                    SubDagOperator(task_id='opSubdag_D', dag=subdag_1, subdag=subdag_D())
                    return subdag_1

                opSubdag_0 = SubDagOperator(
                    task_id='opSubdag_0', dag=dag, subdag=subdag_0())
                opSubdag_1 = SubDagOperator(
                    task_id='opSubdag_1', dag=dag, subdag=subdag_1())

                opA = DummyOperator(task_id='A')
                opA.set_downstream(opSubdag_0)
                opA.set_downstream(opSubdag_1)

            return dag

        testDag = nested_subdag_cycle()
        # sanity check to make sure DAG.subdag is still functioning properly
        self.assertEqual(len(testDag.subdags), 6)

        # Perform processing dag
        dagbag, found_dags, file_path = self.process_dag(nested_subdag_cycle)

        # Validate correctness
        # None of the dags should be found
        self.validate_dags(testDag, found_dags, dagbag, should_be_found=False)
        self.assertIn(file_path, dagbag.import_errors)

    def test_process_file_with_none(self):
        """
        test that process_file can handle Nones
        """
        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)

        self.assertEqual([], dagbag.process_file(None))

    @patch.object(TI, 'handle_failure')
    def test_kill_zombies(self, mock_ti_handle_failure):
        """
        Test that kill_zombies calls the TI's failure handler with proper context
        """
        dagbag = models.DagBag()
        with create_session() as session:
            session.query(TI).delete()
            dag = dagbag.get_dag('example_branch_operator')
            task = dag.get_task(task_id='run_this_first')

            ti = TI(task, DEFAULT_DATE, State.RUNNING)

            session.add(ti)
            session.commit()

            zombies = [SimpleTaskInstance(ti)]
            dagbag.kill_zombies(zombies)
            mock_ti_handle_failure \
                .assert_called_with(ANY,
                                    configuration.getboolean('core',
                                                             'unit_test_mode'),
                                    ANY)

    def test_deactivate_unknown_dags(self):
        """
        Test that dag_ids not passed into deactivate_unknown_dags
        are deactivated when function is invoked
        """
        dagbag = models.DagBag(include_examples=True)
        expected_active_dags = dagbag.dags.keys()

        session = settings.Session()
        session.add(DagModel(dag_id='test_deactivate_unknown_dags', is_active=True))
        session.commit()

        models.DAG.deactivate_unknown_dags(expected_active_dags)

        for dag in session.query(DagModel).all():
            if dag.dag_id in expected_active_dags:
                self.assertTrue(dag.is_active)
            else:
                self.assertEqual(dag.dag_id, 'test_deactivate_unknown_dags')
                self.assertFalse(dag.is_active)

        # clean up
        session.query(DagModel).filter(
            DagModel.dag_id == 'test_deactivate_unknown_dags').delete()
        session.commit()


class TaskInstanceTest(unittest.TestCase):

    def tearDown(self):
        with create_session() as session:
            session.query(models.TaskFail).delete()
            session.query(models.TaskReschedule).delete()
            session.query(models.TaskInstance).delete()

    def test_set_task_dates(self):
        """
        Test that tasks properly take start/end dates from DAGs
        """
        dag = DAG('dag', start_date=DEFAULT_DATE,
                  end_date=DEFAULT_DATE + datetime.timedelta(days=10))

        op1 = DummyOperator(task_id='op_1', owner='test')

        self.assertTrue(op1.start_date is None and op1.end_date is None)

        # dag should assign its dates to op1 because op1 has no dates
        dag.add_task(op1)
        self.assertTrue(
            op1.start_date == dag.start_date and op1.end_date == dag.end_date)

        op2 = DummyOperator(
            task_id='op_2',
            owner='test',
            start_date=DEFAULT_DATE - datetime.timedelta(days=1),
            end_date=DEFAULT_DATE + datetime.timedelta(days=11))

        # dag should assign its dates to op2 because they are more restrictive
        dag.add_task(op2)
        self.assertTrue(
            op2.start_date == dag.start_date and op2.end_date == dag.end_date)

        op3 = DummyOperator(
            task_id='op_3',
            owner='test',
            start_date=DEFAULT_DATE + datetime.timedelta(days=1),
            end_date=DEFAULT_DATE + datetime.timedelta(days=9))
        # op3 should keep its dates because they are more restrictive
        dag.add_task(op3)
        self.assertTrue(
            op3.start_date == DEFAULT_DATE + datetime.timedelta(days=1))
        self.assertTrue(
            op3.end_date == DEFAULT_DATE + datetime.timedelta(days=9))

    def test_timezone_awareness(self):
        NAIVE_DATETIME = DEFAULT_DATE.replace(tzinfo=None)

        # check ti without dag (just for bw compat)
        op_no_dag = DummyOperator(task_id='op_no_dag')
        ti = TI(task=op_no_dag, execution_date=NAIVE_DATETIME)

        self.assertEqual(ti.execution_date, DEFAULT_DATE)

        # check with dag without localized execution_date
        dag = DAG('dag', start_date=DEFAULT_DATE)
        op1 = DummyOperator(task_id='op_1')
        dag.add_task(op1)
        ti = TI(task=op1, execution_date=NAIVE_DATETIME)

        self.assertEqual(ti.execution_date, DEFAULT_DATE)
pendulum.timezone("Europe/Amsterdam") execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tz) utc_date = timezone.convert_to_utc(execution_date) ti = TI(task=op1, execution_date=execution_date) self.assertEqual(ti.execution_date, utc_date) def test_task_naive_datetime(self): NAIVE_DATETIME = DEFAULT_DATE.replace(tzinfo=None) op_no_dag = DummyOperator(task_id='test_task_naive_datetime', start_date=NAIVE_DATETIME, end_date=NAIVE_DATETIME) self.assertTrue(op_no_dag.start_date.tzinfo) self.assertTrue(op_no_dag.end_date.tzinfo) def test_set_dag(self): """ Test assigning Operators to Dags, including deferred assignment """ dag = DAG('dag', start_date=DEFAULT_DATE) dag2 = DAG('dag2', start_date=DEFAULT_DATE) op = DummyOperator(task_id='op_1', owner='test') # no dag assigned self.assertFalse(op.has_dag()) self.assertRaises(AirflowException, getattr, op, 'dag') # no improper assignment with self.assertRaises(TypeError): op.dag = 1 op.dag = dag # no reassignment with self.assertRaises(AirflowException): op.dag = dag2 # but assigning the same dag is ok op.dag = dag self.assertIs(op.dag, dag) self.assertIn(op, dag.tasks) def test_infer_dag(self): dag = DAG('dag', start_date=DEFAULT_DATE) dag2 = DAG('dag2', start_date=DEFAULT_DATE) op1 = DummyOperator(task_id='test_op_1', owner='test') op2 = DummyOperator(task_id='test_op_2', owner='test') op3 = DummyOperator(task_id='test_op_3', owner='test', dag=dag) op4 = DummyOperator(task_id='test_op_4', owner='test', dag=dag2) # double check dags self.assertEqual( [i.has_dag() for i in [op1, op2, op3, op4]], [False, False, True, True]) # can't combine operators with no dags self.assertRaises(AirflowException, op1.set_downstream, op2) # op2 should infer dag from op1 op1.dag = dag op1.set_downstream(op2) self.assertIs(op2.dag, dag) # can't assign across multiple DAGs self.assertRaises(AirflowException, op1.set_downstream, op4) self.assertRaises(AirflowException, op1.set_downstream, [op3, op4]) def test_bitshift_compose_operators(self): dag = DAG('dag', start_date=DEFAULT_DATE) op1 = DummyOperator(task_id='test_op_1', owner='test') op2 = DummyOperator(task_id='test_op_2', owner='test') op3 = DummyOperator(task_id='test_op_3', owner='test') op4 = DummyOperator(task_id='test_op_4', owner='test') op5 = DummyOperator(task_id='test_op_5', owner='test') # can't compose operators without dags with self.assertRaises(AirflowException): op1 >> op2 dag >> op1 >> op2 << op3 # make sure dag assignment carries through # using __rrshift__ self.assertIs(op1.dag, dag) self.assertIs(op2.dag, dag) self.assertIs(op3.dag, dag) # op2 should be downstream of both self.assertIn(op2, op1.downstream_list) self.assertIn(op2, op3.downstream_list) # test dag assignment with __rlshift__ dag << op4 self.assertIs(op4.dag, dag) # dag assignment with __rrshift__ dag >> op5 self.assertIs(op5.dag, dag) @patch.object(DAG, 'concurrency_reached') def test_requeue_over_concurrency(self, mock_concurrency_reached): mock_concurrency_reached.return_value = True dag = DAG(dag_id='test_requeue_over_concurrency', start_date=DEFAULT_DATE, max_active_runs=1, concurrency=2) task = DummyOperator(task_id='test_requeue_over_concurrency_op', dag=dag) ti = TI(task=task, execution_date=timezone.utcnow()) ti.run() self.assertEqual(ti.state, models.State.NONE) @patch.object(TI, 'pool_full') def test_run_pooling_task(self, mock_pool_full): """ test that running task update task state as without running task. 
(no dependency check in ti_deps anymore, so also -> SUCCESS) """ # Mock the pool out with a full pool because the pool doesn't actually exist mock_pool_full.return_value = True dag = models.DAG(dag_id='test_run_pooling_task') task = DummyOperator(task_id='test_run_pooling_task_op', dag=dag, pool='test_run_pooling_task_pool', owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) ti = TI( task=task, execution_date=timezone.utcnow()) ti.run() self.assertEqual(ti.state, models.State.SUCCESS) @patch.object(TI, 'pool_full') def test_run_pooling_task_with_mark_success(self, mock_pool_full): """ test that running task with mark_success param update task state as SUCCESS without running task. """ # Mock the pool out with a full pool because the pool doesn't actually exist mock_pool_full.return_value = True dag = models.DAG(dag_id='test_run_pooling_task_with_mark_success') task = DummyOperator( task_id='test_run_pooling_task_with_mark_success_op', dag=dag, pool='test_run_pooling_task_with_mark_success_pool', owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) ti = TI( task=task, execution_date=timezone.utcnow()) ti.run(mark_success=True) self.assertEqual(ti.state, models.State.SUCCESS) def test_run_pooling_task_with_skip(self): """ test that running task which returns AirflowSkipOperator will end up in a SKIPPED state. """ def raise_skip_exception(): raise AirflowSkipException dag = models.DAG(dag_id='test_run_pooling_task_with_skip') task = PythonOperator( task_id='test_run_pooling_task_with_skip', dag=dag, python_callable=raise_skip_exception, owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) ti = TI( task=task, execution_date=timezone.utcnow()) ti.run() self.assertEqual(models.State.SKIPPED, ti.state) def test_retry_delay(self): """ Test that retry delays are respected """ dag = models.DAG(dag_id='test_retry_handling') task = BashOperator( task_id='test_retry_handling_op', bash_command='exit 1', retries=1, retry_delay=datetime.timedelta(seconds=3), dag=dag, owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) def run_with_error(ti): try: ti.run() except AirflowException: pass ti = TI( task=task, execution_date=timezone.utcnow()) self.assertEqual(ti.try_number, 1) # first run -- up for retry run_with_error(ti) self.assertEqual(ti.state, State.UP_FOR_RETRY) self.assertEqual(ti.try_number, 2) # second run -- still up for retry because retry_delay hasn't expired run_with_error(ti) self.assertEqual(ti.state, State.UP_FOR_RETRY) # third run -- failed time.sleep(3) run_with_error(ti) self.assertEqual(ti.state, State.FAILED) @patch.object(TI, 'pool_full') def test_retry_handling(self, mock_pool_full): """ Test that task retries are handled properly """ # Mock the pool with a pool with slots open since the pool doesn't actually exist mock_pool_full.return_value = False dag = models.DAG(dag_id='test_retry_handling') task = BashOperator( task_id='test_retry_handling_op', bash_command='exit 1', retries=1, retry_delay=datetime.timedelta(seconds=0), dag=dag, owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) def run_with_error(ti): try: ti.run() except AirflowException: pass ti = TI( task=task, execution_date=timezone.utcnow()) self.assertEqual(ti.try_number, 1) # first run -- up for retry run_with_error(ti) self.assertEqual(ti.state, State.UP_FOR_RETRY) self.assertEqual(ti._try_number, 1) self.assertEqual(ti.try_number, 2) # second run -- fail run_with_error(ti) self.assertEqual(ti.state, State.FAILED) 
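# Note on the assertions around this point: _try_number is the raw database
# column, while the public try_number accessor reports the attempt that would
# run next, i.e. it adds one whenever the task is not currently RUNNING.
# Roughly (a paraphrase for orientation, not the exact model code):
#
#     @property
#     def try_number(self):
#         if self.state == State.RUNNING:
#             return self._try_number
#         return self._try_number + 1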
self.assertEqual(ti._try_number, 2) self.assertEqual(ti.try_number, 3) # Clear the TI state since you can't run a task with a FAILED state without # clearing it first dag.clear() # third run -- up for retry run_with_error(ti) self.assertEqual(ti.state, State.UP_FOR_RETRY) self.assertEqual(ti._try_number, 3) self.assertEqual(ti.try_number, 4) # fourth run -- fail run_with_error(ti) ti.refresh_from_db() self.assertEqual(ti.state, State.FAILED) self.assertEqual(ti._try_number, 4) self.assertEqual(ti.try_number, 5) def test_next_retry_datetime(self): delay = datetime.timedelta(seconds=30) max_delay = datetime.timedelta(minutes=60) dag = models.DAG(dag_id='fail_dag') task = BashOperator( task_id='task_with_exp_backoff_and_max_delay', bash_command='exit 1', retries=3, retry_delay=delay, retry_exponential_backoff=True, max_retry_delay=max_delay, dag=dag, owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) ti = TI( task=task, execution_date=DEFAULT_DATE) ti.end_date = pendulum.instance(timezone.utcnow()) dt = ti.next_retry_datetime() # between 30 * 2^-1 and 30 * 2^0 (15 and 30) period = ti.end_date.add(seconds=30) - ti.end_date.add(seconds=15) self.assertTrue(dt in period) ti.try_number = 3 dt = ti.next_retry_datetime() # between 30 * 2^2 and 30 * 2^3 (120 and 240) period = ti.end_date.add(seconds=240) - ti.end_date.add(seconds=120) self.assertTrue(dt in period) ti.try_number = 5 dt = ti.next_retry_datetime() # between 30 * 2^4 and 30 * 2^5 (480 and 960) period = ti.end_date.add(seconds=960) - ti.end_date.add(seconds=480) self.assertTrue(dt in period) ti.try_number = 9 dt = ti.next_retry_datetime() self.assertEqual(dt, ti.end_date + max_delay) ti.try_number = 50 dt = ti.next_retry_datetime() self.assertEqual(dt, ti.end_date + max_delay) @patch.object(TI, 'pool_full') def test_reschedule_handling(self, mock_pool_full): """ Test that task reschedules are handled properly """ # Mock the pool to have open slots, since the pool doesn't actually exist mock_pool_full.return_value = False # Return values of the python sensor callable, modified during tests done = False fail = False def sensor_callable(): if fail: raise AirflowException() return done dag = models.DAG(dag_id='test_reschedule_handling') task = PythonSensor( task_id='test_reschedule_handling_sensor', poke_interval=0, mode='reschedule', python_callable=sensor_callable, retries=1, retry_delay=datetime.timedelta(seconds=0), dag=dag, owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) ti = TI(task=task, execution_date=timezone.utcnow()) self.assertEqual(ti._try_number, 0) self.assertEqual(ti.try_number, 1) def run_ti_and_assert(run_date, expected_start_date, expected_end_date, expected_duration, expected_state, expected_try_number, expected_task_reschedule_count): with freeze_time(run_date): try: ti.run() except AirflowException: if not fail: raise ti.refresh_from_db() self.assertEqual(ti.state, expected_state) self.assertEqual(ti._try_number, expected_try_number) self.assertEqual(ti.try_number, expected_try_number + 1) self.assertEqual(ti.start_date, expected_start_date) self.assertEqual(ti.end_date, expected_end_date) self.assertEqual(ti.duration, expected_duration) trs = TR.find_for_task_instance(ti) self.assertEqual(len(trs), expected_task_reschedule_count) date1 = timezone.utcnow() date2 = date1 + datetime.timedelta(minutes=1) date3 = date2 + datetime.timedelta(minutes=1) date4 = date3 + datetime.timedelta(minutes=1) # Run with multiple reschedules.
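# Each run_ti_and_assert call below passes its arguments positionally,
# matching the helper's signature:
#   run_ti_and_assert(run_date, expected_start_date, expected_end_date,
#                     expected_duration, expected_state,
#                     expected_try_number, expected_task_reschedule_count)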
# During reschedule the try number remains the same, but each reschedule is recorded. # The start date is expected to remain the initial date, hence the duration increases. # When finished the try number is incremented and there is no reschedule expected # for this try. done, fail = False, False run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 0, 1) done, fail = False, False run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RESCHEDULE, 0, 2) done, fail = False, False run_ti_and_assert(date3, date1, date3, 120, State.UP_FOR_RESCHEDULE, 0, 3) done, fail = True, False run_ti_and_assert(date4, date1, date4, 180, State.SUCCESS, 1, 0) # Clear the task instance. dag.clear() ti.refresh_from_db() self.assertEqual(ti.state, State.NONE) self.assertEqual(ti._try_number, 1) # Run again after clearing with reschedules and a retry. # The retry increments the try number, and for that try no reschedule is expected. # After the retry the start date is reset, hence the duration is also reset. done, fail = False, False run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 1, 1) done, fail = False, True run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 2, 0) done, fail = False, False run_ti_and_assert(date3, date3, date3, 0, State.UP_FOR_RESCHEDULE, 2, 1) done, fail = True, False run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0) def test_depends_on_past(self): dagbag = models.DagBag() dag = dagbag.get_dag('test_depends_on_past') dag.clear() task = dag.tasks[0] run_date = task.start_date + datetime.timedelta(days=5) ti = TI(task, run_date) # depends_on_past prevents the run task.run(start_date=run_date, end_date=run_date) ti.refresh_from_db() self.assertIs(ti.state, None) # ignore first depends_on_past to allow the run task.run( start_date=run_date, end_date=run_date, ignore_first_depends_on_past=True) ti.refresh_from_db() self.assertEqual(ti.state, State.SUCCESS) # Parameterized tests to check for the correct firing # of the trigger_rule under various circumstances # Fields in each row are, in order: # trigger_rule, successes, skipped, failed, upstream_failed, done, # flag_upstream_failed, expect_state, expect_completed @parameterized.expand([ # # Tests for all_success # ['all_success', 5, 0, 0, 0, 0, True, None, True], ['all_success', 2, 0, 0, 0, 0, True, None, False], ['all_success', 2, 0, 1, 0, 0, True, ST.UPSTREAM_FAILED, False], ['all_success', 2, 1, 0, 0, 0, True, ST.SKIPPED, False], # # Tests for one_success # ['one_success', 5, 0, 0, 0, 5, True, None, True], ['one_success', 2, 0, 0, 0, 2, True, None, True], ['one_success', 2, 0, 1, 0, 3, True, None, True], ['one_success', 2, 1, 0, 0, 3, True, None, True], # # Tests for all_failed # ['all_failed', 5, 0, 0, 0, 5, True, ST.SKIPPED, False], ['all_failed', 0, 0, 5, 0, 5, True, None, True], ['all_failed', 2, 0, 0, 0, 2, True, ST.SKIPPED, False], ['all_failed', 2, 0, 1, 0, 3, True, ST.SKIPPED, False], ['all_failed', 2, 1, 0, 0, 3, True, ST.SKIPPED, False], # # Tests for one_failed # ['one_failed', 5, 0, 0, 0, 0, True, None, False], ['one_failed', 2, 0, 0, 0, 0, True, None, False], ['one_failed', 2, 0, 1, 0, 0, True, None, True], ['one_failed', 2, 1, 0, 0, 3, True, None, False], ['one_failed', 2, 3, 0, 0, 5, True, ST.SKIPPED, False], # # Tests for all_done # ['all_done', 5, 0, 0, 0, 5, True, None, True], ['all_done', 2, 0, 0, 0, 2, True, None, False], ['all_done', 2, 0, 1, 0, 3, True, None, False], ['all_done', 2, 1, 0, 0, 3, True, None, False] ]) def test_check_task_dependencies(self, trigger_rule, successes, skipped, failed, upstream_failed, done, flag_upstream_failed,
expect_state, expect_completed): start_date = timezone.datetime(2016, 2, 1, 0, 0, 0) dag = models.DAG('test-dag', start_date=start_date) downstream = DummyOperator(task_id='downstream', dag=dag, owner='airflow', trigger_rule=trigger_rule) for i in range(5): task = DummyOperator(task_id='runme_{}'.format(i), dag=dag, owner='airflow') task.set_downstream(downstream) run_date = task.start_date + datetime.timedelta(days=5) ti = TI(downstream, run_date) dep_results = TriggerRuleDep()._evaluate_trigger_rule( ti=ti, successes=successes, skipped=skipped, failed=failed, upstream_failed=upstream_failed, done=done, flag_upstream_failed=flag_upstream_failed) completed = all([dep.passed for dep in dep_results]) self.assertEqual(completed, expect_completed) self.assertEqual(ti.state, expect_state) def test_xcom_pull(self): """ Test xcom_pull, using different filtering methods. """ dag = models.DAG( dag_id='test_xcom', schedule_interval='@monthly', start_date=timezone.datetime(2016, 6, 1, 0, 0, 0)) exec_date = timezone.utcnow() # Push a value task1 = DummyOperator(task_id='test_xcom_1', dag=dag, owner='airflow') ti1 = TI(task=task1, execution_date=exec_date) ti1.xcom_push(key='foo', value='bar') # Push another value with the same key (but by a different task) task2 = DummyOperator(task_id='test_xcom_2', dag=dag, owner='airflow') ti2 = TI(task=task2, execution_date=exec_date) ti2.xcom_push(key='foo', value='baz') # Pull with no arguments result = ti1.xcom_pull() self.assertEqual(result, None) # Pull the value pushed most recently by any task. result = ti1.xcom_pull(key='foo') self.assertEqual(result, 'baz') # Pull the value pushed by the first task result = ti1.xcom_pull(task_ids='test_xcom_1', key='foo') self.assertEqual(result, 'bar') # Pull the value pushed by the second task result = ti1.xcom_pull(task_ids='test_xcom_2', key='foo') self.assertEqual(result, 'baz') # Pull the values pushed by both tasks result = ti1.xcom_pull( task_ids=['test_xcom_1', 'test_xcom_2'], key='foo') self.assertEqual(result, ('bar', 'baz')) def test_xcom_pull_after_success(self): """ Tests xcom set/clear relative to a task in a 'success' rerun scenario """ key = 'xcom_key' value = 'xcom_value' dag = models.DAG(dag_id='test_xcom', schedule_interval='@monthly') task = DummyOperator( task_id='test_xcom', dag=dag, pool='test_xcom', owner='airflow', start_date=timezone.datetime(2016, 6, 2, 0, 0, 0)) exec_date = timezone.utcnow() ti = TI( task=task, execution_date=exec_date) ti.run(mark_success=True) ti.xcom_push(key=key, value=value) self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value) ti.run() # The second run and assert is to handle AIRFLOW-131 (don't clear on # prior success) self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value) # Test AIRFLOW-703: Xcom shouldn't be cleared if the task doesn't # execute, even if dependencies are ignored ti.run(ignore_all_deps=True, mark_success=True) self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value) # Xcom IS finally cleared once task has executed ti.run(ignore_all_deps=True) self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), None) def test_xcom_pull_different_execution_date(self): """ Tests xcom fetch behavior with different execution dates, using both xcom_pull with "include_prior_dates" and without """ key = 'xcom_key' value = 'xcom_value' dag = models.DAG(dag_id='test_xcom', schedule_interval='@monthly') task = DummyOperator( task_id='test_xcom', dag=dag, pool='test_xcom', owner='airflow', start_date=timezone.datetime(2016, 6, 2, 0,
0, 0)) exec_date = timezone.utcnow() ti = TI( task=task, execution_date=exec_date) ti.run(mark_success=True) ti.xcom_push(key=key, value=value) self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value) ti.run() exec_date += datetime.timedelta(days=1) ti = TI( task=task, execution_date=exec_date) ti.run() # We have set a new execution date (and did not pass in # 'include_prior_dates'), which means this task should now have a cleared # xcom value self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), None) # We *should* get a value using 'include_prior_dates' self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key, include_prior_dates=True), value) def test_xcom_push_flag(self): """ Tests the option for Operators to push XComs """ value = 'hello' task_id = 'test_no_xcom_push' dag = models.DAG(dag_id='test_xcom') # nothing saved to XCom task = PythonOperator( task_id=task_id, dag=dag, python_callable=lambda: value, do_xcom_push=False, owner='airflow', start_date=datetime.datetime(2017, 1, 1) ) ti = TI(task=task, execution_date=datetime.datetime(2017, 1, 1)) ti.run() self.assertEqual( ti.xcom_pull( task_ids=task_id, key=models.XCOM_RETURN_KEY ), None ) def test_post_execute_hook(self): """ Test that post_execute hook is called with the Operator's result. The result ('error') will cause an error to be raised and trapped. """ class TestError(Exception): pass class TestOperator(PythonOperator): def post_execute(self, context, result): if result == 'error': raise TestError('expected error.') dag = models.DAG(dag_id='test_post_execute_dag') task = TestOperator( task_id='test_operator', dag=dag, python_callable=lambda: 'error', owner='airflow', start_date=timezone.datetime(2017, 2, 1)) ti = TI(task=task, execution_date=timezone.utcnow()) with self.assertRaises(TestError): ti.run() def test_check_and_change_state_before_execution(self): dag = models.DAG(dag_id='test_check_and_change_state_before_execution') task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE) ti = TI( task=task, execution_date=timezone.utcnow()) self.assertEqual(ti._try_number, 0) self.assertTrue(ti._check_and_change_state_before_execution()) # State should be running, and try_number column should be incremented self.assertEqual(ti.state, State.RUNNING) self.assertEqual(ti._try_number, 1) def test_check_and_change_state_before_execution_dep_not_met(self): dag = models.DAG(dag_id='test_check_and_change_state_before_execution') task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE) task2 = DummyOperator(task_id='task2', dag=dag, start_date=DEFAULT_DATE) task >> task2 ti = TI( task=task2, execution_date=timezone.utcnow()) self.assertFalse(ti._check_and_change_state_before_execution()) def test_try_number(self): """ Test the try_number accessor behaves in various running states """ dag = models.DAG(dag_id='test_check_and_change_state_before_execution') task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE) ti = TI(task=task, execution_date=timezone.utcnow()) self.assertEqual(1, ti.try_number) ti.try_number = 2 ti.state = State.RUNNING self.assertEqual(2, ti.try_number) ti.state = State.SUCCESS self.assertEqual(3, ti.try_number) def test_get_num_running_task_instances(self): session = settings.Session() dag = models.DAG(dag_id='test_get_num_running_task_instances') dag2 = models.DAG(dag_id='test_get_num_running_task_instances_dummy') task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE) task2 = DummyOperator(task_id='task', dag=dag2,
start_date=DEFAULT_DATE) ti1 = TI(task=task, execution_date=DEFAULT_DATE) ti2 = TI(task=task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1)) ti3 = TI(task=task2, execution_date=DEFAULT_DATE) ti1.state = State.RUNNING ti2.state = State.QUEUED ti3.state = State.RUNNING session.add(ti1) session.add(ti2) session.add(ti3) session.commit() self.assertEqual(1, ti1.get_num_running_task_instances(session=session)) self.assertEqual(1, ti2.get_num_running_task_instances(session=session)) self.assertEqual(1, ti3.get_num_running_task_instances(session=session)) def test_log_url(self): dag = DAG('dag', start_date=DEFAULT_DATE) task = DummyOperator(task_id='op', dag=dag) ti = TI(task=task, execution_date=datetime.datetime(2018, 1, 1)) expected_url = ( 'http://localhost:8080/log?' 'execution_date=2018-01-01T00%3A00%3A00%2B00%3A00' '&task_id=op' '&dag_id=dag' ) self.assertEqual(ti.log_url, expected_url) def test_mark_success_url(self): now = pendulum.now('Europe/Brussels') dag = DAG('dag', start_date=DEFAULT_DATE) task = DummyOperator(task_id='op', dag=dag) ti = TI(task=task, execution_date=now) d = urllib.parse.parse_qs( urllib.parse.urlparse(ti.mark_success_url).query, keep_blank_values=True, strict_parsing=True) self.assertEqual(d['dag_id'][0], 'dag') self.assertEqual(d['task_id'][0], 'op') self.assertEqual(pendulum.parse(d['execution_date'][0]), now) def test_overwrite_params_with_dag_run_conf(self): task = DummyOperator(task_id='op') ti = TI(task=task, execution_date=datetime.datetime.now()) dag_run = DagRun() dag_run.conf = {"override": True} params = {"override": False} ti.overwrite_params_with_dag_run_conf(params, dag_run) self.assertEqual(True, params["override"]) def test_overwrite_params_with_dag_run_none(self): task = DummyOperator(task_id='op') ti = TI(task=task, execution_date=datetime.datetime.now()) params = {"override": False} ti.overwrite_params_with_dag_run_conf(params, None) self.assertEqual(False, params["override"]) def test_overwrite_params_with_dag_run_conf_none(self): task = DummyOperator(task_id='op') ti = TI(task=task, execution_date=datetime.datetime.now()) params = {"override": False} dag_run = DagRun() ti.overwrite_params_with_dag_run_conf(params, dag_run) self.assertEqual(False, params["override"]) @patch('airflow.models.send_email') def test_email_alert(self, mock_send_email): dag = models.DAG(dag_id='test_failure_email') task = BashOperator( task_id='test_email_alert', dag=dag, bash_command='exit 1', start_date=DEFAULT_DATE, email='to') ti = TI(task=task, execution_date=datetime.datetime.now()) try: ti.run() except AirflowException: pass (email, title, body), _ = mock_send_email.call_args self.assertEqual(email, 'to') self.assertIn('test_email_alert', title) self.assertIn('test_email_alert', body) @patch('airflow.models.send_email') def test_email_alert_with_config(self, mock_send_email): dag = models.DAG(dag_id='test_failure_email') task = BashOperator( task_id='test_email_alert_with_config', dag=dag, bash_command='exit 1', start_date=DEFAULT_DATE, email='to') ti = TI( task=task,
execution_date=datetime.datetime.now()) configuration.set('email', 'SUBJECT_TEMPLATE', '/subject/path') configuration.set('email', 'HTML_CONTENT_TEMPLATE', '/html_content/path') opener = mock_open(read_data='template: {{ti.task_id}}') with patch('airflow.models.open', opener, create=True): try: ti.run() except AirflowException: pass (email, title, body), _ = mock_send_email.call_args self.assertEqual(email, 'to') self.assertEqual('template: test_email_alert_with_config', title) self.assertEqual('template: test_email_alert_with_config', body) def test_set_duration(self): task = DummyOperator(task_id='op', email='test@test.test') ti = TI( task=task, execution_date=datetime.datetime.now(), ) ti.start_date = datetime.datetime(2018, 10, 1, 1) ti.end_date = datetime.datetime(2018, 10, 1, 2) ti.set_duration() self.assertEqual(ti.duration, 3600) def test_set_duration_empty_dates(self): task = DummyOperator(task_id='op', email='test@test.test') ti = TI(task=task, execution_date=datetime.datetime.now()) ti.set_duration() self.assertIsNone(ti.duration) def test_success_callback_no_race_condition(self): class CallbackWrapper(object): def wrap_task_instance(self, ti): self.task_id = ti.task_id self.dag_id = ti.dag_id self.execution_date = ti.execution_date self.task_state_in_callback = "" self.callback_ran = False def success_handler(self, context): self.callback_ran = True session = settings.Session() temp_instance = session.query(TI).filter( TI.task_id == self.task_id).filter( TI.dag_id == self.dag_id).filter( TI.execution_date == self.execution_date).one() self.task_state_in_callback = temp_instance.state cw = CallbackWrapper() dag = DAG('test_success_callback_no_race_condition', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10)) task = DummyOperator(task_id='op', email='test@test.test', on_success_callback=cw.success_handler, dag=dag) ti = TI(task=task, execution_date=datetime.datetime.now()) ti.state = State.RUNNING session = settings.Session() session.merge(ti) session.commit() cw.wrap_task_instance(ti) ti._run_raw_task() self.assertTrue(cw.callback_ran) self.assertEqual(cw.task_state_in_callback, State.RUNNING) ti.refresh_from_db() self.assertEqual(ti.state, State.SUCCESS) class ClearTasksTest(unittest.TestCase): def test_clear_task_instances(self): dag = DAG('test_clear_task_instances', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10)) task0 = DummyOperator(task_id='0', owner='test', dag=dag) task1 = DummyOperator(task_id='1', owner='test', dag=dag, retries=2) ti0 = TI(task=task0, execution_date=DEFAULT_DATE) ti1 = TI(task=task1, execution_date=DEFAULT_DATE) ti0.run() ti1.run() session = settings.Session() qry = session.query(TI).filter( TI.dag_id == dag.dag_id).all() clear_task_instances(qry, session, dag=dag) session.commit() ti0.refresh_from_db() ti1.refresh_from_db() # Next try to run will be try 2 self.assertEqual(ti0.try_number, 2) self.assertEqual(ti0.max_tries, 1) self.assertEqual(ti1.try_number, 2) self.assertEqual(ti1.max_tries, 3) def test_clear_task_instances_without_task(self): dag = DAG('test_clear_task_instances_without_task', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10)) task0 = DummyOperator(task_id='task0', owner='test', dag=dag) task1 = DummyOperator(task_id='task1', owner='test', dag=dag, retries=2) ti0 = TI(task=task0, execution_date=DEFAULT_DATE) ti1 = TI(task=task1, execution_date=DEFAULT_DATE) ti0.run() ti1.run() # Remove the task from dag.
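# (Emptying task_dict below means clear_task_instances can no longer look up
# the task to read its retries; as asserted further down, max_tries then
# falls back to max(original max_tries, current try_number).)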
dag.task_dict = {} self.assertFalse(dag.has_task(task0.task_id)) self.assertFalse(dag.has_task(task1.task_id)) session = settings.Session() qry = session.query(TI).filter( TI.dag_id == dag.dag_id).all() clear_task_instances(qry, session, dag=dag) session.commit() # When the task is no longer in the dag, max_tries will be the maximum of original max_tries or try_number. ti0.refresh_from_db() ti1.refresh_from_db() # Next try to run will be try 2 self.assertEqual(ti0.try_number, 2) self.assertEqual(ti0.max_tries, 1) self.assertEqual(ti1.try_number, 2) self.assertEqual(ti1.max_tries, 2) def test_clear_task_instances_without_dag(self): dag = DAG('test_clear_task_instances_without_dag', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10)) task0 = DummyOperator(task_id='task_0', owner='test', dag=dag) task1 = DummyOperator(task_id='task_1', owner='test', dag=dag, retries=2) ti0 = TI(task=task0, execution_date=DEFAULT_DATE) ti1 = TI(task=task1, execution_date=DEFAULT_DATE) ti0.run() ti1.run() session = settings.Session() qry = session.query(TI).filter( TI.dag_id == dag.dag_id).all() clear_task_instances(qry, session) session.commit() # When dag is None, max_tries will be the maximum of original max_tries or try_number. ti0.refresh_from_db() ti1.refresh_from_db() # Next try to run will be try 2 self.assertEqual(ti0.try_number, 2) self.assertEqual(ti0.max_tries, 1) self.assertEqual(ti1.try_number, 2) self.assertEqual(ti1.max_tries, 2) def test_dag_clear(self): dag = DAG('test_dag_clear', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10)) task0 = DummyOperator(task_id='test_dag_clear_task_0', owner='test', dag=dag) ti0 = TI(task=task0, execution_date=DEFAULT_DATE) # Next try to run will be try 1 self.assertEqual(ti0.try_number, 1) ti0.run() self.assertEqual(ti0.try_number, 2) dag.clear() ti0.refresh_from_db() self.assertEqual(ti0.try_number, 2) self.assertEqual(ti0.state, State.NONE) self.assertEqual(ti0.max_tries, 1) task1 = DummyOperator(task_id='test_dag_clear_task_1', owner='test', dag=dag, retries=2) ti1 = TI(task=task1, execution_date=DEFAULT_DATE) self.assertEqual(ti1.max_tries, 2) ti1.try_number = 1 # Next try will be 2 ti1.run() self.assertEqual(ti1.try_number, 3) self.assertEqual(ti1.max_tries, 2) dag.clear() ti0.refresh_from_db() ti1.refresh_from_db() # after clearing the dag, ti1 should show attempt 3 of 5 self.assertEqual(ti1.max_tries, 4) self.assertEqual(ti1.try_number, 3) # after clearing the dag, ti0 should show attempt 2 of 2 self.assertEqual(ti0.try_number, 2) self.assertEqual(ti0.max_tries, 1) def test_dags_clear(self): # setup session = settings.Session() dags, tis = [], [] num_of_dags = 5 for i in range(num_of_dags): dag = DAG('test_dag_clear_' + str(i), start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10)) ti = TI(task=DummyOperator(task_id='test_task_clear_' + str(i), owner='test', dag=dag), execution_date=DEFAULT_DATE) dags.append(dag) tis.append(ti) # test clear all dags for i in range(num_of_dags): tis[i].run() self.assertEqual(tis[i].state, State.SUCCESS) self.assertEqual(tis[i].try_number, 2) self.assertEqual(tis[i].max_tries, 0) DAG.clear_dags(dags) for i in range(num_of_dags): tis[i].refresh_from_db() self.assertEqual(tis[i].state, State.NONE) self.assertEqual(tis[i].try_number, 2) self.assertEqual(tis[i].max_tries, 1) # test dry_run for i in range(num_of_dags): tis[i].run() self.assertEqual(tis[i].state, State.SUCCESS) self.assertEqual(tis[i].try_number, 3) self.assertEqual(tis[i].max_tries, 1) DAG.clear_dags(dags, dry_run=True) for i in
range(num_of_dags): tis[i].refresh_from_db() self.assertEqual(tis[i].state, State.SUCCESS) self.assertEqual(tis[i].try_number, 3) self.assertEqual(tis[i].max_tries, 1) # test only_failed from random import randint failed_dag_idx = randint(0, len(tis) - 1) tis[failed_dag_idx].state = State.FAILED session.merge(tis[failed_dag_idx]) session.commit() DAG.clear_dags(dags, only_failed=True) for i in range(num_of_dags): tis[i].refresh_from_db() if i != failed_dag_idx: self.assertEqual(tis[i].state, State.SUCCESS) self.assertEqual(tis[i].try_number, 3) self.assertEqual(tis[i].max_tries, 1) else: self.assertEqual(tis[i].state, State.NONE) self.assertEqual(tis[i].try_number, 3) self.assertEqual(tis[i].max_tries, 2) def test_operator_clear(self): dag = DAG('test_operator_clear', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10)) t1 = DummyOperator(task_id='bash_op', owner='test', dag=dag) t2 = DummyOperator(task_id='dummy_op', owner='test', dag=dag, retries=1) t2.set_upstream(t1) ti1 = TI(task=t1, execution_date=DEFAULT_DATE) ti2 = TI(task=t2, execution_date=DEFAULT_DATE) ti2.run() # Dependency not met self.assertEqual(ti2.try_number, 1) self.assertEqual(ti2.max_tries, 1) t2.clear(upstream=True) ti1.run() ti2.run() self.assertEqual(ti1.try_number, 2) # max_tries is 0 because there is no task instance in db for ti1 # so clear won't change the max_tries. self.assertEqual(ti1.max_tries, 0) self.assertEqual(ti2.try_number, 2) # try_number (0) + retries(1) self.assertEqual(ti2.max_tries, 1) def test_xcom_disable_pickle_type(self): configuration.load_test_config() json_obj = {"key": "value"} execution_date = timezone.utcnow() key = "xcom_test1" dag_id = "test_dag1" task_id = "test_task1" configuration.set("core", "enable_xcom_pickling", "False") XCom.set(key=key, value=json_obj, dag_id=dag_id, task_id=task_id, execution_date=execution_date) ret_value = XCom.get_one(key=key, dag_id=dag_id, task_id=task_id, execution_date=execution_date) self.assertEqual(ret_value, json_obj) session = settings.Session() ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.execution_date == execution_date ).first().value self.assertEqual(ret_value, json_obj) def test_xcom_enable_pickle_type(self): json_obj = {"key": "value"} execution_date = timezone.utcnow() key = "xcom_test2" dag_id = "test_dag2" task_id = "test_task2" configuration.set("core", "enable_xcom_pickling", "True") XCom.set(key=key, value=json_obj, dag_id=dag_id, task_id=task_id, execution_date=execution_date) ret_value = XCom.get_one(key=key, dag_id=dag_id, task_id=task_id, execution_date=execution_date) self.assertEqual(ret_value, json_obj) session = settings.Session() ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.execution_date == execution_date ).first().value self.assertEqual(ret_value, json_obj) def test_xcom_disable_pickle_type_fail_on_non_json(self): class PickleRce(object): def __reduce__(self): return os.system, ("ls -alt",) configuration.set("core", "enable_xcom_pickling", "False") self.assertRaises(TypeError, XCom.set, key="xcom_test3", value=PickleRce(), dag_id="test_dag3", task_id="test_task3", execution_date=timezone.utcnow()) def test_xcom_get_many(self): json_obj = {"key": "value"} execution_date = timezone.utcnow() key = "xcom_test4" dag_id1 = "test_dag4" task_id1 = "test_task4" dag_id2 = "test_dag5" task_id2 = "test_task5" configuration.set("core", "enable_xcom_pickling", "True") XCom.set(key=key,
value=json_obj, dag_id=dag_id1, task_id=task_id1, execution_date=execution_date) XCom.set(key=key, value=json_obj, dag_id=dag_id2, task_id=task_id2, execution_date=execution_date) results = XCom.get_many(key=key, execution_date=execution_date) for result in results: self.assertEqual(result.value, json_obj) class VariableTest(unittest.TestCase): def setUp(self): models._fernet = None def tearDown(self): models._fernet = None @patch('airflow.models.configuration.conf.get') def test_variable_no_encryption(self, mock_get): """ Test variables without encryption """ mock_get.return_value = '' Variable.set('key', 'value') session = settings.Session() test_var = session.query(Variable).filter(Variable.key == 'key').one() self.assertFalse(test_var.is_encrypted) self.assertEqual(test_var.val, 'value') @patch('airflow.models.configuration.conf.get') def test_variable_with_encryption(self, mock_get): """ Test variables with encryption """ mock_get.return_value = Fernet.generate_key().decode() Variable.set('key', 'value') session = settings.Session() test_var = session.query(Variable).filter(Variable.key == 'key').one() self.assertTrue(test_var.is_encrypted) self.assertEqual(test_var.val, 'value') @patch('airflow.models.configuration.conf.get') def test_var_with_encryption_rotate_fernet_key(self, mock_get): """ Tests rotating encrypted variables. """ key1 = Fernet.generate_key() key2 = Fernet.generate_key() mock_get.return_value = key1.decode() Variable.set('key', 'value') session = settings.Session() test_var = session.query(Variable).filter(Variable.key == 'key').one() self.assertTrue(test_var.is_encrypted) self.assertEqual(test_var.val, 'value') self.assertEqual(Fernet(key1).decrypt(test_var._val.encode()), b'value') # Test decrypt of old value with new key mock_get.return_value = ','.join([key2.decode(), key1.decode()]) models._fernet = None self.assertEqual(test_var.val, 'value') # Test decrypt of new value with new key test_var.rotate_fernet_key() self.assertTrue(test_var.is_encrypted) self.assertEqual(test_var.val, 'value') self.assertEqual(Fernet(key2).decrypt(test_var._val.encode()), b'value') class ConnectionTest(unittest.TestCase): def setUp(self): models._fernet = None def tearDown(self): models._fernet = None @patch('airflow.models.configuration.conf.get') def test_connection_extra_no_encryption(self, mock_get): """ Tests extras on a new connection without encryption. The fernet key is set to a non-base64-encoded string and the extra is stored without encryption. """ mock_get.return_value = '' test_connection = Connection(extra='testextra') self.assertFalse(test_connection.is_extra_encrypted) self.assertEqual(test_connection.extra, 'testextra') @patch('airflow.models.configuration.conf.get') def test_connection_extra_with_encryption(self, mock_get): """ Tests extras on a new connection with encryption. """ mock_get.return_value = Fernet.generate_key().decode() test_connection = Connection(extra='testextra') self.assertTrue(test_connection.is_extra_encrypted) self.assertEqual(test_connection.extra, 'testextra') @patch('airflow.models.configuration.conf.get') def test_connection_extra_with_encryption_rotate_fernet_key(self, mock_get): """ Tests rotating encrypted extras. 
""" key1 = Fernet.generate_key() key2 = Fernet.generate_key() mock_get.return_value = key1.decode() test_connection = Connection(extra='testextra') self.assertTrue(test_connection.is_extra_encrypted) self.assertEqual(test_connection.extra, 'testextra') self.assertEqual(Fernet(key1).decrypt(test_connection._extra.encode()), b'testextra') # Test decrypt of old value with new key mock_get.return_value = ','.join([key2.decode(), key1.decode()]) models._fernet = None self.assertEqual(test_connection.extra, 'testextra') # Test decrypt of new value with new key test_connection.rotate_fernet_key() self.assertTrue(test_connection.is_extra_encrypted) self.assertEqual(test_connection.extra, 'testextra') self.assertEqual(Fernet(key2).decrypt(test_connection._extra.encode()), b'testextra') def test_connection_from_uri_without_extras(self): uri = 'scheme://user:password@host%2flocation:1234/schema' connection = Connection(uri=uri) self.assertEqual(connection.conn_type, 'scheme') self.assertEqual(connection.host, 'host/location') self.assertEqual(connection.schema, 'schema') self.assertEqual(connection.login, 'user') self.assertEqual(connection.password, 'password') self.assertEqual(connection.port, 1234) self.assertIsNone(connection.extra) def test_connection_from_uri_with_extras(self): uri = 'scheme://user:password@host%2flocation:1234/schema?' \ 'extra1=a%20value&extra2=%2fpath%2f' connection = Connection(uri=uri) self.assertEqual(connection.conn_type, 'scheme') self.assertEqual(connection.host, 'host/location') self.assertEqual(connection.schema, 'schema') self.assertEqual(connection.login, 'user') self.assertEqual(connection.password, 'password') self.assertEqual(connection.port, 1234) self.assertDictEqual(connection.extra_dejson, {'extra1': 'a value', 'extra2': '/path/'}) def test_connection_from_uri_with_colon_in_hostname(self): uri = 'scheme://user:password@host%2flocation%3ax%3ay:1234/schema?' 
\ 'extra1=a%20value&extra2=%2fpath%2f' connection = Connection(uri=uri) self.assertEqual(connection.conn_type, 'scheme') self.assertEqual(connection.host, 'host/location:x:y') self.assertEqual(connection.schema, 'schema') self.assertEqual(connection.login, 'user') self.assertEqual(connection.password, 'password') self.assertEqual(connection.port, 1234) self.assertDictEqual(connection.extra_dejson, {'extra1': 'a value', 'extra2': '/path/'}) def test_connection_from_uri_with_encoded_password(self): uri = 'scheme://user:password%20with%20space@host%2flocation%3ax%3ay:1234/schema' connection = Connection(uri=uri) self.assertEqual(connection.conn_type, 'scheme') self.assertEqual(connection.host, 'host/location:x:y') self.assertEqual(connection.schema, 'schema') self.assertEqual(connection.login, 'user') self.assertEqual(connection.password, 'password with space') self.assertEqual(connection.port, 1234) def test_connection_from_uri_with_encoded_user(self): uri = 'scheme://domain%2fuser:password@host%2flocation%3ax%3ay:1234/schema' connection = Connection(uri=uri) self.assertEqual(connection.conn_type, 'scheme') self.assertEqual(connection.host, 'host/location:x:y') self.assertEqual(connection.schema, 'schema') self.assertEqual(connection.login, 'domain/user') self.assertEqual(connection.password, 'password') self.assertEqual(connection.port, 1234) def test_connection_from_uri_with_encoded_schema(self): uri = 'scheme://user:password%20with%20space@host:1234/schema%2ftest' connection = Connection(uri=uri) self.assertEqual(connection.conn_type, 'scheme') self.assertEqual(connection.host, 'host') self.assertEqual(connection.schema, 'schema/test') self.assertEqual(connection.login, 'user') self.assertEqual(connection.password, 'password with space') self.assertEqual(connection.port, 1234) def test_connection_from_uri_no_schema(self): uri = 'scheme://user:password%20with%20space@host:1234' connection = Connection(uri=uri) self.assertEqual(connection.conn_type, 'scheme') self.assertEqual(connection.host, 'host') self.assertEqual(connection.schema, '') self.assertEqual(connection.login, 'user') self.assertEqual(connection.password, 'password with space') self.assertEqual(connection.port, 1234) class TestSkipMixin(unittest.TestCase): @patch('airflow.models.timezone.utcnow') def test_skip(self, mock_now): session = settings.Session() now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC')) mock_now.return_value = now dag = DAG( 'dag', start_date=DEFAULT_DATE, ) with dag: tasks = [DummyOperator(task_id='task')] dag_run = dag.create_dagrun( run_id='manual__' + now.isoformat(), state=State.FAILED, ) SkipMixin().skip( dag_run=dag_run, execution_date=now, tasks=tasks, session=session) session.query(TI).filter( TI.dag_id == 'dag', TI.task_id == 'task', TI.state == State.SKIPPED, TI.start_date == now, TI.end_date == now, ).one() @patch('airflow.models.timezone.utcnow') def test_skip_none_dagrun(self, mock_now): session = settings.Session() now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC')) mock_now.return_value = now dag = DAG( 'dag', start_date=DEFAULT_DATE, ) with dag: tasks = [DummyOperator(task_id='task')] SkipMixin().skip( dag_run=None, execution_date=now, tasks=tasks, session=session) session.query(TI).filter( TI.dag_id == 'dag', TI.task_id == 'task', TI.state == State.SKIPPED, TI.start_date == now, TI.end_date == now, ).one() def test_skip_none_tasks(self): session = Mock() SkipMixin().skip(dag_run=None, execution_date=None, tasks=[], session=session) 
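# With no tasks to skip there should be nothing to update or commit, so the
# mocked session must never be touched; the assertions below verify that.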
self.assertFalse(session.query.called) self.assertFalse(session.commit.called) class TestKubeResourceVersion(unittest.TestCase): def test_checkpoint_resource_version(self): session = settings.Session() KubeResourceVersion.checkpoint_resource_version('7', session) self.assertEqual(KubeResourceVersion.get_current_resource_version(session), '7') def test_reset_resource_version(self): session = settings.Session() version = KubeResourceVersion.reset_resource_version(session) self.assertEqual(version, '0') self.assertEqual(KubeResourceVersion.get_current_resource_version(session), '0') class TestKubeWorkerIdentifier(unittest.TestCase): @patch('airflow.models.uuid.uuid4') def test_get_or_create_not_exist(self, mock_uuid): session = settings.Session() session.query(KubeWorkerIdentifier).update({ KubeWorkerIdentifier.worker_uuid: '' }) mock_uuid.return_value = 'abcde' worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(session) self.assertEqual(worker_uuid, 'abcde') def test_get_or_create_exist(self): session = settings.Session() KubeWorkerIdentifier.checkpoint_kube_worker_uuid('fghij', session) worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(session) self.assertEqual(worker_uuid, 'fghij')
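# Allow running this module directly. This is a common unittest convention;
# the project's test runner may already discover these cases on its own, in
# which case this guard is redundant but harmless.
if __name__ == '__main__':
    unittest.main()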