incubator-airflow/tests/models.py

# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import datetime
import inspect
import logging
import os
import re
import textwrap
import time
import unittest
import urllib
import uuid
from tempfile import NamedTemporaryFile, mkdtemp
import pendulum
import six
from mock import ANY, Mock, mock_open, patch
from parameterized import parameterized
from freezegun import freeze_time
from cryptography.fernet import Fernet
from airflow import AirflowException, configuration, models, settings
from airflow.contrib.sensors.python_sensor import PythonSensor
from airflow.exceptions import AirflowDagCycleException, AirflowSkipException
from airflow.jobs import BackfillJob
from airflow.models import DAG, TaskInstance as TI
from airflow.models import DagModel, DagRun
from airflow.models import KubeResourceVersion, KubeWorkerIdentifier
from airflow.models import SkipMixin
from airflow.models import State as ST
from airflow.models import TaskReschedule as TR
from airflow.models import XCom
from airflow.models import Variable
from airflow.models import clear_task_instances
from airflow.models.connection import Connection
from airflow.operators.bash_operator import BashOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import PythonOperator
from airflow.operators.python_operator import ShortCircuitOperator
from airflow.operators.subdag_operator import SubDagOperator
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils import timezone
from airflow.utils.dag_processing import SimpleTaskInstance
from airflow.utils.db import create_session
from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule
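# Note: timezone.datetime yields a timezone-aware datetime (UTC), unlike the naive datetime.datetime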
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
TEST_DAGS_FOLDER = os.path.join(
os.path.dirname(os.path.realpath(__file__)), 'dags')
class DagTest(unittest.TestCase):
def test_params_not_passed_is_empty_dict(self):
"""
Test that when 'params' is _not_ passed to a new DAG, the params
attribute is set to an empty dictionary.
"""
dag = models.DAG('test-dag')
self.assertEqual(dict, type(dag.params))
self.assertEqual(0, len(dag.params))
def test_params_passed_and_params_in_default_args_no_override(self):
"""
Test that when 'params' is passed both via the default_args dict and
explicitly as a DAG argument, the 'params' entry of default_args is
merged with the dict given in the params argument.
"""
params1 = {'parameter1': 1}
params2 = {'parameter2': 2}
dag = models.DAG('test-dag',
default_args={'params': params1},
params=params2)
params_combined = params1.copy()
params_combined.update(params2)
self.assertEqual(params_combined, dag.params)
def test_dag_as_context_manager(self):
"""
Test DAG as a context manager.
When used as a context manager, Operators are automatically added to
the DAG (unless they specify a different DAG)
"""
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'})
dag2 = DAG(
'dag2',
start_date=DEFAULT_DATE,
default_args={'owner': 'owner2'})
with dag:
op1 = DummyOperator(task_id='op1')
op2 = DummyOperator(task_id='op2', dag=dag2)
self.assertIs(op1.dag, dag)
self.assertEqual(op1.owner, 'owner1')
self.assertIs(op2.dag, dag2)
self.assertEqual(op2.owner, 'owner2')
with dag2:
op3 = DummyOperator(task_id='op3')
self.assertIs(op3.dag, dag2)
self.assertEqual(op3.owner, 'owner2')
with dag:
with dag2:
op4 = DummyOperator(task_id='op4')
op5 = DummyOperator(task_id='op5')
self.assertIs(op4.dag, dag2)
self.assertIs(op5.dag, dag)
self.assertEqual(op4.owner, 'owner2')
self.assertEqual(op5.owner, 'owner1')
with DAG('creating_dag_in_cm', start_date=DEFAULT_DATE) as dag:
DummyOperator(task_id='op6')
self.assertEqual(dag.dag_id, 'creating_dag_in_cm')
self.assertEqual(dag.tasks[0].task_id, 'op6')
with dag:
with dag:
op7 = DummyOperator(task_id='op7')
op8 = DummyOperator(task_id='op8')
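# op9 deliberately reuses task_id 'op8'; it is assigned to dag2 below,
# so the ids never collide within a single DAG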
op9 = DummyOperator(task_id='op8')
op9.dag = dag2
self.assertEqual(op7.dag, dag)
self.assertEqual(op8.dag, dag)
self.assertEqual(op9.dag, dag2)
def test_dag_topological_sort(self):
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'})
# A -> B
# A -> C -> D
# ordered: B, D, C, A or D, B, C, A or D, C, B, A
with dag:
op1 = DummyOperator(task_id='A')
op2 = DummyOperator(task_id='B')
op3 = DummyOperator(task_id='C')
op4 = DummyOperator(task_id='D')
op1.set_upstream([op2, op3])
op3.set_upstream(op4)
topological_list = dag.topological_sort()
logging.info(topological_list)
tasks = [op2, op3, op4]
self.assertTrue(topological_list[0] in tasks)
tasks.remove(topological_list[0])
self.assertTrue(topological_list[1] in tasks)
tasks.remove(topological_list[1])
self.assertTrue(topological_list[2] in tasks)
tasks.remove(topological_list[2])
self.assertTrue(topological_list[3] == op1)
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'})
# C -> (A u B) -> D
# C -> E
# ordered: E | D, A | B, C
with dag:
op1 = DummyOperator(task_id='A')
op2 = DummyOperator(task_id='B')
op3 = DummyOperator(task_id='C')
op4 = DummyOperator(task_id='D')
op5 = DummyOperator(task_id='E')
op1.set_downstream(op3)
op2.set_downstream(op3)
op1.set_upstream(op4)
op2.set_upstream(op4)
op5.set_downstream(op3)
topological_list = dag.topological_sort()
logging.info(topological_list)
set1 = [op4, op5]
self.assertTrue(topological_list[0] in set1)
set1.remove(topological_list[0])
set2 = [op1, op2]
set2.extend(set1)
self.assertTrue(topological_list[1] in set2)
set2.remove(topological_list[1])
self.assertTrue(topological_list[2] in set2)
set2.remove(topological_list[2])
self.assertTrue(topological_list[3] in set2)
self.assertTrue(topological_list[4] == op3)
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'})
self.assertEqual(tuple(), dag.topological_sort())
def test_dag_naive_default_args_start_date(self):
dag = DAG('DAG', default_args={'start_date': datetime.datetime(2018, 1, 1)})
self.assertEqual(dag.timezone, settings.TIMEZONE)
dag = DAG('DAG', start_date=datetime.datetime(2018, 1, 1))
self.assertEqual(dag.timezone, settings.TIMEZONE)
def test_dag_none_default_args_start_date(self):
"""
Tests if a start_date of None in default_args
works.
"""
dag = DAG('DAG', default_args={'start_date': None})
self.assertEqual(dag.timezone, settings.TIMEZONE)
def test_dag_task_priority_weight_total(self):
width = 5
depth = 5
weight = 5
pattern = re.compile(r'stage(\d*)\.(\d*)')
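# group(1) is the stage index, group(2) the task's position within the stage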
# Fully connected parallel tasks. i.e. every task at each parallel
# stage is dependent on every task in the previous stage.
# Default weight should be calculated using downstream descendants
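# e.g. with width = depth = weight = 5, a stage-0 task totals
# ((5 - 1) * 5 + 1) * 5 = 105: its 20 downstream tasks plus itself, times 5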
with DAG('dag', start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'}) as dag:
pipeline = [
[DummyOperator(
task_id='stage{}.{}'.format(i, j), priority_weight=weight)
for j in range(0, width)] for i in range(0, depth)
]
for d, stage in enumerate(pipeline):
if d == 0:
continue
for current_task in stage:
for prev_task in pipeline[d - 1]:
current_task.set_upstream(prev_task)
for task in six.itervalues(dag.task_dict):
match = pattern.match(task.task_id)
task_depth = int(match.group(1))
# the sum of all stages after this task, plus the task itself
correct_weight = ((depth - (task_depth + 1)) * width + 1) * weight
calculated_weight = task.priority_weight_total
self.assertEqual(calculated_weight, correct_weight)
# Same test as above except use 'upstream' for weight calculation
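# e.g. a stage-4 task now totals (4 * 5 + 1) * 3 = 63: its 20 upstream
# tasks plus itself, times 3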
weight = 3
with DAG('dag', start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'}) as dag:
pipeline = [
[DummyOperator(
task_id='stage{}.{}'.format(i, j), priority_weight=weight,
weight_rule=WeightRule.UPSTREAM)
for j in range(0, width)] for i in range(0, depth)
]
for d, stage in enumerate(pipeline):
if d == 0:
continue
for current_task in stage:
for prev_task in pipeline[d - 1]:
current_task.set_upstream(prev_task)
for task in six.itervalues(dag.task_dict):
match = pattern.match(task.task_id)
task_depth = int(match.group(1))
# the sum of all stages before this task, plus the task itself
correct_weight = (task_depth * width + 1) * weight
calculated_weight = task.priority_weight_total
self.assertEqual(calculated_weight, correct_weight)
# Same test as above except use 'absolute' for weight calculation
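# position no longer matters here: every task's total is just its own priority_weight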
weight = 10
with DAG('dag', start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'}) as dag:
pipeline = [
[DummyOperator(
task_id='stage{}.{}'.format(i, j), priority_weight=weight,
weight_rule=WeightRule.ABSOLUTE)
for j in range(0, width)] for i in range(0, depth)
]
for d, stage in enumerate(pipeline):
if d == 0:
continue
for current_task in stage:
for prev_task in pipeline[d - 1]:
current_task.set_upstream(prev_task)
for task in six.itervalues(dag.task_dict):
match = pattern.match(task.task_id)
task_depth = int(match.group(1))
# with the absolute rule, each task keeps exactly its own weight
correct_weight = weight
calculated_weight = task.priority_weight_total
self.assertEqual(calculated_weight, correct_weight)
# Test if we enter an invalid weight rule
with DAG('dag', start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'}) as dag:
with self.assertRaises(AirflowException):
DummyOperator(task_id='should_fail', weight_rule='no rule')
def test_get_num_task_instances(self):
test_dag_id = 'test_get_num_task_instances_dag'
test_task_id = 'task_1'
test_dag = DAG(dag_id=test_dag_id, start_date=DEFAULT_DATE)
test_task = DummyOperator(task_id=test_task_id, dag=test_dag)
ti1 = TI(task=test_task, execution_date=DEFAULT_DATE)
ti1.state = None
ti2 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1))
ti2.state = State.RUNNING
ti3 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=2))
ti3.state = State.QUEUED
ti4 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=3))
ti4.state = State.RUNNING
session = settings.Session()
session.merge(ti1)
session.merge(ti2)
session.merge(ti3)
session.merge(ti4)
session.commit()
self.assertEqual(
0,
DAG.get_num_task_instances(test_dag_id, ['fakename'], session=session)
)
self.assertEqual(
4,
DAG.get_num_task_instances(test_dag_id, [test_task_id], session=session)
)
self.assertEqual(
4,
DAG.get_num_task_instances(
test_dag_id, ['fakename', test_task_id], session=session)
)
self.assertEqual(
1,
DAG.get_num_task_instances(
test_dag_id, [test_task_id], states=[None], session=session)
)
self.assertEqual(
2,
DAG.get_num_task_instances(
test_dag_id, [test_task_id], states=[State.RUNNING], session=session)
)
self.assertEqual(
3,
DAG.get_num_task_instances(
test_dag_id, [test_task_id],
states=[None, State.RUNNING], session=session)
)
self.assertEqual(
4,
DAG.get_num_task_instances(
test_dag_id, [test_task_id],
states=[None, State.QUEUED, State.RUNNING], session=session)
)
session.close()
def test_render_template_field(self):
"""Tests if render_template from a field works"""
dag = DAG('test-dag',
start_date=DEFAULT_DATE)
with dag:
task = DummyOperator(task_id='op1')
result = task.render_template('', '{{ foo }}', dict(foo='bar'))
self.assertEqual(result, 'bar')
def test_render_template_list_field(self):
"""Tests if render_template from a list field works"""
dag = DAG('test-dag',
start_date=DEFAULT_DATE)
with dag:
task = DummyOperator(task_id='op1')
self.assertListEqual(
task.render_template('', ['{{ foo }}_1', '{{ foo }}_2'], {'foo': 'bar'}),
['bar_1', 'bar_2']
)
def test_render_template_tuple_field(self):
"""Tests if render_template from a tuple field works"""
dag = DAG('test-dag',
start_date=DEFAULT_DATE)
with dag:
task = DummyOperator(task_id='op1')
# tuple is replaced by a list
self.assertListEqual(
task.render_template('', ('{{ foo }}_1', '{{ foo }}_2'), {'foo': 'bar'}),
['bar_1', 'bar_2']
)
def test_render_template_dict_field(self):
"""Tests if render_template from a dict field works"""
dag = DAG('test-dag',
start_date=DEFAULT_DATE)
with dag:
task = DummyOperator(task_id='op1')
self.assertDictEqual(
task.render_template('', {'key1': '{{ foo }}_1', 'key2': '{{ foo }}_2'}, {'foo': 'bar'}),
{'key1': 'bar_1', 'key2': 'bar_2'}
)
def test_render_template_dict_field_with_templated_keys(self):
"""Tests if render_template from a dict field works as expected:
dictionary keys are not templated"""
dag = DAG('test-dag',
start_date=DEFAULT_DATE)
with dag:
task = DummyOperator(task_id='op1')
self.assertDictEqual(
task.render_template('', {'key_{{ foo }}_1': 1, 'key_2': '{{ foo }}_2'}, {'foo': 'bar'}),
{'key_{{ foo }}_1': 1, 'key_2': 'bar_2'}
)
def test_render_template_date_field(self):
"""Tests if render_template from a date field works"""
dag = DAG('test-dag',
start_date=DEFAULT_DATE)
with dag:
task = DummyOperator(task_id='op1')
self.assertEqual(
task.render_template('', datetime.date(2018, 12, 6), {'foo': 'bar'}),
datetime.date(2018, 12, 6)
)
def test_render_template_datetime_field(self):
"""Tests if render_template from a datetime field works"""
dag = DAG('test-dag',
start_date=DEFAULT_DATE)
with dag:
task = DummyOperator(task_id='op1')
self.assertEqual(
task.render_template('', datetime.datetime(2018, 12, 6, 10, 55), {'foo': 'bar'}),
datetime.datetime(2018, 12, 6, 10, 55)
)
def test_render_template_UUID_field(self):
"""Tests if render_template from a UUID field works"""
dag = DAG('test-dag',
start_date=DEFAULT_DATE)
with dag:
task = DummyOperator(task_id='op1')
random_uuid = uuid.uuid4()
self.assertIs(
task.render_template('', random_uuid, {'foo': 'bar'}),
random_uuid
)
def test_render_template_object_field(self):
"""Tests if render_template from an object field works"""
dag = DAG('test-dag',
start_date=DEFAULT_DATE)
with dag:
task = DummyOperator(task_id='op1')
test_object = object()
self.assertIs(
task.render_template('', test_object, {'foo': 'bar'}),
test_object
)
def test_render_template_field_macro(self):
""" Tests if render_template from a field works,
if a custom filter was defined"""
dag = DAG('test-dag',
start_date=DEFAULT_DATE,
user_defined_macros=dict(foo='bar'))
with dag:
task = DummyOperator(task_id='op1')
result = task.render_template('', '{{ foo }}', dict())
self.assertEqual(result, 'bar')
def test_render_template_numeric_field(self):
""" Tests if render_template from a field works,
if a custom filter was defined"""
dag = DAG('test-dag',
start_date=DEFAULT_DATE,
user_defined_macros=dict(foo='bar'))
with dag:
task = DummyOperator(task_id='op1')
result = task.render_template('', 1, dict())
self.assertEqual(result, 1)
def test_user_defined_filters(self):
def jinja_udf(name):
return 'Hello %s' % name
dag = models.DAG('test-dag',
start_date=DEFAULT_DATE,
user_defined_filters=dict(hello=jinja_udf))
jinja_env = dag.get_template_env()
self.assertIn('hello', jinja_env.filters)
self.assertEqual(jinja_env.filters['hello'], jinja_udf)
def test_render_template_field_filter(self):
""" Tests if render_template from a field works,
if a custom filter was defined"""
def jinja_udf(name):
return 'Hello %s' % name
dag = DAG('test-dag',
start_date=DEFAULT_DATE,
user_defined_filters=dict(hello=jinja_udf))
with dag:
task = DummyOperator(task_id='op1')
result = task.render_template('', "{{ 'world' | hello}}", dict())
self.assertEqual(result, 'Hello world')
def test_resolve_template_files_value(self):
with NamedTemporaryFile(suffix='.template') as f:
f.write('{{ ds }}'.encode('utf8'))
f.flush()
template_dir = os.path.dirname(f.name)
template_file = os.path.basename(f.name)
dag = DAG('test-dag',
start_date=DEFAULT_DATE,
template_searchpath=template_dir)
with dag:
task = DummyOperator(task_id='op1')
task.test_field = template_file
task.template_fields = ('test_field',)
task.template_ext = ('.template',)
task.resolve_template_files()
self.assertEqual(task.test_field, '{{ ds }}')
def test_resolve_template_files_list(self):
with NamedTemporaryFile(suffix='.template') as f:
f.write('{{ ds }}'.encode('utf8'))
f.flush()
template_dir = os.path.dirname(f.name)
template_file = os.path.basename(f.name)
dag = DAG('test-dag',
start_date=DEFAULT_DATE,
template_searchpath=template_dir)
with dag:
task = DummyOperator(task_id='op1')
task.test_field = [template_file, 'some_string']
task.template_fields = ('test_field',)
task.template_ext = ('.template',)
task.resolve_template_files()
self.assertEqual(task.test_field, ['{{ ds }}', 'some_string'])
def test_cycle(self):
# test empty
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'})
self.assertFalse(dag.test_cycle())
# test single task
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'})
with dag:
opA = DummyOperator(task_id='A')
self.assertFalse(dag.test_cycle())
# test no cycle
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'})
# A -> B -> C
# B -> D
# E -> F
with dag:
opA = DummyOperator(task_id='A')
opB = DummyOperator(task_id='B')
opC = DummyOperator(task_id='C')
opD = DummyOperator(task_id='D')
opE = DummyOperator(task_id='E')
opF = DummyOperator(task_id='F')
opA.set_downstream(opB)
opB.set_downstream(opC)
opB.set_downstream(opD)
opE.set_downstream(opF)
self.assertFalse(dag.test_cycle())
# test self loop
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'})
# A -> A
with dag:
opA = DummyOperator(task_id='A')
opA.set_downstream(opA)
with self.assertRaises(AirflowDagCycleException):
dag.test_cycle()
# test downstream self loop
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'})
# A -> B -> C -> D -> E -> E
with dag:
opA = DummyOperator(task_id='A')
opB = DummyOperator(task_id='B')
opC = DummyOperator(task_id='C')
opD = DummyOperator(task_id='D')
opE = DummyOperator(task_id='E')
opA.set_downstream(opB)
opB.set_downstream(opC)
opC.set_downstream(opD)
opD.set_downstream(opE)
opE.set_downstream(opE)
with self.assertRaises(AirflowDagCycleException):
dag.test_cycle()
# large loop
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'})
# A -> B -> C -> D -> E -> A
with dag:
opA = DummyOperator(task_id='A')
opB = DummyOperator(task_id='B')
opC = DummyOperator(task_id='C')
opD = DummyOperator(task_id='D')
opE = DummyOperator(task_id='E')
opA.set_downstream(opB)
opB.set_downstream(opC)
opC.set_downstream(opD)
opD.set_downstream(opE)
opE.set_downstream(opA)
with self.assertRaises(AirflowDagCycleException):
dag.test_cycle()
# test arbitrary loop
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'})
# E -> A -> B -> F -> A
# -> C -> F
with dag:
opA = DummyOperator(task_id='A')
opB = DummyOperator(task_id='B')
opC = DummyOperator(task_id='C')
opD = DummyOperator(task_id='D')
opE = DummyOperator(task_id='E')
opF = DummyOperator(task_id='F')
opA.set_downstream(opB)
opA.set_downstream(opC)
opE.set_downstream(opA)
opC.set_downstream(opF)
opB.set_downstream(opF)
opF.set_downstream(opA)
with self.assertRaises(AirflowDagCycleException):
dag.test_cycle()
def test_following_previous_schedule(self):
"""
Make sure DST transitions are properly observed
"""
local_tz = pendulum.timezone('Europe/Zurich')
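# Europe/Zurich leaves DST on 2018-10-28: at 03:00 CEST clocks fall back to
# 02:00 CET, so wall-clock times between 02:00 and 03:00 occur twice;
# PRE_TRANSITION resolves the ambiguity to the earlier (+02:00) occurrence,
# as the assertion below confirms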
start = local_tz.convert(datetime.datetime(2018, 10, 28, 2, 55),
dst_rule=pendulum.PRE_TRANSITION)
self.assertEqual(start.isoformat(), "2018-10-28T02:55:00+02:00",
"Pre-condition: start date is in DST")
utc = timezone.convert_to_utc(start)
dag = DAG('tz_dag', start_date=start, schedule_interval='*/5 * * * *')
_next = dag.following_schedule(utc)
next_local = local_tz.convert(_next)
self.assertEqual(_next.isoformat(), "2018-10-28T01:00:00+00:00")
self.assertEqual(next_local.isoformat(), "2018-10-28T02:00:00+01:00")
prev = dag.previous_schedule(utc)
prev_local = local_tz.convert(prev)
self.assertEqual(prev_local.isoformat(), "2018-10-28T02:50:00+02:00")
prev = dag.previous_schedule(_next)
prev_local = local_tz.convert(prev)
self.assertEqual(prev_local.isoformat(), "2018-10-28T02:55:00+02:00")
self.assertEqual(prev, utc)
def test_following_previous_schedule_daily_dag_CEST_to_CET(self):
"""
Make sure DST transitions are properly observed
"""
local_tz = pendulum.timezone('Europe/Zurich')
start = local_tz.convert(datetime.datetime(2018, 10, 27, 3),
dst_rule=pendulum.PRE_TRANSITION)
utc = timezone.convert_to_utc(start)
dag = DAG('tz_dag', start_date=start, schedule_interval='0 3 * * *')
prev = dag.previous_schedule(utc)
prev_local = local_tz.convert(prev)
self.assertEqual(prev_local.isoformat(), "2018-10-26T03:00:00+02:00")
self.assertEqual(prev.isoformat(), "2018-10-26T01:00:00+00:00")
_next = dag.following_schedule(utc)
next_local = local_tz.convert(_next)
self.assertEqual(next_local.isoformat(), "2018-10-28T03:00:00+01:00")
self.assertEqual(_next.isoformat(), "2018-10-28T02:00:00+00:00")
prev = dag.previous_schedule(_next)
prev_local = local_tz.convert(prev)
self.assertEqual(prev_local.isoformat(), "2018-10-27T03:00:00+02:00")
self.assertEqual(prev.isoformat(), "2018-10-27T01:00:00+00:00")
def test_following_previous_schedule_daily_dag_CET_to_CEST(self):
"""
Make sure DST transitions are properly observed
"""
local_tz = pendulum.timezone('Europe/Zurich')
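# Europe/Zurich enters DST on 2018-03-25: clocks jump from 02:00 CET straight
# to 03:00 CEST, so the 02:00 wall-clock start does not exist on that day;
# PRE_TRANSITION maps it to the pre-gap side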
start = local_tz.convert(datetime.datetime(2018, 3, 25, 2),
dst_rule=pendulum.PRE_TRANSITION)
utc = timezone.convert_to_utc(start)
dag = DAG('tz_dag', start_date=start, schedule_interval='0 3 * * *')
prev = dag.previous_schedule(utc)
prev_local = local_tz.convert(prev)
self.assertEqual(prev_local.isoformat(), "2018-03-24T03:00:00+01:00")
self.assertEqual(prev.isoformat(), "2018-03-24T02:00:00+00:00")
_next = dag.following_schedule(utc)
next_local = local_tz.convert(_next)
self.assertEqual(next_local.isoformat(), "2018-03-25T03:00:00+02:00")
self.assertEqual(_next.isoformat(), "2018-03-25T01:00:00+00:00")
prev = dag.previous_schedule(_next)
prev_local = local_tz.convert(prev)
self.assertEqual(prev_local.isoformat(), "2018-03-24T03:00:00+01:00")
self.assertEqual(prev.isoformat(), "2018-03-24T02:00:00+00:00")
@patch('airflow.models.timezone.utcnow')
def test_sync_to_db(self, mock_now):
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
)
with dag:
DummyOperator(task_id='task', owner='owner1')
SubDagOperator(
task_id='subtask',
owner='owner2',
subdag=DAG(
'dag.subtask',
start_date=DEFAULT_DATE,
)
)
now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
mock_now.return_value = now
session = settings.Session()
dag.sync_to_db(session=session)
orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag').one()
self.assertEqual(set(orm_dag.owners.split(', ')), {'owner1', 'owner2'})
self.assertEqual(orm_dag.last_scheduler_run, now)
self.assertTrue(orm_dag.is_active)
self.assertIsNone(orm_dag.default_view)
self.assertEqual(orm_dag.get_default_view(),
configuration.conf.get('webserver', 'dag_default_view').lower())
self.assertEqual(orm_dag.safe_dag_id, 'dag')
orm_subdag = session.query(DagModel).filter(
DagModel.dag_id == 'dag.subtask').one()
self.assertEqual(set(orm_subdag.owners.split(', ')), {'owner1', 'owner2'})
self.assertEqual(orm_subdag.last_scheduler_run, now)
self.assertTrue(orm_subdag.is_active)
self.assertEqual(orm_subdag.safe_dag_id, 'dag__dot__subtask')
@patch('airflow.models.timezone.utcnow')
def test_sync_to_db_default_view(self, mock_now):
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
default_view="graph",
)
with dag:
DummyOperator(task_id='task', owner='owner1')
SubDagOperator(
task_id='subtask',
owner='owner2',
subdag=DAG(
'dag.subtask',
start_date=DEFAULT_DATE,
)
)
now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
mock_now.return_value = now
session = settings.Session()
dag.sync_to_db(session=session)
orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag').one()
self.assertIsNotNone(orm_dag.default_view)
self.assertEqual(orm_dag.get_default_view(), "graph")
class DagRunTest(unittest.TestCase):
def create_dag_run(self, dag,
state=State.RUNNING,
task_states=None,
execution_date=None,
is_backfill=False,
):
now = timezone.utcnow()
if execution_date is None:
execution_date = now
if is_backfill:
run_id = BackfillJob.ID_PREFIX + now.isoformat()
else:
run_id = 'manual__' + now.isoformat()
dag_run = dag.create_dagrun(
run_id=run_id,
execution_date=execution_date,
start_date=now,
state=state,
external_trigger=False,
)
if task_states is not None:
session = settings.Session()
for task_id, state in task_states.items():
ti = dag_run.get_task_instance(task_id)
ti.set_state(state, session)
session.close()
return dag_run
def test_clear_task_instances_for_backfill_dagrun(self):
now = timezone.utcnow()
session = settings.Session()
dag_id = 'test_clear_task_instances_for_backfill_dagrun'
dag = DAG(dag_id=dag_id, start_date=now)
self.create_dag_run(dag, execution_date=now, is_backfill=True)
task0 = DummyOperator(task_id='backfill_task_0', owner='test', dag=dag)
ti0 = TI(task=task0, execution_date=now)
ti0.run()
qry = session.query(TI).filter(
TI.dag_id == dag.dag_id).all()
clear_task_instances(qry, session)
session.commit()
ti0.refresh_from_db()
dr0 = session.query(DagRun).filter(
DagRun.dag_id == dag_id,
DagRun.execution_date == now
).first()
self.assertEqual(dr0.state, State.RUNNING)
def test_id_for_date(self):
run_id = models.DagRun.id_for_date(
timezone.datetime(2015, 1, 2, 3, 4, 5, 6))
self.assertEqual(
'scheduled__2015-01-02T03:04:05', run_id,
'Generated run_id did not match expectations: {0}'.format(run_id))
def test_dagrun_find(self):
session = settings.Session()
now = timezone.utcnow()
dag_id1 = "test_dagrun_find_externally_triggered"
dag_run = models.DagRun(
dag_id=dag_id1,
run_id='manual__' + now.isoformat(),
execution_date=now,
start_date=now,
state=State.RUNNING,
external_trigger=True,
)
session.add(dag_run)
dag_id2 = "test_dagrun_find_not_externally_triggered"
dag_run = models.DagRun(
dag_id=dag_id2,
run_id='manual__' + now.isoformat(),
execution_date=now,
start_date=now,
state=State.RUNNING,
external_trigger=False,
)
session.add(dag_run)
session.commit()
self.assertEqual(1,
len(models.DagRun.find(dag_id=dag_id1, external_trigger=True)))
self.assertEqual(0,
len(models.DagRun.find(dag_id=dag_id1, external_trigger=False)))
self.assertEqual(0,
len(models.DagRun.find(dag_id=dag_id2, external_trigger=True)))
self.assertEqual(1,
len(models.DagRun.find(dag_id=dag_id2, external_trigger=False)))
def test_dagrun_success_when_all_skipped(self):
"""
Tests that a DAG run succeeds when all tasks are skipped
"""
dag = DAG(
dag_id='test_dagrun_success_when_all_skipped',
start_date=timezone.datetime(2017, 1, 1)
)
dag_task1 = ShortCircuitOperator(
task_id='test_short_circuit_false',
dag=dag,
python_callable=lambda: False)
dag_task2 = DummyOperator(
task_id='test_state_skipped1',
dag=dag)
dag_task3 = DummyOperator(
task_id='test_state_skipped2',
dag=dag)
dag_task1.set_downstream(dag_task2)
dag_task2.set_downstream(dag_task3)
initial_task_states = {
'test_short_circuit_false': State.SUCCESS,
'test_state_skipped1': State.SKIPPED,
'test_state_skipped2': State.SKIPPED,
}
dag_run = self.create_dag_run(dag=dag,
state=State.RUNNING,
task_states=initial_task_states)
updated_dag_state = dag_run.update_state()
self.assertEqual(State.SUCCESS, updated_dag_state)
def test_dagrun_success_conditions(self):
session = settings.Session()
dag = DAG(
'test_dagrun_success_conditions',
start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'})
# A -> B
# A -> C -> D
# ordered: B, D, C, A or D, B, C, A or D, C, B, A
with dag:
op1 = DummyOperator(task_id='A')
op2 = DummyOperator(task_id='B')
op3 = DummyOperator(task_id='C')
op4 = DummyOperator(task_id='D')
op1.set_upstream([op2, op3])
op3.set_upstream(op4)
dag.clear()
now = timezone.utcnow()
dr = dag.create_dagrun(run_id='test_dagrun_success_conditions',
state=State.RUNNING,
execution_date=now,
start_date=now)
# op1 = root
ti_op1 = dr.get_task_instance(task_id=op1.task_id)
ti_op1.set_state(state=State.SUCCESS, session=session)
ti_op2 = dr.get_task_instance(task_id=op2.task_id)
ti_op3 = dr.get_task_instance(task_id=op3.task_id)
ti_op4 = dr.get_task_instance(task_id=op4.task_id)
# root is successful, but there are still unfinished tasks
state = dr.update_state()
self.assertEqual(State.RUNNING, state)
# one has failed, but root is successful
ti_op2.set_state(state=State.FAILED, session=session)
ti_op3.set_state(state=State.SUCCESS, session=session)
ti_op4.set_state(state=State.SUCCESS, session=session)
state = dr.update_state()
self.assertEqual(State.SUCCESS, state)
def test_dagrun_deadlock(self):
session = settings.Session()
dag = DAG(
'test_dagrun_deadlock',
start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'})
with dag:
op1 = DummyOperator(task_id='A')
op2 = DummyOperator(task_id='B')
op2.trigger_rule = TriggerRule.ONE_FAILED
op2.set_upstream(op1)
dag.clear()
now = timezone.utcnow()
dr = dag.create_dagrun(run_id='test_dagrun_deadlock',
state=State.RUNNING,
execution_date=now,
start_date=now)
ti_op1 = dr.get_task_instance(task_id=op1.task_id)
ti_op1.set_state(state=State.SUCCESS, session=session)
ti_op2 = dr.get_task_instance(task_id=op2.task_id)
ti_op2.set_state(state=State.NONE, session=session)
dr.update_state()
self.assertEqual(dr.state, State.RUNNING)
ti_op2.set_state(state=State.NONE, session=session)
op2.trigger_rule = 'invalid'
dr.update_state()
self.assertEqual(dr.state, State.FAILED)
def test_dagrun_no_deadlock_with_shutdown(self):
session = settings.Session()
dag = DAG('test_dagrun_no_deadlock_with_shutdown',
start_date=DEFAULT_DATE)
with dag:
op1 = DummyOperator(task_id='upstream_task')
op2 = DummyOperator(task_id='downstream_task')
op2.set_upstream(op1)
dr = dag.create_dagrun(run_id='test_dagrun_no_deadlock_with_shutdown',
state=State.RUNNING,
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE)
upstream_ti = dr.get_task_instance(task_id='upstream_task')
upstream_ti.set_state(State.SHUTDOWN, session=session)
dr.update_state()
self.assertEqual(dr.state, State.RUNNING)
def test_dagrun_no_deadlock_with_depends_on_past(self):
session = settings.Session()
dag = DAG('test_dagrun_no_deadlock',
start_date=DEFAULT_DATE)
with dag:
DummyOperator(task_id='dop', depends_on_past=True)
DummyOperator(task_id='tc', task_concurrency=1)
dag.clear()
dr = dag.create_dagrun(run_id='test_dagrun_no_deadlock_1',
state=State.RUNNING,
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE)
dr2 = dag.create_dagrun(run_id='test_dagrun_no_deadlock_2',
state=State.RUNNING,
execution_date=DEFAULT_DATE + datetime.timedelta(days=1),
start_date=DEFAULT_DATE + datetime.timedelta(days=1))
ti1_op1 = dr.get_task_instance(task_id='dop')
dr2.get_task_instance(task_id='dop')
ti2_op1 = dr.get_task_instance(task_id='tc')
dr.get_task_instance(task_id='tc')
ti1_op1.set_state(state=State.RUNNING, session=session)
dr.update_state()
dr2.update_state()
self.assertEqual(dr.state, State.RUNNING)
self.assertEqual(dr2.state, State.RUNNING)
ti2_op1.set_state(state=State.RUNNING, session=session)
dr.update_state()
dr2.update_state()
self.assertEqual(dr.state, State.RUNNING)
self.assertEqual(dr2.state, State.RUNNING)
def test_dagrun_success_callback(self):
def on_success_callable(context):
self.assertEqual(
context['dag_run'].dag_id,
'test_dagrun_success_callback'
)
dag = DAG(
dag_id='test_dagrun_success_callback',
start_date=datetime.datetime(2017, 1, 1),
on_success_callback=on_success_callable,
)
dag_task1 = DummyOperator(
task_id='test_state_succeeded1',
dag=dag)
dag_task2 = DummyOperator(
task_id='test_state_succeeded2',
dag=dag)
dag_task1.set_downstream(dag_task2)
initial_task_states = {
'test_state_succeeded1': State.SUCCESS,
'test_state_succeeded2': State.SUCCESS,
}
dag_run = self.create_dag_run(dag=dag,
state=State.RUNNING,
task_states=initial_task_states)
updated_dag_state = dag_run.update_state()
self.assertEqual(State.SUCCESS, updated_dag_state)
def test_dagrun_failure_callback(self):
def on_failure_callable(context):
self.assertEqual(
context['dag_run'].dag_id,
'test_dagrun_failure_callback'
)
dag = DAG(
dag_id='test_dagrun_failure_callback',
start_date=datetime.datetime(2017, 1, 1),
on_failure_callback=on_failure_callable,
)
dag_task1 = DummyOperator(
task_id='test_state_succeeded1',
dag=dag)
dag_task2 = DummyOperator(
task_id='test_state_failed2',
dag=dag)
initial_task_states = {
'test_state_succeeded1': State.SUCCESS,
'test_state_failed2': State.FAILED,
}
dag_task1.set_downstream(dag_task2)
dag_run = self.create_dag_run(dag=dag,
state=State.RUNNING,
task_states=initial_task_states)
updated_dag_state = dag_run.update_state()
self.assertEqual(State.FAILED, updated_dag_state)
def test_dagrun_set_state_end_date(self):
session = settings.Session()
dag = DAG(
'test_dagrun_set_state_end_date',
start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'})
dag.clear()
now = timezone.utcnow()
dr = dag.create_dagrun(run_id='test_dagrun_set_state_end_date',
state=State.RUNNING,
execution_date=now,
start_date=now)
# Initial end_date should be NULL
# State.SUCCESS and State.FAILED are both terminal states and should set end_date
# State.RUNNING sets end_date back to NULL
session.add(dr)
session.commit()
self.assertIsNone(dr.end_date)
dr.set_state(State.SUCCESS)
session.merge(dr)
session.commit()
dr_database = session.query(DagRun).filter(
DagRun.run_id == 'test_dagrun_set_state_end_date'
).one()
self.assertIsNotNone(dr_database.end_date)
self.assertEqual(dr.end_date, dr_database.end_date)
dr.set_state(State.RUNNING)
session.merge(dr)
session.commit()
dr_database = session.query(DagRun).filter(
DagRun.run_id == 'test_dagrun_set_state_end_date'
).one()
self.assertIsNone(dr_database.end_date)
dr.set_state(State.FAILED)
session.merge(dr)
session.commit()
dr_database = session.query(DagRun).filter(
DagRun.run_id == 'test_dagrun_set_state_end_date'
).one()
self.assertIsNotNone(dr_database.end_date)
self.assertEqual(dr.end_date, dr_database.end_date)
def test_dagrun_update_state_end_date(self):
session = settings.Session()
dag = DAG(
'test_dagrun_update_state_end_date',
start_date=DEFAULT_DATE,
default_args={'owner': 'owner1'})
# A -> B
with dag:
op1 = DummyOperator(task_id='A')
op2 = DummyOperator(task_id='B')
op1.set_upstream(op2)
dag.clear()
now = timezone.utcnow()
dr = dag.create_dagrun(run_id='test_dagrun_update_state_end_date',
state=State.RUNNING,
execution_date=now,
start_date=now)
# Initial end_date should be NULL
# State.SUCCESS and State.FAILED are both terminal states and should set end_date
# State.RUNNING sets end_date back to NULL
session.merge(dr)
session.commit()
self.assertIsNone(dr.end_date)
ti_op1 = dr.get_task_instance(task_id=op1.task_id)
ti_op1.set_state(state=State.SUCCESS, session=session)
ti_op2 = dr.get_task_instance(task_id=op2.task_id)
ti_op2.set_state(state=State.SUCCESS, session=session)
dr.update_state()
dr_database = session.query(DagRun).filter(
DagRun.run_id == 'test_dagrun_update_state_end_date'
).one()
self.assertIsNotNone(dr_database.end_date)
self.assertEqual(dr.end_date, dr_database.end_date)
ti_op1.set_state(state=State.RUNNING, session=session)
ti_op2.set_state(state=State.RUNNING, session=session)
dr.update_state()
dr_database = session.query(DagRun).filter(
DagRun.run_id == 'test_dagrun_update_state_end_date'
).one()
self.assertEqual(dr._state, State.RUNNING)
self.assertIsNone(dr.end_date)
self.assertIsNone(dr_database.end_date)
ti_op1.set_state(state=State.FAILED, session=session)
ti_op2.set_state(state=State.FAILED, session=session)
dr.update_state()
dr_database = session.query(DagRun).filter(
DagRun.run_id == 'test_dagrun_update_state_end_date'
).one()
self.assertIsNotNone(dr_database.end_date)
self.assertEqual(dr.end_date, dr_database.end_date)
def test_get_task_instance_on_empty_dagrun(self):
"""
Make sure None is returned when a dagrun has no task instances
"""
dag = DAG(
dag_id='test_get_task_instance_on_empty_dagrun',
start_date=timezone.datetime(2017, 1, 1)
)
ShortCircuitOperator(
task_id='test_short_circuit_false',
dag=dag,
python_callable=lambda: False)
session = settings.Session()
now = timezone.utcnow()
# Don't use create_dagrun since it will create the task instances too which we
# don't want
dag_run = models.DagRun(
dag_id=dag.dag_id,
run_id='manual__' + now.isoformat(),
execution_date=now,
start_date=now,
state=State.RUNNING,
external_trigger=False,
)
session.add(dag_run)
session.commit()
ti = dag_run.get_task_instance('test_short_circuit_false')
self.assertIsNone(ti)
def test_get_latest_runs(self):
session = settings.Session()
dag = DAG(
dag_id='test_latest_runs_1',
start_date=DEFAULT_DATE)
self.create_dag_run(dag, execution_date=timezone.datetime(2015, 1, 1))
self.create_dag_run(dag, execution_date=timezone.datetime(2015, 1, 2))
dagruns = models.DagRun.get_latest_runs(session)
session.close()
for dagrun in dagruns:
if dagrun.dag_id == 'test_latest_runs_1':
self.assertEqual(dagrun.execution_date, timezone.datetime(2015, 1, 2))
def test_is_backfill(self):
dag = DAG(dag_id='test_is_backfill', start_date=DEFAULT_DATE)
dagrun = self.create_dag_run(dag, execution_date=DEFAULT_DATE)
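# is_backfill is derived from the run_id: only run_ids starting with
# BackfillJob.ID_PREFIX count as backfill runs (a None run_id does not)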
dagrun.run_id = BackfillJob.ID_PREFIX + '_sfddsffds'
dagrun2 = self.create_dag_run(
dag, execution_date=DEFAULT_DATE + datetime.timedelta(days=1))
dagrun3 = self.create_dag_run(
dag, execution_date=DEFAULT_DATE + datetime.timedelta(days=2))
dagrun3.run_id = None
self.assertTrue(dagrun.is_backfill)
self.assertFalse(dagrun2.is_backfill)
self.assertFalse(dagrun3.is_backfill)
def test_removed_task_instances_can_be_restored(self):
def with_all_tasks_removed(dag):
return DAG(dag_id=dag.dag_id, start_date=dag.start_date)
dag = DAG('test_task_restoration', start_date=DEFAULT_DATE)
dag.add_task(DummyOperator(task_id='flaky_task', owner='test'))
dagrun = self.create_dag_run(dag)
flaky_ti = dagrun.get_task_instances()[0]
self.assertEqual('flaky_task', flaky_ti.task_id)
self.assertEqual(State.NONE, flaky_ti.state)
dagrun.dag = with_all_tasks_removed(dag)
dagrun.verify_integrity()
flaky_ti.refresh_from_db()
self.assertEqual(State.NONE, flaky_ti.state)
dagrun.dag.add_task(DummyOperator(task_id='flaky_task', owner='test'))
dagrun.verify_integrity()
flaky_ti.refresh_from_db()
self.assertEqual(State.NONE, flaky_ti.state)
class DagBagTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.empty_dir = mkdtemp()
@classmethod
def tearDownClass(cls):
os.rmdir(cls.empty_dir)
def test_get_existing_dag(self):
"""
Test that we're able to parse some example DAGs and retrieve them
"""
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=True)
some_expected_dag_ids = ["example_bash_operator",
"example_branch_operator"]
for dag_id in some_expected_dag_ids:
dag = dagbag.get_dag(dag_id)
self.assertIsNotNone(dag)
self.assertEqual(dag_id, dag.dag_id)
self.assertGreaterEqual(dagbag.size(), 7)
def test_get_non_existing_dag(self):
"""
Test that retrieving a non-existing dag id returns None without crashing
"""
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
non_existing_dag_id = "non_existing_dag_id"
self.assertIsNone(dagbag.get_dag(non_existing_dag_id))
def test_process_file_that_contains_multi_bytes_char(self):
"""
Test that we're able to parse a file that contains a multi-byte char
"""
f = NamedTemporaryFile()
f.write('\u3042'.encode('utf8')) # write multi-byte char (hiragana)
f.flush()
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
self.assertEqual([], dagbag.process_file(f.name))
def test_zip_skip_log(self):
"""
Test that loading DAGs from a zip file skips a file whose contents
mention neither "airflow" nor "DAG"
"""
with patch('airflow.models.DagBag.log') as log_mock:
log_mock.info = Mock()
test_zip_path = os.path.join(TEST_DAGS_FOLDER, "test_zip.zip")
dagbag = models.DagBag(dag_folder=test_zip_path, include_examples=False)
self.assertTrue(dagbag.has_logged)
log_mock.info.assert_any_call("File %s assumed to contain no DAGs. Skipping.",
test_zip_path)
def test_zip(self):
"""
test the loading of a DAG within a zip file that includes dependencies
"""
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_zip.zip"))
self.assertTrue(dagbag.get_dag("test_zip_dag"))
def test_process_file_cron_validity_check(self):
"""
test if an invalid cron expression
as schedule interval can be identified
"""
invalid_dag_files = ["test_invalid_cron.py", "test_zip_invalid_cron.zip"]
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
self.assertEqual(len(dagbag.import_errors), 0)
for d in invalid_dag_files:
dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, d))
self.assertEqual(len(dagbag.import_errors), len(invalid_dag_files))
@patch.object(DagModel, 'get_current')
def test_get_dag_without_refresh(self, mock_dagmodel):
"""
Test that, once a DAG is loaded, it doesn't get refreshed again if it
hasn't been expired.
"""
dag_id = 'example_bash_operator'
mock_dagmodel.return_value = DagModel()
mock_dagmodel.return_value.last_expired = None
mock_dagmodel.return_value.fileloc = 'foo'
class TestDagBag(models.DagBag):
process_file_calls = 0
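# counts how many times example_bash_operator.py gets parsed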
def process_file(self, filepath, only_if_updated=True, safe_mode=True):
if 'example_bash_operator.py' == os.path.basename(filepath):
TestDagBag.process_file_calls += 1
super(TestDagBag, self).process_file(filepath, only_if_updated, safe_mode)
dagbag = TestDagBag(include_examples=True)
# Should not call process_file again, since it's already loaded during init.
self.assertEqual(1, dagbag.process_file_calls)
self.assertIsNotNone(dagbag.get_dag(dag_id))
self.assertEqual(1, dagbag.process_file_calls)
def test_get_dag_fileloc(self):
"""
Test that fileloc is correctly set when we load example DAGs,
specifically SubDAGs.
"""
dagbag = models.DagBag(include_examples=True)
expected = {
'example_bash_operator': 'example_bash_operator.py',
'example_subdag_operator': 'example_subdag_operator.py',
'example_subdag_operator.section-1': 'subdags/subdag.py'
}
for dag_id, path in expected.items():
dag = dagbag.get_dag(dag_id)
self.assertTrue(
dag.fileloc.endswith('airflow/example_dags/' + path))
def process_dag(self, create_dag):
"""
Helper method to process a file generated from the input create_dag function.
"""
# write source to file
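# inspect.getsource returns the whole function; [1:-1] drops the enclosing
# 'def ...():' header and the trailing 'return dag' line, leaving just the
# module-level DAG definition to write out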
source = textwrap.dedent(''.join(
inspect.getsource(create_dag).splitlines(True)[1:-1]))
f = NamedTemporaryFile()
f.write(source.encode('utf8'))
f.flush()
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
found_dags = dagbag.process_file(f.name)
return dagbag, found_dags, f.name
def validate_dags(self, expected_parent_dag, actual_found_dags, actual_dagbag,
should_be_found=True):
expected_dag_ids = list(map(lambda dag: dag.dag_id, expected_parent_dag.subdags))
expected_dag_ids.append(expected_parent_dag.dag_id)
actual_found_dag_ids = list(map(lambda dag: dag.dag_id, actual_found_dags))
for dag_id in expected_dag_ids:
actual_dagbag.log.info('validating %s' % dag_id)
self.assertEqual(
dag_id in actual_found_dag_ids, should_be_found,
'dag "%s" should %shave been found after processing dag "%s"' %
(dag_id, '' if should_be_found else 'not ', expected_parent_dag.dag_id)
)
self.assertEqual(
dag_id in actual_dagbag.dags, should_be_found,
'dag "%s" should %sbe in dagbag.dags after processing dag "%s"' %
(dag_id, '' if should_be_found else 'not ', expected_parent_dag.dag_id)
)
def test_load_subdags(self):
# Define Dag to load
def standard_subdag():
from airflow.models import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.subdag_operator import SubDagOperator
import datetime
DAG_NAME = 'master'
DEFAULT_ARGS = {
'owner': 'owner1',
'start_date': datetime.datetime(2016, 1, 1)
}
dag = DAG(
DAG_NAME,
default_args=DEFAULT_ARGS)
# master:
# A -> opSubDag_0
# master.opsubdag_0:
# -> subdag_0.task
# A -> opSubDag_1
# master.opsubdag_1:
# -> subdag_1.task
with dag:
def subdag_0():
subdag_0 = DAG('master.opSubdag_0', default_args=DEFAULT_ARGS)
DummyOperator(task_id='subdag_0.task', dag=subdag_0)
return subdag_0
def subdag_1():
subdag_1 = DAG('master.opSubdag_1', default_args=DEFAULT_ARGS)
DummyOperator(task_id='subdag_1.task', dag=subdag_1)
return subdag_1
opSubdag_0 = SubDagOperator(
task_id='opSubdag_0', dag=dag, subdag=subdag_0())
opSubdag_1 = SubDagOperator(
task_id='opSubdag_1', dag=dag, subdag=subdag_1())
opA = DummyOperator(task_id='A')
opA.set_downstream(opSubdag_0)
opA.set_downstream(opSubdag_1)
return dag
testDag = standard_subdag()
# sanity check to make sure DAG.subdags is still functioning properly
self.assertEqual(len(testDag.subdags), 2)
# Process the dag file
dagbag, found_dags, _ = self.process_dag(standard_subdag)
# Validate correctness
# all dags from testDag should be listed
self.validate_dags(testDag, found_dags, dagbag)
# Define Dag to load
def nested_subdags():
from airflow.models import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.subdag_operator import SubDagOperator
import datetime
DAG_NAME = 'master'
DEFAULT_ARGS = {
'owner': 'owner1',
'start_date': datetime.datetime(2016, 1, 1)
}
dag = DAG(
DAG_NAME,
default_args=DEFAULT_ARGS)
# master:
# A -> opSubdag_0
# master.opSubdag_0:
# -> opSubDag_A
# master.opSubdag_0.opSubdag_A:
# -> subdag_A.task
# -> opSubdag_B
# master.opSubdag_0.opSubdag_B:
# -> subdag_B.task
# A -> opSubdag_1
# master.opSubdag_1:
# -> opSubdag_C
# master.opSubdag_1.opSubdag_C:
# -> subdag_C.task
# -> opSubDag_D
# master.opSubdag_1.opSubdag_D:
# -> subdag_D.task
with dag:
def subdag_A():
subdag_A = DAG(
'master.opSubdag_0.opSubdag_A', default_args=DEFAULT_ARGS)
DummyOperator(task_id='subdag_A.task', dag=subdag_A)
return subdag_A
def subdag_B():
subdag_B = DAG(
'master.opSubdag_0.opSubdag_B', default_args=DEFAULT_ARGS)
DummyOperator(task_id='subdag_B.task', dag=subdag_B)
return subdag_B
def subdag_C():
subdag_C = DAG(
'master.opSubdag_1.opSubdag_C', default_args=DEFAULT_ARGS)
DummyOperator(task_id='subdag_C.task', dag=subdag_C)
return subdag_C
def subdag_D():
subdag_D = DAG(
'master.opSubdag_1.opSubdag_D', default_args=DEFAULT_ARGS)
DummyOperator(task_id='subdag_D.task', dag=subdag_D)
return subdag_D
def subdag_0():
subdag_0 = DAG('master.opSubdag_0', default_args=DEFAULT_ARGS)
SubDagOperator(task_id='opSubdag_A', dag=subdag_0, subdag=subdag_A())
SubDagOperator(task_id='opSubdag_B', dag=subdag_0, subdag=subdag_B())
return subdag_0
def subdag_1():
subdag_1 = DAG('master.opSubdag_1', default_args=DEFAULT_ARGS)
SubDagOperator(task_id='opSubdag_C', dag=subdag_1, subdag=subdag_C())
SubDagOperator(task_id='opSubdag_D', dag=subdag_1, subdag=subdag_D())
return subdag_1
opSubdag_0 = SubDagOperator(
task_id='opSubdag_0', dag=dag, subdag=subdag_0())
opSubdag_1 = SubDagOperator(
task_id='opSubdag_1', dag=dag, subdag=subdag_1())
opA = DummyOperator(task_id='A')
opA.set_downstream(opSubdag_0)
opA.set_downstream(opSubdag_1)
return dag
testDag = nested_subdags()
# sanity check to make sure DAG.subdags is still functioning properly
self.assertEqual(len(testDag.subdags), 6)
# Process the dag file
dagbag, found_dags, _ = self.process_dag(nested_subdags)
# Validate correctness
# all dags from testDag should be listed
self.validate_dags(testDag, found_dags, dagbag)
def test_skip_cycle_dags(self):
"""
Don't crash when loading an invalid DAG file (one that contains a
cycle), and don't load the dag into the DagBag either
"""
# Define Dag to load
def basic_cycle():
from airflow.models import DAG
from airflow.operators.dummy_operator import DummyOperator
import datetime
DAG_NAME = 'cycle_dag'
DEFAULT_ARGS = {
'owner': 'owner1',
'start_date': datetime.datetime(2016, 1, 1)
}
dag = DAG(
DAG_NAME,
default_args=DEFAULT_ARGS)
# A -> A
with dag:
opA = DummyOperator(task_id='A')
opA.set_downstream(opA)
return dag
testDag = basic_cycle()
# sanity check to make sure DAG.subdags is still functioning properly
self.assertEqual(len(testDag.subdags), 0)
# Process the dag file
dagbag, found_dags, file_path = self.process_dag(basic_cycle)
# Validate correctness
# None of the dags should be found
self.validate_dags(testDag, found_dags, dagbag, should_be_found=False)
self.assertIn(file_path, dagbag.import_errors)
# Define Dag to load
def nested_subdag_cycle():
from airflow.models import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.subdag_operator import SubDagOperator
import datetime
DAG_NAME = 'nested_cycle'
DEFAULT_ARGS = {
'owner': 'owner1',
'start_date': datetime.datetime(2016, 1, 1)
}
dag = DAG(
DAG_NAME,
default_args=DEFAULT_ARGS)
# cycle:
# A -> opSubdag_0
# cycle.opSubdag_0:
# -> opSubDag_A
# cycle.opSubdag_0.opSubdag_A:
# -> subdag_A.task
# -> opSubdag_B
# cycle.opSubdag_0.opSubdag_B:
# -> subdag_B.task
# A -> opSubdag_1
# cycle.opSubdag_1:
# -> opSubdag_C
# cycle.opSubdag_1.opSubdag_C:
# -> subdag_C.task -> subdag_C.task >Invalid Loop<
# -> opSubDag_D
# cycle.opSubdag_1.opSubdag_D:
# -> subdag_D.task
with dag:
def subdag_A():
subdag_A = DAG(
'nested_cycle.opSubdag_0.opSubdag_A', default_args=DEFAULT_ARGS)
DummyOperator(task_id='subdag_A.task', dag=subdag_A)
return subdag_A
def subdag_B():
subdag_B = DAG(
'nested_cycle.opSubdag_0.opSubdag_B', default_args=DEFAULT_ARGS)
DummyOperator(task_id='subdag_B.task', dag=subdag_B)
return subdag_B
def subdag_C():
subdag_C = DAG(
'nested_cycle.opSubdag_1.opSubdag_C', default_args=DEFAULT_ARGS)
opSubdag_C_task = DummyOperator(
task_id='subdag_C.task', dag=subdag_C)
# introduce a loop in opSubdag_C
opSubdag_C_task.set_downstream(opSubdag_C_task)
return subdag_C
def subdag_D():
subdag_D = DAG(
'nested_cycle.opSubdag_1.opSubdag_D', default_args=DEFAULT_ARGS)
DummyOperator(task_id='subdag_D.task', dag=subdag_D)
return subdag_D
def subdag_0():
subdag_0 = DAG('nested_cycle.opSubdag_0', default_args=DEFAULT_ARGS)
SubDagOperator(task_id='opSubdag_A', dag=subdag_0, subdag=subdag_A())
SubDagOperator(task_id='opSubdag_B', dag=subdag_0, subdag=subdag_B())
return subdag_0
def subdag_1():
subdag_1 = DAG('nested_cycle.opSubdag_1', default_args=DEFAULT_ARGS)
SubDagOperator(task_id='opSubdag_C', dag=subdag_1, subdag=subdag_C())
SubDagOperator(task_id='opSubdag_D', dag=subdag_1, subdag=subdag_D())
return subdag_1
opSubdag_0 = SubDagOperator(
task_id='opSubdag_0', dag=dag, subdag=subdag_0())
opSubdag_1 = SubDagOperator(
task_id='opSubdag_1', dag=dag, subdag=subdag_1())
opA = DummyOperator(task_id='A')
opA.set_downstream(opSubdag_0)
opA.set_downstream(opSubdag_1)
return dag
testDag = nested_subdag_cycle()
# sanity check to make sure DAG.subdags is still functioning properly
self.assertEqual(len(testDag.subdags), 6)
# Process the dag file
dagbag, found_dags, file_path = self.process_dag(nested_subdag_cycle)
# Validate correctness
# None of the dags should be found
self.validate_dags(testDag, found_dags, dagbag, should_be_found=False)
self.assertIn(file_path, dagbag.import_errors)
def test_process_file_with_none(self):
"""
Test that process_file can handle a None file path
"""
dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
self.assertEqual([], dagbag.process_file(None))
@patch.object(TI, 'handle_failure')
def test_kill_zombies(self, mock_ti_handle_failure):
"""
Test that kill_zombies calls the TI's failure handler with the proper context
"""
dagbag = models.DagBag()
with create_session() as session:
session.query(TI).delete()
dag = dagbag.get_dag('example_branch_operator')
task = dag.get_task(task_id='run_this_first')
ti = TI(task, DEFAULT_DATE, State.RUNNING)
session.add(ti)
session.commit()
zombies = [SimpleTaskInstance(ti)]
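# kill_zombies should route each zombie through TI.handle_failure,
# passing the configured error context along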
dagbag.kill_zombies(zombies)
mock_ti_handle_failure \
.assert_called_with(ANY,
configuration.getboolean('core',
'unit_test_mode'),
ANY)
def test_deactivate_unknown_dags(self):
"""
Test that DAGs whose dag_ids are not passed to deactivate_unknown_dags
are deactivated when the function is invoked
"""
dagbag = models.DagBag(include_examples=True)
expected_active_dags = dagbag.dags.keys()
session = settings.Session()
session.add(DagModel(dag_id='test_deactivate_unknown_dags', is_active=True))
session.commit()
models.DAG.deactivate_unknown_dags(expected_active_dags)
for dag in session.query(DagModel).all():
if dag.dag_id in expected_active_dags:
self.assertTrue(dag.is_active)
else:
self.assertEqual(dag.dag_id, 'test_deactivate_unknown_dags')
self.assertFalse(dag.is_active)
# clean up
session.query(DagModel).filter(DagModel.dag_id == 'test_deactivate_unknown_dags').delete()
session.commit()
class TaskInstanceTest(unittest.TestCase):
def tearDown(self):
with create_session() as session:
session.query(models.TaskFail).delete()
session.query(models.TaskReschedule).delete()
session.query(models.TaskInstance).delete()
def test_set_task_dates(self):
"""
Test that tasks properly take start/end dates from DAGs
"""
dag = DAG('dag', start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10))
op1 = DummyOperator(task_id='op_1', owner='test')
self.assertTrue(op1.start_date is None and op1.end_date is None)
# dag should assign its dates to op1 because op1 has no dates
dag.add_task(op1)
self.assertTrue(
op1.start_date == dag.start_date and op1.end_date == dag.end_date)
op2 = DummyOperator(
task_id='op_2',
owner='test',
start_date=DEFAULT_DATE - datetime.timedelta(days=1),
end_date=DEFAULT_DATE + datetime.timedelta(days=11))
# dag should assign its dates to op2 because they are more restrictive
dag.add_task(op2)
self.assertTrue(
op2.start_date == dag.start_date and op2.end_date == dag.end_date)
op3 = DummyOperator(
task_id='op_3',
owner='test',
start_date=DEFAULT_DATE + datetime.timedelta(days=1),
end_date=DEFAULT_DATE + datetime.timedelta(days=9))
# op3 should keep its dates because they are more restrictive
dag.add_task(op3)
self.assertTrue(
op3.start_date == DEFAULT_DATE + datetime.timedelta(days=1))
self.assertTrue(
op3.end_date == DEFAULT_DATE + datetime.timedelta(days=9))
def test_timezone_awareness(self):
NAIVE_DATETIME = DEFAULT_DATE.replace(tzinfo=None)
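# a naive copy of DEFAULT_DATE; Airflow should localize naive dates
# (to UTC here) wherever they are accepted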
# check ti without dag (just for bw compat)
op_no_dag = DummyOperator(task_id='op_no_dag')
ti = TI(task=op_no_dag, execution_date=NAIVE_DATETIME)
self.assertEqual(ti.execution_date, DEFAULT_DATE)
# check with dag without localized execution_date
dag = DAG('dag', start_date=DEFAULT_DATE)
op1 = DummyOperator(task_id='op_1')
dag.add_task(op1)
ti = TI(task=op1, execution_date=NAIVE_DATETIME)
self.assertEqual(ti.execution_date, DEFAULT_DATE)
# with dag and localized execution_date
tz = pendulum.timezone("Europe/Amsterdam")
execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tz)
utc_date = timezone.convert_to_utc(execution_date)
ti = TI(task=op1, execution_date=execution_date)
self.assertEqual(ti.execution_date, utc_date)
def test_task_naive_datetime(self):
NAIVE_DATETIME = DEFAULT_DATE.replace(tzinfo=None)
op_no_dag = DummyOperator(task_id='test_task_naive_datetime',
start_date=NAIVE_DATETIME,
end_date=NAIVE_DATETIME)
self.assertTrue(op_no_dag.start_date.tzinfo)
self.assertTrue(op_no_dag.end_date.tzinfo)
def test_set_dag(self):
"""
Test assigning Operators to Dags, including deferred assignment
"""
dag = DAG('dag', start_date=DEFAULT_DATE)
dag2 = DAG('dag2', start_date=DEFAULT_DATE)
op = DummyOperator(task_id='op_1', owner='test')
# no dag assigned
self.assertFalse(op.has_dag())
self.assertRaises(AirflowException, getattr, op, 'dag')
# no improper assignment
with self.assertRaises(TypeError):
op.dag = 1
op.dag = dag
# no reassignment
with self.assertRaises(AirflowException):
op.dag = dag2
# but assigning the same dag is ok
op.dag = dag
self.assertIs(op.dag, dag)
self.assertIn(op, dag.tasks)
def test_infer_dag(self):
dag = DAG('dag', start_date=DEFAULT_DATE)
dag2 = DAG('dag2', start_date=DEFAULT_DATE)
op1 = DummyOperator(task_id='test_op_1', owner='test')
op2 = DummyOperator(task_id='test_op_2', owner='test')
op3 = DummyOperator(task_id='test_op_3', owner='test', dag=dag)
op4 = DummyOperator(task_id='test_op_4', owner='test', dag=dag2)
# double check dags
self.assertEqual(
[i.has_dag() for i in [op1, op2, op3, op4]],
[False, False, True, True])
# can't combine operators with no dags
self.assertRaises(AirflowException, op1.set_downstream, op2)
# op2 should infer dag from op1
op1.dag = dag
op1.set_downstream(op2)
self.assertIs(op2.dag, dag)
# can't assign across multiple DAGs
self.assertRaises(AirflowException, op1.set_downstream, op4)
self.assertRaises(AirflowException, op1.set_downstream, [op3, op4])
def test_bitshift_compose_operators(self):
dag = DAG('dag', start_date=DEFAULT_DATE)
op1 = DummyOperator(task_id='test_op_1', owner='test')
op2 = DummyOperator(task_id='test_op_2', owner='test')
op3 = DummyOperator(task_id='test_op_3', owner='test')
op4 = DummyOperator(task_id='test_op_4', owner='test')
op5 = DummyOperator(task_id='test_op_5', owner='test')
# can't compose operators without dags
with self.assertRaises(AirflowException):
op1 >> op2
dag >> op1 >> op2 << op3
# make sure dag assignment carries through
# using __rrshift__
self.assertIs(op1.dag, dag)
self.assertIs(op2.dag, dag)
self.assertIs(op3.dag, dag)
# op2 should be downstream of both
self.assertIn(op2, op1.downstream_list)
self.assertIn(op2, op3.downstream_list)
# test dag assignment with __rlshift__
dag << op4
self.assertIs(op4.dag, dag)
# dag assignment with __rrshift__
dag >> op5
self.assertIs(op5.dag, dag)
@patch.object(DAG, 'concurrency_reached')
def test_requeue_over_concurrency(self, mock_concurrency_reached):
mock_concurrency_reached.return_value = True
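# Reporting concurrency as reached should make the TI requeue instead of
# running, leaving its state as NONE.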
dag = DAG(dag_id='test_requeue_over_concurrency', start_date=DEFAULT_DATE,
max_active_runs=1, concurrency=2)
task = DummyOperator(task_id='test_requeue_over_concurrency_op', dag=dag)
ti = TI(task=task, execution_date=timezone.utcnow())
ti.run()
self.assertEqual(ti.state, models.State.NONE)
@patch.object(TI, 'pool_full')
def test_run_pooling_task(self, mock_pool_full):
"""
Test that running a pooled task updates the task state (there is no
dependency check in ti_deps anymore, so the state also ends up as SUCCESS).
"""
# Mock the pool out with a full pool because the pool doesn't actually exist
mock_pool_full.return_value = True
dag = models.DAG(dag_id='test_run_pooling_task')
task = DummyOperator(task_id='test_run_pooling_task_op', dag=dag,
pool='test_run_pooling_task_pool', owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
ti = TI(
task=task, execution_date=timezone.utcnow())
ti.run()
self.assertEqual(ti.state, models.State.SUCCESS)
@patch.object(TI, 'pool_full')
def test_run_pooling_task_with_mark_success(self, mock_pool_full):
"""
Test that running a task with the mark_success param updates the task
state to SUCCESS without actually running the task.
"""
# Mock the pool out with a full pool because the pool doesn't actually exist
mock_pool_full.return_value = True
dag = models.DAG(dag_id='test_run_pooling_task_with_mark_success')
task = DummyOperator(
task_id='test_run_pooling_task_with_mark_success_op',
dag=dag,
pool='test_run_pooling_task_with_mark_success_pool',
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
ti = TI(
task=task, execution_date=timezone.utcnow())
ti.run(mark_success=True)
self.assertEqual(ti.state, models.State.SUCCESS)
def test_run_pooling_task_with_skip(self):
"""
Test that running a task which raises AirflowSkipException ends up in a
SKIPPED state.
"""
def raise_skip_exception():
raise AirflowSkipException
dag = models.DAG(dag_id='test_run_pooling_task_with_skip')
task = PythonOperator(
task_id='test_run_pooling_task_with_skip',
dag=dag,
python_callable=raise_skip_exception,
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
ti = TI(
task=task, execution_date=timezone.utcnow())
ti.run()
self.assertEqual(models.State.SKIPPED, ti.state)
def test_retry_delay(self):
"""
Test that retry delays are respected
"""
dag = models.DAG(dag_id='test_retry_handling')
task = BashOperator(
task_id='test_retry_handling_op',
bash_command='exit 1',
retries=1,
retry_delay=datetime.timedelta(seconds=3),
dag=dag,
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
def run_with_error(ti):
try:
ti.run()
except AirflowException:
pass
ti = TI(
task=task, execution_date=timezone.utcnow())
self.assertEqual(ti.try_number, 1)
# first run -- up for retry
run_with_error(ti)
self.assertEqual(ti.state, State.UP_FOR_RETRY)
self.assertEqual(ti.try_number, 2)
# second run -- still up for retry because retry_delay hasn't expired
run_with_error(ti)
self.assertEqual(ti.state, State.UP_FOR_RETRY)
# third run -- failed
time.sleep(3)
run_with_error(ti)
self.assertEqual(ti.state, State.FAILED)
@patch.object(TI, 'pool_full')
def test_retry_handling(self, mock_pool_full):
"""
Test that task retries are handled properly
"""
# Mock the pool as having open slots since the pool doesn't actually exist
mock_pool_full.return_value = False
dag = models.DAG(dag_id='test_retry_handling')
task = BashOperator(
task_id='test_retry_handling_op',
bash_command='exit 1',
retries=1,
retry_delay=datetime.timedelta(seconds=0),
dag=dag,
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
def run_with_error(ti):
try:
ti.run()
except AirflowException:
pass
ti = TI(
task=task, execution_date=timezone.utcnow())
self.assertEqual(ti.try_number, 1)
# first run -- up for retry
run_with_error(ti)
self.assertEqual(ti.state, State.UP_FOR_RETRY)
self.assertEqual(ti._try_number, 1)
self.assertEqual(ti.try_number, 2)
# second run -- fail
run_with_error(ti)
self.assertEqual(ti.state, State.FAILED)
self.assertEqual(ti._try_number, 2)
self.assertEqual(ti.try_number, 3)
# Clear the TI state since you can't run a task with a FAILED state without
# clearing it first
dag.clear()
# third run -- up for retry
run_with_error(ti)
self.assertEqual(ti.state, State.UP_FOR_RETRY)
self.assertEqual(ti._try_number, 3)
self.assertEqual(ti.try_number, 4)
# fourth run -- fail
run_with_error(ti)
ti.refresh_from_db()
self.assertEqual(ti.state, State.FAILED)
self.assertEqual(ti._try_number, 4)
self.assertEqual(ti.try_number, 5)
def test_next_retry_datetime(self):
delay = datetime.timedelta(seconds=30)
max_delay = datetime.timedelta(minutes=60)
dag = models.DAG(dag_id='fail_dag')
task = BashOperator(
task_id='task_with_exp_backoff_and_max_delay',
bash_command='exit 1',
retries=3,
retry_delay=delay,
retry_exponential_backoff=True,
max_retry_delay=max_delay,
dag=dag,
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
ti = TI(
task=task, execution_date=DEFAULT_DATE)
ti.end_date = pendulum.instance(timezone.utcnow())
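# With retry_exponential_backoff, the retry datetime is chosen
# deterministically (hash-based) from a window that doubles with each try,
# capped at max_retry_delay. Note that assigning ti.try_number below stores
# _try_number, so the property reads back one higher while the task is not
# running; subtracting two pendulum datetimes yields a Period, making
# `dt in period` a containment check.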
dt = ti.next_retry_datetime()
# between 30 * 2^-1 and 30 * 2^0 (15 and 30)
period = ti.end_date.add(seconds=30) - ti.end_date.add(seconds=15)
self.assertTrue(dt in period)
ti.try_number = 3
dt = ti.next_retry_datetime()
# between 30 * 2^2 and 30 * 2^3 (120 and 240)
period = ti.end_date.add(seconds=240) - ti.end_date.add(seconds=120)
self.assertTrue(dt in period)
ti.try_number = 5
dt = ti.next_retry_datetime()
# between 30 * 2^4 and 30 * 2^5 (480 and 960)
period = ti.end_date.add(seconds=960) - ti.end_date.add(seconds=480)
self.assertTrue(dt in period)
ti.try_number = 9
dt = ti.next_retry_datetime()
self.assertEqual(dt, ti.end_date + max_delay)
ti.try_number = 50
dt = ti.next_retry_datetime()
self.assertEqual(dt, ti.end_date + max_delay)
@patch.object(TI, 'pool_full')
def test_reschedule_handling(self, mock_pool_full):
"""
Test that task reschedules are handled properly
"""
# Mock the pool as having open slots since the pool doesn't actually exist
mock_pool_full.return_value = False
# Return values of the python sensor callable, modified during tests
done = False
fail = False
def sensor_callable():
if fail:
raise AirflowException()
return done
dag = models.DAG(dag_id='test_reschedule_handling')
task = PythonSensor(
task_id='test_reschedule_handling_sensor',
poke_interval=0,
mode='reschedule',
python_callable=sensor_callable,
retries=1,
retry_delay=datetime.timedelta(seconds=0),
dag=dag,
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))
ti = TI(task=task, execution_date=timezone.utcnow())
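# A freshly created task instance has not run: _try_number holds the
# stored count (0) while the try_number property reports the upcoming
# attempt (1).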
self.assertEqual(ti._try_number, 0)
self.assertEqual(ti.try_number, 1)
def run_ti_and_assert(run_date, expected_start_date, expected_end_date, expected_duration,
expected_state, expected_try_number, expected_task_reschedule_count):
with freeze_time(run_date):
try:
ti.run()
except AirflowException:
if not fail:
raise
ti.refresh_from_db()
self.assertEqual(ti.state, expected_state)
self.assertEqual(ti._try_number, expected_try_number)
self.assertEqual(ti.try_number, expected_try_number + 1)
self.assertEqual(ti.start_date, expected_start_date)
self.assertEqual(ti.end_date, expected_end_date)
self.assertEqual(ti.duration, expected_duration)
trs = TR.find_for_task_instance(ti)
self.assertEqual(len(trs), expected_task_reschedule_count)
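# freeze_time pins the clock inside the helper so start/end dates and the
# computed durations are deterministic.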
date1 = timezone.utcnow()
date2 = date1 + datetime.timedelta(minutes=1)
date3 = date2 + datetime.timedelta(minutes=1)
date4 = date3 + datetime.timedelta(minutes=1)
# Run with multiple reschedules.
# During reschedule the try number remains the same, but each reschedule is recorded.
# The start date is expected to remain the initial date, hence the duration increases.
# When finished the try number is incremented and there is no reschedule expected
# for this try.
done, fail = False, False
run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 0, 1)
done, fail = False, False
run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RESCHEDULE, 0, 2)
done, fail = False, False
run_ti_and_assert(date3, date1, date3, 120, State.UP_FOR_RESCHEDULE, 0, 3)
done, fail = True, False
run_ti_and_assert(date4, date1, date4, 180, State.SUCCESS, 1, 0)
# Clear the task instance.
dag.clear()
ti.refresh_from_db()
self.assertEqual(ti.state, State.NONE)
self.assertEqual(ti._try_number, 1)
# Run again after clearing with reschedules and a retry.
# The retry increments the try number, and for that try no reschedule is expected.
# After the retry the start date is reset, hence the duration is also reset.
done, fail = False, False
run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 1, 1)
done, fail = False, True
run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 2, 0)
done, fail = False, False
run_ti_and_assert(date3, date3, date3, 0, State.UP_FOR_RESCHEDULE, 2, 1)
done, fail = True, False
run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0)
def test_depends_on_past(self):
dagbag = models.DagBag()
dag = dagbag.get_dag('test_depends_on_past')
dag.clear()
task = dag.tasks[0]
run_date = task.start_date + datetime.timedelta(days=5)
ti = TI(task, run_date)
# depends_on_past prevents the run
task.run(start_date=run_date, end_date=run_date)
ti.refresh_from_db()
self.assertIs(ti.state, None)
# ignore first depends_on_past to allow the run
task.run(
start_date=run_date,
end_date=run_date,
ignore_first_depends_on_past=True)
ti.refresh_from_db()
self.assertEqual(ti.state, State.SUCCESS)
# Parameterized tests to check for the correct firing
# of the trigger_rule under various circumstances
# Numeric fields are in order:
# successes, skipped, failed, upstream_failed, done
@parameterized.expand([
#
# Tests for all_success
#
['all_success', 5, 0, 0, 0, 0, True, None, True],
['all_success', 2, 0, 0, 0, 0, True, None, False],
['all_success', 2, 0, 1, 0, 0, True, ST.UPSTREAM_FAILED, False],
['all_success', 2, 1, 0, 0, 0, True, ST.SKIPPED, False],
#
# Tests for one_success
#
['one_success', 5, 0, 0, 0, 5, True, None, True],
['one_success', 2, 0, 0, 0, 2, True, None, True],
['one_success', 2, 0, 1, 0, 3, True, None, True],
['one_success', 2, 1, 0, 0, 3, True, None, True],
#
# Tests for all_failed
#
['all_failed', 5, 0, 0, 0, 5, True, ST.SKIPPED, False],
['all_failed', 0, 0, 5, 0, 5, True, None, True],
['all_failed', 2, 0, 0, 0, 2, True, ST.SKIPPED, False],
['all_failed', 2, 0, 1, 0, 3, True, ST.SKIPPED, False],
['all_failed', 2, 1, 0, 0, 3, True, ST.SKIPPED, False],
#
# Tests for one_failed
#
['one_failed', 5, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 1, 0, 0, True, None, True],
['one_failed', 2, 1, 0, 0, 3, True, None, False],
['one_failed', 2, 3, 0, 0, 5, True, ST.SKIPPED, False],
#
# Tests for done
#
['all_done', 5, 0, 0, 0, 5, True, None, True],
['all_done', 2, 0, 0, 0, 2, True, None, False],
['all_done', 2, 0, 1, 0, 3, True, None, False],
['all_done', 2, 1, 0, 0, 3, True, None, False]
])
def test_check_task_dependencies(self, trigger_rule, successes, skipped,
failed, upstream_failed, done,
flag_upstream_failed,
expect_state, expect_completed):
start_date = timezone.datetime(2016, 2, 1, 0, 0, 0)
dag = models.DAG('test-dag', start_date=start_date)
downstream = DummyOperator(task_id='downstream',
dag=dag, owner='airflow',
trigger_rule=trigger_rule)
for i in range(5):
task = DummyOperator(task_id='runme_{}'.format(i),
dag=dag, owner='airflow')
task.set_downstream(downstream)
run_date = task.start_date + datetime.timedelta(days=5)
ti = TI(downstream, run_date)
dep_results = TriggerRuleDep()._evaluate_trigger_rule(
ti=ti,
successes=successes,
skipped=skipped,
failed=failed,
upstream_failed=upstream_failed,
done=done,
flag_upstream_failed=flag_upstream_failed)
completed = all([dep.passed for dep in dep_results])
self.assertEqual(completed, expect_completed)
self.assertEqual(ti.state, expect_state)
def test_xcom_pull(self):
"""
Test xcom_pull, using different filtering methods.
"""
dag = models.DAG(
dag_id='test_xcom', schedule_interval='@monthly',
start_date=timezone.datetime(2016, 6, 1, 0, 0, 0))
exec_date = timezone.utcnow()
# Push a value
task1 = DummyOperator(task_id='test_xcom_1', dag=dag, owner='airflow')
ti1 = TI(task=task1, execution_date=exec_date)
ti1.xcom_push(key='foo', value='bar')
# Push another value with the same key (but by a different task)
task2 = DummyOperator(task_id='test_xcom_2', dag=dag, owner='airflow')
ti2 = TI(task=task2, execution_date=exec_date)
ti2.xcom_push(key='foo', value='baz')
# Pull with no arguments
result = ti1.xcom_pull()
self.assertEqual(result, None)
# Pull the value pushed most recently by any task.
result = ti1.xcom_pull(key='foo')
self.assertEqual(result, 'baz')
# Pull the value pushed by the first task
result = ti1.xcom_pull(task_ids='test_xcom_1', key='foo')
self.assertEqual(result, 'bar')
# Pull the value pushed by the second task
result = ti1.xcom_pull(task_ids='test_xcom_2', key='foo')
self.assertEqual(result, 'baz')
# Pull the values pushed by both tasks
result = ti1.xcom_pull(
task_ids=['test_xcom_1', 'test_xcom_2'], key='foo')
self.assertEqual(result, ('bar', 'baz'))
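# With a list of task_ids, xcom_pull returns the values in the same order
# as the requested ids.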
def test_xcom_pull_after_success(self):
"""
Tests XCom set/clear relative to a task in a 'success' rerun scenario.
"""
key = 'xcom_key'
value = 'xcom_value'
dag = models.DAG(dag_id='test_xcom', schedule_interval='@monthly')
task = DummyOperator(
task_id='test_xcom',
dag=dag,
pool='test_xcom',
owner='airflow',
start_date=timezone.datetime(2016, 6, 2, 0, 0, 0))
exec_date = timezone.utcnow()
ti = TI(
task=task, execution_date=exec_date)
ti.run(mark_success=True)
ti.xcom_push(key=key, value=value)
self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)
ti.run()
# The second run and assert is to handle AIRFLOW-131 (don't clear on
# prior success)
self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)
# Test AIRFLOW-703: Xcom shouldn't be cleared if the task doesn't
# execute, even if dependencies are ignored
ti.run(ignore_all_deps=True, mark_success=True)
self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)
# Xcom IS finally cleared once task has executed
ti.run(ignore_all_deps=True)
self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), None)
def test_xcom_pull_different_execution_date(self):
"""
Tests XCom fetch behavior with different execution dates, using
xcom_pull both with and without 'include_prior_dates'.
"""
key = 'xcom_key'
value = 'xcom_value'
dag = models.DAG(dag_id='test_xcom', schedule_interval='@monthly')
task = DummyOperator(
task_id='test_xcom',
dag=dag,
pool='test_xcom',
owner='airflow',
start_date=timezone.datetime(2016, 6, 2, 0, 0, 0))
exec_date = timezone.utcnow()
ti = TI(
task=task, execution_date=exec_date)
ti.run(mark_success=True)
ti.xcom_push(key=key, value=value)
self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)
ti.run()
exec_date += datetime.timedelta(days=1)
ti = TI(
task=task, execution_date=exec_date)
ti.run()
# We have set a new execution date (and did not pass in
# 'include_prior_dates'), which means this task should now have a
# cleared xcom value
self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), None)
# We *should* get a value using 'include_prior_dates'
self.assertEqual(ti.xcom_pull(task_ids='test_xcom',
key=key,
include_prior_dates=True),
value)
def test_xcom_push_flag(self):
"""
Tests the option for Operators to push XComs
"""
value = 'hello'
task_id = 'test_no_xcom_push'
dag = models.DAG(dag_id='test_xcom')
# nothing saved to XCom
task = PythonOperator(
task_id=task_id,
dag=dag,
python_callable=lambda: value,
do_xcom_push=False,
owner='airflow',
start_date=datetime.datetime(2017, 1, 1)
)
ti = TI(task=task, execution_date=datetime.datetime(2017, 1, 1))
ti.run()
self.assertEqual(
ti.xcom_pull(
task_ids=task_id, key=models.XCOM_RETURN_KEY
),
None
)
def test_post_execute_hook(self):
"""
Test that post_execute hook is called with the Operator's result.
The result ('error') will cause an error to be raised and trapped.
"""
class TestError(Exception):
pass
class TestOperator(PythonOperator):
def post_execute(self, context, result):
if result == 'error':
raise TestError('expected error.')
dag = models.DAG(dag_id='test_post_execute_dag')
task = TestOperator(
task_id='test_operator',
dag=dag,
python_callable=lambda: 'error',
owner='airflow',
start_date=timezone.datetime(2017, 2, 1))
ti = TI(task=task, execution_date=timezone.utcnow())
with self.assertRaises(TestError):
ti.run()
def test_check_and_change_state_before_execution(self):
dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
ti = TI(
task=task, execution_date=timezone.utcnow())
self.assertEqual(ti._try_number, 0)
self.assertTrue(ti._check_and_change_state_before_execution())
# State should be running, and try_number column should be incremented
self.assertEqual(ti.state, State.RUNNING)
self.assertEqual(ti._try_number, 1)
def test_check_and_change_state_before_execution_dep_not_met(self):
dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
task2 = DummyOperator(task_id='task2', dag=dag, start_date=DEFAULT_DATE)
task >> task2
ti = TI(
task=task2, execution_date=timezone.utcnow())
self.assertFalse(ti._check_and_change_state_before_execution())
def test_try_number(self):
"""
Test the try_number accessor behaves in various running states
"""
dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
ti = TI(task=task, execution_date=timezone.utcnow())
self.assertEqual(1, ti.try_number)
ti.try_number = 2
ti.state = State.RUNNING
self.assertEqual(2, ti.try_number)
ti.state = State.SUCCESS
self.assertEqual(3, ti.try_number)
def test_get_num_running_task_instances(self):
session = settings.Session()
dag = models.DAG(dag_id='test_get_num_running_task_instances')
dag2 = models.DAG(dag_id='test_get_num_running_task_instances_dummy')
task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
task2 = DummyOperator(task_id='task', dag=dag2, start_date=DEFAULT_DATE)
ti1 = TI(task=task, execution_date=DEFAULT_DATE)
ti2 = TI(task=task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1))
ti3 = TI(task=task2, execution_date=DEFAULT_DATE)
ti1.state = State.RUNNING
ti2.state = State.QUEUED
ti3.state = State.RUNNING
session.add(ti1)
session.add(ti2)
session.add(ti3)
session.commit()
self.assertEqual(1, ti1.get_num_running_task_instances(session=session))
self.assertEqual(1, ti2.get_num_running_task_instances(session=session))
self.assertEqual(1, ti3.get_num_running_task_instances(session=session))
def test_log_url(self):
dag = DAG('dag', start_date=DEFAULT_DATE)
task = DummyOperator(task_id='op', dag=dag)
ti = TI(task=task, execution_date=datetime.datetime(2018, 1, 1))
expected_url = (
'http://localhost:8080/log?'
'execution_date=2018-01-01T00%3A00%3A00%2B00%3A00'
'&task_id=op'
'&dag_id=dag'
)
self.assertEqual(ti.log_url, expected_url)
def test_mark_success_url(self):
now = pendulum.now('Europe/Brussels')
dag = DAG('dag', start_date=DEFAULT_DATE)
task = DummyOperator(task_id='op', dag=dag)
ti = TI(task=task, execution_date=now)
d = urllib.parse.parse_qs(
urllib.parse.urlparse(ti.mark_success_url).query,
keep_blank_values=True, strict_parsing=True)
self.assertEqual(d['dag_id'][0], 'dag')
self.assertEqual(d['task_id'][0], 'op')
self.assertEqual(pendulum.parse(d['execution_date'][0]), now)
def test_overwrite_params_with_dag_run_conf(self):
task = DummyOperator(task_id='op')
ti = TI(task=task, execution_date=datetime.datetime.now())
dag_run = DagRun()
dag_run.conf = {"override": True}
params = {"override": False}
ti.overwrite_params_with_dag_run_conf(params, dag_run)
self.assertEqual(True, params["override"])
def test_overwrite_params_with_dag_run_none(self):
task = DummyOperator(task_id='op')
ti = TI(task=task, execution_date=datetime.datetime.now())
params = {"override": False}
ti.overwrite_params_with_dag_run_conf(params, None)
self.assertEqual(False, params["override"])
def test_overwrite_params_with_dag_run_conf_none(self):
task = DummyOperator(task_id='op')
ti = TI(task=task, execution_date=datetime.datetime.now())
params = {"override": False}
dag_run = DagRun()
ti.overwrite_params_with_dag_run_conf(params, dag_run)
self.assertEqual(False, params["override"])
@patch('airflow.models.send_email')
def test_email_alert(self, mock_send_email):
dag = models.DAG(dag_id='test_failure_email')
task = BashOperator(
task_id='test_email_alert',
dag=dag,
bash_command='exit 1',
start_date=DEFAULT_DATE,
email='to')
ti = TI(task=task, execution_date=datetime.datetime.now())
try:
ti.run()
except AirflowException:
pass
(email, title, body), _ = mock_send_email.call_args
self.assertEqual(email, 'to')
self.assertIn('test_email_alert', title)
self.assertIn('test_email_alert', body)
@patch('airflow.models.send_email')
def test_email_alert_with_config(self, mock_send_email):
dag = models.DAG(dag_id='test_failure_email')
task = BashOperator(
task_id='test_email_alert_with_config',
dag=dag,
bash_command='exit 1',
start_date=DEFAULT_DATE,
email='to')
ti = TI(
task=task, execution_date=datetime.datetime.now())
configuration.set('email', 'SUBJECT_TEMPLATE', '/subject/path')
configuration.set('email', 'HTML_CONTENT_TEMPLATE', '/html_content/path')
opener = mock_open(read_data='template: {{ti.task_id}}')
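# The mock stands in for both template files configured above; create=True
# is needed because 'open' is a builtin, not an attribute of airflow.models.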
with patch('airflow.models.open', opener, create=True):
try:
ti.run()
except AirflowException:
pass
(email, title, body), _ = mock_send_email.call_args
self.assertEqual(email, 'to')
self.assertEqual('template: test_email_alert_with_config', title)
self.assertEqual('template: test_email_alert_with_config', body)
def test_set_duration(self):
task = DummyOperator(task_id='op', email='test@test.test')
ti = TI(
task=task,
execution_date=datetime.datetime.now(),
)
ti.start_date = datetime.datetime(2018, 10, 1, 1)
ti.end_date = datetime.datetime(2018, 10, 1, 2)
ti.set_duration()
self.assertEqual(ti.duration, 3600)
def test_set_duration_empty_dates(self):
task = DummyOperator(task_id='op', email='test@test.test')
ti = TI(task=task, execution_date=datetime.datetime.now())
ti.set_duration()
self.assertIsNone(ti.duration)
def test_success_callback_no_race_condition(self):
class CallbackWrapper(object):
def wrap_task_instance(self, ti):
self.task_id = ti.task_id
self.dag_id = ti.dag_id
self.execution_date = ti.execution_date
self.task_state_in_callback = ""
self.callback_ran = False
def success_handler(self, context):
self.callback_ran = True
session = settings.Session()
temp_instance = session.query(TI).filter(
TI.task_id == self.task_id).filter(
TI.dag_id == self.dag_id).filter(
TI.execution_date == self.execution_date).one()
self.task_state_in_callback = temp_instance.state
cw = CallbackWrapper()
dag = DAG('test_success_callback_no_race_condition', start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10))
task = DummyOperator(task_id='op', email='test@test.test',
on_success_callback=cw.success_handler, dag=dag)
ti = TI(task=task, execution_date=datetime.datetime.now())
ti.state = State.RUNNING
session = settings.Session()
session.merge(ti)
session.commit()
cw.wrap_task_instance(ti)
ti._run_raw_task()
self.assertTrue(cw.callback_ran)
self.assertEqual(cw.task_state_in_callback, State.RUNNING)
ti.refresh_from_db()
self.assertEqual(ti.state, State.SUCCESS)
class ClearTasksTest(unittest.TestCase):
def test_clear_task_instances(self):
dag = DAG('test_clear_task_instances', start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10))
task0 = DummyOperator(task_id='0', owner='test', dag=dag)
task1 = DummyOperator(task_id='1', owner='test', dag=dag, retries=2)
ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
ti0.run()
ti1.run()
session = settings.Session()
qry = session.query(TI).filter(
TI.dag_id == dag.dag_id).all()
clear_task_instances(qry, session, dag=dag)
session.commit()
ti0.refresh_from_db()
ti1.refresh_from_db()
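# With the dag supplied, clear resets max_tries so each task gets a fresh
# set of retries: try_number + retries - 1 (ti0: 2 + 0 - 1 = 1,
# ti1: 2 + 2 - 1 = 3).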
# Next try to run will be try 2
self.assertEqual(ti0.try_number, 2)
self.assertEqual(ti0.max_tries, 1)
self.assertEqual(ti1.try_number, 2)
self.assertEqual(ti1.max_tries, 3)
def test_clear_task_instances_without_task(self):
dag = DAG('test_clear_task_instances_without_task', start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10))
task0 = DummyOperator(task_id='task0', owner='test', dag=dag)
task1 = DummyOperator(task_id='task1', owner='test', dag=dag, retries=2)
ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
ti0.run()
ti1.run()
# Remove the task from dag.
dag.task_dict = {}
self.assertFalse(dag.has_task(task0.task_id))
self.assertFalse(dag.has_task(task1.task_id))
session = settings.Session()
qry = session.query(TI).filter(
TI.dag_id == dag.dag_id).all()
clear_task_instances(qry, session)
session.commit()
# When dag is None, max_tries will be maximum of original max_tries or try_number.
ti0.refresh_from_db()
ti1.refresh_from_db()
# Next try to run will be try 2
self.assertEqual(ti0.try_number, 2)
self.assertEqual(ti0.max_tries, 1)
self.assertEqual(ti1.try_number, 2)
self.assertEqual(ti1.max_tries, 2)
def test_clear_task_instances_without_dag(self):
dag = DAG('test_clear_task_instances_without_dag', start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10))
task0 = DummyOperator(task_id='task_0', owner='test', dag=dag)
task1 = DummyOperator(task_id='task_1', owner='test', dag=dag, retries=2)
ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
ti0.run()
ti1.run()
session = settings.Session()
qry = session.query(TI).filter(
TI.dag_id == dag.dag_id).all()
clear_task_instances(qry, session)
session.commit()
# When dag is None, max_tries will be maximum of original max_tries or try_number.
ti0.refresh_from_db()
ti1.refresh_from_db()
# Next try to run will be try 2
self.assertEqual(ti0.try_number, 2)
self.assertEqual(ti0.max_tries, 1)
self.assertEqual(ti1.try_number, 2)
self.assertEqual(ti1.max_tries, 2)
def test_dag_clear(self):
dag = DAG('test_dag_clear', start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10))
task0 = DummyOperator(task_id='test_dag_clear_task_0', owner='test', dag=dag)
ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
# Next try to run will be try 1
self.assertEqual(ti0.try_number, 1)
ti0.run()
self.assertEqual(ti0.try_number, 2)
dag.clear()
ti0.refresh_from_db()
self.assertEqual(ti0.try_number, 2)
self.assertEqual(ti0.state, State.NONE)
self.assertEqual(ti0.max_tries, 1)
task1 = DummyOperator(task_id='test_dag_clear_task_1', owner='test',
dag=dag, retries=2)
ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
self.assertEqual(ti1.max_tries, 2)
ti1.try_number = 1
# Next try will be 2
ti1.run()
self.assertEqual(ti1.try_number, 3)
self.assertEqual(ti1.max_tries, 2)
dag.clear()
ti0.refresh_from_db()
ti1.refresh_from_db()
# after clearing the dag, ti1 should show attempt 3 of 5
self.assertEqual(ti1.max_tries, 4)
self.assertEqual(ti1.try_number, 3)
# after clearing the dag, ti0 should show attempt 2 of 2
self.assertEqual(ti0.try_number, 2)
self.assertEqual(ti0.max_tries, 1)
def test_dags_clear(self):
# setup
session = settings.Session()
dags, tis = [], []
num_of_dags = 5
for i in range(num_of_dags):
dag = DAG('test_dag_clear_' + str(i), start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10))
ti = TI(task=DummyOperator(task_id='test_task_clear_' + str(i), owner='test',
dag=dag),
execution_date=DEFAULT_DATE)
dags.append(dag)
tis.append(ti)
# test clear all dags
for i in range(num_of_dags):
tis[i].run()
self.assertEqual(tis[i].state, State.SUCCESS)
self.assertEqual(tis[i].try_number, 2)
self.assertEqual(tis[i].max_tries, 0)
DAG.clear_dags(dags)
for i in range(num_of_dags):
tis[i].refresh_from_db()
self.assertEqual(tis[i].state, State.NONE)
self.assertEqual(tis[i].try_number, 2)
self.assertEqual(tis[i].max_tries, 1)
# test dry_run
for i in range(num_of_dags):
tis[i].run()
self.assertEqual(tis[i].state, State.SUCCESS)
self.assertEqual(tis[i].try_number, 3)
self.assertEqual(tis[i].max_tries, 1)
DAG.clear_dags(dags, dry_run=True)
for i in range(num_of_dags):
tis[i].refresh_from_db()
self.assertEqual(tis[i].state, State.SUCCESS)
self.assertEqual(tis[i].try_number, 3)
self.assertEqual(tis[i].max_tries, 1)
# test only_failed
from random import randint
failed_dag_idx = randint(0, len(tis) - 1)
tis[failed_dag_idx].state = State.FAILED
session.merge(tis[failed_dag_idx])
session.commit()
DAG.clear_dags(dags, only_failed=True)
for i in range(num_of_dags):
tis[i].refresh_from_db()
if i != failed_dag_idx:
self.assertEqual(tis[i].state, State.SUCCESS)
self.assertEqual(tis[i].try_number, 3)
self.assertEqual(tis[i].max_tries, 1)
else:
self.assertEqual(tis[i].state, State.NONE)
self.assertEqual(tis[i].try_number, 3)
self.assertEqual(tis[i].max_tries, 2)
def test_operator_clear(self):
dag = DAG('test_operator_clear', start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10))
t1 = DummyOperator(task_id='bash_op', owner='test', dag=dag)
t2 = DummyOperator(task_id='dummy_op', owner='test', dag=dag, retries=1)
t2.set_upstream(t1)
ti1 = TI(task=t1, execution_date=DEFAULT_DATE)
ti2 = TI(task=t2, execution_date=DEFAULT_DATE)
ti2.run()
# Dependency not met
self.assertEqual(ti2.try_number, 1)
self.assertEqual(ti2.max_tries, 1)
t2.clear(upstream=True)
ti1.run()
ti2.run()
self.assertEqual(ti1.try_number, 2)
# max_tries is 0 because there is no task instance in db for ti1
# so clear won't change the max_tries.
self.assertEqual(ti1.max_tries, 0)
self.assertEqual(ti2.try_number, 2)
# try_number (0) + retries(1)
self.assertEqual(ti2.max_tries, 1)
def test_xcom_disable_pickle_type(self):
configuration.load_test_config()
json_obj = {"key": "value"}
execution_date = timezone.utcnow()
key = "xcom_test1"
dag_id = "test_dag1"
task_id = "test_task1"
configuration.set("core", "enable_xcom_pickling", "False")
XCom.set(key=key,
value=json_obj,
dag_id=dag_id,
task_id=task_id,
execution_date=execution_date)
ret_value = XCom.get_one(key=key,
dag_id=dag_id,
task_id=task_id,
execution_date=execution_date)
self.assertEqual(ret_value, json_obj)
session = settings.Session()
ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id,
XCom.task_id == task_id,
XCom.execution_date == execution_date
).first().value
self.assertEqual(ret_value, json_obj)
def test_xcom_enable_pickle_type(self):
json_obj = {"key": "value"}
execution_date = timezone.utcnow()
key = "xcom_test2"
dag_id = "test_dag2"
task_id = "test_task2"
configuration.set("core", "enable_xcom_pickling", "True")
XCom.set(key=key,
value=json_obj,
dag_id=dag_id,
task_id=task_id,
execution_date=execution_date)
ret_value = XCom.get_one(key=key,
dag_id=dag_id,
task_id=task_id,
execution_date=execution_date)
self.assertEqual(ret_value, json_obj)
session = settings.Session()
ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id,
XCom.task_id == task_id,
XCom.execution_date == execution_date
).first().value
self.assertEqual(ret_value, json_obj)
def test_xcom_disable_pickle_type_fail_on_non_json(self):
class PickleRce(object):
def __reduce__(self):
return os.system, ("ls -alt",)
configuration.set("core", "xcom_enable_pickling", "False")
self.assertRaises(TypeError, XCom.set,
key="xcom_test3",
value=PickleRce(),
dag_id="test_dag3",
task_id="test_task3",
execution_date=timezone.utcnow())
def test_xcom_get_many(self):
json_obj = {"key": "value"}
execution_date = timezone.utcnow()
key = "xcom_test4"
dag_id1 = "test_dag4"
task_id1 = "test_task4"
dag_id2 = "test_dag5"
task_id2 = "test_task5"
configuration.set("core", "xcom_enable_pickling", "True")
XCom.set(key=key,
value=json_obj,
dag_id=dag_id1,
task_id=task_id1,
execution_date=execution_date)
XCom.set(key=key,
value=json_obj,
dag_id=dag_id2,
task_id=task_id2,
execution_date=execution_date)
results = XCom.get_many(key=key,
execution_date=execution_date)
for result in results:
self.assertEqual(result.value, json_obj)
class VariableTest(unittest.TestCase):
def setUp(self):
models._fernet = None
def tearDown(self):
models._fernet = None
@patch('airflow.models.configuration.conf.get')
def test_variable_no_encryption(self, mock_get):
"""
Test variables without encryption
"""
mock_get.return_value = ''
Variable.set('key', 'value')
session = settings.Session()
test_var = session.query(Variable).filter(Variable.key == 'key').one()
self.assertFalse(test_var.is_encrypted)
self.assertEqual(test_var.val, 'value')
@patch('airflow.models.configuration.conf.get')
def test_variable_with_encryption(self, mock_get):
"""
Test variables with encryption
"""
mock_get.return_value = Fernet.generate_key().decode()
Variable.set('key', 'value')
session = settings.Session()
test_var = session.query(Variable).filter(Variable.key == 'key').one()
self.assertTrue(test_var.is_encrypted)
self.assertEqual(test_var.val, 'value')
@patch('airflow.models.configuration.conf.get')
def test_var_with_encryption_rotate_fernet_key(self, mock_get):
"""
Tests rotating encrypted variables.
"""
key1 = Fernet.generate_key()
key2 = Fernet.generate_key()
mock_get.return_value = key1.decode()
Variable.set('key', 'value')
session = settings.Session()
test_var = session.query(Variable).filter(Variable.key == 'key').one()
self.assertTrue(test_var.is_encrypted)
self.assertEqual(test_var.val, 'value')
self.assertEqual(Fernet(key1).decrypt(test_var._val.encode()), b'value')
# Test decrypt of old value with new key
mock_get.return_value = ','.join([key2.decode(), key1.decode()])
models._fernet = None
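# Resetting the cached Fernet forces it to be rebuilt from the rotated key
# list; MultiFernet decrypts with any listed key but always encrypts with
# the first, so rotate_fernet_key() re-encrypts the value under key2.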
self.assertEqual(test_var.val, 'value')
# Test decrypt of new value with new key
test_var.rotate_fernet_key()
self.assertTrue(test_var.is_encrypted)
self.assertEqual(test_var.val, 'value')
self.assertEqual(Fernet(key2).decrypt(test_var._val.encode()), b'value')
class ConnectionTest(unittest.TestCase):
def setUp(self):
models._fernet = None
def tearDown(self):
models._fernet = None
@patch('airflow.models.configuration.conf.get')
def test_connection_extra_no_encryption(self, mock_get):
"""
Tests extras on a new connection without encryption. The fernet key
is set to a non-base64-encoded string and the extra is stored without
encryption.
"""
mock_get.return_value = ''
test_connection = Connection(extra='testextra')
self.assertFalse(test_connection.is_extra_encrypted)
self.assertEqual(test_connection.extra, 'testextra')
@patch('airflow.models.configuration.conf.get')
def test_connection_extra_with_encryption(self, mock_get):
"""
Tests extras on a new connection with encryption.
"""
mock_get.return_value = Fernet.generate_key().decode()
test_connection = Connection(extra='testextra')
self.assertTrue(test_connection.is_extra_encrypted)
self.assertEqual(test_connection.extra, 'testextra')
@patch('airflow.models.configuration.conf.get')
def test_connection_extra_with_encryption_rotate_fernet_key(self, mock_get):
"""
Tests rotating encrypted extras.
"""
key1 = Fernet.generate_key()
key2 = Fernet.generate_key()
mock_get.return_value = key1.decode()
test_connection = Connection(extra='testextra')
self.assertTrue(test_connection.is_extra_encrypted)
self.assertEqual(test_connection.extra, 'testextra')
self.assertEqual(Fernet(key1).decrypt(test_connection._extra.encode()), b'testextra')
# Test decrypt of old value with new key
mock_get.return_value = ','.join([key2.decode(), key1.decode()])
models._fernet = None
self.assertEqual(test_connection.extra, 'testextra')
# Test decrypt of new value with new key
test_connection.rotate_fernet_key()
self.assertTrue(test_connection.is_extra_encrypted)
self.assertEqual(test_connection.extra, 'testextra')
self.assertEqual(Fernet(key2).decrypt(test_connection._extra.encode()), b'testextra')
def test_connection_from_uri_without_extras(self):
uri = 'scheme://user:password@host%2flocation:1234/schema'
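# %2f in the host segment decodes to '/', so a path-like host survives
# the round trip through the URI.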
connection = Connection(uri=uri)
self.assertEqual(connection.conn_type, 'scheme')
self.assertEqual(connection.host, 'host/location')
self.assertEqual(connection.schema, 'schema')
self.assertEqual(connection.login, 'user')
self.assertEqual(connection.password, 'password')
self.assertEqual(connection.port, 1234)
self.assertIsNone(connection.extra)
def test_connection_from_uri_with_extras(self):
uri = 'scheme://user:password@host%2flocation:1234/schema?' \
'extra1=a%20value&extra2=%2fpath%2f'
connection = Connection(uri=uri)
self.assertEqual(connection.conn_type, 'scheme')
self.assertEqual(connection.host, 'host/location')
self.assertEqual(connection.schema, 'schema')
self.assertEqual(connection.login, 'user')
self.assertEqual(connection.password, 'password')
self.assertEqual(connection.port, 1234)
self.assertDictEqual(connection.extra_dejson, {'extra1': 'a value',
'extra2': '/path/'})
def test_connection_from_uri_with_colon_in_hostname(self):
uri = 'scheme://user:password@host%2flocation%3ax%3ay:1234/schema?' \
'extra1=a%20value&extra2=%2fpath%2f'
connection = Connection(uri=uri)
self.assertEqual(connection.conn_type, 'scheme')
self.assertEqual(connection.host, 'host/location:x:y')
self.assertEqual(connection.schema, 'schema')
self.assertEqual(connection.login, 'user')
self.assertEqual(connection.password, 'password')
self.assertEqual(connection.port, 1234)
self.assertDictEqual(connection.extra_dejson, {'extra1': 'a value',
'extra2': '/path/'})
def test_connection_from_uri_with_encoded_password(self):
uri = 'scheme://user:password%20with%20space@host%2flocation%3ax%3ay:1234/schema'
connection = Connection(uri=uri)
self.assertEqual(connection.conn_type, 'scheme')
self.assertEqual(connection.host, 'host/location:x:y')
self.assertEqual(connection.schema, 'schema')
self.assertEqual(connection.login, 'user')
self.assertEqual(connection.password, 'password with space')
self.assertEqual(connection.port, 1234)
def test_connection_from_uri_with_encoded_user(self):
uri = 'scheme://domain%2fuser:password@host%2flocation%3ax%3ay:1234/schema'
connection = Connection(uri=uri)
self.assertEqual(connection.conn_type, 'scheme')
self.assertEqual(connection.host, 'host/location:x:y')
self.assertEqual(connection.schema, 'schema')
self.assertEqual(connection.login, 'domain/user')
self.assertEqual(connection.password, 'password')
self.assertEqual(connection.port, 1234)
def test_connection_from_uri_with_encoded_schema(self):
uri = 'scheme://user:password%20with%20space@host:1234/schema%2ftest'
connection = Connection(uri=uri)
self.assertEqual(connection.conn_type, 'scheme')
self.assertEqual(connection.host, 'host')
self.assertEqual(connection.schema, 'schema/test')
self.assertEqual(connection.login, 'user')
self.assertEqual(connection.password, 'password with space')
self.assertEqual(connection.port, 1234)
def test_connection_from_uri_no_schema(self):
uri = 'scheme://user:password%20with%20space@host:1234'
connection = Connection(uri=uri)
self.assertEqual(connection.conn_type, 'scheme')
self.assertEqual(connection.host, 'host')
self.assertEqual(connection.schema, '')
self.assertEqual(connection.login, 'user')
self.assertEqual(connection.password, 'password with space')
self.assertEqual(connection.port, 1234)
class TestSkipMixin(unittest.TestCase):
@patch('airflow.models.timezone.utcnow')
def test_skip(self, mock_now):
session = settings.Session()
now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
mock_now.return_value = now
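# Freeze timezone.utcnow() so that skip() stamps start_date and end_date
# equal to 'now'.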
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
)
with dag:
tasks = [DummyOperator(task_id='task')]
dag_run = dag.create_dagrun(
run_id='manual__' + now.isoformat(),
state=State.FAILED,
)
SkipMixin().skip(
dag_run=dag_run,
execution_date=now,
tasks=tasks,
session=session)
session.query(TI).filter(
TI.dag_id == 'dag',
TI.task_id == 'task',
TI.state == State.SKIPPED,
TI.start_date == now,
TI.end_date == now,
).one()
@patch('airflow.models.timezone.utcnow')
def test_skip_none_dagrun(self, mock_now):
session = settings.Session()
now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
mock_now.return_value = now
dag = DAG(
'dag',
start_date=DEFAULT_DATE,
)
with dag:
tasks = [DummyOperator(task_id='task')]
SkipMixin().skip(
dag_run=None,
execution_date=now,
tasks=tasks,
session=session)
session.query(TI).filter(
TI.dag_id == 'dag',
TI.task_id == 'task',
TI.state == State.SKIPPED,
TI.start_date == now,
TI.end_date == now,
).one()
def test_skip_none_tasks(self):
session = Mock()
SkipMixin().skip(dag_run=None, execution_date=None, tasks=[], session=session)
self.assertFalse(session.query.called)
self.assertFalse(session.commit.called)
class TestKubeResourceVersion(unittest.TestCase):
def test_checkpoint_resource_version(self):
session = settings.Session()
KubeResourceVersion.checkpoint_resource_version('7', session)
self.assertEqual(KubeResourceVersion.get_current_resource_version(session), '7')
def test_reset_resource_version(self):
session = settings.Session()
version = KubeResourceVersion.reset_resource_version(session)
self.assertEqual(version, '0')
self.assertEqual(KubeResourceVersion.get_current_resource_version(session), '0')
class TestKubeWorkerIdentifier(unittest.TestCase):
@patch('airflow.models.uuid.uuid4')
def test_get_or_create_not_exist(self, mock_uuid):
session = settings.Session()
session.query(KubeWorkerIdentifier).update({
KubeWorkerIdentifier.worker_uuid: ''
})
mock_uuid.return_value = 'abcde'
worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(session)
self.assertEqual(worker_uuid, 'abcde')
def test_get_or_create_exist(self):
session = settings.Session()
KubeWorkerIdentifier.checkpoint_kube_worker_uuid('fghij', session)
worker_uuid = KubeWorkerIdentifier.get_or_create_current_kube_worker_uuid(session)
self.assertEqual(worker_uuid, 'fghij')