AIRFLOW-[3823] Exclude branch's downstream tasks from the tasks to skip (#4666)

This commit is contained in:
BasPH 2019-02-10 19:59:46 +01:00 коммит произвёл Fokko Driesprong
Родитель 8d6dcd1840
Коммит 59d2615459
2 изменённых файлов: 72 добавлений и 3 удалений

Просмотреть файл

@ -23,10 +23,11 @@ import pickle
import subprocess
import sys
import types
from builtins import str
from textwrap import dedent
import dill
from builtins import str
import six
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@ -138,7 +139,7 @@ class BranchPythonOperator(PythonOperator, SkipMixin):
"""
def execute(self, context):
branch = super(BranchPythonOperator, self).execute(context)
if isinstance(branch, str):
if isinstance(branch, six.string_types):
branch = [branch]
self.log.info("Following branch %s", branch)
self.log.info("Marking other directly downstream tasks as skipped")
@ -146,8 +147,18 @@ class BranchPythonOperator(PythonOperator, SkipMixin):
downstream_tasks = context['task'].downstream_list
self.log.debug("Downstream task_ids %s", downstream_tasks)
skip_tasks = [t for t in downstream_tasks if t.task_id not in branch]
if downstream_tasks:
# Also check downstream tasks of the branch task. In case the task to skip
# is a downstream task of the branch task, we exclude it from skipping.
branch_downstream_task_ids = set()
for b in branch:
branch_downstream_task_ids.update(context["dag"].
get_task(b).
get_flat_relative_ids(upstream=False))
skip_tasks = [t
for t in downstream_tasks
if t.task_id not in branch and
t.task_id not in branch_downstream_task_ids]
self.skip(context['dag_run'], context['ti'].execution_date, skip_tasks)
self.log.info("Done.")

Просмотреть файл

@ -289,6 +289,64 @@ class BranchOperatorTest(unittest.TestCase):
else:
raise Exception
def test_with_skip_in_branch_downstream_dependencies(self):
self.branch_op = BranchPythonOperator(task_id='make_choice',
dag=self.dag,
python_callable=lambda: 'branch_1')
self.branch_op >> self.branch_1 >> self.branch_2
self.branch_op >> self.branch_2
self.dag.clear()
dr = self.dag.create_dagrun(
run_id="manual__",
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
state=State.RUNNING
)
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == 'make_choice':
self.assertEqual(ti.state, State.SUCCESS)
elif ti.task_id == 'branch_1':
self.assertEqual(ti.state, State.NONE)
elif ti.task_id == 'branch_2':
self.assertEqual(ti.state, State.NONE)
else:
raise Exception
def test_with_skip_in_branch_downstream_dependencies2(self):
self.branch_op = BranchPythonOperator(task_id='make_choice',
dag=self.dag,
python_callable=lambda: 'branch_2')
self.branch_op >> self.branch_1 >> self.branch_2
self.branch_op >> self.branch_2
self.dag.clear()
dr = self.dag.create_dagrun(
run_id="manual__",
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
state=State.RUNNING
)
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == 'make_choice':
self.assertEqual(ti.state, State.SUCCESS)
elif ti.task_id == 'branch_1':
self.assertEqual(ti.state, State.SKIPPED)
elif ti.task_id == 'branch_2':
self.assertEqual(ti.state, State.NONE)
else:
raise Exception
class ShortCircuitOperatorTest(unittest.TestCase):
@classmethod