Replace methods on state with frozenset properties (#11576)
Although these lists are short, there's no need to re-create them each time, and also no need for them to be a method. I have made them lowercase (`finished`, `running`) instead of uppercase (`FINISHED`, `RUNNING`) to distinguish them from the actual states.
This commit is contained in:
Родитель
44031bf72b
Коммит
0c5bbe83c6
|
@ -121,7 +121,7 @@ def set_state(
|
|||
tis_altered += qry_sub_dag.with_for_update().all()
|
||||
for task_instance in tis_altered:
|
||||
task_instance.state = state
|
||||
if state in State.finished():
|
||||
if state in State.finished:
|
||||
task_instance.end_date = timezone.utcnow()
|
||||
task_instance.set_duration()
|
||||
else:
|
||||
|
|
|
@ -630,7 +630,7 @@ class BackfillJob(BaseJob):
|
|||
_dag_runs = ti_status.active_runs[:]
|
||||
for run in _dag_runs:
|
||||
run.update_state(session=session)
|
||||
if run.state in State.finished():
|
||||
if run.state in State.finished:
|
||||
ti_status.finished_runs += 1
|
||||
ti_status.active_runs.remove(run)
|
||||
executed_run_dates.append(run.execution_date)
|
||||
|
@ -749,7 +749,7 @@ class BackfillJob(BaseJob):
|
|||
"""
|
||||
for dag_run in dag_runs:
|
||||
dag_run.update_state()
|
||||
if dag_run.state not in State.finished():
|
||||
if dag_run.state not in State.finished:
|
||||
dag_run.set_state(State.FAILED)
|
||||
session.merge(dag_run)
|
||||
|
||||
|
|
|
@ -870,7 +870,7 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes
|
|||
}
|
||||
|
||||
# Only add end_date and duration if the new_state is 'success', 'failed' or 'skipped'
|
||||
if new_state in State.finished():
|
||||
if new_state in State.finished:
|
||||
ti_prop_update.update({
|
||||
models.TaskInstance.end_date: current_time,
|
||||
models.TaskInstance.duration: 0,
|
||||
|
@ -1484,7 +1484,7 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes
|
|||
func.count(TI.execution_date.distinct()),
|
||||
).filter(
|
||||
TI.dag_id.in_(list({dag_run.dag_id for dag_run in dag_runs})),
|
||||
TI.state.notin_(State.finished())
|
||||
TI.state.notin_(list(State.finished))
|
||||
).group_by(TI.dag_id).all())
|
||||
|
||||
for dag_run in dag_runs:
|
||||
|
|
|
@ -126,7 +126,7 @@ class DagRun(Base, LoggingMixin):
|
|||
def set_state(self, state):
|
||||
if self._state != state:
|
||||
self._state = state
|
||||
self.end_date = timezone.utcnow() if self._state in State.finished() else None
|
||||
self.end_date = timezone.utcnow() if self._state in State.finished else None
|
||||
|
||||
@declared_attr
|
||||
def state(self):
|
||||
|
@ -385,8 +385,8 @@ class DagRun(Base, LoggingMixin):
|
|||
for ti in tis:
|
||||
ti.task = dag.get_task(ti.task_id)
|
||||
|
||||
unfinished_tasks = [t for t in tis if t.state in State.unfinished()]
|
||||
finished_tasks = [t for t in tis if t.state in State.finished() + [State.UPSTREAM_FAILED]]
|
||||
unfinished_tasks = [t for t in tis if t.state in State.unfinished]
|
||||
finished_tasks = [t for t in tis if t.state in State.finished | {State.UPSTREAM_FAILED}]
|
||||
none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks)
|
||||
none_task_concurrency = all(t.task.task_concurrency is None for t in unfinished_tasks)
|
||||
if unfinished_tasks:
|
||||
|
|
|
@ -152,7 +152,7 @@ class SensorInstance(Base):
|
|||
database, in all other cases this will be incremented.
|
||||
"""
|
||||
# This is designed so that task logs end up in the right file.
|
||||
if self.state in State.running():
|
||||
if self.state in State.running:
|
||||
return self._try_number
|
||||
return self._try_number + 1
|
||||
|
||||
|
|
|
@ -295,7 +295,7 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
|
|||
"""
|
||||
# This is designed so that task logs end up in the right file.
|
||||
# TODO: whether we need sensing here or not (in sensor and task_instance state machine)
|
||||
if self.state in State.running():
|
||||
if self.state in State.running:
|
||||
return self._try_number
|
||||
return self._try_number + 1
|
||||
|
||||
|
@ -623,7 +623,7 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
|
|||
self.log.debug("Setting task state for %s to %s", self, state)
|
||||
self.state = state
|
||||
self.start_date = current_time
|
||||
if self.state in State.finished():
|
||||
if self.state in State.finished:
|
||||
self.end_date = current_time
|
||||
self.duration = 0
|
||||
session.merge(self)
|
||||
|
|
|
@ -436,7 +436,7 @@ class SmartSensorOperator(BaseOperator, SkipMixin):
|
|||
def mark_state(ti, sensor_instance):
|
||||
ti.state = state
|
||||
sensor_instance.state = state
|
||||
if state in State.finished():
|
||||
if state in State.finished:
|
||||
ti.end_date = end_date
|
||||
ti.set_duration()
|
||||
|
||||
|
|
|
@ -100,7 +100,7 @@ class DepContext:
|
|||
self.finished_tasks = dag.get_task_instances(
|
||||
start_date=execution_date,
|
||||
end_date=execution_date,
|
||||
state=State.finished() + [State.UPSTREAM_FAILED],
|
||||
state=State.finished | {State.UPSTREAM_FAILED},
|
||||
session=session,
|
||||
)
|
||||
return self.finished_tasks
|
||||
|
|
|
@ -99,45 +99,39 @@ class State:
|
|||
return 'white'
|
||||
return 'black'
|
||||
|
||||
@classmethod
|
||||
def running(cls):
|
||||
"""
|
||||
A list of states indicating that a task is being executed.
|
||||
"""
|
||||
return [
|
||||
cls.RUNNING,
|
||||
cls.SENSING
|
||||
]
|
||||
running = frozenset([
|
||||
RUNNING,
|
||||
SENSING
|
||||
])
|
||||
"""
|
||||
A list of states indicating that a task is being executed.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def finished(cls):
|
||||
"""
|
||||
A list of states indicating that a task started and completed a
|
||||
run attempt. Note that the attempt could have resulted in failure or
|
||||
have been interrupted; in any case, it is no longer running.
|
||||
"""
|
||||
return [
|
||||
cls.SUCCESS,
|
||||
cls.FAILED,
|
||||
cls.SKIPPED,
|
||||
]
|
||||
finished = frozenset([
|
||||
SUCCESS,
|
||||
FAILED,
|
||||
SKIPPED,
|
||||
])
|
||||
"""
|
||||
A list of states indicating that a task started and completed a
|
||||
run attempt. Note that the attempt could have resulted in failure or
|
||||
have been interrupted; in any case, it is no longer running.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def unfinished(cls):
|
||||
"""
|
||||
A list of states indicating that a task either has not completed
|
||||
a run or has not even started.
|
||||
"""
|
||||
return [
|
||||
cls.NONE,
|
||||
cls.SCHEDULED,
|
||||
cls.QUEUED,
|
||||
cls.RUNNING,
|
||||
cls.SENSING,
|
||||
cls.SHUTDOWN,
|
||||
cls.UP_FOR_RETRY,
|
||||
cls.UP_FOR_RESCHEDULE,
|
||||
]
|
||||
unfinished = frozenset([
|
||||
NONE,
|
||||
SCHEDULED,
|
||||
QUEUED,
|
||||
RUNNING,
|
||||
SENSING,
|
||||
SHUTDOWN,
|
||||
UP_FOR_RETRY,
|
||||
UP_FOR_RESCHEDULE,
|
||||
])
|
||||
"""
|
||||
A list of states indicating that a task either has not completed
|
||||
a run or has not even started.
|
||||
"""
|
||||
|
||||
|
||||
class PokeState:
|
||||
|
|
|
@ -107,7 +107,7 @@ class TestMarkTasks(unittest.TestCase):
|
|||
self.assertEqual(ti.operator, dag.get_task(ti.task_id).task_type)
|
||||
if ti.task_id in task_ids and ti.execution_date in execution_dates:
|
||||
self.assertEqual(ti.state, state)
|
||||
if state in State.finished():
|
||||
if state in State.finished:
|
||||
self.assertIsNotNone(ti.end_date)
|
||||
else:
|
||||
for old_ti in old_tis:
|
||||
|
|
|
@ -1952,7 +1952,7 @@ class TestSchedulerJob(unittest.TestCase):
|
|||
ti = dr.get_task_instance(task_id=op1.task_id, session=session)
|
||||
self.assertEqual(ti.state, expected_task_state)
|
||||
self.assertIsNotNone(ti.start_date)
|
||||
if expected_task_state in State.finished():
|
||||
if expected_task_state in State.finished:
|
||||
self.assertIsNotNone(ti.end_date)
|
||||
self.assertEqual(ti.start_date, ti.end_date)
|
||||
self.assertIsNotNone(ti.duration)
|
||||
|
|
|
@ -547,7 +547,7 @@ class TestTriggerRuleDep(unittest.TestCase):
|
|||
finished_tasks = DepContext().ensure_finished_tasks(ti_op2.task.dag, ti_op2.execution_date, session)
|
||||
self.assertEqual(get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op2),
|
||||
(1, 0, 0, 0, 1))
|
||||
finished_tasks = dr.get_task_instances(state=State.finished() + [State.UPSTREAM_FAILED],
|
||||
finished_tasks = dr.get_task_instances(state=State.finished | {State.UPSTREAM_FAILED},
|
||||
session=session)
|
||||
self.assertEqual(get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op4),
|
||||
(1, 0, 1, 0, 2))
|
||||
|
|
Загрузка…
Ссылка в новой задаче