[AIRFLOW-2761] Parallelize enqueue in celery executor (#4234)

Kevin Yang 2018-11-28 14:23:44 -08:00 committed by Kaxil Naik
Parent 8fdf5ce5f5
Commit 1d53f93966
9 changed files with 389 additions and 125 deletions

View file

@ -19,11 +19,12 @@
from builtins import range
# To avoid circular imports
import airflow.utils.dag_processing
from airflow import configuration
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State
PARALLELISM = configuration.conf.getint('core', 'PARALLELISM')
@ -50,11 +51,11 @@ class BaseExecutor(LoggingMixin):
"""
pass
def queue_command(self, task_instance, command, priority=1, queue=None):
key = task_instance.key
def queue_command(self, simple_task_instance, command, priority=1, queue=None):
key = simple_task_instance.key
if key not in self.queued_tasks and key not in self.running:
self.log.info("Adding to queue: %s", command)
self.queued_tasks[key] = (command, priority, queue, task_instance)
self.queued_tasks[key] = (command, priority, queue, simple_task_instance)
else:
self.log.info("could not queue task {}".format(key))
@ -86,7 +87,7 @@ class BaseExecutor(LoggingMixin):
pickle_id=pickle_id,
cfg_path=cfg_path)
self.queue_command(
task_instance,
airflow.utils.dag_processing.SimpleTaskInstance(task_instance),
command,
priority=task_instance.task.priority_weight_total,
queue=task_instance.task.queue)
@ -124,26 +125,13 @@ class BaseExecutor(LoggingMixin):
key=lambda x: x[1][1],
reverse=True)
for i in range(min((open_slots, len(self.queued_tasks)))):
key, (command, _, queue, ti) = sorted_queue.pop(0)
# TODO(jlowin) without a way to know what Job ran which tasks,
# there is a danger that another Job started running a task
# that was also queued to this executor. This is the last chance
# to check if that happened. The most probable way is that a
# Scheduler tried to run a task that was originally queued by a
# Backfill. This fix reduces the probability of a collision but
# does NOT eliminate it.
key, (command, _, queue, simple_ti) = sorted_queue.pop(0)
self.queued_tasks.pop(key)
ti.refresh_from_db()
if ti.state != State.RUNNING:
self.running[key] = command
self.execute_async(key=key,
command=command,
queue=queue,
executor_config=ti.executor_config)
else:
self.log.info(
'Task is already running, not sending to '
'executor: {}'.format(key))
self.running[key] = command
self.execute_async(key=key,
command=command,
queue=queue,
executor_config=simple_ti.executor_config)
# Calling child class sync method
self.log.debug("Calling the %s sync method", self.__class__)
@ -151,7 +139,7 @@ class BaseExecutor(LoggingMixin):
def change_state(self, key, state):
self.log.debug("Changing state: {}".format(key))
self.running.pop(key)
self.running.pop(key, None)
self.event_buffer[key] = state
def fail(self, key):

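The net effect of the BaseExecutor changes above is that queued entries now carry a lightweight SimpleTaskInstance snapshot and are dispatched without re-reading the TaskInstance from the database first (the old refresh_from_db / State.RUNNING guard is gone). Below is a minimal, self-contained sketch of that flow; MiniExecutor and TaskSnapshot are hypothetical stand-ins, not Airflow's real classes.

from collections import namedtuple

# Stand-in for SimpleTaskInstance: only the fields the executor actually needs.
TaskSnapshot = namedtuple('TaskSnapshot', ['key', 'executor_config'])

class MiniExecutor(object):
    def __init__(self, parallelism=4):
        self.parallelism = parallelism
        self.queued_tasks = {}   # key -> (command, priority, queue, snapshot)
        self.running = {}        # key -> command

    def queue_command(self, snapshot, command, priority=1, queue=None):
        if snapshot.key not in self.queued_tasks and snapshot.key not in self.running:
            self.queued_tasks[snapshot.key] = (command, priority, queue, snapshot)

    def heartbeat(self):
        open_slots = self.parallelism - len(self.running)
        # Highest priority first, mirroring the sort on priority in the diff above.
        ordered = sorted(self.queued_tasks.items(), key=lambda kv: kv[1][1], reverse=True)
        for key, (command, _, queue, snapshot) in ordered[:open_slots]:
            self.queued_tasks.pop(key)
            self.running[key] = command
            self.execute_async(key, command, queue, snapshot.executor_config)

    def execute_async(self, key, command, queue, executor_config):
        print('dispatching', key, command, queue, executor_config)

if __name__ == '__main__':
    executor = MiniExecutor()
    executor.queue_command(TaskSnapshot(key=('dag', 'task', '2018-11-28', 1),
                                        executor_config={}),
                           command=['airflow', 'run', 'dag', 'task'])
    executor.heartbeat()

Dropping the per-task DB refresh is what makes it safe to hand the whole enqueue loop to worker processes in the CeleryExecutor changes below.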
View file

@ -33,10 +33,13 @@ from airflow.exceptions import AirflowException
from airflow.executors.base_executor import BaseExecutor
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.module_loading import import_string
from airflow.utils.timeout import timeout
# Keep these as module-level constants so unit tests can reference them.
CELERY_FETCH_ERR_MSG_HEADER = 'Error fetching Celery task state'
CELERY_SEND_ERR_MSG_HEADER = 'Error sending Celery task'
'''
To start the celery worker, run the command:
airflow worker
@ -55,12 +58,12 @@ app = Celery(
@app.task
def execute_command(command):
def execute_command(command_to_exec):
log = LoggingMixin().log
log.info("Executing command in Celery: %s", command)
log.info("Executing command in Celery: %s", command_to_exec)
env = os.environ.copy()
try:
subprocess.check_call(command, stderr=subprocess.STDOUT,
subprocess.check_call(command_to_exec, stderr=subprocess.STDOUT,
close_fds=True, env=env)
except subprocess.CalledProcessError as e:
log.exception('execute_command encountered a CalledProcessError')
@ -95,9 +98,10 @@ def fetch_celery_task_state(celery_task):
"""
try:
# Accessing state property of celery task will make actual network request
# to get the current state of the task.
res = (celery_task[0], celery_task[1].state)
with timeout(seconds=2):
# Accessing the state property of a Celery task makes an actual network
# request to get the current state of the task.
res = (celery_task[0], celery_task[1].state)
except Exception as e:
exception_traceback = "Celery Task ID: {}\n{}".format(celery_task[0],
traceback.format_exc())
@ -105,6 +109,19 @@ def fetch_celery_task_state(celery_task):
return res
def send_task_to_executor(task_tuple):
key, simple_ti, command, queue, task = task_tuple
try:
with timeout(seconds=2):
result = task.apply_async(args=[command], queue=queue)
except Exception as e:
exception_traceback = "Celery Task ID: {}\n{}".format(key,
traceback.format_exc())
result = ExceptionWithTraceback(e, exception_traceback)
return key, command, result
class CeleryExecutor(BaseExecutor):
"""
CeleryExecutor is recommended for production use of Airflow. It allows
@ -135,24 +152,91 @@ class CeleryExecutor(BaseExecutor):
'Starting Celery Executor using {} processes for syncing'.format(
self._sync_parallelism))
def execute_async(self, key, command,
queue=DEFAULT_CELERY_CONFIG['task_default_queue'],
executor_config=None):
self.log.info("[celery] queuing {key} through celery, "
"queue={queue}".format(**locals()))
self.tasks[key] = execute_command.apply_async(
args=[command], queue=queue)
self.last_state[key] = celery_states.PENDING
def _num_tasks_per_send_process(self, to_send_count):
"""
How many Celery tasks should each worker process send.
def _num_tasks_per_process(self):
:return: Number of tasks that should be sent per process
:rtype: int
"""
return max(1,
int(math.ceil(1.0 * to_send_count / self._sync_parallelism)))
def _num_tasks_per_fetch_process(self):
"""
How many Celery task states each worker process should fetch.
:return: Number of tasks per process to fetch the state for
:rtype: int
"""
return max(1,
int(math.ceil(1.0 * len(self.tasks) / self._sync_parallelism)))
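Both helpers reduce to a ceiling division with a floor of one. A quick worked example of the arithmetic follows; chunk_size here is an illustrative stand-in, not part of the diff.

import math

def chunk_size(task_count, parallelism):
    # Same formula as above: ceil(task_count / parallelism), never below 1.
    return max(1, int(math.ceil(1.0 * task_count / parallelism)))

assert chunk_size(10, 3) == 4   # three worker processes get 4 + 4 + 2 tasks
assert chunk_size(2, 16) == 1   # small batches still send at least one per chunk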
def heartbeat(self):
# Triggering new jobs
if not self.parallelism:
open_slots = len(self.queued_tasks)
else:
open_slots = self.parallelism - len(self.running)
self.log.debug("{} running task instances".format(len(self.running)))
self.log.debug("{} in queue".format(len(self.queued_tasks)))
self.log.debug("{} open slots".format(open_slots))
sorted_queue = sorted(
[(k, v) for k, v in self.queued_tasks.items()],
key=lambda x: x[1][1],
reverse=True)
task_tuples_to_send = []
for i in range(min((open_slots, len(self.queued_tasks)))):
key, (command, _, queue, simple_ti) = sorted_queue.pop(0)
task_tuples_to_send.append((key, simple_ti, command, queue,
execute_command))
cached_celery_backend = None
if task_tuples_to_send:
tasks = [t[4] for t in task_tuples_to_send]
# Celery state queries will get stuck if we do not use one shared backend
# for all tasks.
cached_celery_backend = tasks[0].backend
if task_tuples_to_send:
# Use chunking instead of a work queue to reduce context switching
# since tasks are roughly uniform in size
chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
send_pool = Pool(processes=num_processes)
key_and_async_results = send_pool.map(
send_task_to_executor,
task_tuples_to_send,
chunksize=chunksize)
send_pool.close()
send_pool.join()
self.log.debug('Sent all tasks.')
for key, command, result in key_and_async_results:
if isinstance(result, ExceptionWithTraceback):
self.log.error(
CELERY_SEND_ERR_MSG_HEADER + ":{}\n{}\n".format(
result.exception, result.traceback))
elif result is not None:
# Only pop the task if it was enqueued successfully; otherwise keep it
# and let the scheduler loop deal with it.
self.queued_tasks.pop(key)
result.backend = cached_celery_backend
self.running[key] = command
self.tasks[key] = result
self.last_state[key] = celery_states.PENDING
# Calling child class sync method
self.log.debug("Calling the {} sync method".format(self.__class__))
self.sync()
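The heart of this change is that heartbeat() fans the apply_async calls out across a short-lived multiprocessing Pool, using an explicit chunksize so each worker handles a contiguous slice of the task list instead of pulling items from a shared work queue. The sketch below reproduces just that pattern in isolation; fake_send is a hypothetical stand-in for send_task_to_executor and the task tuples are made up.

import math
from multiprocessing import Pool

def fake_send(task_tuple):
    # In the real executor this wraps task.apply_async(args=[command], queue=queue)
    # in a timeout and returns either an AsyncResult or an ExceptionWithTraceback.
    key, command = task_tuple
    return key, command, 'sent'

if __name__ == '__main__':
    task_tuples = [(('dag', 'task_%d' % i, '2018-11-28', 1), ['airflow', 'run'])
                   for i in range(8)]
    parallelism = 3
    chunksize = max(1, int(math.ceil(1.0 * len(task_tuples) / parallelism)))
    send_pool = Pool(processes=min(len(task_tuples), parallelism))
    results = send_pool.map(fake_send, task_tuples, chunksize=chunksize)
    send_pool.close()
    send_pool.join()
    print(results)

Chunking works here because the send operations are roughly uniform in cost, so static partitioning loses little compared to the context switching a shared work queue would add.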
def sync(self):
num_processes = min(len(self.tasks), self._sync_parallelism)
if num_processes == 0:
@ -167,7 +251,7 @@ class CeleryExecutor(BaseExecutor):
# Use chunking instead of a work queue to reduce context switching since tasks are
# roughly uniform in size
chunksize = self._num_tasks_per_process()
chunksize = self._num_tasks_per_fetch_process()
self.log.debug("Waiting for inquiries to complete...")
task_keys_to_states = self._sync_pool.map(

View file

@ -52,6 +52,7 @@ from airflow.utils.dag_processing import (AbstractDagFileProcessor,
DagFileProcessorAgent,
SimpleDag,
SimpleDagBag,
SimpleTaskInstance,
list_py_file_paths)
from airflow.utils.db import create_session, provide_session
from airflow.utils.email import get_email_address_list, send_email
@ -598,6 +599,7 @@ class SchedulerJob(BaseJob):
'run_duration')
self.processor_agent = None
self._last_loop = False
signal.signal(signal.SIGINT, self._exit_gracefully)
signal.signal(signal.SIGTERM, self._exit_gracefully)
@ -1228,13 +1230,13 @@ class SchedulerJob(BaseJob):
acceptable_states, session=None):
"""
Changes the state of task instances in the list with one of the given states
to QUEUED atomically, and returns the TIs changed.
to QUEUED atomically, and returns the TIs changed in SimpleTaskInstance format.
:param task_instances: TaskInstances to change the state of
:type task_instances: List[TaskInstance]
:param acceptable_states: Filters the TaskInstances updated to be in these states
:type acceptable_states: Iterable[State]
:return: List[TaskInstance]
:return: List[SimpleTaskInstance]
"""
if len(task_instances) == 0:
session.commit()
@ -1276,81 +1278,57 @@ class SchedulerJob(BaseJob):
else task_instance.queued_dttm)
session.merge(task_instance)
# save which TIs we set before session expires them
filter_for_ti_enqueue = ([and_(TI.dag_id == ti.dag_id,
TI.task_id == ti.task_id,
TI.execution_date == ti.execution_date)
for ti in tis_to_set_to_queued])
# Generate a list of SimpleTaskInstance objects for queuing them
# in the executor.
simple_task_instances = [SimpleTaskInstance(ti) for ti in
tis_to_set_to_queued]
task_instance_str = "\n\t".join(
["{}".format(x) for x in tis_to_set_to_queued])
session.commit()
self.log.info("Setting the following {} tasks to queued state:\n\t{}"
.format(len(tis_to_set_to_queued), task_instance_str))
return simple_task_instances
# requery in batches since above was expired by commit
def query(result, items):
tis_to_be_queued = (
session
.query(TI)
.filter(or_(*items))
.all())
task_instance_str = "\n\t".join(
["{}".format(x) for x in tis_to_be_queued])
self.log.info("Setting the following {} tasks to queued state:\n\t{}"
.format(len(tis_to_be_queued),
task_instance_str))
return result + tis_to_be_queued
tis_to_be_queued = helpers.reduce_in_chunks(query,
filter_for_ti_enqueue,
[],
self.max_tis_per_query)
return tis_to_be_queued
def _enqueue_task_instances_with_queued_state(self, simple_dag_bag, task_instances):
def _enqueue_task_instances_with_queued_state(self, simple_dag_bag,
simple_task_instances):
"""
Takes task_instances, which should have been set to queued, and enqueues them
with the executor.
:param task_instances: TaskInstances to enqueue
:type task_instances: List[TaskInstance]
:param simple_task_instances: SimpleTaskInstances to enqueue
:type simple_task_instances: List[SimpleTaskInstance]
:param simple_dag_bag: Should contain all of the task_instances' dags
:type simple_dag_bag: SimpleDagBag
"""
TI = models.TaskInstance
# actually enqueue them
for task_instance in task_instances:
simple_dag = simple_dag_bag.get_dag(task_instance.dag_id)
for simple_task_instance in simple_task_instances:
simple_dag = simple_dag_bag.get_dag(simple_task_instance.dag_id)
command = TI.generate_command(
task_instance.dag_id,
task_instance.task_id,
task_instance.execution_date,
simple_task_instance.dag_id,
simple_task_instance.task_id,
simple_task_instance.execution_date,
local=True,
mark_success=False,
ignore_all_deps=False,
ignore_depends_on_past=False,
ignore_task_deps=False,
ignore_ti_state=False,
pool=task_instance.pool,
pool=simple_task_instance.pool,
file_path=simple_dag.full_filepath,
pickle_id=simple_dag.pickle_id)
priority = task_instance.priority_weight
queue = task_instance.queue
priority = simple_task_instance.priority_weight
queue = simple_task_instance.queue
self.log.info(
"Sending %s to executor with priority %s and queue %s",
task_instance.key, priority, queue
simple_task_instance.key, priority, queue
)
# save attributes so sqlalchemy doesnt expire them
copy_dag_id = task_instance.dag_id
copy_task_id = task_instance.task_id
copy_execution_date = task_instance.execution_date
make_transient(task_instance)
task_instance.dag_id = copy_dag_id
task_instance.task_id = copy_task_id
task_instance.execution_date = copy_execution_date
self.executor.queue_command(
task_instance,
simple_task_instance,
command,
priority=priority,
queue=queue)
@ -1374,24 +1352,65 @@ class SchedulerJob(BaseJob):
:type simple_dag_bag: SimpleDagBag
:param states: Execute TaskInstances in these states
:type states: Tuple[State]
:return: None
:return: Number of task instances whose state was changed.
"""
executable_tis = self._find_executable_task_instances(simple_dag_bag, states,
session=session)
def query(result, items):
tis_with_state_changed = self._change_state_for_executable_task_instances(
items,
states,
session=session)
simple_tis_with_state_changed = \
self._change_state_for_executable_task_instances(items,
states,
session=session)
self._enqueue_task_instances_with_queued_state(
simple_dag_bag,
tis_with_state_changed)
simple_tis_with_state_changed)
session.commit()
return result + len(tis_with_state_changed)
return result + len(simple_tis_with_state_changed)
return helpers.reduce_in_chunks(query, executable_tis, 0, self.max_tis_per_query)
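_execute_task_instances folds over the executable task instances in chunks of max_tis_per_query, committing once per chunk, so a large backlog never becomes one giant query or transaction. Roughly, the chunked reduction behaves like the sketch below (a simplified illustration, not the actual airflow.utils.helpers implementation).

def reduce_in_chunks_sketch(fn, items, initializer, chunk_size):
    # Fold fn over the items one slice at a time, threading an accumulator through.
    result = initializer
    for start in range(0, len(items), chunk_size):
        result = fn(result, items[start:start + chunk_size])
    return result

# Counting the items per chunk reproduces the "number of TIs with state changed".
total = reduce_in_chunks_sketch(lambda acc, chunk: acc + len(chunk),
                                list(range(7)), 0, 3)
assert total == 7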
@provide_session
def _change_state_for_tasks_failed_to_execute(self, session):
"""
If there are tasks left over in the executor,
we set them back to SCHEDULED to avoid creating hanging tasks.
:param session: session for ORM operations
"""
if self.executor.queued_tasks:
TI = models.TaskInstance
filter_for_ti_state_change = (
[and_(
TI.dag_id == dag_id,
TI.task_id == task_id,
TI.execution_date == execution_date,
# TI.try_number returns the raw try_number + 1 since the TI is not
# running, so we subtract 1 here to match the DB record.
TI._try_number == try_number - 1,
TI.state == State.QUEUED)
for dag_id, task_id, execution_date, try_number
in self.executor.queued_tasks.keys()])
ti_query = (session.query(TI)
.filter(or_(*filter_for_ti_state_change)))
tis_to_set_to_scheduled = (ti_query
.with_for_update()
.all())
if len(tis_to_set_to_scheduled) == 0:
session.commit()
return
# set TIs to scheduled state
for task_instance in tis_to_set_to_scheduled:
task_instance.state = State.SCHEDULED
task_instance_str = "\n\t".join(
["{}".format(x) for x in tis_to_set_to_scheduled])
session.commit()
self.log.info("Set the follow tasks to scheduled state:\n\t{}"
.format(task_instance_str))
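The try_number - 1 in the filter above is easy to misread: the executor key stores TI.try_number, a derived property that reports the next try for a non-running task instance, while the query must match the raw _try_number column. A toy model of that convention follows (ToyTI is hypothetical, not the real TaskInstance).

class ToyTI(object):
    def __init__(self, _try_number=0, state='queued'):
        self._try_number = _try_number   # what the DB column stores
        self.state = state

    @property
    def try_number(self):
        # Mirrors Airflow's convention: report the current try while running,
        # otherwise report the try that would run next.
        if self.state == 'running':
            return self._try_number
        return self._try_number + 1

ti = ToyTI(_try_number=0, state='queued')
key_try_number = ti.try_number               # the executor key holds 1
assert key_try_number - 1 == ti._try_number  # the DB record holds 0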
def _process_dags(self, dagbag, dags, tis_out):
"""
Iterates over the dags and processes them. Processing includes:
@ -1507,6 +1526,8 @@ class SchedulerJob(BaseJob):
try:
self._execute_helper()
except Exception:
self.log.exception("Exception when executing execute_helper")
finally:
self.processor_agent.end()
self.log.info("Exited execute loop")
@ -1557,6 +1578,7 @@ class SchedulerJob(BaseJob):
self.log.info("Harvesting DAG parsing results")
simple_dags = self.processor_agent.harvest_simple_dags()
self.log.debug("Harvested {} SimpleDAGs".format(len(simple_dags)))
# Send tasks for execution if available
simple_dag_bag = SimpleDagBag(simple_dags)
@ -1593,6 +1615,8 @@ class SchedulerJob(BaseJob):
self.log.debug("Heartbeating the executor")
self.executor.heartbeat()
self._change_state_for_tasks_failed_to_execute()
# Process events from the executor
self._process_executor_events(simple_dag_bag)
@ -1612,8 +1636,13 @@ class SchedulerJob(BaseJob):
self.log.debug("Sleeping for %.2f seconds", self._processor_poll_interval)
time.sleep(self._processor_poll_interval)
# Exit early for a test mode
# Exit early in test mode, but run one additional scheduler loop
# to reduce the possibility that a parsed DAG was put into the queue
# by the DAG manager but not yet received by the DAG agent.
if self.processor_agent.done:
self._last_loop = True
if self._last_loop:
self.log.info("Exiting scheduler loop as all files"
" have been processed {} times".format(self.num_runs))
break

View file

@ -22,12 +22,12 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import copy
from collections import defaultdict, namedtuple
from builtins import ImportError as BuiltinImportError, bytes, object, str
from future.standard_library import install_aliases
from builtins import str, object, bytes, ImportError as BuiltinImportError
import copy
from collections import namedtuple, defaultdict
try:
# Fix Python > 3.7 deprecation
from collections.abc import Hashable

View file

@ -146,6 +146,21 @@ class SimpleTaskInstance(object):
self._end_date = ti.end_date
self._try_number = ti.try_number
self._state = ti.state
self._executor_config = ti.executor_config
if hasattr(ti, 'run_as_user'):
self._run_as_user = ti.run_as_user
else:
self._run_as_user = None
if hasattr(ti, 'pool'):
self._pool = ti.pool
else:
self._pool = None
if hasattr(ti, 'priority_weight'):
self._priority_weight = ti.priority_weight
else:
self._priority_weight = None
self._queue = ti.queue
self._key = ti.key
@property
def dag_id(self):
@ -175,6 +190,49 @@ class SimpleTaskInstance(object):
def state(self):
return self._state
@property
def pool(self):
return self._pool
@property
def priority_weight(self):
return self._priority_weight
@property
def queue(self):
return self._queue
@property
def key(self):
return self._key
@property
def executor_config(self):
return self._executor_config
@provide_session
def construct_task_instance(self, session=None, lock_for_update=False):
"""
Construct a TaskInstance from the database based on the primary key
:param session: DB session.
:param lock_for_update: if True, indicates that the database should
lock the TaskInstance (issuing a FOR UPDATE clause) until the
session is committed.
"""
TI = airflow.models.TaskInstance
qry = session.query(TI).filter(
TI.dag_id == self._dag_id,
TI.task_id == self._task_id,
TI.execution_date == self._execution_date)
if lock_for_update:
ti = qry.with_for_update().first()
else:
ti = qry.first()
return ti
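The general pattern SimpleTaskInstance follows is: copy the handful of fields a consumer needs into a plain object that is cheap to pickle across processes, and keep the primary-key fields so the full ORM row can be fetched again when it is actually needed. A schematic, ORM-free sketch of that idea follows (Snapshot and FakeStore are illustrative, not Airflow code).

class Snapshot(object):
    def __init__(self, record):
        # Copy only what downstream consumers need, plus the primary key.
        self.dag_id = record['dag_id']
        self.task_id = record['task_id']
        self.execution_date = record['execution_date']
        self.priority_weight = record.get('priority_weight')

    def reconstruct(self, store):
        # Analogue of construct_task_instance(): re-fetch the full record by key.
        return store.get((self.dag_id, self.task_id, self.execution_date))

class FakeStore(object):
    def __init__(self):
        self._rows = {}

    def put(self, record):
        key = (record['dag_id'], record['task_id'], record['execution_date'])
        self._rows[key] = record

    def get(self, key):
        return self._rows.get(key)

store = FakeStore()
row = {'dag_id': 'd', 'task_id': 't', 'execution_date': '2018-11-28', 'priority_weight': 5}
store.put(row)
snapshot = Snapshot(row)
assert snapshot.reconstruct(store) is row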
class SimpleDagBag(BaseDagBag):
"""
@ -571,11 +629,16 @@ class DagFileProcessorAgent(LoggingMixin):
Terminate (and then kill) the manager process launched.
:return:
"""
if not self._process or not self._process.is_alive():
if not self._process:
self.log.warn('Ending without manager process.')
return
this_process = psutil.Process(os.getpid())
manager_process = psutil.Process(self._process.pid)
try:
manager_process = psutil.Process(self._process.pid)
except psutil.NoSuchProcess:
self.log.info("Manager process not running.")
return
# First try SIGTERM
if manager_process.is_running() \
and manager_process.pid in [x.pid for x in this_process.children()]:

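The end() change above guards against a race where the manager process exits between the liveness check and the psutil lookup. The same defensive pattern, stripped to its essentials, is sketched below; terminate_child is an illustrative helper that assumes psutil is installed and is not part of the diff.

import os
import psutil

def terminate_child(pid):
    try:
        proc = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return  # already gone, nothing to clean up
    if proc.is_running() and proc.ppid() == os.getpid():
        proc.terminate()      # SIGTERM first
        try:
            proc.wait(timeout=5)
        except psutil.TimeoutExpired:
            proc.kill()       # escalate only if it refuses to exit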
View file

@ -23,6 +23,7 @@ from __future__ import print_function
from __future__ import unicode_literals
import signal
import os
from airflow.exceptions import AirflowTaskTimeout
from airflow.utils.log.logging_mixin import LoggingMixin
@ -35,10 +36,10 @@ class timeout(LoggingMixin):
def __init__(self, seconds=1, error_message='Timeout'):
self.seconds = seconds
self.error_message = error_message
self.error_message = error_message + ', PID: ' + str(os.getpid())
def handle_timeout(self, signum, frame):
self.log.error("Process timed out")
self.log.error("Process timed out, PID: " + str(os.getpid()))
raise AirflowTaskTimeout(self.error_message)
def __enter__(self):

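For readers unfamiliar with airflow.utils.timeout, it is a SIGALRM-based context manager, which is why it can bound the blocking Celery calls in fetch_celery_task_state and send_task_to_executor. A stripped-down sketch of the same pattern follows; simple_timeout and SketchTimeout are illustrative, and the approach only works on Unix in the main thread.

import signal

class SketchTimeout(Exception):
    pass

class simple_timeout(object):
    def __init__(self, seconds=1, message='Timed out'):
        self.seconds = seconds
        self.message = message

    def _handle(self, signum, frame):
        raise SketchTimeout(self.message)

    def __enter__(self):
        signal.signal(signal.SIGALRM, self._handle)
        signal.alarm(self.seconds)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        signal.alarm(0)  # cancel any pending alarm

# usage: bound a potentially slow network call, as the executor code does
try:
    with simple_timeout(seconds=2):
        pass  # a slow Celery state query or apply_async would go here
except SketchTimeout:
    pass  # caller logs the error and moves on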
View file

@ -18,12 +18,17 @@
# under the License.
import sys
import unittest
from multiprocessing import Pool
import mock
from celery.contrib.testing.worker import start_worker
from airflow.executors.celery_executor import CeleryExecutor
from airflow.executors.celery_executor import app
from airflow.executors import celery_executor
from airflow.executors.celery_executor import CELERY_FETCH_ERR_MSG_HEADER
from airflow.executors.celery_executor import (CeleryExecutor, celery_configuration,
send_task_to_executor, execute_command)
from airflow.executors.celery_executor import app
from celery import states as celery_states
from airflow.utils.state import State
from airflow import configuration
@ -40,16 +45,37 @@ class CeleryExecutorTest(unittest.TestCase):
executor = CeleryExecutor()
executor.start()
with start_worker(app=app, logfile=sys.stdout, loglevel='debug'):
success_command = ['true', 'some_parameter']
fail_command = ['false', 'some_parameter']
executor.execute_async(key='success', command=success_command)
# errors are propagated for some reason
try:
executor.execute_async(key='fail', command=fail_command)
except Exception:
pass
cached_celery_backend = execute_command.backend
task_tuples_to_send = [('success', 'fake_simple_ti', success_command,
celery_configuration['task_default_queue'],
execute_command),
('fail', 'fake_simple_ti', fail_command,
celery_configuration['task_default_queue'],
execute_command)]
chunksize = executor._num_tasks_per_send_process(len(task_tuples_to_send))
num_processes = min(len(task_tuples_to_send), executor._sync_parallelism)
send_pool = Pool(processes=num_processes)
key_and_async_results = send_pool.map(
send_task_to_executor,
task_tuples_to_send,
chunksize=chunksize)
send_pool.close()
send_pool.join()
for key, command, result in key_and_async_results:
# Only pop the task if it was enqueued successfully; otherwise keep it
# and let the scheduler loop deal with it.
result.backend = cached_celery_backend
executor.running[key] = command
executor.tasks[key] = result
executor.last_state[key] = celery_states.PENDING
executor.running['success'] = True
executor.running['fail'] = True
@ -64,6 +90,23 @@ class CeleryExecutorTest(unittest.TestCase):
self.assertNotIn('success', executor.last_state)
self.assertNotIn('fail', executor.last_state)
@unittest.skipIf('sqlite' in configuration.conf.get('core', 'sql_alchemy_conn'),
"sqlite is configured with SequentialExecutor")
def test_error_sending_task(self):
@app.task
def fake_execute_command():
pass
# fake_execute_command takes no arguments while execute_command takes 1,
# which will cause TypeError when calling task.apply_async()
celery_executor.execute_command = fake_execute_command
executor = CeleryExecutor()
value_tuple = 'command', '_', 'queue', 'should_be_a_simple_ti'
executor.queued_tasks['key'] = value_tuple
executor.heartbeat()
self.assertEquals(1, len(executor.queued_tasks))
self.assertEquals(executor.queued_tasks['key'], value_tuple)
def test_exception_propagation(self):
@app.task
def fake_celery_task():

View file

@ -46,7 +46,8 @@ class TestExecutor(BaseExecutor):
ti = self._running.pop()
ti.set_state(State.SUCCESS, session)
for key, val in list(self.queued_tasks.items()):
(command, priority, queue, ti) = val
(command, priority, queue, simple_ti) = val
ti = simple_ti.construct_task_instance()
ti.set_state(State.RUNNING, session)
self._running.append(ti)
self.queued_tasks.pop(key)

View file

@ -2033,6 +2033,54 @@ class SchedulerJobTest(unittest.TestCase):
ti2.refresh_from_db(session=session)
self.assertEqual(ti2.state, State.SCHEDULED)
def test_change_state_for_tasks_failed_to_execute(self):
dag = DAG(
dag_id='dag_id',
start_date=DEFAULT_DATE)
task = DummyOperator(
task_id='task_id',
dag=dag,
owner='airflow')
# If there are no leftover tasks in executor.queued_tasks, nothing happens
session = settings.Session()
scheduler_job = SchedulerJob()
mock_logger = mock.MagicMock()
test_executor = TestExecutor()
scheduler_job.executor = test_executor
scheduler_job._logger = mock_logger
scheduler_job._change_state_for_tasks_failed_to_execute()
mock_logger.info.assert_not_called()
# Tasks failed to execute with QUEUED state will be set to SCHEDULED state.
session.query(TI).delete()
session.commit()
key = 'dag_id', 'task_id', DEFAULT_DATE, 1
test_executor.queued_tasks[key] = 'value'
ti = TI(task, DEFAULT_DATE)
ti.state = State.QUEUED
session.merge(ti)
session.commit()
scheduler_job._change_state_for_tasks_failed_to_execute()
ti.refresh_from_db()
self.assertEquals(State.SCHEDULED, ti.state)
# Tasks failed to execute with RUNNING state will not be set to SCHEDULED state.
session.query(TI).delete()
session.commit()
ti.state = State.RUNNING
session.merge(ti)
session.commit()
scheduler_job._change_state_for_tasks_failed_to_execute()
ti.refresh_from_db()
self.assertEquals(State.RUNNING, ti.state)
def test_execute_helper_reset_orphaned_tasks(self):
session = settings.Session()
dag = DAG(
@ -2950,7 +2998,8 @@ class SchedulerJobTest(unittest.TestCase):
pass
ti_tuple = six.next(six.itervalues(executor.queued_tasks))
(command, priority, queue, ti) = ti_tuple
(command, priority, queue, simple_ti) = ti_tuple
ti = simple_ti.construct_task_instance()
ti.task = dag_task1
self.assertEqual(ti.try_number, 1)
@ -2971,15 +3020,21 @@ class SchedulerJobTest(unittest.TestCase):
# removing self.assertEqual(ti.state, State.SCHEDULED)
# as scheduler will move state from SCHEDULED to QUEUED
# now the executor has cleared and it should be allowed the re-queue
# now the executor has cleared and the task should be allowed to re-queue,
# but tasks that remain in executor.queued_tasks after executor.heartbeat()
# will be set back to the SCHEDULED state
executor.queued_tasks.clear()
do_schedule()
ti.refresh_from_db()
self.assertEqual(ti.state, State.QUEUED)
# calling below again in order to ensure with try_number 2,
# scheduler doesn't put task in queue
self.assertEqual(ti.state, State.SCHEDULED)
# Verify that the task does get re-queued.
executor.queued_tasks.clear()
executor.do_update = True
do_schedule()
self.assertEquals(1, len(executor.queued_tasks))
ti.refresh_from_db()
self.assertEqual(ti.state, State.RUNNING)
@unittest.skipUnless("INTEGRATION" in os.environ, "Can only run end to end")
def test_retry_handling_job(self):
@ -3024,8 +3079,8 @@ class SchedulerJobTest(unittest.TestCase):
logging.info("Test ran in %.2fs, expected %.2fs",
run_duration,
expected_run_duration)
# 5s to wait for child process to exit and 1s dummy sleep
# in scheduler loop to prevent excessive logs.
# 5s to wait for child process to exit, 1s dummy sleep
# in scheduler loop to prevent excessive logs and 1s for last loop to finish.
self.assertLess(run_duration - expected_run_duration, 6.0)
def test_dag_with_system_exit(self):