Resolve upstream tasks when template field is XComArg (#8805)
* Resolve upstream tasks when template field is XComArg closes: #8054 * fixup! Resolve upstream tasks when template field is XComArg * Resolve task relations in DagRun and DagBag * Add tests for serialized DAG * Set dependencies only in bag_dag, refactor tests * Traverse template_fields attribute * Use provide_test_dag_bag in all tests * fixup! Use provide_test_dag_bag in all tests * Use metaclass + setattr * Add prepare_for_execution method * Check signature of __init__ not class * Apply suggestions from code review Co-authored-by: Ash Berlin-Taylor <ash_github@firemirror.com> * Update airflow/models/baseoperator.py Co-authored-by: Ash Berlin-Taylor <ash_github@firemirror.com>
This commit is contained in:
Родитель
aee6ab94eb
Коммит
431ea3291c
|
@ -0,0 +1,51 @@
|
|||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Example DAG demonstrating the usage of the XComArgs."""
|
||||
|
||||
from airflow import DAG
|
||||
from airflow.operators.python import PythonOperator
|
||||
from airflow.utils.dates import days_ago
|
||||
|
||||
# Default arguments handed to the example DAG below via ``default_args``.
args = {'owner': 'airflow', 'start_date': days_ago(2)}
|
||||
|
||||
|
||||
def dummy(*args, **kwargs):
    """
    Placeholder callable for the example tasks.

    Accepts any positional and keyword arguments, ignores them and always
    returns the string ``"pass"``.
    """
    result = "pass"
    return result
|
||||
|
||||
|
||||
# Example DAG: two Python tasks, the second one receiving the first one's output.
with DAG(dag_id='example_xcom_args', default_args=args, schedule_interval=None, tags=['example']) as dag:
    task1 = PythonOperator(task_id='task1', python_callable=dummy)

    # ``task1.output`` is handed to task2 through ``op_kwargs``.
    # NOTE(review): presumably ``op_kwargs`` is a template field of PythonOperator, in which
    # case this also makes task1 an upstream of task2 - confirm against PythonOperator.
    task2 = PythonOperator(
        task_id='task2',
        python_callable=dummy,
        op_kwargs={"dummy": task1.output},
    )
|
|
@ -18,6 +18,7 @@
|
|||
"""
|
||||
Base operator for all operators.
|
||||
"""
|
||||
import abc
|
||||
import copy
|
||||
import functools
|
||||
import logging
|
||||
|
@ -60,9 +61,29 @@ from airflow.utils.weight_rule import WeightRule
|
|||
ScheduleInterval = Union[str, timedelta, relativedelta]
|
||||
|
||||
|
||||
class BaseOperatorMeta(abc.ABCMeta):
    """
    Base metaclass of BaseOperator.

    It intercepts instantiation so that post-construction work runs only after
    the complete ``__init__`` chain of the operator has finished.
    """

    def __call__(cls, *args, **kwargs):
        """
        Called when you call BaseOperator(). In this way we are able to perform an action
        after initializing an operator no matter where the ``super().__init__`` is called
        (before or after assign of new attributes in a custom operator).

        :param args: positional arguments forwarded to the operator constructor
        :param kwargs: keyword arguments forwarded to the operator constructor
        :return: the fully initialised operator instance
        """
        # Use ``super().__call__`` instead of ``type.__call__`` so that any other metaclass
        # taking part in the MRO still gets its ``__call__`` invoked.
        obj: BaseOperator = super().__call__(*args, **kwargs)

        # Here we set upstream task defined by XComArgs passed to template fields of the operator
        obj.set_xcomargs_dependencies()

        # Mark instance as instantiated. The attribute is written in its name-mangled form
        # because ``__instantiated`` is a class-private name of BaseOperator, see
        # https://docs.python.org/3/tutorial/classes.html#private-variables
        obj._BaseOperator__instantiated = True
        return obj
|
||||
|
||||
|
||||
# pylint: disable=too-many-instance-attributes,too-many-public-methods
|
||||
@functools.total_ordering
|
||||
class BaseOperator(Operator, LoggingMixin):
|
||||
class BaseOperator(Operator, LoggingMixin, metaclass=BaseOperatorMeta):
|
||||
"""
|
||||
Abstract base class for all operators. Since operators create objects that
|
||||
become nodes in the dag, BaseOperator contains many recursive methods for
|
||||
|
@ -292,6 +313,12 @@ class BaseOperator(Operator, LoggingMixin):
|
|||
# Defines if the operator supports lineage without manual definitions
|
||||
supports_lineage = False
|
||||
|
||||
# If True then the class constructor was called
|
||||
__instantiated = False
|
||||
|
||||
# Set to True before calling execute method
|
||||
_lock_for_execution = False
|
||||
|
||||
# noinspection PyUnusedLocal
|
||||
# pylint: disable=too-many-arguments,too-many-locals, too-many-statements
|
||||
@apply_defaults
|
||||
|
@ -547,6 +574,18 @@ class BaseOperator(Operator, LoggingMixin):
|
|||
|
||||
return self
|
||||
|
||||
def __setattr__(self, key, value):
    """
    Store the attribute, then refresh XComArg-based upstream relations when appropriate.

    The refresh is triggered only when the task is not locked for execution, the
    constructor has already completed and ``key`` is one of ``template_fields``.

    :param key: name of the attribute being assigned
    :param value: value assigned to the attribute
    """
    super().__setattr__(key, value)

    # No custom behaviour during execute (locked task) nor while the constructor is still running
    resolving_allowed = not self._lock_for_execution and self.__instantiated
    if resolving_allowed and key in self.template_fields:
        # Resolve upstreams set by assigning an XComArg after initializing
        # an operator, example:
        #   op = BashOperator()
        #   op.bash_command = "sleep 1"
        self.set_xcomargs_dependencies()
|
||||
|
||||
def add_inlets(self, inlets: Iterable[Any]):
|
||||
"""
|
||||
Sets inlets to this operator
|
||||
|
@ -633,6 +672,56 @@ class BaseOperator(Operator, LoggingMixin):
|
|||
NotPreviouslySkippedDep(),
|
||||
}
|
||||
|
||||
def prepare_for_execution(self) -> "BaseOperator":
    """
    Return a shallow copy of this task that is locked for execution.

    The lock flag is set on the copy only; it disables the custom action
    performed in ``__setattr__``.

    :return: the locked copy of the task
    """
    locked_copy = copy.copy(self)
    locked_copy._lock_for_execution = True  # pylint: disable=protected-access
    return locked_copy
|
||||
|
||||
def set_xcomargs_dependencies(self) -> None:
    """
    Resolves upstream dependencies of a task. In this way passing an ``XComArg``
    as value for a template field will result in creating upstream relation between
    two tasks.

    Template field values are searched recursively: tuples, sets, lists, dict values
    and the template fields of any nested object exposing ``template_fields``.

    **Example**: ::

        with DAG(...):
            generate_content = GenerateContentOperator(task_id="generate_content")
            send_email = EmailOperator(..., html_content=generate_content.output)

        # This is equivalent to
        with DAG(...):
            generate_content = GenerateContentOperator(task_id="generate_content")
            send_email = EmailOperator(
                ..., html_content="{{ task_instance.xcom_pull('generate_content') }}"
            )
            generate_content >> send_email

    """
    from airflow.models.xcom_arg import XComArg

    # ids of nested templated objects already traversed; protects against reference cycles
    visited = set()

    def apply_set_upstream(arg: Any):
        if isinstance(arg, XComArg):
            self.set_upstream(arg.operator)
        elif isinstance(arg, (tuple, set, list)):
            for elem in arg:
                apply_set_upstream(elem)
        elif isinstance(arg, dict):
            for elem in arg.values():
                apply_set_upstream(elem)
        elif hasattr(arg, "template_fields"):
            if id(arg) in visited:
                return
            visited.add(id(arg))
            # ``template_fields`` lists attribute *names*, so the values have to be looked up
            # on the object; recursing on the names themselves would never find an XComArg.
            for field_name in arg.template_fields:
                if hasattr(arg, field_name):
                    apply_set_upstream(getattr(arg, field_name))

    for field in self.template_fields:
        if hasattr(self, field):
            apply_set_upstream(getattr(self, field))
|
||||
|
||||
@property
|
||||
def priority_weight_total(self) -> int:
|
||||
"""
|
||||
|
@ -1140,7 +1229,7 @@ class BaseOperator(Operator, LoggingMixin):
|
|||
|
||||
@property
def output(self):
    """Returns reference to XCom pushed by current operator"""
    # Imported at call time rather than at module level.
    # NOTE(review): presumably to avoid a circular import with xcom_arg - confirm.
    from airflow.models.xcom_arg import XComArg
    return XComArg(operator=self)
|
||||
|
||||
|
@ -1205,7 +1294,8 @@ class BaseOperator(Operator, LoggingMixin):
|
|||
if not cls.__serialized_fields:
|
||||
cls.__serialized_fields = frozenset(
|
||||
vars(BaseOperator(task_id='test')).keys() - {
|
||||
'inlets', 'outlets', '_upstream_task_ids', 'default_args', 'dag', '_dag'
|
||||
'inlets', 'outlets', '_upstream_task_ids', 'default_args', 'dag', '_dag',
|
||||
'_BaseOperator__instantiated',
|
||||
} | {'_task_type', 'subdag', 'ui_color', 'ui_fgcolor', 'template_fields'})
|
||||
|
||||
return cls.__serialized_fields
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import copy
|
||||
import getpass
|
||||
import hashlib
|
||||
import logging
|
||||
|
@ -970,7 +969,7 @@ class TaskInstance(Base, LoggingMixin):
|
|||
if not mark_success:
|
||||
context = self.get_template_context()
|
||||
|
||||
task_copy = copy.copy(task)
|
||||
task_copy = task.prepare_for_execution()
|
||||
|
||||
# Sensors in `poke` mode can block execution of DAGs when running
|
||||
# with single process executor, thus we change the mode to `reschedule`
|
||||
|
@ -1154,7 +1153,7 @@ class TaskInstance(Base, LoggingMixin):
|
|||
|
||||
def dry_run(self):
|
||||
task = self.task
|
||||
task_copy = copy.copy(task)
|
||||
task_copy = task.prepare_for_execution()
|
||||
self.task = task_copy
|
||||
|
||||
self.render_templates()
|
||||
|
|
|
@ -31,7 +31,7 @@ from airflow.providers.google.cloud.hooks.gcs import GCSHook
|
|||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class BaseSQLToGCSOperator(BaseOperator, metaclass=abc.ABCMeta):
|
||||
class BaseSQLToGCSOperator(BaseOperator):
|
||||
"""
|
||||
:param sql: The SQL to execute.
|
||||
:type sql: str
|
||||
|
|
|
@ -299,7 +299,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
|
|||
_decorated_fields = {'executor_config'}
|
||||
|
||||
_CONSTRUCTOR_PARAMS = {
|
||||
k: v.default for k, v in signature(BaseOperator).parameters.items()
|
||||
k: v.default for k, v in signature(BaseOperator.__init__).parameters.items()
|
||||
if v.default is not v.empty
|
||||
}
|
||||
|
||||
|
@ -537,7 +537,7 @@ class SerializedDAG(DAG, BaseSerialization):
|
|||
'access_control': '_access_control',
|
||||
}
|
||||
return {
|
||||
param_to_attr.get(k, k): v.default for k, v in signature(DAG).parameters.items()
|
||||
param_to_attr.get(k, k): v.default for k, v in signature(DAG.__init__).parameters.items()
|
||||
if v.default is not v.empty
|
||||
}
|
||||
|
||||
|
|
|
@ -15,13 +15,13 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import unittest
|
||||
import uuid
|
||||
from datetime import date, datetime
|
||||
from unittest import mock
|
||||
|
||||
import jinja2
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from airflow.exceptions import AirflowException
|
||||
|
@ -29,6 +29,7 @@ from airflow.lineage.entities import File
|
|||
from airflow.models import DAG
|
||||
from airflow.models.baseoperator import chain, cross_downstream
|
||||
from airflow.operators.dummy_operator import DummyOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
from tests.models import DEFAULT_DATE
|
||||
from tests.test_utils.mock_operators import MockNamedTuple, MockOperator
|
||||
|
||||
|
@ -347,3 +348,61 @@ class TestBaseOperatorMethods(unittest.TestCase):
|
|||
task4 = DummyOperator(task_id="op4", dag=dag)
|
||||
task4 > [inlet, outlet, extra]
|
||||
self.assertEqual(task4.get_outlet_defs(), [inlet, outlet, extra])
|
||||
|
||||
|
||||
class CustomOp(DummyOperator):
    """Operator used by the tests below; it declares ``field`` and ``field2`` as template fields."""

    template_fields = ("field", "field2")

    @apply_defaults
    def __init__(self, field=None, field2=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Assigned after the parent constructor, in the same order as the template fields
        self.field, self.field2 = field, field2

    def execute(self, context):
        """Assign a template field, so tests can observe ``__setattr__`` during execution."""
        self.field = None
|
||||
|
||||
|
||||
class TestXComArgsRelationsAreResolved:
    """Checks that XComArgs passed to template fields are turned into task relations."""

    def test_setattr_performs_no_custom_action_at_execute_time(self):
        """Assigning a template field on a task prepared for execution must not resolve relations."""
        op = CustomOp(task_id="test_task")
        op_copy = op.prepare_for_execution()

        with mock.patch(
            "airflow.models.baseoperator.BaseOperator.set_xcomargs_dependencies"
        ) as method_mock:
            op_copy.execute({})
        assert method_mock.call_count == 0

    def test_upstream_is_set_when_template_field_is_xcomarg(self):
        """An XComArg given to the constructor links both tasks in both directions."""
        with DAG("xcomargs_test", default_args={"start_date": datetime.today()}):
            op1 = DummyOperator(task_id="op1")
            op2 = CustomOp(task_id="op2", field=op1.output)

        assert op1 in op2.upstream_list
        assert op2 in op1.downstream_list

    def test_set_xcomargs_dependencies_works_recursively(self):
        """XComArgs nested in a list or in dict values are discovered as well."""
        with DAG("xcomargs_test", default_args={"start_date": datetime.today()}):
            op1 = DummyOperator(task_id="op1")
            op2 = DummyOperator(task_id="op2")
            op3 = CustomOp(task_id="op3", field=[op1.output, op2.output])
            op4 = CustomOp(task_id="op4", field={"op1": op1.output, "op2": op2.output})

        assert op1 in op3.upstream_list
        assert op2 in op3.upstream_list
        assert op1 in op4.upstream_list
        assert op2 in op4.upstream_list

    def test_set_xcomargs_dependencies_works_when_set_after_init(self):
        """Assigning an XComArg to a template field after construction sets the upstream too."""
        with DAG(dag_id='xcomargs_test', default_args={"start_date": datetime.today()}):
            op1 = DummyOperator(task_id="op1")
            op2 = CustomOp(task_id="op2")
            op2.field = op1.output  # value is set after init

        assert op1 in op2.upstream_list

    def test_set_xcomargs_dependencies_error_when_outside_dag(self):
        """Relating two tasks that belong to no DAG must raise AirflowException."""
        # Built outside ``pytest.raises`` so that only the call under test can satisfy
        # the expectation; a failure while creating ``op1`` now surfaces as an error.
        op1 = DummyOperator(task_id="op1")
        with pytest.raises(AirflowException):
            CustomOp(task_id="op2", field=op1.output)
|
||||
|
|
|
@ -726,7 +726,8 @@ class TestStringifiedDAGs(unittest.TestCase):
|
|||
"""
|
||||
base_operator = BaseOperator(task_id="10")
|
||||
fields = base_operator.__dict__
|
||||
self.assertEqual({'_dag': None,
|
||||
self.assertEqual({'_BaseOperator__instantiated': True,
|
||||
'_dag': None,
|
||||
'_downstream_task_ids': set(),
|
||||
'_inlets': [],
|
||||
'_log': base_operator.log,
|
||||
|
|
Загрузка…
Ссылка в новой задаче