Resolve upstream tasks when template field is XComArg (#8805)

* Resolve upstream tasks when template field is XComArg

closes: #8054

* fixup! Resolve upstream tasks when template field is XComArg

* Resolve task relations in DagRun and DagBag

* Add tests for serialized DAG

* Set dependencies only in bag_dag, refactor tests

* Traverse template_fields attribute

* Use provide_test_dag_bag in all tests

* fixup! Use provide_test_dag_bag in all tests

* Use metaclass + setattr

* Add prepare_for_execution method

* Check signature of __init__ not class

* Apply suggestions from code review

Co-authored-by: Ash Berlin-Taylor <ash_github@firemirror.com>

* Update airflow/models/baseoperator.py

Co-authored-by: Ash Berlin-Taylor <ash_github@firemirror.com>
This commit is contained in:
Tomek Urbaszek 2020-06-15 12:29:16 +02:00 коммит произвёл GitHub
Родитель aee6ab94eb
Коммит 431ea3291c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 211 добавлений и 11 удалений

Просмотреть файл

@ -0,0 +1,51 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Example DAG demonstrating the usage of the XComArgs."""
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
args = {
'owner': 'airflow',
'start_date': days_ago(2),
}
def dummy(*args, **kwargs):
"""Dummy function"""
return "pass"
with DAG(
dag_id='example_xcom_args',
default_args=args,
schedule_interval=None,
tags=['example']
) as dag:
task1 = PythonOperator(
task_id='task1',
python_callable=dummy,
)
task2 = PythonOperator(
task_id='task2',
python_callable=dummy,
op_kwargs={"dummy": task1.output},
)

Просмотреть файл

@ -18,6 +18,7 @@
"""
Base operator for all operators.
"""
import abc
import copy
import functools
import logging
@ -60,9 +61,29 @@ from airflow.utils.weight_rule import WeightRule
ScheduleInterval = Union[str, timedelta, relativedelta]
class BaseOperatorMeta(abc.ABCMeta):
"""
Base metaclass of BaseOperator.
"""
def __call__(cls, *args, **kwargs):
"""
Called when you call BaseOperator(). In this way we are able to perform an action
after initializing an operator no matter where the ``super().__init__`` is called
(before or after assign of new attributes in a custom operator).
"""
obj: BaseOperator = type.__call__(cls, *args, **kwargs)
# Here we set upstream task defined by XComArgs passed to template fields of the operator
obj.set_xcomargs_dependencies()
# Mark instance as instantiated https://docs.python.org/3/tutorial/classes.html#private-variables
obj._BaseOperator__instantiated = True
return obj
# pylint: disable=too-many-instance-attributes,too-many-public-methods
@functools.total_ordering
class BaseOperator(Operator, LoggingMixin):
class BaseOperator(Operator, LoggingMixin, metaclass=BaseOperatorMeta):
"""
Abstract base class for all operators. Since operators create objects that
become nodes in the dag, BaseOperator contains many recursive methods for
@ -292,6 +313,12 @@ class BaseOperator(Operator, LoggingMixin):
# Defines if the operator supports lineage without manual definitions
supports_lineage = False
# If True then the class constructor was called
__instantiated = False
# Set to True before calling execute method
_lock_for_execution = False
# noinspection PyUnusedLocal
# pylint: disable=too-many-arguments,too-many-locals, too-many-statements
@apply_defaults
@ -547,6 +574,18 @@ class BaseOperator(Operator, LoggingMixin):
return self
def __setattr__(self, key, value):
super().__setattr__(key, value)
if self._lock_for_execution:
# Skip any custom behaviour during execute
return
if self.__instantiated and key in self.template_fields:
# Resolve upstreams set by assigning an XComArg after initializing
# an operator, example:
# op = BashOperator()
# op.bash_command = "sleep 1"
self.set_xcomargs_dependencies()
def add_inlets(self, inlets: Iterable[Any]):
"""
Sets inlets to this operator
@ -633,6 +672,56 @@ class BaseOperator(Operator, LoggingMixin):
NotPreviouslySkippedDep(),
}
def prepare_for_execution(self) -> "BaseOperator":
"""
Lock task for execution to disable custom action in __setattr__ and
returns a copy of the task
"""
other = copy.copy(self)
other._lock_for_execution = True # pylint: disable=protected-access
return other
def set_xcomargs_dependencies(self) -> None:
"""
Resolves upstream dependencies of a task. In this way passing an ``XComArg``
as value for a template field will result in creating upstream relation between
two tasks.
**Example**: ::
with DAG(...):
generate_content = GenerateContentOperator(task_id="generate_content")
send_email = EmailOperator(..., html_content=generate_content.output)
# This is equivalent to
with DAG(...):
generate_content = GenerateContentOperator(task_id="generate_content")
send_email = EmailOperator(
..., html_content="{{ task_instance.xcom_pull('generate_content') }}"
)
generate_content >> send_email
"""
from airflow.models.xcom_arg import XComArg
def apply_set_upstream(arg: Any):
if isinstance(arg, XComArg):
self.set_upstream(arg.operator)
elif isinstance(arg, (tuple, set, list)):
for elem in arg:
apply_set_upstream(elem)
elif isinstance(arg, dict):
for elem in arg.values():
apply_set_upstream(elem)
elif hasattr(arg, "template_fields"):
for elem in arg.template_fields:
apply_set_upstream(elem)
for field in self.template_fields:
if hasattr(self, field):
arg = getattr(self, field)
apply_set_upstream(arg)
@property
def priority_weight_total(self) -> int:
"""
@ -1140,7 +1229,7 @@ class BaseOperator(Operator, LoggingMixin):
@property
def output(self):
"""Returns default XComArg for the operator"""
"""Returns reference to XCom pushed by current operator"""
from airflow.models.xcom_arg import XComArg
return XComArg(operator=self)
@ -1205,7 +1294,8 @@ class BaseOperator(Operator, LoggingMixin):
if not cls.__serialized_fields:
cls.__serialized_fields = frozenset(
vars(BaseOperator(task_id='test')).keys() - {
'inlets', 'outlets', '_upstream_task_ids', 'default_args', 'dag', '_dag'
'inlets', 'outlets', '_upstream_task_ids', 'default_args', 'dag', '_dag',
'_BaseOperator__instantiated',
} | {'_task_type', 'subdag', 'ui_color', 'ui_fgcolor', 'template_fields'})
return cls.__serialized_fields

Просмотреть файл

@ -16,7 +16,6 @@
# specific language governing permissions and limitations
# under the License.
import copy
import getpass
import hashlib
import logging
@ -970,7 +969,7 @@ class TaskInstance(Base, LoggingMixin):
if not mark_success:
context = self.get_template_context()
task_copy = copy.copy(task)
task_copy = task.prepare_for_execution()
# Sensors in `poke` mode can block execution of DAGs when running
# with single process executor, thus we change the mode to`reschedule`
@ -1154,7 +1153,7 @@ class TaskInstance(Base, LoggingMixin):
def dry_run(self):
task = self.task
task_copy = copy.copy(task)
task_copy = task.prepare_for_execution()
self.task = task_copy
self.render_templates()

Просмотреть файл

@ -31,7 +31,7 @@ from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.utils.decorators import apply_defaults
class BaseSQLToGCSOperator(BaseOperator, metaclass=abc.ABCMeta):
class BaseSQLToGCSOperator(BaseOperator):
"""
:param sql: The SQL to execute.
:type sql: str

Просмотреть файл

@ -299,7 +299,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
_decorated_fields = {'executor_config'}
_CONSTRUCTOR_PARAMS = {
k: v.default for k, v in signature(BaseOperator).parameters.items()
k: v.default for k, v in signature(BaseOperator.__init__).parameters.items()
if v.default is not v.empty
}
@ -537,7 +537,7 @@ class SerializedDAG(DAG, BaseSerialization):
'access_control': '_access_control',
}
return {
param_to_attr.get(k, k): v.default for k, v in signature(DAG).parameters.items()
param_to_attr.get(k, k): v.default for k, v in signature(DAG.__init__).parameters.items()
if v.default is not v.empty
}

Просмотреть файл

@ -15,13 +15,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
import uuid
from datetime import date, datetime
from unittest import mock
import jinja2
import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
@ -29,6 +29,7 @@ from airflow.lineage.entities import File
from airflow.models import DAG
from airflow.models.baseoperator import chain, cross_downstream
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.decorators import apply_defaults
from tests.models import DEFAULT_DATE
from tests.test_utils.mock_operators import MockNamedTuple, MockOperator
@ -347,3 +348,61 @@ class TestBaseOperatorMethods(unittest.TestCase):
task4 = DummyOperator(task_id="op4", dag=dag)
task4 > [inlet, outlet, extra]
self.assertEqual(task4.get_outlet_defs(), [inlet, outlet, extra])
class CustomOp(DummyOperator):
template_fields = ("field", "field2")
@apply_defaults
def __init__(self, field=None, field2=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.field = field
self.field2 = field2
def execute(self, context):
self.field = None
class TestXComArgsRelationsAreResolved:
def test_setattr_performs_no_custom_action_at_execute_time(self):
op = CustomOp(task_id="test_task")
op_copy = op.prepare_for_execution()
with mock.patch(
"airflow.models.baseoperator.BaseOperator.set_xcomargs_dependencies"
) as method_mock:
op_copy.execute({})
assert method_mock.call_count == 0
def test_upstream_is_set_when_template_field_is_xcomarg(self):
with DAG("xcomargs_test", default_args={"start_date": datetime.today()}):
op1 = DummyOperator(task_id="op1")
op2 = CustomOp(task_id="op2", field=op1.output)
assert op1 in op2.upstream_list
assert op2 in op1.downstream_list
def test_set_xcomargs_dependencies_works_recursively(self):
with DAG("xcomargs_test", default_args={"start_date": datetime.today()}):
op1 = DummyOperator(task_id="op1")
op2 = DummyOperator(task_id="op2")
op3 = CustomOp(task_id="op3", field=[op1.output, op2.output])
op4 = CustomOp(task_id="op4", field={"op1": op1.output, "op2": op2.output})
assert op1 in op3.upstream_list
assert op2 in op3.upstream_list
assert op1 in op4.upstream_list
assert op2 in op4.upstream_list
def test_set_xcomargs_dependencies_works_when_set_after_init(self):
with DAG(dag_id='xcomargs_test', default_args={"start_date": datetime.today()}):
op1 = DummyOperator(task_id="op1")
op2 = CustomOp(task_id="op2")
op2.field = op1.output # value is set after init
assert op1 in op2.upstream_list
def test_set_xcomargs_dependencies_error_when_outside_dag(self):
with pytest.raises(AirflowException):
op1 = DummyOperator(task_id="op1")
CustomOp(task_id="op2", field=op1.output)

Просмотреть файл

@ -726,7 +726,8 @@ class TestStringifiedDAGs(unittest.TestCase):
"""
base_operator = BaseOperator(task_id="10")
fields = base_operator.__dict__
self.assertEqual({'_dag': None,
self.assertEqual({'_BaseOperator__instantiated': True,
'_dag': None,
'_downstream_task_ids': set(),
'_inlets': [],
'_log': base_operator.log,