295 строки
10 KiB
Python
295 строки
10 KiB
Python
#
|
|
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
import json
|
|
import logging
|
|
import pickle
|
|
from json import JSONDecodeError
|
|
from typing import Any, Iterable, Optional, Union
|
|
|
|
import pendulum
|
|
from sqlalchemy import Column, LargeBinary, String, and_
|
|
from sqlalchemy.orm import Query, Session, reconstructor
|
|
|
|
from airflow.configuration import conf
|
|
from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
|
|
from airflow.utils import timezone
|
|
from airflow.utils.helpers import is_container
|
|
from airflow.utils.log.logging_mixin import LoggingMixin
|
|
from airflow.utils.session import provide_session
|
|
from airflow.utils.sqlalchemy import UtcDateTime
|
|
|
|
log = logging.getLogger(__name__)

# Maximum XCom size: roughly 48 KB (49,344 bytes). Rationale for the number:
# https://github.com/apache/airflow/pull/1618#discussion_r68249677
MAX_XCOM_SIZE = 49_344

# Default XCom key; presumably the key a task's return value is stored under
# (inferred from the name — confirm against callers).
XCOM_RETURN_KEY = "return_value"
|
|
|
|
|
|
class BaseXCom(Base, LoggingMixin):
    """
    Base class for XCom objects.

    A row is identified by the primary key ``(key, execution_date, task_id, dag_id)``.
    The ``value`` column stores the serialized payload: UTF-8 encoded JSON, or a
    pickle when ``[core] enable_xcom_pickling`` is enabled.

    Serialization helpers are invoked through the module-level ``XCom`` name
    (the resolved backend class) so that a configured custom backend's
    ``serialize_value``/``deserialize_value`` overrides are used.
    """

    __tablename__ = "xcom"

    key = Column(String(512, **COLLATION_ARGS), primary_key=True)
    value = Column(LargeBinary)
    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
    execution_date = Column(UtcDateTime, primary_key=True)

    # source information
    task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)

    # TODO: "pickling" has been deprecated and JSON is preferred.
    #       "pickling" will be removed in Airflow 2.0.
    @reconstructor
    def init_on_load(self):
        """
        Deserialize ``self.value`` after the ORM loads or reconstitutes this instance.

        Delegates to ``XCom.deserialize_value``. If that raises ``ValueError``
        (which includes ``json.JSONDecodeError`` and ``UnicodeDecodeError``),
        the raw bytes are unpickled instead. This keeps rows written with
        pickling enabled readable (e.g. in the webserver) when the table holds
        a mix of pickled and JSON values.
        """
        try:
            self.value = XCom.deserialize_value(self)
        except ValueError:
            # For backward-compatibility: XComs may be a mix of pickled and
            # JSON-encoded values. UnicodeEncodeError/UnicodeDecodeError and
            # JSONDecodeError are all ValueError subclasses.
            # NOTE(review): pickle.loads runs arbitrary code embedded in the
            # payload; this is only safe if the xcom table cannot be written
            # by untrusted parties.
            self.value = pickle.loads(self.value)

    def __repr__(self):
        return '<XCom "{key}" ({task_id} @ {execution_date})>'.format(
            key=self.key,
            task_id=self.task_id,
            execution_date=self.execution_date)

    @classmethod
    @provide_session
    def set(
        cls,
        key,
        value,
        execution_date,
        task_id,
        dag_id,
        session=None):
        """
        Store an XCom value, replacing any existing row with the same
        ``key``, ``execution_date``, ``task_id`` and ``dag_id``.

        The deletion of the old row and the insertion of the new one are
        committed together in a single transaction. A failure while inserting
        therefore does not leave the previous value already deleted.

        TODO: "pickling" has been deprecated and JSON is preferred.
        "pickling" will be removed in Airflow 2.0.

        :param key: XCom key
        :param value: value to store; serialized with ``XCom.serialize_value``
            before anything is written
        :param execution_date: execution date the XCom belongs to
        :param task_id: id of the task that produced the XCom
        :param dag_id: id of the DAG that produced the XCom
        :param session: database session. ``session.expunge_all()`` is called
            first, detaching every object the session currently holds.
        :return: None
        """
        session.expunge_all()

        # Serialize before touching the table so a serialization error
        # leaves the existing row intact.
        value = XCom.serialize_value(value)

        # remove any duplicate XComs
        session.query(cls).filter(
            cls.key == key,
            cls.execution_date == execution_date,
            cls.task_id == task_id,
            cls.dag_id == dag_id).delete()

        # insert new XCom in the same transaction as the delete above
        session.add(XCom(
            key=key,
            value=value,
            execution_date=execution_date,
            task_id=task_id,
            dag_id=dag_id))

        session.commit()

    @classmethod
    @provide_session
    def get_one(cls,
                execution_date: pendulum.DateTime,
                key: Optional[str] = None,
                task_id: Optional[Union[str, Iterable[str]]] = None,
                dag_id: Optional[Union[str, Iterable[str]]] = None,
                include_prior_dates: bool = False,
                session: Optional[Session] = None) -> Optional[Any]:
        """
        Retrieve an XCom value, optionally meeting certain criteria. Returns None
        if there are no results.

        This returns the deserialized value of the first row of ``get_many``
        (most recent ``execution_date``, then most recent ``timestamp``).

        :param execution_date: Execution date for the task
        :type execution_date: pendulum.datetime
        :param key: A key for the XCom. If provided, only XComs with matching
            keys will be returned. To remove the filter, pass key=None.
        :type key: str
        :param task_id: Only XComs from task(s) with matching id will be
            pulled. Can pass None to remove the filter.
        :type task_id: str or iterable of str
        :param dag_id: If provided, only pulls XCom from this DAG (or these DAGs).
            If None (default), no DAG filter is applied.
        :type dag_id: str or iterable of str
        :param include_prior_dates: If False, only XCom from the current
            execution_date are returned. If True, XCom from previous dates
            are returned as well.
        :type include_prior_dates: bool
        :param session: database session
        :type session: sqlalchemy.orm.session.Session
        """
        result = cls.get_many(execution_date=execution_date,
                              key=key,
                              task_ids=task_id,
                              dag_ids=dag_id,
                              include_prior_dates=include_prior_dates,
                              session=session).first()
        if result:
            return result.value
        return None

    @classmethod
    @provide_session
    def get_many(cls,
                 execution_date: pendulum.DateTime,
                 key: Optional[str] = None,
                 task_ids: Optional[Union[str, Iterable[str]]] = None,
                 dag_ids: Optional[Union[str, Iterable[str]]] = None,
                 include_prior_dates: bool = False,
                 limit: Optional[int] = None,
                 session: Optional[Session] = None) -> Query:
        """
        Composes a query to get one or more values from the xcom table.

        The ``key``, ``task_ids`` and ``dag_ids`` filters are only applied when
        the corresponding argument is truthy. ``None``, ``""`` or an empty
        collection therefore mean "do not filter". Rows are ordered by
        ``execution_date`` descending, then ``timestamp`` descending.

        :param execution_date: Execution date for the task
        :type execution_date: pendulum.datetime
        :param key: A key for the XCom. If provided, only XComs with matching
            keys will be returned. To remove the filter, pass key=None.
        :type key: str
        :param task_ids: Only XComs from tasks with matching ids will be
            pulled. Can pass None to remove the filter.
        :type task_ids: str or iterable of strings (representing task_ids)
        :param dag_ids: If provided, only pulls XComs from this DAG (or these DAGs).
            If None (default), no DAG filter is applied.
        :type dag_ids: str or iterable of strings
        :param include_prior_dates: If False, only XComs from the current
            execution_date are returned. If True, XComs from previous dates
            are returned as well.
        :type include_prior_dates: bool
        :param limit: If required, limit the number of returned objects.
            XCom objects can be quite big and you might want to limit the
            number of rows. ``None`` or ``0`` means no limit.
        :type limit: int
        :param session: database session
        :type session: sqlalchemy.orm.session.Session
        :return: the (not yet executed) SQLAlchemy query
        """
        filters = []

        if key:
            filters.append(cls.key == key)

        if task_ids:
            # A container of ids becomes an IN clause; a single id an equality.
            if is_container(task_ids):
                filters.append(cls.task_id.in_(task_ids))
            else:
                filters.append(cls.task_id == task_ids)

        if dag_ids:
            if is_container(dag_ids):
                filters.append(cls.dag_id.in_(dag_ids))
            else:
                filters.append(cls.dag_id == dag_ids)

        if include_prior_dates:
            filters.append(cls.execution_date <= execution_date)
        else:
            filters.append(cls.execution_date == execution_date)

        query = (session
                 .query(cls)
                 .filter(and_(*filters))
                 .order_by(cls.execution_date.desc(), cls.timestamp.desc()))

        if limit:
            return query.limit(limit)
        else:
            return query

    @classmethod
    @provide_session
    def delete(cls, xcoms, session=None):
        """
        Delete one or more XCom rows and commit.

        :param xcoms: a single ``XCom`` instance or an iterable of them
        :param session: database session
        :raises TypeError: if any element is not an ``XCom`` instance. Rows
            already marked for deletion in this call are not committed in
            that case.
        """
        if isinstance(xcoms, XCom):
            xcoms = [xcoms]
        for xcom in xcoms:
            if not isinstance(xcom, XCom):
                raise TypeError(
                    'Expected XCom; received {}'.format(xcom.__class__.__name__)
                )
            session.delete(xcom)
        session.commit()

    @staticmethod
    def serialize_value(value: Any):
        """
        Serialize an XCom value for storage in the ``value`` column.

        :param value: the value to serialize
        :return: ``pickle.dumps(value)`` when ``[core] enable_xcom_pickling``
            is set, otherwise the UTF-8 encoded ``json.dumps(value)``
        :raises ValueError, TypeError: if the value is not JSON-serializable
            (an error is logged first)
        """
        # TODO: "pickling" has been deprecated and JSON is preferred.
        # "pickling" will be removed in Airflow 2.0.
        if conf.getboolean('core', 'enable_xcom_pickling'):
            return pickle.dumps(value)
        try:
            return json.dumps(value).encode('UTF-8')
        except (ValueError, TypeError):
            log.error("Could not serialize the XCOM value into JSON. "
                      "If you are using pickles instead of JSON "
                      "for XCOM, then you need to enable pickle "
                      "support for XCOM in your airflow config.")
            raise

    @staticmethod
    def deserialize_value(result) -> Any:
        """
        Deserialize the stored bytes in ``result.value``.

        :param result: an object exposing the raw stored bytes as ``.value``
            (an XCom row)
        :return: ``pickle.loads(result.value)`` when ``[core]
            enable_xcom_pickling`` is set, otherwise the JSON-decoded value
        :raises json.JSONDecodeError: if the bytes are not valid JSON (an
            error is logged first)
        :raises UnicodeDecodeError: if the bytes are not valid UTF-8 (not
            caught here)
        """
        # TODO: "pickling" has been deprecated and JSON is preferred.
        # "pickling" will be removed in Airflow 2.0.
        enable_pickling = conf.getboolean('core', 'enable_xcom_pickling')
        if enable_pickling:
            return pickle.loads(result.value)

        try:
            return json.loads(result.value.decode('UTF-8'))
        except JSONDecodeError:
            log.error("Could not deserialize the XCOM value from JSON. "
                      "If you are using pickles instead of JSON "
                      "for XCOM, then you need to enable pickle "
                      "support for XCOM in your airflow config.")
            raise
|
|
|
|
|
|
def resolve_xcom_backend():
    """
    Resolve the XCom class configured by ``[core] xcom_backend``.

    The option is read with ``conf.getimport``, with a fallback of
    ``airflow.models.xcom.BaseXCom``. If the lookup yields a falsy value,
    ``BaseXCom`` is returned.

    :return: ``BaseXCom`` or a subclass of it
    :raises TypeError: if the configured object is not a class, or is a class
        that does not subclass ``BaseXCom``
    """
    clazz = conf.getimport("core", "xcom_backend", fallback=f"airflow.models.xcom.{BaseXCom.__name__}")
    if not clazz:
        return BaseXCom
    # issubclass() on a non-class raises an opaque "arg 1 must be a class"
    # TypeError; report the misconfiguration explicitly instead.
    if not isinstance(clazz, type):
        raise TypeError(
            f"Your custom XCom backend `{clazz!r}` is not a class; "
            f"expected a subclass of `{BaseXCom.__name__}`."
        )
    if not issubclass(clazz, BaseXCom):
        raise TypeError(
            f"Your custom XCom class `{clazz.__name__}` is not a subclass of `{BaseXCom.__name__}`."
        )
    return clazz


# Resolved once at import time. BaseXCom's methods call serialization through
# this module-level name, so they dispatch to the configured backend.
XCom = resolve_xcom_backend()
|