[AIRFLOW-4549] Allow skipped tasks to satisfy wait_for_downstream (#7735)
Previously, tasks in the SUCCESS or SKIPPED state satisfied the depends_on_past check, but only tasks in the SUCCESS state satisfied the wait_for_downstream check. This inconsistency in behavior made the API less intuitive to users.
Parent: 5ae76d8cc0
Commit: 2ec0130099
UPDATING.md
@@ -62,6 +62,16 @@ https://developers.google.com/style/inclusive-documentation
 -->
 
+### Skipped tasks can satisfy wait_for_downstream
+
+Previously, a task instance with `wait_for_downstream=True` would only run if the downstream task of
+the previous task instance was successful. Meanwhile, a task instance with `depends_on_past=True`
+will run if the previous task instance is either successful or skipped. These two flags are close siblings,
+yet they behaved differently. This inconsistency made the API less intuitive to users.
+To maintain consistent behavior, downstream tasks that either succeeded or were skipped can now satisfy
+the `wait_for_downstream=True` flag.
+
+
 ### Ability to patch Pool.DEFAULT_POOL_NAME in BaseOperator
 It was not possible to patch pool in BaseOperator as the signature sets the default value of pool
 as Pool.DEFAULT_POOL_NAME.
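To illustrate the documented change, here is a minimal sketch of a DAG that exercises the new rule
(the dag_id, task names, and dates are illustrative and not part of this commit):

from datetime import datetime

from airflow.models import DAG
from airflow.operators.dummy_operator import DummyOperator

# Minimal sketch: with wait_for_downstream=True, the 2020-01-02 instance of
# `upstream` may now run once the 2020-01-01 instance of `downstream` has
# finished as either SUCCESS or SKIPPED; previously only SUCCESS qualified.
with DAG(
    dag_id="example_wait_for_downstream",
    start_date=datetime(2020, 1, 1),
    schedule_interval="@daily",
) as dag:
    upstream = DummyOperator(task_id="upstream", wait_for_downstream=True)
    downstream = DummyOperator(task_id="downstream")
    upstream >> downstream

Because wait_for_downstream forces depends_on_past, each `upstream` instance also still waits for its
own previous instance to have succeeded or been skipped.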
airflow/models/baseoperator.py
@@ -127,12 +127,12 @@ class BaseOperator(Operator, LoggingMixin):
     :param end_date: if specified, the scheduler won't go beyond this date
     :type end_date: datetime.datetime
     :param depends_on_past: when set to true, task instances will run
-        sequentially while relying on the previous task's schedule to
-        succeed. The task instance for the start_date is allowed to run.
+        sequentially and only if the previous instance has succeeded or has been skipped.
+        The task instance for the start_date is allowed to run.
     :type depends_on_past: bool
     :param wait_for_downstream: when set to true, an instance of task
         X will wait for tasks immediately downstream of the previous instance
-        of task X to finish successfully before it runs. This is useful if the
+        of task X to finish successfully or be skipped before it runs. This is useful if the
         different instances of a task X alter the same asset, and this asset
         is used by tasks downstream of task X. Note that depends_on_past
         is forced to True wherever wait_for_downstream is used. Also note that
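A quick, hedged check of the coupling mentioned in the docstring (this assumes the constructor enforces
the documented behavior; the task_id is hypothetical):

from airflow.operators.dummy_operator import DummyOperator

# Per the docstring, using wait_for_downstream also turns on depends_on_past.
task = DummyOperator(task_id="example_task", wait_for_downstream=True)
assert task.wait_for_downstream is True
assert task.depends_on_past is True  # set implicitly, per the documented coupling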
airflow/models/taskinstance.py
@@ -540,7 +540,7 @@ class TaskInstance(Base, LoggingMixin):
     @provide_session
     def are_dependents_done(self, session=None):
         """
-        Checks whether the dependents of this task instance have all succeeded.
+        Checks whether the immediate dependents of this task instance have succeeded or have been skipped.
         This is meant to be used by wait_for_downstream.
 
         This is useful when you do not want to start processing the next
@@ -556,7 +556,7 @@ class TaskInstance(Base, LoggingMixin):
             TaskInstance.dag_id == self.dag_id,
             TaskInstance.task_id.in_(task.downstream_task_ids),
             TaskInstance.execution_date == self.execution_date,
-            TaskInstance.state == State.SUCCESS,
+            TaskInstance.state.in_([State.SKIPPED, State.SUCCESS]),
         )
         count = ti[0][0]
         return count == len(task.downstream_task_ids)
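Put differently, are_dependents_done now counts a downstream task instance as "done" when it is
skipped as well as when it succeeds. A standalone sketch of the equivalent query (the helper name and
session handling are illustrative, not part of this commit):

from sqlalchemy import func

from airflow.models.taskinstance import TaskInstance
from airflow.utils.state import State


def count_done_dependents(session, ti, task):
    # A dependent is "done" when it is SKIPPED or SUCCESS,
    # not only SUCCESS as before this commit.
    count = (
        session.query(func.count(TaskInstance.task_id))
        .filter(
            TaskInstance.dag_id == ti.dag_id,
            TaskInstance.task_id.in_(task.downstream_task_ids),
            TaskInstance.execution_date == ti.execution_date,
            TaskInstance.state.in_([State.SKIPPED, State.SUCCESS]),
        )
        .scalar()
    )
    return count == len(task.downstream_task_ids)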
@@ -47,12 +47,7 @@ dag1_task1 = DummyOperator(
     dag=dag1,
     pool='test_backfill_pooled_task_pool',)
 
-# DAG tests depends_on_past dependencies
-dag2 = DAG(dag_id='test_depends_on_past', default_args=default_args)
-dag2_task1 = DummyOperator(
-    task_id='test_dop_task',
-    dag=dag2,
-    depends_on_past=True,)
+# dag2 has been moved to test_prev_dagrun_dep.py
 
 # DAG tests that a Dag run that doesn't complete is marked failed
 dag3 = DAG(dag_id='test_dagrun_states_fail', default_args=default_args)
test_prev_dagrun_dep.py (new file)
@@ -0,0 +1,37 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime
+
+from airflow.models import DAG
+from airflow.operators.dummy_operator import DummyOperator
+
+DEFAULT_DATE = datetime(2016, 1, 1)
+default_args = dict(start_date=DEFAULT_DATE, owner="airflow")
+
+# DAG tests depends_on_past dependencies
+dag_dop = DAG(dag_id="test_depends_on_past", default_args=default_args)
+with dag_dop:
+    dag_dop_task = DummyOperator(task_id="test_dop_task", depends_on_past=True,)
+
+# DAG tests wait_for_downstream dependencies
+dag_wfd = DAG(dag_id="test_wait_for_downstream", default_args=default_args)
+with dag_wfd:
+    dag_wfd_upstream = DummyOperator(task_id="upstream_task", wait_for_downstream=True,)
+    dag_wfd_downstream = DummyOperator(task_id="downstream_task",)
+    dag_wfd_upstream >> dag_wfd_downstream
tests/models/test_dagrun.py
@@ -19,8 +19,10 @@
 import datetime
 import unittest
 
+from parameterized import parameterized
+
 from airflow import models, settings
-from airflow.models import DAG, TaskInstance as TI, clear_task_instances
+from airflow.models import DAG, DagBag, TaskInstance as TI, clear_task_instances
 from airflow.models.dagrun import DagRun
 from airflow.operators.dummy_operator import DummyOperator
 from airflow.operators.python import ShortCircuitOperator
@@ -29,10 +31,19 @@ from airflow.utils.state import State
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import DagRunType
 from tests.models import DEFAULT_DATE
+from tests.test_utils.db import clear_db_pools, clear_db_runs
 
 
 class TestDagRun(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.dagbag = DagBag(include_examples=True)
+
+    def setUp(self):
+        clear_db_runs()
+        clear_db_pools()
+
     def create_dag_run(self, dag,
                        state=State.RUNNING,
                        task_states=None,
@@ -552,3 +563,55 @@ class TestDagRun(unittest.TestCase):
         dagrun.verify_integrity()
         flaky_ti.refresh_from_db()
         self.assertEqual(State.NONE, flaky_ti.state)
+
+    @parameterized.expand([
+        (State.SUCCESS, True),
+        (State.SKIPPED, True),
+        (State.RUNNING, False),
+        (State.FAILED, False),
+        (State.NONE, False),
+    ])
+    def test_depends_on_past(self, prev_ti_state, is_ti_success):
+        dag_id = 'test_depends_on_past'
+
+        dag = self.dagbag.get_dag(dag_id)
+        task = dag.tasks[0]
+
+        self.create_dag_run(dag, execution_date=timezone.datetime(2016, 1, 1, 0, 0, 0))
+        self.create_dag_run(dag, execution_date=timezone.datetime(2016, 1, 2, 0, 0, 0))
+
+        prev_ti = TI(task, timezone.datetime(2016, 1, 1, 0, 0, 0))
+        ti = TI(task, timezone.datetime(2016, 1, 2, 0, 0, 0))
+
+        prev_ti.set_state(prev_ti_state)
+        ti.set_state(State.QUEUED)
+        ti.run()
+        self.assertEqual(ti.state == State.SUCCESS, is_ti_success)
+
+    @parameterized.expand([
+        (State.SUCCESS, True),
+        (State.SKIPPED, True),
+        (State.RUNNING, False),
+        (State.FAILED, False),
+        (State.NONE, False),
+    ])
+    def test_wait_for_downstream(self, prev_ti_state, is_ti_success):
+        dag_id = 'test_wait_for_downstream'
+        dag = self.dagbag.get_dag(dag_id)
+        upstream, downstream = dag.tasks
+
+        # For ti.set_state() to work, the DagRun has to exist;
+        # otherwise ti.previous_ti returns an unpersisted TI.
+        self.create_dag_run(dag, execution_date=timezone.datetime(2016, 1, 1, 0, 0, 0))
+        self.create_dag_run(dag, execution_date=timezone.datetime(2016, 1, 2, 0, 0, 0))
+
+        prev_ti_downstream = TI(task=downstream, execution_date=timezone.datetime(2016, 1, 1, 0, 0, 0))
+        ti = TI(task=upstream, execution_date=timezone.datetime(2016, 1, 2, 0, 0, 0))
+        prev_ti = ti.get_previous_ti()
+        prev_ti.set_state(State.SUCCESS)
+        self.assertEqual(prev_ti.state, State.SUCCESS)
+
+        prev_ti_downstream.set_state(prev_ti_state)
+        ti.set_state(State.QUEUED)
+        ti.run()
+        self.assertEqual(ti.state == State.SUCCESS, is_ti_success)
tests/models/test_taskinstance.py
@@ -855,6 +855,38 @@ class TestTaskInstance(unittest.TestCase):
         self.assertEqual(completed, expect_completed)
         self.assertEqual(ti.state, expect_state)
+
+    def test_respects_prev_dagrun_dep(self):
+        with DAG(dag_id='test_dag'):
+            task = DummyOperator(task_id='task', start_date=DEFAULT_DATE)
+        ti = TI(task, DEFAULT_DATE)
+        failing_status = [TIDepStatus('test fail status name', False, 'test fail reason')]
+        passing_status = [TIDepStatus('test pass status name', True, 'test passing reason')]
+        with patch('airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep.get_dep_statuses',
+                   return_value=failing_status):
+            self.assertFalse(ti.are_dependencies_met())
+        with patch('airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep.get_dep_statuses',
+                   return_value=passing_status):
+            self.assertTrue(ti.are_dependencies_met())
+
+    @parameterized.expand([
+        (State.SUCCESS, True),
+        (State.SKIPPED, True),
+        (State.RUNNING, False),
+        (State.FAILED, False),
+        (State.NONE, False),
+    ])
+    def test_are_dependents_done(self, downstream_ti_state, expected_are_dependents_done):
+        with DAG(dag_id='test_dag'):
+            task = DummyOperator(task_id='task', start_date=DEFAULT_DATE)
+            downstream_task = DummyOperator(task_id='downstream_task', start_date=DEFAULT_DATE)
+            task >> downstream_task
+
+        ti = TI(task, DEFAULT_DATE)
+        downstream_ti = TI(downstream_task, DEFAULT_DATE)
+
+        downstream_ti.set_state(downstream_ti_state)
+        self.assertEqual(ti.are_dependents_done(), expected_are_dependents_done)
+
     def test_xcom_pull(self):
         """
         Test xcom_pull, using different filtering methods.