[AIRFLOW-2156] Parallelize Celery Executor task state fetching (#3830)
This commit is contained in:
Parent: 1f038a7919
Commit: 9b82fcb5fb
@@ -17,6 +17,11 @@ so you might need to update your config.
 The scheduler.min_file_parsing_loop_time config option has been temporarily removed due to
 some bugs.
 
+### New `sync_parallelism` config option in the celery section
+
+The new `sync_parallelism` config option controls how many processes CeleryExecutor uses to
+fetch celery task state in parallel. The default value is max(1, number of cores - 1).
+
 ## Airflow 1.10
 
 Installation and upgrading requires setting `SLUGIFY_USES_TEXT_UNIDECODE=yes` in your environment or
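For clarity, here is how that default resolves: a configured value of 0 means max(1, number of cores - 1), so an 8-core machine gets 7 sync processes. Below is a minimal sketch of the rule; the helper name resolve_sync_parallelism is illustrative, not part of this commit:

    from multiprocessing import cpu_count

    def resolve_sync_parallelism(configured):
        # sync_parallelism = 0 means "use max(1, number of cores - 1)".
        if configured == 0:
            return max(1, cpu_count() - 1)
        return configured

    print(resolve_sync_parallelism(0))  # 7 on an 8-core machine
    print(resolve_sync_parallelism(4))  # explicit values are used as-is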
@@ -380,6 +380,10 @@ flower_port = 5555
 # Default queue that tasks get assigned to and that workers listen on.
 default_queue = default
 
+# How many processes CeleryExecutor uses to sync task state.
+# 0 means to use max(1, number of cores - 1) processes.
+sync_parallelism = 0
+
 # Import path for celery configuration options
 celery_config_options = airflow.config_templates.default_celery.DEFAULT_CELERY_CONFIG
@@ -97,6 +97,7 @@ result_backend = db+mysql://airflow:airflow@localhost:3306/airflow
 flower_host = 0.0.0.0
 flower_port = 5555
 default_queue = default
+sync_parallelism = 0
 
 [mesos]
 master = localhost:5050
@@ -17,20 +17,26 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import math
+import os
 import subprocess
 import time
-import os
 import traceback
+from multiprocessing import Pool, cpu_count
 
 from celery import Celery
 from celery import states as celery_states
 
+from airflow import configuration
 from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
 from airflow.exceptions import AirflowException
 from airflow.executors.base_executor import BaseExecutor
-from airflow import configuration
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.module_loading import import_string
 
+# Make it constant for unit test.
+CELERY_FETCH_ERR_MSG_HEADER = 'Error fetching Celery task state'
+
 '''
 To start the celery worker, run the command:
 airflow worker
@@ -63,6 +69,42 @@ def execute_command(command):
         raise AirflowException('Celery command failed')
 
 
+class ExceptionWithTraceback(object):
+    """
+    Wrapper class used to propagate exceptions to parent processes from subprocesses.
+
+    :param exception: The exception to wrap
+    :type exception: Exception
+    :param exception_traceback: The stacktrace to wrap
+    :type exception_traceback: str
+    """
+
+    def __init__(self, exception, exception_traceback):
+        self.exception = exception
+        self.traceback = exception_traceback
+
+
+def fetch_celery_task_state(celery_task):
+    """
+    Fetch and return the state of the given celery task. The scope of this function is
+    global so that it can be called by subprocesses in the pool.
+
+    :param celery_task: a tuple of the Celery task key and the async Celery object used
+        to fetch the task's state
+    :type celery_task: (str, celery.result.AsyncResult)
+    :return: a tuple of the Celery task key and the Celery state of the task
+    :rtype: (str, str)
+    """
+
+    try:
+        # Accessing the state property of a celery task will make an actual network
+        # request to get the current state of the task.
+        res = (celery_task[0], celery_task[1].state)
+    except Exception as e:
+        exception_traceback = "Celery Task ID: {}\n{}".format(celery_task[0],
+                                                              traceback.format_exc())
+        res = ExceptionWithTraceback(e, exception_traceback)
+    return res
+
+
 class CeleryExecutor(BaseExecutor):
     """
     CeleryExecutor is recommended for production use of Airflow. It allows
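A note on why the wrapper exists: when a function running inside a multiprocessing.Pool raises, the exception is re-raised in the parent, but the child process's original stack trace is not reliably preserved (notably on Python 2). Returning the exception together with a pre-formatted traceback string as an ordinary picklable value keeps the diagnostics. A minimal standalone sketch of the same pattern, with illustrative names outside Airflow:

    import traceback
    from multiprocessing import Pool

    class WrappedError(object):
        def __init__(self, exception, exception_traceback):
            self.exception = exception
            self.traceback = exception_traceback

    def safe_call(x):
        try:
            return 1.0 / x  # raises ZeroDivisionError for x == 0
        except Exception as e:
            # Format the child's traceback into a string so it survives pickling.
            return WrappedError(e, traceback.format_exc())

    if __name__ == '__main__':
        pool = Pool(processes=2)
        for result in pool.map(safe_call, [1, 0, 4]):
            if isinstance(result, WrappedError):
                print('worker failed:', result.exception)
                print(result.traceback)
            else:
                print('worker returned:', result)
        pool.close()
        pool.join()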
@@ -72,10 +114,27 @@ class CeleryExecutor(BaseExecutor):
     vast amounts of messages, while providing operations with the tools
     required to maintain such a system.
     """
-    def start(self):
-        self.tasks = {}
-        self.last_state = {}
+
+    def __init__(self):
+        super(CeleryExecutor, self).__init__()
+
+        # Celery doesn't support querying the state of multiple tasks in parallel
+        # (which can become a bottleneck on bigger clusters), so we use
+        # a multiprocessing pool to speed this up.
+        # How many worker processes are created for checking celery task state.
+        self._sync_parallelism = configuration.getint('celery', 'SYNC_PARALLELISM')
+        if self._sync_parallelism == 0:
+            self._sync_parallelism = max(1, cpu_count() - 1)
+
+        self._sync_pool = None
+        self.tasks = {}
+        self.last_state = {}
+
+    def start(self):
+        self.log.debug(
+            'Starting Celery Executor using {} processes for syncing'.format(
+                self._sync_parallelism))
 
     def execute_async(self, key, command,
                       queue=DEFAULT_CELERY_CONFIG['task_default_queue'],
                       executor_config=None):
@@ -85,11 +144,48 @@ class CeleryExecutor(BaseExecutor):
             args=[command], queue=queue)
         self.last_state[key] = celery_states.PENDING
 
+    def _num_tasks_per_process(self):
+        """
+        How many Celery tasks should be sent to each worker process.
+
+        :return: Number of tasks that should be used per process
+        :rtype: int
+        """
+        return max(1,
+                   int(math.ceil(1.0 * len(self.tasks) / self._sync_parallelism)))
+
     def sync(self):
-        self.log.debug("Inquiring about %s celery task(s)", len(self.tasks))
-        for key, task in list(self.tasks.items()):
+        num_processes = min(len(self.tasks), self._sync_parallelism)
+        if num_processes == 0:
+            self.log.debug("No task to query celery, skipping sync")
+            return
+
+        self.log.debug("Inquiring about %s celery task(s) using %s processes",
+                       len(self.tasks), num_processes)
+
+        # Recreate the process pool each sync in case processes in the pool die
+        self._sync_pool = Pool(processes=num_processes)
+
+        # Use chunking instead of a work queue to reduce context switching since tasks are
+        # roughly uniform in size
+        chunksize = self._num_tasks_per_process()
+
+        self.log.debug("Waiting for inquiries to complete...")
+        task_keys_to_states = self._sync_pool.map(
+            fetch_celery_task_state,
+            self.tasks.items(),
+            chunksize=chunksize)
+        self._sync_pool.close()
+        self._sync_pool.join()
+        self.log.debug("Inquiries completed.")
+
+        for key_and_state in task_keys_to_states:
+            if isinstance(key_and_state, ExceptionWithTraceback):
+                self.log.error(
+                    CELERY_FETCH_ERR_MSG_HEADER + ", ignoring it:{}\n{}\n".format(
+                        key_and_state.exception, key_and_state.traceback))
+                continue
+            key, state = key_and_state
             try:
-                state = task.state
                 if self.last_state[key] != state:
                     if state == celery_states.SUCCESS:
                         self.success(key)
@@ -104,11 +200,10 @@ class CeleryExecutor(BaseExecutor):
                     del self.tasks[key]
                     del self.last_state[key]
                 else:
-                    self.log.info("Unexpected state: " + state)
+                    self.log.info("Unexpected state: %s", state)
                 self.last_state[key] = state
-            except Exception as e:
-                self.log.error("Error syncing the celery executor, ignoring it:")
-                self.log.exception(e)
+            except Exception:
+                self.log.exception("Error syncing the Celery executor, ignoring it.")
 
     def end(self, synchronous=False):
         if synchronous:
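The chunksize arithmetic deserves a worked example: with 100 queued tasks and seven sync processes, _num_tasks_per_process returns ceil(100 / 7) = 15, so each pool worker receives work in batches of 15 instead of one item at a time, cutting down on inter-process communication. A minimal standalone sketch of the chunked Pool.map pattern; fetch_state here is a stand-in for the real network call to Celery, not Airflow code:

    import math
    from multiprocessing import Pool

    def fetch_state(task_key):
        # Stand-in for AsyncResult.state, which does a network round trip.
        return (task_key, 'SUCCESS')

    if __name__ == '__main__':
        task_keys = ['task-%d' % i for i in range(100)]
        num_processes = 7
        # One chunk per worker: ceil(100 / 7) == 15.
        chunksize = max(1, int(math.ceil(1.0 * len(task_keys) / num_processes)))
        pool = Pool(processes=num_processes)
        states = pool.map(fetch_state, task_keys, chunksize=chunksize)
        pool.close()
        pool.join()
        print(len(states), 'states fetched with chunksize', chunksize)  # 100, 15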
@@ -59,6 +59,7 @@ broker_url = amqp://guest:guest@rabbitmq:5672/
 result_backend = db+mysql://root@mysql/airflow
 flower_port = 5555
 default_queue = default
+sync_parallelism = 0
 
 [celery_broker_transport_options]
 visibility_timeout = 21600
@@ -256,6 +256,10 @@ data:
 # Default queue that tasks get assigned to and that workers listen on.
 default_queue = default
 
+# How many processes CeleryExecutor uses to sync task state.
+# 0 means to use max(1, number of cores - 1) processes.
+sync_parallelism = 0
+
 # Import path for celery configuration options
 celery_config_options = airflow.config_templates.default_celery.DEFAULT_CELERY_CONFIG
@@ -18,10 +18,12 @@
 # under the License.
 import sys
 import unittest
+import mock
 
 from celery.contrib.testing.worker import start_worker
 
 from airflow.executors.celery_executor import CeleryExecutor
 from airflow.executors.celery_executor import app
+from airflow.executors.celery_executor import CELERY_FETCH_ERR_MSG_HEADER
 from airflow.utils.state import State
 
 # Leave this; it is used by the test worker.

@@ -57,5 +59,26 @@ class CeleryExecutorTest(unittest.TestCase):
         self.assertNotIn('success', executor.last_state)
         self.assertNotIn('fail', executor.last_state)
 
+    def test_exception_propagation(self):
+        @app.task
+        def fake_celery_task():
+            return {}
+
+        mock_log = mock.MagicMock()
+        executor = CeleryExecutor()
+        executor._log = mock_log
+
+        executor.tasks = {'key': fake_celery_task()}
+        executor.sync()
+        mock_log.error.assert_called_once()
+        args, kwargs = mock_log.error.call_args_list[0]
+        log = args[0]
+        # Calling the task directly returns a dict, not an AsyncResult, so
+        # fetching its state raises AttributeError, which should be propagated
+        # to the error log.
+        self.assertIn(CELERY_FETCH_ERR_MSG_HEADER, log)
+        self.assertIn('AttributeError', log)
+
+
 if __name__ == '__main__':
     unittest.main()
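The test drives the failure path end to end through sync(); the same behaviour can be observed by calling fetch_celery_task_state directly with something that is not an AsyncResult. A hedged sketch, assuming an Airflow version that contains this commit:

    from airflow.executors.celery_executor import (
        ExceptionWithTraceback, fetch_celery_task_state)

    # A dict has no .state attribute, so the fetch fails with AttributeError
    # and the error comes back wrapped instead of being raised in the caller.
    result = fetch_celery_task_state(('key', {}))
    assert isinstance(result, ExceptionWithTraceback)
    assert 'AttributeError' in result.traceback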