incubator-airflow/airflow/models/xcom.py

295 строки
10 KiB
Python

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
import pickle
from json import JSONDecodeError
from typing import Any, Iterable, Optional, Union
import pendulum
from sqlalchemy import Column, LargeBinary, String, and_
from sqlalchemy.orm import Query, Session, reconstructor
from airflow.configuration import conf
from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
from airflow.utils import timezone
from airflow.utils.helpers import is_container
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import provide_session
from airflow.utils.sqlalchemy import UtcDateTime
# Module-level logger named after this module; used by the (de)serialization
# helpers below to report JSON encoding/decoding failures.
log = logging.getLogger(__name__)
# Maximum XCom size, described upstream as 48KB (presumably counted in bytes —
# confirm against the code that enforces it; nothing in this file reads it).
# https://github.com/apache/airflow/pull/1618#discussion_r68249677
MAX_XCOM_SIZE = 49344
# Name of a well-known XCom key.
# NOTE(review): presumably the key under which a task's return value is
# stored — confirm against callers, it is not used in this file.
XCOM_RETURN_KEY = 'return_value'
class BaseXCom(Base, LoggingMixin):
    """
    ORM model for rows of the ``xcom`` table and base class for XCom backends.

    A row is identified by the composite primary key
    (``key``, ``execution_date``, ``task_id``, ``dag_id``). ``value`` is stored
    as bytes produced by :meth:`serialize_value`; when a row is loaded through
    the ORM, :meth:`init_on_load` replaces it with the deserialized object.
    """

    __tablename__ = "xcom"

    # XCom key, serialized payload and bookkeeping timestamps.
    key = Column(String(512, **COLLATION_ARGS), primary_key=True)
    value = Column(LargeBinary)
    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    execution_date = Column(UtcDateTime, primary_key=True)

    # source information
    task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)

    # TODO: "pickling" has been deprecated and JSON is preferred.
    # "pickling" will be removed in Airflow 2.0.

    @reconstructor
    def init_on_load(self):
        """
        Deserialize ``value`` in place once the ORM has loaded or otherwise
        reconstituted this instance.

        The configured deserializer is tried first; if it fails with a
        ``ValueError`` (which includes JSON and Unicode decoding errors) the
        raw bytes are unpickled instead.
        """
        try:
            self.value = XCom.deserialize_value(self)
        except (UnicodeEncodeError, ValueError):
            # Backward compatibility: the table may hold a mix of pickled and
            # JSON-encoded values, and the webserver must not break on them.
            # NOTE: pickle.loads can execute arbitrary code for a crafted
            # payload; the bytes here come straight from the xcom table.
            self.value = pickle.loads(self.value)

    def __repr__(self):
        template = '<XCom "{key}" ({task_id} @ {execution_date})>'
        return template.format(
            key=self.key,
            task_id=self.task_id,
            execution_date=self.execution_date,
        )

    @classmethod
    @provide_session
    def set(
            cls,
            key,
            value,
            execution_date,
            task_id,
            dag_id,
            session=None):
        """
        Store an XCom value, replacing any row with the same
        (key, execution_date, task_id, dag_id).

        The session is expunged first, the existing rows are deleted and
        committed, and the new row is then added and committed separately.

        TODO: "pickling" has been deprecated and JSON is preferred.
        "pickling" will be removed in Airflow 2.0.

        :param key: XCom key
        :param value: object to store; serialized with ``XCom.serialize_value``
        :param execution_date: execution date the value belongs to
        :param task_id: id of the task that produced the value
        :param dag_id: id of the DAG that produced the value
        :param session: database session
        :return: None
        """
        session.expunge_all()

        serialized = XCom.serialize_value(value)

        # remove any duplicate XComs
        duplicates = session.query(cls).filter(
            cls.key == key,
            cls.execution_date == execution_date,
            cls.task_id == task_id,
            cls.dag_id == dag_id,
        )
        duplicates.delete()
        session.commit()

        # insert new XCom
        new_entry = XCom(
            key=key,
            value=serialized,
            execution_date=execution_date,
            task_id=task_id,
            dag_id=dag_id,
        )
        session.add(new_entry)
        session.commit()

    @classmethod
    @provide_session
    def get_one(cls,
                execution_date: pendulum.DateTime,
                key: Optional[str] = None,
                task_id: Optional[Union[str, Iterable[str]]] = None,
                dag_id: Optional[Union[str, Iterable[str]]] = None,
                include_prior_dates: bool = False,
                session: Session = None) -> Optional[Any]:
        """
        Retrieve a single XCom value, optionally meeting certain criteria.

        Delegates to :meth:`get_many` and takes the first row of the resulting
        query. Returns None if there are no results.

        :param execution_date: Execution date for the task
        :type execution_date: pendulum.datetime
        :param key: A key for the XCom. If provided, only XComs with matching
            keys will be considered. To remove the filter, pass key=None.
        :type key: str
        :param task_id: Only XComs from the task(s) with matching id will be
            considered. Can pass None to remove the filter.
        :type task_id: str or iterable of strings
        :param dag_id: If provided, only XComs from the matching DAG(s) are
            considered. If None (default), no DAG filter is applied.
        :type dag_id: str or iterable of strings
        :param include_prior_dates: If False, only XComs from the given
            execution_date are considered. If True, XComs from earlier dates
            are considered as well.
        :type include_prior_dates: bool
        :param session: database session
        :type session: sqlalchemy.orm.session.Session
        :return: the ``value`` of the first matching row, or None
        """
        matches = cls.get_many(
            execution_date=execution_date,
            key=key,
            task_ids=task_id,
            dag_ids=dag_id,
            include_prior_dates=include_prior_dates,
            session=session,
        )
        first_row = matches.first()
        return first_row.value if first_row else None

    @classmethod
    @provide_session
    def get_many(cls,
                 execution_date: pendulum.DateTime,
                 key: Optional[str] = None,
                 task_ids: Optional[Union[str, Iterable[str]]] = None,
                 dag_ids: Optional[Union[str, Iterable[str]]] = None,
                 include_prior_dates: bool = False,
                 limit: Optional[int] = None,
                 session: Session = None) -> Query:
        """
        Compose a query selecting one or more rows from the xcom table.

        The query is ordered by ``execution_date`` and then ``timestamp``,
        both descending; it is returned unevaluated.

        :param execution_date: Execution date for the task
        :type execution_date: pendulum.datetime
        :param key: A key for the XCom. If provided, only XComs with matching
            keys will be returned. To remove the filter, pass key=None.
        :type key: str
        :param task_ids: Only XComs from tasks with matching ids will be
            pulled. Can pass None to remove the filter.
        :type task_ids: str or iterable of strings (representing task_ids)
        :param dag_ids: If provided, only XComs from the matching DAG(s) are
            pulled. If None (default), no DAG filter is applied.
        :type dag_ids: str or iterable of strings (representing dag_ids)
        :param include_prior_dates: If False, only XComs from the given
            execution_date are returned. If True, XComs from earlier dates
            are returned as well.
        :type include_prior_dates: bool
        :param limit: If truthy, limit the number of returned objects.
            XCom objects can be quite big and you might want to limit the
            number of rows.
        :type limit: int
        :param session: database session
        :type session: sqlalchemy.orm.session.Session
        :return: the composed query
        """
        filters = []

        if key:
            filters.append(cls.key == key)

        # task_ids and dag_ids follow the same rule: a container becomes an
        # IN clause, any other truthy value an equality test, and a falsy
        # value adds no filter at all.
        for column, wanted in ((cls.task_id, task_ids), (cls.dag_id, dag_ids)):
            if not wanted:
                continue
            if is_container(wanted):
                filters.append(column.in_(wanted))
            else:
                filters.append(column == wanted)

        date_filter = (
            cls.execution_date <= execution_date
            if include_prior_dates
            else cls.execution_date == execution_date
        )
        filters.append(date_filter)

        query = session.query(cls).filter(and_(*filters))
        query = query.order_by(cls.execution_date.desc(), cls.timestamp.desc())

        return query.limit(limit) if limit else query

    @classmethod
    @provide_session
    def delete(cls, xcoms, session=None):
        """
        Delete one XCom or an iterable of XComs and commit.

        :param xcoms: a single ``XCom`` instance or an iterable of them
        :param session: database session
        :raises TypeError: if an element is not an ``XCom`` instance; nothing
            is committed in that case
        """
        targets = [xcoms] if isinstance(xcoms, XCom) else xcoms
        for entry in targets:
            if not isinstance(entry, XCom):
                raise TypeError(
                    'Expected XCom; received {}'.format(entry.__class__.__name__)
                )
            session.delete(entry)
        session.commit()

    @staticmethod
    def serialize_value(value: Any):
        """
        Serialize an XCom value to bytes.

        Pickle is used when ``[core] enable_xcom_pickling`` is true, otherwise
        the value is JSON-encoded as UTF-8. A value that cannot be JSON-encoded
        is logged and the original exception is re-raised.
        """
        # TODO: "pickling" has been deprecated and JSON is preferred.
        # "pickling" will be removed in Airflow 2.0.
        pickling_enabled = conf.getboolean('core', 'enable_xcom_pickling')
        if pickling_enabled:
            return pickle.dumps(value)

        try:
            encoded = json.dumps(value).encode('UTF-8')
        except (ValueError, TypeError):
            log.error("Could not serialize the XCOM value into JSON. "
                      "If you are using pickles instead of JSON "
                      "for XCOM, then you need to enable pickle "
                      "support for XCOM in your airflow config.")
            raise
        return encoded

    @staticmethod
    def deserialize_value(result) -> Any:
        """
        Deserialize ``result.value`` from pickled or JSON-encoded bytes.

        Pickle is used when ``[core] enable_xcom_pickling`` is true, otherwise
        the bytes are decoded as UTF-8 and parsed as JSON. A JSON parse error
        is logged and re-raised; a UTF-8 decoding error propagates unlogged.

        :param result: object exposing the raw bytes as ``value``
        """
        # TODO: "pickling" has been deprecated and JSON is preferred.
        # "pickling" will be removed in Airflow 2.0.
        if conf.getboolean('core', 'enable_xcom_pickling'):
            # NOTE: pickle.loads can execute arbitrary code for a crafted
            # payload; only enable pickling when the stored data is trusted.
            return pickle.loads(result.value)

        text = result.value.decode('UTF-8')
        try:
            return json.loads(text)
        except JSONDecodeError:
            log.error("Could not deserialize the XCOM value from JSON. "
                      "If you are using pickles instead of JSON "
                      "for XCOM, then you need to enable pickle "
                      "support for XCOM in your airflow config.")
            raise
def resolve_xcom_backend():
    """
    Resolve the XCom class to use.

    The class is looked up through ``conf.getimport("core", "xcom_backend")``,
    with the dotted path of ``BaseXCom`` as the fallback. A falsy lookup
    result yields ``BaseXCom`` itself.

    :return: ``BaseXCom`` or a subclass of it
    :raises TypeError: if the configured class is not a subclass of ``BaseXCom``
    """
    default_path = f"airflow.models.xcom.{BaseXCom.__name__}"
    clazz = conf.getimport("core", "xcom_backend", fallback=default_path)

    if not clazz:
        return BaseXCom
    if issubclass(clazz, BaseXCom):
        return clazz
    raise TypeError(
        f"Your custom XCom class `{clazz.__name__}` is not a subclass of `{BaseXCom.__name__}`."
    )
# Resolve the XCom class once, at import time. BaseXCom's methods refer to this
# module-level name (e.g. XCom.serialize_value, XCom(...)), so it must be bound
# before any of them is called.
XCom = resolve_xcom_backend()