incubator-airflow/airflow/models/xcom_arg.py

141 строка
4.9 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Sequence, Union
from airflow.exceptions import AirflowException
from airflow.models.baseoperator import BaseOperator # pylint: disable=R0401
from airflow.models.taskmixin import TaskMixin
from airflow.models.xcom import XCOM_RETURN_KEY
class XComArg(TaskMixin):
    """
    Class that represents a XCom push from a previous operator.

    Defaults to "return_value" as only key.

    Current implementation supports::

        xcomarg >> op
        xcomarg << op
        op >> xcomarg   (by BaseOperator code)
        op << xcomarg   (by BaseOperator code)

    **Example**: The moment you get a result from any operator (decorated or regular) you can ::

        any_op = AnyOperator()
        xcomarg = XComArg(any_op)
        # or equivalently
        xcomarg = any_op.output
        my_op = MyOperator()
        my_op >> xcomarg

    This object can be used in legacy Operators via Jinja.

    **Example**: You can make this result to be part of any generated string ::

        any_op = AnyOperator()
        xcomarg = any_op.output
        op1 = MyOperator(my_text_message=f"the value is {xcomarg}")
        op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}")

    :param operator: operator to which the XComArg belongs to
    :type operator: airflow.models.baseoperator.BaseOperator
    :param key: key value which is used for xcom_pull (key in the XCom table)
    :type key: str
    """

    def __init__(self, operator: BaseOperator, key: str = XCOM_RETURN_KEY):
        self._operator = operator
        self._key = key

    def __eq__(self, other: object) -> bool:
        """
        Two XComArgs are equal when both their operator and their key are equal.

        Returns ``NotImplemented`` for operands that are not XComArg instances so
        that comparing against arbitrary objects (``None``, strings, ...) yields
        ``False`` instead of raising ``AttributeError``.
        """
        if not isinstance(other, XComArg):
            return NotImplemented
        return self.operator == other.operator and self.key == other.key

    def __getitem__(self, item):
        """Implements xcomresult['some_result_key']"""
        return XComArg(operator=self.operator, key=item)

    def __str__(self):
        """
        Backward compatibility for old-style jinja used in Airflow Operators

        **Example**: to use XComArg at BashOperator::

            BashOperator(cmd=f"... { xcomarg } ...")

        :return: a Jinja expression string that calls ``task_instance.xcom_pull``
            with this operator's task id and dag id, plus the key when it is not None
        """
        pull_kwargs = [
            f"task_ids='{self.operator.task_id}'",
            f"dag_id='{self.operator.dag.dag_id}'",
        ]
        if self.key is not None:
            pull_kwargs.append(f"key='{self.key}'")
        joined_kwargs = ", ".join(pull_kwargs)
        # {{{{ are required for escape {{ in f-string
        return f"{{{{ task_instance.xcom_pull({joined_kwargs}) }}}}"

    @property
    def operator(self) -> BaseOperator:
        """Returns operator of this XComArg."""
        return self._operator

    @property
    def roots(self) -> List[BaseOperator]:
        """Required by TaskMixin"""
        return [self._operator]

    @property
    def leaves(self) -> List[BaseOperator]:
        """Required by TaskMixin"""
        return [self._operator]

    @property
    def key(self) -> str:
        """Returns key of this XComArg"""
        return self._key

    def set_upstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]):
        """Proxy to underlying operator set_upstream method. Required by TaskMixin."""
        self.operator.set_upstream(task_or_task_list)

    def set_downstream(self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]]):
        """Proxy to underlying operator set_downstream method. Required by TaskMixin."""
        self.operator.set_downstream(task_or_task_list)

    def resolve(self, context: Dict) -> Any:
        """
        Pull XCom value for the existing arg. This method is run during ``op.execute()``
        in respective context.

        :param context: execution context forwarded to the operator's ``xcom_pull``
        :return: the first element of the ``xcom_pull`` result
        :raises AirflowException: when ``xcom_pull`` returns an empty/falsy result
        """
        # task_ids is passed as a one-element list, so the result is indexed with [0] below.
        # NOTE(review): a None key is stringified to "None" here, while __str__ omits
        # the key entirely when it is None -- confirm which behaviour is intended.
        resolved_value = self.operator.xcom_pull(
            context=context,
            task_ids=[self.operator.task_id],
            key=str(self.key),  # xcom_pull supports only key as str
            dag_id=self.operator.dag.dag_id,
        )
        if not resolved_value:
            raise AirflowException(
                f'XComArg result from {self.operator.task_id} at {self.operator.dag.dag_id} '
                f'with key="{self.key}" is not found!'
            )
        return resolved_value[0]