Replace methods on state with frozenset properties (#11576)

Although these lists are short, there's no need to re-create them on every
call, and no need for them to be methods at all.

I have made them lowercase (`finished`, `running`) instead of uppercase
(`FINISHED`, `RUNNING`) to distinguish them from the actual states.
This commit is contained in:
Ash Berlin-Taylor 2020-10-16 21:09:36 +01:00 committed by GitHub
Parent 44031bf72b
Commit 0c5bbe83c6
No key matching this signature was found
GPG key ID: 4AEE18F83AFDEB23
12 changed files: 47 additions and 53 deletions
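
For context, a minimal sketch of the before/after pattern this commit applies (simplified stand-in classes, not the real `airflow.utils.state.State`):

```python
# Simplified illustration only; the real State class defines many more states.


class StateBefore:
    SUCCESS = "success"
    FAILED = "failed"

    @classmethod
    def finished(cls):
        # Rebuilt on every call, and callers must remember the trailing ().
        return [cls.SUCCESS, cls.FAILED]


class StateAfter:
    SUCCESS = "success"
    FAILED = "failed"

    # Built once at class-definition time; immutable, with O(1) membership
    # tests. Lowercase so it is not confused with an individual state value.
    finished = frozenset([SUCCESS, FAILED])


print("failed" in StateAfter.finished)            # True, no call needed
print(StateAfter.finished | {"upstream_failed"})  # union still yields a frozenset
```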

View file

@@ -121,7 +121,7 @@ def set_state(
tis_altered += qry_sub_dag.with_for_update().all()
for task_instance in tis_altered:
task_instance.state = state
if state in State.finished():
if state in State.finished:
task_instance.end_date = timezone.utcnow()
task_instance.set_duration()
else:

View file

@@ -630,7 +630,7 @@ class BackfillJob(BaseJob):
_dag_runs = ti_status.active_runs[:]
for run in _dag_runs:
run.update_state(session=session)
if run.state in State.finished():
if run.state in State.finished:
ti_status.finished_runs += 1
ti_status.active_runs.remove(run)
executed_run_dates.append(run.execution_date)
@@ -749,7 +749,7 @@ class BackfillJob(BaseJob):
"""
for dag_run in dag_runs:
dag_run.update_state()
if dag_run.state not in State.finished():
if dag_run.state not in State.finished:
dag_run.set_state(State.FAILED)
session.merge(dag_run)

View file

@@ -870,7 +870,7 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes
}
# Only add end_date and duration if the new_state is 'success', 'failed' or 'skipped'
if new_state in State.finished():
if new_state in State.finished:
ti_prop_update.update({
models.TaskInstance.end_date: current_time,
models.TaskInstance.duration: 0,
@@ -1484,7 +1484,7 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes
func.count(TI.execution_date.distinct()),
).filter(
TI.dag_id.in_(list({dag_run.dag_id for dag_run in dag_runs})),
TI.state.notin_(State.finished())
TI.state.notin_(list(State.finished))
).group_by(TI.dag_id).all())
for dag_run in dag_runs:

View file

@@ -126,7 +126,7 @@ class DagRun(Base, LoggingMixin):
def set_state(self, state):
if self._state != state:
self._state = state
self.end_date = timezone.utcnow() if self._state in State.finished() else None
self.end_date = timezone.utcnow() if self._state in State.finished else None
@declared_attr
def state(self):
@@ -385,8 +385,8 @@ class DagRun(Base, LoggingMixin):
for ti in tis:
ti.task = dag.get_task(ti.task_id)
unfinished_tasks = [t for t in tis if t.state in State.unfinished()]
finished_tasks = [t for t in tis if t.state in State.finished() + [State.UPSTREAM_FAILED]]
unfinished_tasks = [t for t in tis if t.state in State.unfinished]
finished_tasks = [t for t in tis if t.state in State.finished | {State.UPSTREAM_FAILED}]
none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks)
none_task_concurrency = all(t.task.task_concurrency is None for t in unfinished_tasks)
if unfinished_tasks:

View file

@@ -152,7 +152,7 @@ class SensorInstance(Base):
database, in all other cases this will be incremented.
"""
# This is designed so that task logs end up in the right file.
if self.state in State.running():
if self.state in State.running:
return self._try_number
return self._try_number + 1

View file

@@ -295,7 +295,7 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
"""
# This is designed so that task logs end up in the right file.
# TODO: whether we need sensing here or not (in sensor and task_instance state machine)
if self.state in State.running():
if self.state in State.running:
return self._try_number
return self._try_number + 1
@@ -623,7 +623,7 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
self.log.debug("Setting task state for %s to %s", self, state)
self.state = state
self.start_date = current_time
if self.state in State.finished():
if self.state in State.finished:
self.end_date = current_time
self.duration = 0
session.merge(self)

View file

@@ -436,7 +436,7 @@ class SmartSensorOperator(BaseOperator, SkipMixin):
def mark_state(ti, sensor_instance):
ti.state = state
sensor_instance.state = state
if state in State.finished():
if state in State.finished:
ti.end_date = end_date
ti.set_duration()

View file

@@ -100,7 +100,7 @@ class DepContext:
self.finished_tasks = dag.get_task_instances(
start_date=execution_date,
end_date=execution_date,
state=State.finished() + [State.UPSTREAM_FAILED],
state=State.finished | {State.UPSTREAM_FAILED},
session=session,
)
return self.finished_tasks

View file

@@ -99,45 +99,39 @@ class State:
return 'white'
return 'black'
@classmethod
def running(cls):
"""
A list of states indicating that a task is being executed.
"""
return [
cls.RUNNING,
cls.SENSING
]
running = frozenset([
RUNNING,
SENSING
])
"""
A list of states indicating that a task is being executed.
"""
@classmethod
def finished(cls):
"""
A list of states indicating that a task started and completed a
run attempt. Note that the attempt could have resulted in failure or
have been interrupted; in any case, it is no longer running.
"""
return [
cls.SUCCESS,
cls.FAILED,
cls.SKIPPED,
]
finished = frozenset([
SUCCESS,
FAILED,
SKIPPED,
])
"""
A list of states indicating that a task started and completed a
run attempt. Note that the attempt could have resulted in failure or
have been interrupted; in any case, it is no longer running.
"""
@classmethod
def unfinished(cls):
"""
A list of states indicating that a task either has not completed
a run or has not even started.
"""
return [
cls.NONE,
cls.SCHEDULED,
cls.QUEUED,
cls.RUNNING,
cls.SENSING,
cls.SHUTDOWN,
cls.UP_FOR_RETRY,
cls.UP_FOR_RESCHEDULE,
]
unfinished = frozenset([
NONE,
SCHEDULED,
QUEUED,
RUNNING,
SENSING,
SHUTDOWN,
UP_FOR_RETRY,
UP_FOR_RESCHEDULE,
])
"""
A list of states indicating that a task either has not completed
a run or has not even started.
"""
class PokeState:
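
A short usage sketch of the new attributes, assuming Airflow with this patch applied (illustrative only, mirroring the call sites changed above):

```python
from airflow.utils.state import State

# Membership checks drop the call parentheses:
assert State.SUCCESS in State.finished
assert State.RUNNING in State.unfinished

# Combining with extra states uses set union instead of list concatenation:
terminal_states = State.finished | {State.UPSTREAM_FAILED}

# Where a plain sequence is still required (e.g. the .notin_() filter in the
# scheduler hunk above), the frozenset is wrapped in list() explicitly:
terminal_list = list(State.finished)
```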

View file

@@ -107,7 +107,7 @@ class TestMarkTasks(unittest.TestCase):
self.assertEqual(ti.operator, dag.get_task(ti.task_id).task_type)
if ti.task_id in task_ids and ti.execution_date in execution_dates:
self.assertEqual(ti.state, state)
if state in State.finished():
if state in State.finished:
self.assertIsNotNone(ti.end_date)
else:
for old_ti in old_tis:

View file

@@ -1952,7 +1952,7 @@ class TestSchedulerJob(unittest.TestCase):
ti = dr.get_task_instance(task_id=op1.task_id, session=session)
self.assertEqual(ti.state, expected_task_state)
self.assertIsNotNone(ti.start_date)
if expected_task_state in State.finished():
if expected_task_state in State.finished:
self.assertIsNotNone(ti.end_date)
self.assertEqual(ti.start_date, ti.end_date)
self.assertIsNotNone(ti.duration)

View file

@@ -547,7 +547,7 @@ class TestTriggerRuleDep(unittest.TestCase):
finished_tasks = DepContext().ensure_finished_tasks(ti_op2.task.dag, ti_op2.execution_date, session)
self.assertEqual(get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op2),
(1, 0, 0, 0, 1))
finished_tasks = dr.get_task_instances(state=State.finished() + [State.UPSTREAM_FAILED],
finished_tasks = dr.get_task_instances(state=State.finished | {State.UPSTREAM_FAILED},
session=session)
self.assertEqual(get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op4),
(1, 0, 1, 0, 2))