Adding wait_for_downstream as BaseOperator attribute

This commit is contained in:
Maxime 2014-12-12 05:47:35 +00:00
Родитель 4315179061
Коммит c79458b1a6
1 изменённых файлов: 54 добавлений и 9 удалений

Просмотреть файл

@ -331,11 +331,41 @@ class TaskInstance(Base):
else:
return False
def are_dependents_done(self, main_session=None):
"""
Checks whether the dependents of this task instance have all succeeded.
This is meant to be used by wait_for_downstream.
This is useful when you do not want to start processing the next
schedule of a task until the dependents are done. For instance,
if the task DROPs and recreates a table.
"""
session = main_session or settings.Session()
task = self.task
if not task._downstream_list:
return True
downstream_task_ids = [t.task_id for t in task._downstream_list]
ti = session.query(func.count(TaskInstance.task_id)).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id.in_(downstream_task_ids),
TaskInstance.execution_date == self.execution_date,
TaskInstance.state == State.SUCCESS,
)
count = ti[0][0]
if not main_session:
session.commit()
session.close()
return count == len(task._downstream_list)
def are_dependencies_met(self, main_session=None):
"""
Returns a boolean on whether the upstream tasks are in a SUCCESS state
and considers depends_on_past and the previous' run state.
"""
TI = TaskInstance
# Using the session if passed as param
session = main_session or settings.Session()
@ -344,18 +374,31 @@ class TaskInstance(Base):
# Checking that the depends_on_past is fulfilled
if (task.depends_on_past and
not self.execution_date == task.start_date):
current_state = self.current_state()
if current_state == State.SUCCESS:
return False
previous_ti = session.query(TI).filter(
TI.dag_id == self.dag_id,
TI.task_id == task.task_id,
TI.execution_date == \
self.execution_date-task.schedule_interval,
TI.state == State.SUCCESS,
).first()
if previous_ti:
previous_ti.task = task
if previous_ti.state != State.SUCCESS:
return False
# Applying wait_for_downstream
if task.wait_for_downstream and not \
previous_ti.are_dependents_done(session):
return False
# Checking that all upstream dependencies have succeeded
if task._upstream_list:
upstream_task_ids = [t.task_id for t in task._upstream_list]
ti = session.query(func.count(TaskInstance.task_id)).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id.in_(upstream_task_ids),
TaskInstance.execution_date == self.execution_date,
TaskInstance.state == State.SUCCESS,
ti = session.query(func.count(TI.task_id)).filter(
TI.dag_id == self.dag_id,
TI.task_id.in_(upstream_task_ids),
TI.execution_date == self.execution_date,
TI.state == State.SUCCESS,
)
count = ti[0][0]
if count < len(task._upstream_list):
@ -608,6 +651,7 @@ class BaseOperator(Base):
end_date=None,
schedule_interval=timedelta(days=1),
depends_on_past=False,
wait_for_downstream=False,
dag=None,
params=None,
default_args=None,
@ -622,6 +666,7 @@ class BaseOperator(Base):
self.start_date = start_date
self.end_date = end_date
self.depends_on_past = depends_on_past
self.wait_for_downstream = wait_for_downstream
self._schedule_interval = schedule_interval
self.retries = retries
self.retry_delay = retry_delay
@ -1030,7 +1075,7 @@ class DAG(Base):
def run(self, start_date=None, end_date=None, mark_success=False):
from airflow import jobs
job = jobs.BackfillJob(
self, start_date, end_date)
self, start_date, end_date, mark_success)
job.run()