AIRFLOW-[3823] Exclude branch's downstream tasks from the tasks to skip (#4666)
This commit is contained in:
Родитель
8d6dcd1840
Коммит
59d2615459
|
@ -23,10 +23,11 @@ import pickle
|
|||
import subprocess
|
||||
import sys
|
||||
import types
|
||||
from builtins import str
|
||||
from textwrap import dedent
|
||||
|
||||
import dill
|
||||
from builtins import str
|
||||
import six
|
||||
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.models import BaseOperator
|
||||
|
@ -138,7 +139,7 @@ class BranchPythonOperator(PythonOperator, SkipMixin):
|
|||
"""
|
||||
def execute(self, context):
|
||||
branch = super(BranchPythonOperator, self).execute(context)
|
||||
if isinstance(branch, str):
|
||||
if isinstance(branch, six.string_types):
|
||||
branch = [branch]
|
||||
self.log.info("Following branch %s", branch)
|
||||
self.log.info("Marking other directly downstream tasks as skipped")
|
||||
|
@ -146,8 +147,18 @@ class BranchPythonOperator(PythonOperator, SkipMixin):
|
|||
downstream_tasks = context['task'].downstream_list
|
||||
self.log.debug("Downstream task_ids %s", downstream_tasks)
|
||||
|
||||
skip_tasks = [t for t in downstream_tasks if t.task_id not in branch]
|
||||
if downstream_tasks:
|
||||
# Also check downstream tasks of the branch task. In case the task to skip
|
||||
# is a downstream task of the branch task, we exclude it from skipping.
|
||||
branch_downstream_task_ids = set()
|
||||
for b in branch:
|
||||
branch_downstream_task_ids.update(context["dag"].
|
||||
get_task(b).
|
||||
get_flat_relative_ids(upstream=False))
|
||||
skip_tasks = [t
|
||||
for t in downstream_tasks
|
||||
if t.task_id not in branch and
|
||||
t.task_id not in branch_downstream_task_ids]
|
||||
self.skip(context['dag_run'], context['ti'].execution_date, skip_tasks)
|
||||
|
||||
self.log.info("Done.")
|
||||
|
|
|
@ -289,6 +289,64 @@ class BranchOperatorTest(unittest.TestCase):
|
|||
else:
|
||||
raise Exception
|
||||
|
||||
def test_with_skip_in_branch_downstream_dependencies(self):
    """Check task states when the followed branch has a sibling downstream.

    ``make_choice`` points at both ``branch_1`` and ``branch_2``, and
    ``branch_2`` is also wired downstream of ``branch_1``. The callable
    returns ``'branch_1'``; after running the branch operator the test
    expects ``make_choice`` to be ``State.SUCCESS`` and both branch tasks
    to be ``State.NONE``.
    """
    self.branch_op = BranchPythonOperator(
        task_id='make_choice',
        dag=self.dag,
        python_callable=lambda: 'branch_1',
    )

    # make_choice -> branch_1 -> branch_2, plus a direct
    # make_choice -> branch_2 edge.
    self.branch_op >> self.branch_1 >> self.branch_2
    self.branch_op >> self.branch_2
    self.dag.clear()

    dag_run = self.dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    # State expected for each task instance of the run, keyed by task_id.
    expected_states = {
        'make_choice': State.SUCCESS,
        'branch_1': State.NONE,
        'branch_2': State.NONE,
    }
    for task_instance in dag_run.get_task_instances():
        # Any task id outside the three wired above is an error.
        if task_instance.task_id not in expected_states:
            raise Exception
        self.assertEqual(task_instance.state,
                         expected_states[task_instance.task_id])
|
||||
|
||||
def test_with_skip_in_branch_downstream_dependencies2(self):
    """Check task states when the followed branch is the shared downstream.

    The wiring is ``make_choice -> branch_1 -> branch_2`` plus a direct
    ``make_choice -> branch_2`` edge. The callable returns ``'branch_2'``;
    after running the branch operator the test expects ``make_choice`` to
    be ``State.SUCCESS``, ``branch_1`` to be ``State.SKIPPED`` and
    ``branch_2`` to be ``State.NONE``.
    """
    self.branch_op = BranchPythonOperator(
        task_id='make_choice',
        dag=self.dag,
        python_callable=lambda: 'branch_2',
    )

    # make_choice -> branch_1 -> branch_2, plus a direct
    # make_choice -> branch_2 edge.
    self.branch_op >> self.branch_1 >> self.branch_2
    self.branch_op >> self.branch_2
    self.dag.clear()

    dag_run = self.dag.create_dagrun(
        run_id="manual__",
        start_date=timezone.utcnow(),
        execution_date=DEFAULT_DATE,
        state=State.RUNNING,
    )

    self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    # State expected for each task instance of the run, keyed by task_id.
    expected_states = {
        'make_choice': State.SUCCESS,
        'branch_1': State.SKIPPED,
        'branch_2': State.NONE,
    }
    for task_instance in dag_run.get_task_instances():
        # Any task id outside the three wired above is an error.
        if task_instance.task_id not in expected_states:
            raise Exception
        self.assertEqual(task_instance.state,
                         expected_states[task_instance.task_id])
|
||||
|
||||
|
||||
class ShortCircuitOperatorTest(unittest.TestCase):
|
||||
@classmethod
|
||||
|
|
Загрузка…
Ссылка в новой задаче