telemetry-airflow/utils/callbacks.py

30 lines
1.0 KiB
Python

from typing import TYPE_CHECKING
from airflow.models.taskinstance import clear_task_instances
from airflow.utils.context import Context
from airflow.utils.db import provide_session
from sqlalchemy.orm.session import Session
if TYPE_CHECKING:
from airflow.models.dagrun import DagRun
@provide_session
def retry_tasks_callback(context: Context, session: Session | None = None) -> None:
    """
    Clear the tasks listed in the `retry_tasks` DAG param so they run again.

    Intended to be used as an `on_retry_callback`, so that when a task fails and
    is retried, other tasks in the same DAG run are retried along with it.

    The `retry_tasks` param may be a single task ID string or a collection of
    task IDs. A missing, ``None``, or empty value means there is nothing to clear.

    Args:
        context: Airflow task context; its ``params`` and ``dag_run`` entries are read.
        session: SQLAlchemy session, injected by ``@provide_session`` when the
            caller does not pass one.
    """
    raw_task_ids = context["params"].get("retry_tasks")
    if not raw_task_ids:
        # Nothing requested: skip querying the DAG run's task instances at all.
        # This also guards against the param being explicitly None, which would
        # otherwise make the membership test below raise TypeError.
        return

    # A bare string is shorthand for a single task ID (iterating it would yield
    # characters); a set gives O(1) membership checks in the filter below.
    retry_task_ids: set[str] = (
        {raw_task_ids} if isinstance(raw_task_ids, str) else set(raw_task_ids)
    )

    dag_run: DagRun = context["dag_run"]
    # NOTE(review): the task this callback is attached to is not excluded, so it
    # is cleared as well if its own ID appears in `retry_tasks` — confirm intended.
    retry_task_instances = [
        task_instance
        for task_instance in dag_run.get_task_instances(session=session)
        if task_instance.task_id in retry_task_ids
    ]
    if retry_task_instances:
        clear_task_instances(retry_task_instances, session=session)