[AIRFLOW-4549] Allow skipped tasks to satisfy wait_for_downstream (#7735)

Previously, tasks that were in the SUCCESS or SKIPPED state satisfied the
depends_on_past check, but only tasks in the SUCCESS state satisfied the
wait_for_downstream check. This inconsistency in behavior made the API
less intuitive to users.
Teddy Hartanto 2020-05-11 22:22:31 +08:00, committed by GitHub
Parent: 5ae76d8cc0
Commit: 2ec0130099
7 changed files: 149 additions and 12 deletions

View file

@@ -62,6 +62,16 @@ https://developers.google.com/style/inclusive-documentation
-->
### Skipped tasks can satisfy wait_for_downstream
Previously, a task instance with `wait_for_downstream=True` would only run if the downstream task of
the previous task instance was successful. Meanwhile, a task instance with `depends_on_past=True`
would run if the previous task instance was either successful or skipped. These two flags are close
siblings, yet they behaved differently. This inconsistency made the API less intuitive to users.
To make the behavior consistent, a downstream task that is either successful or skipped can now satisfy
the `wait_for_downstream=True` flag.
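
For illustration only (this snippet is not part of the change, and the `dag_id`/`task_id` values are made up), a minimal DAG where the new rule matters might look like this:

```python
from datetime import datetime

from airflow.models import DAG
from airflow.operators.dummy_operator import DummyOperator

# Each run's upstream_task waits for the previous run's downstream_task to
# have finished as SUCCESS or, after this change, SKIPPED.
with DAG(dag_id="example_wait_for_downstream",
         start_date=datetime(2020, 1, 1),
         schedule_interval="@daily") as dag:
    upstream = DummyOperator(task_id="upstream_task", wait_for_downstream=True)
    downstream = DummyOperator(task_id="downstream_task")
    upstream >> downstream
```

If `downstream_task` ends up skipped in one run (for example because of a branch or short-circuit elsewhere in a real DAG), the next run's `upstream_task` is no longer blocked.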
### Ability to patch Pool.DEFAULT_POOL_NAME in BaseOperator
It was not possible to patch pool in BaseOperator as the signature sets the default value of pool
as Pool.DEFAULT_POOL_NAME.

View file

@@ -127,12 +127,12 @@ class BaseOperator(Operator, LoggingMixin):
     :param end_date: if specified, the scheduler won't go beyond this date
     :type end_date: datetime.datetime
     :param depends_on_past: when set to true, task instances will run
-        sequentially while relying on the previous task's schedule to
-        succeed. The task instance for the start_date is allowed to run.
+        sequentially and only if the previous instance has succeeded or has been skipped.
+        The task instance for the start_date is allowed to run.
     :type depends_on_past: bool
     :param wait_for_downstream: when set to true, an instance of task
         X will wait for tasks immediately downstream of the previous instance
-        of task X to finish successfully before it runs. This is useful if the
+        of task X to finish successfully or be skipped before it runs. This is useful if the
         different instances of a task X alter the same asset, and this asset
         is used by tasks downstream of task X. Note that depends_on_past
         is forced to True wherever wait_for_downstream is used. Also note that
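
As an aside (not part of this diff), the "depends_on_past is forced to True" note above can be seen with a tiny sketch; the DAG and task names here are invented:

```python
from datetime import datetime

from airflow.models import DAG
from airflow.operators.dummy_operator import DummyOperator

with DAG(dag_id="example_forced_dop", start_date=datetime(2020, 1, 1)):
    # depends_on_past is not set explicitly here...
    task = DummyOperator(task_id="wfd_task", wait_for_downstream=True)

# ...yet it is switched on because wait_for_downstream is used.
assert task.depends_on_past is True
```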

View file

@@ -540,7 +540,7 @@ class TaskInstance(Base, LoggingMixin):
     @provide_session
     def are_dependents_done(self, session=None):
         """
-        Checks whether the dependents of this task instance have all succeeded.
+        Checks whether the immediate dependents of this task instance have succeeded or have been skipped.
         This is meant to be used by wait_for_downstream.

         This is useful when you do not want to start processing the next
@@ -556,7 +556,7 @@ class TaskInstance(Base, LoggingMixin):
             TaskInstance.dag_id == self.dag_id,
             TaskInstance.task_id.in_(task.downstream_task_ids),
             TaskInstance.execution_date == self.execution_date,
-            TaskInstance.state == State.SUCCESS,
+            TaskInstance.state.in_([State.SKIPPED, State.SUCCESS]),
         )
         count = ti[0][0]
         return count == len(task.downstream_task_ids)
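
To make the updated semantics concrete, here is a small self-contained sketch (not from the commit; the helper name and plain-string states are illustrative) of the check that the query above implements:

```python
from typing import Iterable


def dependents_done(downstream_states: Iterable[str]) -> bool:
    """True when every downstream task instance of the previous run
    finished in a state that satisfies wait_for_downstream."""
    # After this change, SKIPPED counts alongside SUCCESS.
    satisfying = {"success", "skipped"}
    return all(state in satisfying for state in downstream_states)


# The previous run's only downstream task was skipped, so the check now passes.
assert dependents_done(["skipped"])
assert not dependents_done(["running", "success"])
```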

View file

@@ -47,12 +47,7 @@ dag1_task1 = DummyOperator(
     dag=dag1,
     pool='test_backfill_pooled_task_pool',)
 # DAG tests depends_on_past dependencies
-dag2 = DAG(dag_id='test_depends_on_past', default_args=default_args)
-dag2_task1 = DummyOperator(
-    task_id='test_dop_task',
-    dag=dag2,
-    depends_on_past=True,)
+# dag2 has been moved to test_prev_dagrun_dep.py
 # DAG tests that a Dag run that doesn't complete is marked failed
 dag3 = DAG(dag_id='test_dagrun_states_fail', default_args=default_args)

View file

@@ -0,0 +1,37 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime

from airflow.models import DAG
from airflow.operators.dummy_operator import DummyOperator

DEFAULT_DATE = datetime(2016, 1, 1)
default_args = dict(start_date=DEFAULT_DATE, owner="airflow")

# DAG tests depends_on_past dependencies
dag_dop = DAG(dag_id="test_depends_on_past", default_args=default_args)
with dag_dop:
    dag_dop_task = DummyOperator(task_id="test_dop_task", depends_on_past=True,)

# DAG tests wait_for_downstream dependencies
dag_wfd = DAG(dag_id="test_wait_for_downstream", default_args=default_args)
with dag_wfd:
    dag_wfd_upstream = DummyOperator(task_id="upstream_task", wait_for_downstream=True,)
    dag_wfd_downstream = DummyOperator(task_id="downstream_task",)
    dag_wfd_upstream >> dag_wfd_downstream

View file

@@ -19,8 +19,10 @@
 import datetime
 import unittest
+from parameterized import parameterized
 from airflow import models, settings
-from airflow.models import DAG, TaskInstance as TI, clear_task_instances
+from airflow.models import DAG, DagBag, TaskInstance as TI, clear_task_instances
 from airflow.models.dagrun import DagRun
 from airflow.operators.dummy_operator import DummyOperator
 from airflow.operators.python import ShortCircuitOperator
@@ -29,10 +31,19 @@ from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import DagRunType
from tests.models import DEFAULT_DATE
from tests.test_utils.db import clear_db_pools, clear_db_runs


class TestDagRun(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.dagbag = DagBag(include_examples=True)

    def setUp(self):
        clear_db_runs()
        clear_db_pools()

    def create_dag_run(self, dag,
                       state=State.RUNNING,
                       task_states=None,
@@ -552,3 +563,55 @@ class TestDagRun(unittest.TestCase):
        dagrun.verify_integrity()
        flaky_ti.refresh_from_db()
        self.assertEqual(State.NONE, flaky_ti.state)

    @parameterized.expand([
        (State.SUCCESS, True),
        (State.SKIPPED, True),
        (State.RUNNING, False),
        (State.FAILED, False),
        (State.NONE, False),
    ])
    def test_depends_on_past(self, prev_ti_state, is_ti_success):
        dag_id = 'test_depends_on_past'
        dag = self.dagbag.get_dag(dag_id)
        task = dag.tasks[0]

        self.create_dag_run(dag, execution_date=timezone.datetime(2016, 1, 1, 0, 0, 0))
        self.create_dag_run(dag, execution_date=timezone.datetime(2016, 1, 2, 0, 0, 0))

        prev_ti = TI(task, timezone.datetime(2016, 1, 1, 0, 0, 0))
        ti = TI(task, timezone.datetime(2016, 1, 2, 0, 0, 0))

        prev_ti.set_state(prev_ti_state)
        ti.set_state(State.QUEUED)
        ti.run()
        self.assertEqual(ti.state == State.SUCCESS, is_ti_success)

    @parameterized.expand([
        (State.SUCCESS, True),
        (State.SKIPPED, True),
        (State.RUNNING, False),
        (State.FAILED, False),
        (State.NONE, False),
    ])
    def test_wait_for_downstream(self, prev_ti_state, is_ti_success):
        dag_id = 'test_wait_for_downstream'
        dag = self.dagbag.get_dag(dag_id)
        upstream, downstream = dag.tasks

        # For ti.set_state() to work, the DagRun has to exist,
        # Otherwise ti.previous_ti returns an unpersisted TI
        self.create_dag_run(dag, execution_date=timezone.datetime(2016, 1, 1, 0, 0, 0))
        self.create_dag_run(dag, execution_date=timezone.datetime(2016, 1, 2, 0, 0, 0))

        prev_ti_downstream = TI(task=downstream, execution_date=timezone.datetime(2016, 1, 1, 0, 0, 0))
        ti = TI(task=upstream, execution_date=timezone.datetime(2016, 1, 2, 0, 0, 0))
        prev_ti = ti.get_previous_ti()
        prev_ti.set_state(State.SUCCESS)
        self.assertEqual(prev_ti.state, State.SUCCESS)

        prev_ti_downstream.set_state(prev_ti_state)
        ti.set_state(State.QUEUED)
        ti.run()
        self.assertEqual(ti.state == State.SUCCESS, is_ti_success)

View file

@@ -855,6 +855,38 @@ class TestTaskInstance(unittest.TestCase):
        self.assertEqual(completed, expect_completed)
        self.assertEqual(ti.state, expect_state)

    def test_respects_prev_dagrun_dep(self):
        with DAG(dag_id='test_dag'):
            task = DummyOperator(task_id='task', start_date=DEFAULT_DATE)
        ti = TI(task, DEFAULT_DATE)
        failing_status = [TIDepStatus('test fail status name', False, 'test fail reason')]
        passing_status = [TIDepStatus('test pass status name', True, 'test passing reason')]
        with patch('airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep.get_dep_statuses',
                   return_value=failing_status):
            self.assertFalse(ti.are_dependencies_met())
        with patch('airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep.get_dep_statuses',
                   return_value=passing_status):
            self.assertTrue(ti.are_dependencies_met())

    @parameterized.expand([
        (State.SUCCESS, True),
        (State.SKIPPED, True),
        (State.RUNNING, False),
        (State.FAILED, False),
        (State.NONE, False),
    ])
    def test_are_dependents_done(self, downstream_ti_state, expected_are_dependents_done):
        with DAG(dag_id='test_dag'):
            task = DummyOperator(task_id='task', start_date=DEFAULT_DATE)
            downstream_task = DummyOperator(task_id='downstream_task', start_date=DEFAULT_DATE)
            task >> downstream_task
        ti = TI(task, DEFAULT_DATE)
        downstream_ti = TI(downstream_task, DEFAULT_DATE)

        downstream_ti.set_state(downstream_ti_state)
        self.assertEqual(ti.are_dependents_done(), expected_are_dependents_done)

    def test_xcom_pull(self):
        """
        Test xcom_pull, using different filtering methods.