[AIRFLOW-2761] Parallelize enqueue in celery executor (#4234)
Parent: 8fdf5ce5f5
Commit: 1d53f93966
airflow/executors/base_executor.py

@@ -19,11 +19,12 @@
 from builtins import range

+# To avoid circular imports
+import airflow.utils.dag_processing
 from airflow import configuration
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.state import State

 PARALLELISM = configuration.conf.getint('core', 'PARALLELISM')
@@ -50,11 +51,11 @@ class BaseExecutor(LoggingMixin):
         """
         pass

-    def queue_command(self, task_instance, command, priority=1, queue=None):
-        key = task_instance.key
+    def queue_command(self, simple_task_instance, command, priority=1, queue=None):
+        key = simple_task_instance.key
         if key not in self.queued_tasks and key not in self.running:
             self.log.info("Adding to queue: %s", command)
-            self.queued_tasks[key] = (command, priority, queue, task_instance)
+            self.queued_tasks[key] = (command, priority, queue, simple_task_instance)
         else:
             self.log.info("could not queue task {}".format(key))
@@ -86,7 +87,7 @@ class BaseExecutor(LoggingMixin):
             pickle_id=pickle_id,
             cfg_path=cfg_path)
         self.queue_command(
-            task_instance,
+            airflow.utils.dag_processing.SimpleTaskInstance(task_instance),
             command,
             priority=task_instance.task.priority_weight_total,
             queue=task_instance.task.queue)
@@ -124,26 +125,13 @@ class BaseExecutor(LoggingMixin):
             key=lambda x: x[1][1],
             reverse=True)
         for i in range(min((open_slots, len(self.queued_tasks)))):
-            key, (command, _, queue, ti) = sorted_queue.pop(0)
-            # TODO(jlowin) without a way to know what Job ran which tasks,
-            # there is a danger that another Job started running a task
-            # that was also queued to this executor. This is the last chance
-            # to check if that happened. The most probable way is that a
-            # Scheduler tried to run a task that was originally queued by a
-            # Backfill. This fix reduces the probability of a collision but
-            # does NOT eliminate it.
+            key, (command, _, queue, simple_ti) = sorted_queue.pop(0)
             self.queued_tasks.pop(key)
-            ti.refresh_from_db()
-            if ti.state != State.RUNNING:
-                self.running[key] = command
-                self.execute_async(key=key,
-                                   command=command,
-                                   queue=queue,
-                                   executor_config=ti.executor_config)
-            else:
-                self.log.info(
-                    'Task is already running, not sending to '
-                    'executor: {}'.format(key))
+            self.running[key] = command
+            self.execute_async(key=key,
+                               command=command,
+                               queue=queue,
+                               executor_config=simple_ti.executor_config)

         # Calling child class sync method
         self.log.debug("Calling the %s sync method", self.__class__)
@@ -151,7 +139,7 @@ class BaseExecutor(LoggingMixin):

     def change_state(self, key, state):
         self.log.debug("Changing state: {}".format(key))
-        self.running.pop(key)
+        self.running.pop(key, None)
         self.event_buffer[key] = state

     def fail(self, key):
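With the change above, heartbeat() no longer calls ti.refresh_from_db() per task; it simply drains the highest-priority entries from queued_tasks and hands them to execute_async(). A minimal standalone sketch of that drain step (plain Python written for illustration, not Airflow code; the helper name and sample data are made up):

# Select up to open_slots entries from a queued_tasks-style dict, highest
# priority first, without touching the database. Each value is
# (command, priority, queue, simple_ti), keyed by task key.
def select_to_send(queued_tasks, open_slots):
    sorted_queue = sorted(queued_tasks.items(), key=lambda x: x[1][1], reverse=True)
    to_send = []
    for _ in range(min(open_slots, len(sorted_queue))):
        key, (command, _priority, queue, simple_ti) = sorted_queue.pop(0)
        to_send.append((key, simple_ti, command, queue))
    return to_send


if __name__ == '__main__':
    queued = {
        'low': ('airflow run ...', 1, 'default', None),
        'high': ('airflow run ...', 10, 'default', None),
    }
    print(select_to_send(queued, open_slots=1))  # picks the 'high' entry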
airflow/executors/celery_executor.py

@@ -33,10 +33,13 @@ from airflow.exceptions import AirflowException
 from airflow.executors.base_executor import BaseExecutor
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.module_loading import import_string
+from airflow.utils.timeout import timeout

 # Make it constant for unit test.
 CELERY_FETCH_ERR_MSG_HEADER = 'Error fetching Celery task state'

+CELERY_SEND_ERR_MSG_HEADER = 'Error sending Celery task'
+
 '''
 To start the celery worker, run the command:
 airflow worker
@@ -55,12 +58,12 @@ app = Celery(


 @app.task
-def execute_command(command):
+def execute_command(command_to_exec):
     log = LoggingMixin().log
-    log.info("Executing command in Celery: %s", command)
+    log.info("Executing command in Celery: %s", command_to_exec)
     env = os.environ.copy()
     try:
-        subprocess.check_call(command, stderr=subprocess.STDOUT,
+        subprocess.check_call(command_to_exec, stderr=subprocess.STDOUT,
                               close_fds=True, env=env)
     except subprocess.CalledProcessError as e:
         log.exception('execute_command encountered a CalledProcessError')
@@ -95,9 +98,10 @@ def fetch_celery_task_state(celery_task):
     """

     try:
-        # Accessing state property of celery task will make actual network request
-        # to get the current state of the task.
-        res = (celery_task[0], celery_task[1].state)
+        with timeout(seconds=2):
+            # Accessing state property of celery task will make actual network request
+            # to get the current state of the task.
+            res = (celery_task[0], celery_task[1].state)
     except Exception as e:
         exception_traceback = "Celery Task ID: {}\n{}".format(celery_task[0],
                                                               traceback.format_exc())
@@ -105,6 +109,19 @@ def fetch_celery_task_state(celery_task):
     return res


+def send_task_to_executor(task_tuple):
+    key, simple_ti, command, queue, task = task_tuple
+    try:
+        with timeout(seconds=2):
+            result = task.apply_async(args=[command], queue=queue)
+    except Exception as e:
+        exception_traceback = "Celery Task ID: {}\n{}".format(key,
+                                                              traceback.format_exc())
+        result = ExceptionWithTraceback(e, exception_traceback)
+
+    return key, command, result
+
+
 class CeleryExecutor(BaseExecutor):
     """
     CeleryExecutor is recommended for production use of Airflow. It allows
@@ -135,24 +152,91 @@ class CeleryExecutor(BaseExecutor):
             'Starting Celery Executor using {} processes for syncing'.format(
                 self._sync_parallelism))

-    def execute_async(self, key, command,
-                      queue=DEFAULT_CELERY_CONFIG['task_default_queue'],
-                      executor_config=None):
-        self.log.info("[celery] queuing {key} through celery, "
-                      "queue={queue}".format(**locals()))
-        self.tasks[key] = execute_command.apply_async(
-            args=[command], queue=queue)
-        self.last_state[key] = celery_states.PENDING
-
-    def _num_tasks_per_process(self):
+    def _num_tasks_per_send_process(self, to_send_count):
+        """
+        How many Celery tasks should each worker process send.
+
+        :return: Number of tasks that should be sent per process
+        :rtype: int
+        """
+        return max(1,
+                   int(math.ceil(1.0 * to_send_count / self._sync_parallelism)))
+
+    def _num_tasks_per_fetch_process(self):
         """
         How many Celery tasks should be sent to each worker process.

         :return: Number of tasks that should be used per process
         :rtype: int
         """
         return max(1,
                    int(math.ceil(1.0 * len(self.tasks) / self._sync_parallelism)))

+    def heartbeat(self):
+        # Triggering new jobs
+        if not self.parallelism:
+            open_slots = len(self.queued_tasks)
+        else:
+            open_slots = self.parallelism - len(self.running)
+
+        self.log.debug("{} running task instances".format(len(self.running)))
+        self.log.debug("{} in queue".format(len(self.queued_tasks)))
+        self.log.debug("{} open slots".format(open_slots))
+
+        sorted_queue = sorted(
+            [(k, v) for k, v in self.queued_tasks.items()],
+            key=lambda x: x[1][1],
+            reverse=True)
+
+        task_tuples_to_send = []
+
+        for i in range(min((open_slots, len(self.queued_tasks)))):
+            key, (command, _, queue, simple_ti) = sorted_queue.pop(0)
+            task_tuples_to_send.append((key, simple_ti, command, queue,
+                                        execute_command))
+
+        cached_celery_backend = None
+        if task_tuples_to_send:
+            tasks = [t[4] for t in task_tuples_to_send]
+
+            # Celery state queries will stuck if we do not use one same backend
+            # for all tasks.
+            cached_celery_backend = tasks[0].backend
+
+        if task_tuples_to_send:
+            # Use chunking instead of a work queue to reduce context switching
+            # since tasks are roughly uniform in size
+            chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
+            num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
+
+            send_pool = Pool(processes=num_processes)
+            key_and_async_results = send_pool.map(
+                send_task_to_executor,
+                task_tuples_to_send,
+                chunksize=chunksize)
+
+            send_pool.close()
+            send_pool.join()
+            self.log.debug('Sent all tasks.')
+
+            for key, command, result in key_and_async_results:
+                if isinstance(result, ExceptionWithTraceback):
+                    self.log.error(
+                        CELERY_SEND_ERR_MSG_HEADER + ":{}\n{}\n".format(
+                            result.exception, result.traceback))
+                elif result is not None:
+                    # Only pops when enqueued successfully, otherwise keep it
+                    # and expect scheduler loop to deal with it.
+                    self.queued_tasks.pop(key)
+                    result.backend = cached_celery_backend
+                    self.running[key] = command
+                    self.tasks[key] = result
+                    self.last_state[key] = celery_states.PENDING
+
+        # Calling child class sync method
+        self.log.debug("Calling the {} sync method".format(self.__class__))
+        self.sync()
+
     def sync(self):
         num_processes = min(len(self.tasks), self._sync_parallelism)
         if num_processes == 0:
@@ -167,7 +251,7 @@ class CeleryExecutor(BaseExecutor):

         # Use chunking instead of a work queue to reduce context switching since tasks are
         # roughly uniform in size
-        chunksize = self._num_tasks_per_process()
+        chunksize = self._num_tasks_per_fetch_process()

         self.log.debug("Waiting for inquiries to complete...")
         task_keys_to_states = self._sync_pool.map(
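The new heartbeat() above pushes queued commands to Celery through a multiprocessing Pool with explicit chunking instead of calling apply_async() one task at a time. A self-contained sketch of that send pattern, assuming the work items are picklable and roughly uniform in size (fake_send and send_all are illustrative stand-ins, not Airflow or Celery APIs):

import math
from multiprocessing import Pool


def fake_send(task_tuple):
    # A real sender would call task.apply_async(args=[command], queue=queue) here.
    key, payload = task_tuple
    return key, 'sent:{}'.format(payload)


def send_all(task_tuples, sync_parallelism=4):
    if not task_tuples:
        return []
    # One chunk per worker process keeps context switching low.
    chunksize = max(1, int(math.ceil(1.0 * len(task_tuples) / sync_parallelism)))
    num_processes = min(len(task_tuples), sync_parallelism)
    send_pool = Pool(processes=num_processes)
    try:
        results = send_pool.map(fake_send, task_tuples, chunksize=chunksize)
    finally:
        send_pool.close()
        send_pool.join()
    return results


if __name__ == '__main__':
    print(send_all([('a', 'cmd-1'), ('b', 'cmd-2'), ('c', 'cmd-3')]))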
airflow/jobs.py

@@ -52,6 +52,7 @@ from airflow.utils.dag_processing import (AbstractDagFileProcessor,
                                           DagFileProcessorAgent,
                                           SimpleDag,
                                           SimpleDagBag,
+                                          SimpleTaskInstance,
                                           list_py_file_paths)
 from airflow.utils.db import create_session, provide_session
 from airflow.utils.email import get_email_address_list, send_email
@@ -598,6 +599,7 @@ class SchedulerJob(BaseJob):
             'run_duration')

         self.processor_agent = None
+        self._last_loop = False

         signal.signal(signal.SIGINT, self._exit_gracefully)
         signal.signal(signal.SIGTERM, self._exit_gracefully)
@@ -1228,13 +1230,13 @@ class SchedulerJob(BaseJob):
                                                      acceptable_states, session=None):
         """
         Changes the state of task instances in the list with one of the given states
-        to QUEUED atomically, and returns the TIs changed.
+        to QUEUED atomically, and returns the TIs changed in SimpleTaskInstance format.

         :param task_instances: TaskInstances to change the state of
         :type task_instances: List[TaskInstance]
         :param acceptable_states: Filters the TaskInstances updated to be in these states
         :type acceptable_states: Iterable[State]
-        :return: List[TaskInstance]
+        :return: List[SimpleTaskInstance]
         """
         if len(task_instances) == 0:
             session.commit()
@@ -1276,81 +1278,57 @@ class SchedulerJob(BaseJob):
                                         else task_instance.queued_dttm)
             session.merge(task_instance)

-        # save which TIs we set before session expires them
-        filter_for_ti_enqueue = ([and_(TI.dag_id == ti.dag_id,
-                                       TI.task_id == ti.task_id,
-                                       TI.execution_date == ti.execution_date)
-                                  for ti in tis_to_set_to_queued])
+        # Generate a list of SimpleTaskInstance for the use of queuing
+        # them in the executor.
+        simple_task_instances = [SimpleTaskInstance(ti) for ti in
+                                 tis_to_set_to_queued]
+
+        task_instance_str = "\n\t".join(
+            ["{}".format(x) for x in tis_to_set_to_queued])
+
         session.commit()
+        self.log.info("Setting the following {} tasks to queued state:\n\t{}"
+                      .format(len(tis_to_set_to_queued), task_instance_str))
+        return simple_task_instances

-        # requery in batches since above was expired by commit
-
-        def query(result, items):
-            tis_to_be_queued = (
-                session
-                .query(TI)
-                .filter(or_(*items))
-                .all())
-            task_instance_str = "\n\t".join(
-                ["{}".format(x) for x in tis_to_be_queued])
-            self.log.info("Setting the following {} tasks to queued state:\n\t{}"
-                          .format(len(tis_to_be_queued),
-                                  task_instance_str))
-            return result + tis_to_be_queued
-
-        tis_to_be_queued = helpers.reduce_in_chunks(query,
-                                                    filter_for_ti_enqueue,
-                                                    [],
-                                                    self.max_tis_per_query)
-
-        return tis_to_be_queued
-
-    def _enqueue_task_instances_with_queued_state(self, simple_dag_bag, task_instances):
+    def _enqueue_task_instances_with_queued_state(self, simple_dag_bag,
+                                                  simple_task_instances):
         """
         Takes task_instances, which should have been set to queued, and enqueues them
         with the executor.

-        :param task_instances: TaskInstances to enqueue
-        :type task_instances: List[TaskInstance]
+        :param simple_task_instances: TaskInstances to enqueue
+        :type simple_task_instances: List[SimpleTaskInstance]
         :param simple_dag_bag: Should contains all of the task_instances' dags
         :type simple_dag_bag: SimpleDagBag
         """
         TI = models.TaskInstance
         # actually enqueue them
-        for task_instance in task_instances:
-            simple_dag = simple_dag_bag.get_dag(task_instance.dag_id)
+        for simple_task_instance in simple_task_instances:
+            simple_dag = simple_dag_bag.get_dag(simple_task_instance.dag_id)
             command = TI.generate_command(
-                task_instance.dag_id,
-                task_instance.task_id,
-                task_instance.execution_date,
+                simple_task_instance.dag_id,
+                simple_task_instance.task_id,
+                simple_task_instance.execution_date,
                 local=True,
                 mark_success=False,
                 ignore_all_deps=False,
                 ignore_depends_on_past=False,
                 ignore_task_deps=False,
                 ignore_ti_state=False,
-                pool=task_instance.pool,
+                pool=simple_task_instance.pool,
                 file_path=simple_dag.full_filepath,
                 pickle_id=simple_dag.pickle_id)

-            priority = task_instance.priority_weight
-            queue = task_instance.queue
+            priority = simple_task_instance.priority_weight
+            queue = simple_task_instance.queue
             self.log.info(
                 "Sending %s to executor with priority %s and queue %s",
-                task_instance.key, priority, queue
+                simple_task_instance.key, priority, queue
             )

-            # save attributes so sqlalchemy doesnt expire them
-            copy_dag_id = task_instance.dag_id
-            copy_task_id = task_instance.task_id
-            copy_execution_date = task_instance.execution_date
-            make_transient(task_instance)
-            task_instance.dag_id = copy_dag_id
-            task_instance.task_id = copy_task_id
-            task_instance.execution_date = copy_execution_date
-
             self.executor.queue_command(
-                task_instance,
+                simple_task_instance,
                 command,
                 priority=priority,
                 queue=queue)
@@ -1374,24 +1352,65 @@ class SchedulerJob(BaseJob):
         :type simple_dag_bag: SimpleDagBag
         :param states: Execute TaskInstances in these states
         :type states: Tuple[State]
-        :return: None
+        :return: Number of task instance with state changed.
         """
         executable_tis = self._find_executable_task_instances(simple_dag_bag, states,
                                                               session=session)

         def query(result, items):
-            tis_with_state_changed = self._change_state_for_executable_task_instances(
-                items,
-                states,
-                session=session)
+            simple_tis_with_state_changed = \
+                self._change_state_for_executable_task_instances(items,
                                                                  states,
                                                                  session=session)
             self._enqueue_task_instances_with_queued_state(
                 simple_dag_bag,
-                tis_with_state_changed)
+                simple_tis_with_state_changed)
             session.commit()
-            return result + len(tis_with_state_changed)
+            return result + len(simple_tis_with_state_changed)

         return helpers.reduce_in_chunks(query, executable_tis, 0, self.max_tis_per_query)

+    @provide_session
+    def _change_state_for_tasks_failed_to_execute(self, session):
+        """
+        If there are tasks left over in the executor,
+        we set them back to SCHEDULED to avoid creating hanging tasks.
+
+        :param session: session for ORM operations
+        """
+        if self.executor.queued_tasks:
+            TI = models.TaskInstance
+            filter_for_ti_state_change = (
+                [and_(
+                    TI.dag_id == dag_id,
+                    TI.task_id == task_id,
+                    TI.execution_date == execution_date,
+                    # The TI.try_number will return raw try_number+1 since the
+                    # ti is not running. And we need to -1 to match the DB record.
+                    TI._try_number == try_number - 1,
+                    TI.state == State.QUEUED)
+                    for dag_id, task_id, execution_date, try_number
+                    in self.executor.queued_tasks.keys()])
+            ti_query = (session.query(TI)
+                        .filter(or_(*filter_for_ti_state_change)))
+            tis_to_set_to_scheduled = (ti_query
+                                       .with_for_update()
+                                       .all())
+            if len(tis_to_set_to_scheduled) == 0:
+                session.commit()
+                return
+
+            # set TIs to queued state
+            for task_instance in tis_to_set_to_scheduled:
+                task_instance.state = State.SCHEDULED
+
+            task_instance_str = "\n\t".join(
+                ["{}".format(x) for x in tis_to_set_to_scheduled])
+
+            session.commit()
+            self.log.info("Set the follow tasks to scheduled state:\n\t{}"
+                          .format(task_instance_str))
+
     def _process_dags(self, dagbag, dags, tis_out):
         """
         Iterates over the dags and processes them. Processing includes:
@@ -1507,6 +1526,8 @@ class SchedulerJob(BaseJob):

         try:
             self._execute_helper()
+        except Exception:
+            self.log.exception("Exception when executing execute_helper")
         finally:
             self.processor_agent.end()
             self.log.info("Exited execute loop")
@@ -1557,6 +1578,7 @@ class SchedulerJob(BaseJob):

             self.log.info("Harvesting DAG parsing results")
             simple_dags = self.processor_agent.harvest_simple_dags()
+            self.log.debug("Harvested {} SimpleDAGs".format(len(simple_dags)))

             # Send tasks for execution if available
             simple_dag_bag = SimpleDagBag(simple_dags)
@@ -1593,6 +1615,8 @@ class SchedulerJob(BaseJob):
             self.log.debug("Heartbeating the executor")
             self.executor.heartbeat()

+            self._change_state_for_tasks_failed_to_execute()
+
             # Process events from the executor
             self._process_executor_events(simple_dag_bag)

@@ -1612,8 +1636,13 @@ class SchedulerJob(BaseJob):
                 self.log.debug("Sleeping for %.2f seconds", self._processor_poll_interval)
                 time.sleep(self._processor_poll_interval)

-            # Exit early for a test mode
+            # Exit early for a test mode, run one additional scheduler loop
+            # to reduce the possibility that parsed DAG was put into the queue
+            # by the DAG manager but not yet received by DAG agent.
             if self.processor_agent.done:
+                self._last_loop = True
+
+            if self._last_loop:
                 self.log.info("Exiting scheduler loop as all files"
                               " have been processed {} times".format(self.num_runs))
                 break
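In the jobs.py changes above, _execute_task_instances folds the executable task instances through helpers.reduce_in_chunks so each round of state changes and enqueues stays under max_tis_per_query. A rough sketch of that chunked-reduce pattern (a generic stand-in written for illustration, not the Airflow helper itself):

def reduce_in_chunks(fn, iterable, initial, chunk_size):
    # Fold fn over the iterable one fixed-size chunk at a time.
    acc = initial
    for start in range(0, len(iterable), chunk_size):
        acc = fn(acc, iterable[start:start + chunk_size])
    return acc


if __name__ == '__main__':
    # Sum ten numbers three at a time: fn sees [0,1,2], [3,4,5], [6,7,8], [9].
    total = reduce_in_chunks(lambda acc, chunk: acc + sum(chunk), list(range(10)), 0, 3)
    print(total)  # 45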
airflow/utils/dag_processing.py

@@ -22,12 +22,12 @@ from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals

+import copy
+from collections import defaultdict, namedtuple
+
+from builtins import ImportError as BuiltinImportError, bytes, object, str
 from future.standard_library import install_aliases

-from builtins import str, object, bytes, ImportError as BuiltinImportError
-import copy
-from collections import namedtuple, defaultdict
 try:
     # Fix Python > 3.7 deprecation
     from collections.abc import Hashable
@@ -146,6 +146,21 @@ class SimpleTaskInstance(object):
         self._end_date = ti.end_date
         self._try_number = ti.try_number
         self._state = ti.state
+        self._executor_config = ti.executor_config
+        if hasattr(ti, 'run_as_user'):
+            self._run_as_user = ti.run_as_user
+        else:
+            self._run_as_user = None
+        if hasattr(ti, 'pool'):
+            self._pool = ti.pool
+        else:
+            self._pool = None
+        if hasattr(ti, 'priority_weight'):
+            self._priority_weight = ti.priority_weight
+        else:
+            self._priority_weight = None
+        self._queue = ti.queue
+        self._key = ti.key

     @property
     def dag_id(self):
@@ -175,6 +190,49 @@ class SimpleTaskInstance(object):
     def state(self):
         return self._state

+    @property
+    def pool(self):
+        return self._pool
+
+    @property
+    def priority_weight(self):
+        return self._priority_weight
+
+    @property
+    def queue(self):
+        return self._queue
+
+    @property
+    def key(self):
+        return self._key
+
+    @property
+    def executor_config(self):
+        return self._executor_config
+
+    @provide_session
+    def construct_task_instance(self, session=None, lock_for_update=False):
+        """
+        Construct a TaskInstance from the database based on the primary key
+
+        :param session: DB session.
+        :param lock_for_update: if True, indicates that the database should
+            lock the TaskInstance (issuing a FOR UPDATE clause) until the
+            session is committed.
+        """
+        TI = airflow.models.TaskInstance
+
+        qry = session.query(TI).filter(
+            TI.dag_id == self._dag_id,
+            TI.task_id == self._task_id,
+            TI.execution_date == self._execution_date)
+
+        if lock_for_update:
+            ti = qry.with_for_update().first()
+        else:
+            ti = qry.first()
+        return ti
+

 class SimpleDagBag(BaseDagBag):
     """
@@ -571,11 +629,16 @@ class DagFileProcessorAgent(LoggingMixin):
         Terminate (and then kill) the manager process launched.
         :return:
         """
-        if not self._process or not self._process.is_alive():
+        if not self._process:
             self.log.warn('Ending without manager process.')
             return
         this_process = psutil.Process(os.getpid())
-        manager_process = psutil.Process(self._process.pid)
+        try:
+            manager_process = psutil.Process(self._process.pid)
+        except psutil.NoSuchProcess:
+            self.log.info("Manager process not running.")
+            return
+
         # First try SIGTERM
         if manager_process.is_running() \
                 and manager_process.pid in [x.pid for x in this_process.children()]:
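The SimpleTaskInstance additions above snapshot just the fields the executor needs (pool, priority_weight, queue, key, executor_config, and so on), so ORM-attached TaskInstance objects never have to cross a process boundary; construct_task_instance() re-queries the real row when it is needed again. A minimal sketch of the same value-object idea (SimpleTI and simplify are illustrative names, not Airflow code):

from collections import namedtuple
from types import SimpleNamespace

SimpleTI = namedtuple(
    'SimpleTI',
    ['dag_id', 'task_id', 'execution_date', 'pool', 'priority_weight',
     'queue', 'key', 'executor_config'])


def simplify(ti):
    # Copy the scheduling-relevant attributes off a TaskInstance-like object.
    return SimpleTI(
        dag_id=ti.dag_id,
        task_id=ti.task_id,
        execution_date=ti.execution_date,
        pool=getattr(ti, 'pool', None),
        priority_weight=getattr(ti, 'priority_weight', None),
        queue=ti.queue,
        key=ti.key,
        executor_config=ti.executor_config)


if __name__ == '__main__':
    fake_ti = SimpleNamespace(dag_id='d', task_id='t', execution_date='2018-01-01',
                              pool=None, priority_weight=1, queue='default',
                              key=('d', 't', '2018-01-01', 1), executor_config={})
    print(simplify(fake_ti))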
airflow/utils/timeout.py

@@ -23,6 +23,7 @@ from __future__ import print_function
 from __future__ import unicode_literals

 import signal
+import os

 from airflow.exceptions import AirflowTaskTimeout
 from airflow.utils.log.logging_mixin import LoggingMixin
@@ -35,10 +36,10 @@ class timeout(LoggingMixin):

     def __init__(self, seconds=1, error_message='Timeout'):
         self.seconds = seconds
-        self.error_message = error_message
+        self.error_message = error_message + ', PID: ' + str(os.getpid())

     def handle_timeout(self, signum, frame):
-        self.log.error("Process timed out")
+        self.log.error("Process timed out, PID: " + str(os.getpid()))
         raise AirflowTaskTimeout(self.error_message)

     def __enter__(self):
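The timeout change above only adds the calling PID to the error message and the log line; the context manager mechanics are unchanged. For reference, a minimal SIGALRM-based sketch of that kind of timeout guard (Unix-only and main-thread-only; SimpleTimeout is an illustrative class, not the Airflow one):

import os
import signal


class SimpleTimeout(object):
    def __init__(self, seconds=1, error_message='Timeout'):
        self.seconds = seconds
        self.error_message = error_message + ', PID: ' + str(os.getpid())

    def handle_timeout(self, signum, frame):
        raise RuntimeError(self.error_message)

    def __enter__(self):
        # Arm an alarm; if it fires before __exit__, handle_timeout raises.
        signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        signal.alarm(0)  # disarm


if __name__ == '__main__':
    try:
        with SimpleTimeout(seconds=1):
            signal.pause()  # block until the alarm fires
    except RuntimeError as err:
        print(err)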
tests/executors/test_celery_executor.py

@@ -18,12 +18,17 @@
 # under the License.
 import sys
 import unittest
+from multiprocessing import Pool

 import mock
 from celery.contrib.testing.worker import start_worker

-from airflow.executors.celery_executor import CeleryExecutor
-from airflow.executors.celery_executor import app
+from airflow.executors import celery_executor
+from airflow.executors.celery_executor import CELERY_FETCH_ERR_MSG_HEADER
+from airflow.executors.celery_executor import (CeleryExecutor, celery_configuration,
+                                               send_task_to_executor, execute_command)
+from airflow.executors.celery_executor import app
+from celery import states as celery_states
 from airflow.utils.state import State

 from airflow import configuration
@@ -40,16 +45,37 @@ class CeleryExecutorTest(unittest.TestCase):
         executor = CeleryExecutor()
         executor.start()
         with start_worker(app=app, logfile=sys.stdout, loglevel='debug'):

             success_command = ['true', 'some_parameter']
             fail_command = ['false', 'some_parameter']

-            executor.execute_async(key='success', command=success_command)
-            # errors are propagated for some reason
-            try:
-                executor.execute_async(key='fail', command=fail_command)
-            except Exception:
-                pass
+            cached_celery_backend = execute_command.backend
+            task_tuples_to_send = [('success', 'fake_simple_ti', success_command,
+                                    celery_configuration['task_default_queue'],
+                                    execute_command),
+                                   ('fail', 'fake_simple_ti', fail_command,
+                                    celery_configuration['task_default_queue'],
+                                    execute_command)]
+
+            chunksize = executor._num_tasks_per_send_process(len(task_tuples_to_send))
+            num_processes = min(len(task_tuples_to_send), executor._sync_parallelism)
+
+            send_pool = Pool(processes=num_processes)
+            key_and_async_results = send_pool.map(
+                send_task_to_executor,
+                task_tuples_to_send,
+                chunksize=chunksize)
+
+            send_pool.close()
+            send_pool.join()
+
+            for key, command, result in key_and_async_results:
+                # Only pops when enqueued successfully, otherwise keep it
+                # and expect scheduler loop to deal with it.
+                result.backend = cached_celery_backend
+                executor.running[key] = command
+                executor.tasks[key] = result
+                executor.last_state[key] = celery_states.PENDING
+
             executor.running['success'] = True
             executor.running['fail'] = True
@@ -64,6 +90,23 @@ class CeleryExecutorTest(unittest.TestCase):
         self.assertNotIn('success', executor.last_state)
         self.assertNotIn('fail', executor.last_state)

+    @unittest.skipIf('sqlite' in configuration.conf.get('core', 'sql_alchemy_conn'),
+                     "sqlite is configured with SequentialExecutor")
+    def test_error_sending_task(self):
+        @app.task
+        def fake_execute_command():
+            pass
+
+        # fake_execute_command takes no arguments while execute_command takes 1,
+        # which will cause TypeError when calling task.apply_async()
+        celery_executor.execute_command = fake_execute_command
+        executor = CeleryExecutor()
+        value_tuple = 'command', '_', 'queue', 'should_be_a_simple_ti'
+        executor.queued_tasks['key'] = value_tuple
+        executor.heartbeat()
+        self.assertEquals(1, len(executor.queued_tasks))
+        self.assertEquals(executor.queued_tasks['key'], value_tuple)
+
     def test_exception_propagation(self):
         @app.task
         def fake_celery_task():
tests/executors/test_executor.py

@@ -46,7 +46,8 @@ class TestExecutor(BaseExecutor):
                 ti = self._running.pop()
                 ti.set_state(State.SUCCESS, session)
             for key, val in list(self.queued_tasks.items()):
-                (command, priority, queue, ti) = val
+                (command, priority, queue, simple_ti) = val
+                ti = simple_ti.construct_task_instance()
                 ti.set_state(State.RUNNING, session)
                 self._running.append(ti)
                 self.queued_tasks.pop(key)
tests/jobs.py

@@ -2033,6 +2033,54 @@ class SchedulerJobTest(unittest.TestCase):
         ti2.refresh_from_db(session=session)
         self.assertEqual(ti2.state, State.SCHEDULED)

+    def test_change_state_for_tasks_failed_to_execute(self):
+        dag = DAG(
+            dag_id='dag_id',
+            start_date=DEFAULT_DATE)
+
+        task = DummyOperator(
+            task_id='task_id',
+            dag=dag,
+            owner='airflow')
+
+        # If there's no left over task in executor.queued_tasks, nothing happens
+        session = settings.Session()
+        scheduler_job = SchedulerJob()
+        mock_logger = mock.MagicMock()
+        test_executor = TestExecutor()
+        scheduler_job.executor = test_executor
+        scheduler_job._logger = mock_logger
+        scheduler_job._change_state_for_tasks_failed_to_execute()
+        mock_logger.info.assert_not_called()
+
+        # Tasks failed to execute with QUEUED state will be set to SCHEDULED state.
+        session.query(TI).delete()
+        session.commit()
+        key = 'dag_id', 'task_id', DEFAULT_DATE, 1
+        test_executor.queued_tasks[key] = 'value'
+        ti = TI(task, DEFAULT_DATE)
+        ti.state = State.QUEUED
+        session.merge(ti)
+        session.commit()
+
+        scheduler_job._change_state_for_tasks_failed_to_execute()
+
+        ti.refresh_from_db()
+        self.assertEquals(State.SCHEDULED, ti.state)
+
+        # Tasks failed to execute with RUNNING state will not be set to SCHEDULED state.
+        session.query(TI).delete()
+        session.commit()
+        ti.state = State.RUNNING
+
+        session.merge(ti)
+        session.commit()
+
+        scheduler_job._change_state_for_tasks_failed_to_execute()
+
+        ti.refresh_from_db()
+        self.assertEquals(State.RUNNING, ti.state)
+
     def test_execute_helper_reset_orphaned_tasks(self):
         session = settings.Session()
         dag = DAG(
@@ -2950,7 +2998,8 @@ class SchedulerJobTest(unittest.TestCase):
             pass

         ti_tuple = six.next(six.itervalues(executor.queued_tasks))
-        (command, priority, queue, ti) = ti_tuple
+        (command, priority, queue, simple_ti) = ti_tuple
+        ti = simple_ti.construct_task_instance()
         ti.task = dag_task1

         self.assertEqual(ti.try_number, 1)
@@ -2971,15 +3020,21 @@ class SchedulerJobTest(unittest.TestCase):
         # removing self.assertEqual(ti.state, State.SCHEDULED)
         # as scheduler will move state from SCHEDULED to QUEUED

-        # now the executor has cleared and it should be allowed the re-queue
+        # now the executor has cleared and it should be allowed the re-queue,
+        # but tasks stay in the executor.queued_tasks after executor.heartbeat()
+        # will be set back to SCHEDULED state
         executor.queued_tasks.clear()
         do_schedule()
         ti.refresh_from_db()
-        self.assertEqual(ti.state, State.QUEUED)
-        # calling below again in order to ensure with try_number 2,
-        # scheduler doesn't put task in queue
+        self.assertEqual(ti.state, State.SCHEDULED)
+
+        # To verify that task does get re-queued.
+        executor.queued_tasks.clear()
+        executor.do_update = True
+        do_schedule()
         self.assertEquals(1, len(executor.queued_tasks))
         ti.refresh_from_db()
         self.assertEqual(ti.state, State.RUNNING)

     @unittest.skipUnless("INTEGRATION" in os.environ, "Can only run end to end")
     def test_retry_handling_job(self):
@@ -3024,8 +3079,8 @@ class SchedulerJobTest(unittest.TestCase):
         logging.info("Test ran in %.2fs, expected %.2fs",
                      run_duration,
                      expected_run_duration)
-        # 5s to wait for child process to exit and 1s dummy sleep
-        # in scheduler loop to prevent excessive logs.
+        # 5s to wait for child process to exit, 1s dummy sleep
+        # in scheduler loop to prevent excessive logs and 1s for last loop to finish.
         self.assertLess(run_duration - expected_run_duration, 6.0)

     def test_dag_with_system_exit(self):