# (file viewer metadata from the original paste: 30 lines, 1.0 KiB, Python)
from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.models.taskinstance import clear_task_instances
from airflow.utils.context import Context
from airflow.utils.db import provide_session
from sqlalchemy.orm.session import Session

if TYPE_CHECKING:
    from airflow.models.dagrun import DagRun


@provide_session
def retry_tasks_callback(context: Context, session: Session | None = None) -> None:
    """
    Clear the tasks specified by the ``retry_tasks`` task param.

    Intended to be used as an ``on_retry_callback`` so that other tasks in the
    same DAG run are retried as well when a task fails.

    ``retry_tasks`` may be a single task ID string or a collection of task IDs.
    A missing, ``None`` or empty value makes this callback a no-op.

    :param context: Task context handed to the callback; only its ``"params"``
        and ``"dag_run"`` entries are read.
    :param session: SQLAlchemy session used for the task-instance lookup and
        the clear; presumably filled in by ``provide_session`` when the caller
        does not pass one (TODO confirm against the decorator's contract).
    """
    retry_tasks = context["params"].get("retry_tasks")
    if not retry_tasks:
        # Param absent, explicitly None, or empty: nothing to clear, and no
        # reason to query the DB. (A None value used to hit `x in None`.)
        return

    # A bare string is a single task ID; iterating it would yield characters.
    if isinstance(retry_tasks, str):
        retry_tasks = [retry_tasks]
    # Set membership is O(1) while scanning every task instance of the run.
    retry_task_ids: set[str] = set(retry_tasks)

    dag_run: DagRun = context["dag_run"]
    retry_task_instances = [
        task_instance
        for task_instance in dag_run.get_task_instances(session=session)
        if task_instance.task_id in retry_task_ids
    ]
    if retry_task_instances:
        clear_task_instances(retry_task_instances, session=session)