diff --git a/UPDATING.md b/UPDATING.md
index 78b8327f05..0405309d62 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -17,6 +17,11 @@ so you might need to update your config.
 The scheduler.min_file_parsing_loop_time config option has been temporarily removed due to
 some bugs.
 
+### New `sync_parallelism` config option in `[celery]` section
+
+The new `sync_parallelism` config option controls how many processes CeleryExecutor uses to
+fetch Celery task state in parallel. The default value is `max(1, number of cores - 1)`.
+
 ## Airflow 1.10
 
 Installation and upgrading requires setting `SLUGIFY_USES_TEXT_UNIDECODE=yes` in your environment or
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index 18c486cb1e..000dd67a13 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -380,6 +380,10 @@ flower_port = 5555
 # Default queue that tasks get assigned to and that worker listen on.
 default_queue = default
 
+# How many processes CeleryExecutor uses to sync task state.
+# 0 means to use max(1, number of cores - 1) processes.
+sync_parallelism = 0
+
 # Import path for celery configuration options
 celery_config_options = airflow.config_templates.default_celery.DEFAULT_CELERY_CONFIG
 
diff --git a/airflow/config_templates/default_test.cfg b/airflow/config_templates/default_test.cfg
index 06937452b0..f9279cce54 100644
--- a/airflow/config_templates/default_test.cfg
+++ b/airflow/config_templates/default_test.cfg
@@ -97,6 +97,7 @@ result_backend = db+mysql://airflow:airflow@localhost:3306/airflow
 flower_host = 0.0.0.0
 flower_port = 5555
 default_queue = default
+sync_parallelism = 0
 
 [mesos]
 master = localhost:5050
diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index 61bbc66716..0de48b4d39 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -17,20 +17,26 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import math
+import os
 import subprocess
 import time
-import os
+import traceback
+from multiprocessing import Pool, cpu_count
 
 from celery import Celery
 from celery import states as celery_states
 
+from airflow import configuration
 from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
 from airflow.exceptions import AirflowException
 from airflow.executors.base_executor import BaseExecutor
-from airflow import configuration
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.module_loading import import_string
 
+# Make it constant for unit tests.
+CELERY_FETCH_ERR_MSG_HEADER = 'Error fetching Celery task state'
+
 '''
 To start the celery worker, run the command:
 airflow worker
@@ -63,6 +69,42 @@ def execute_command(command):
         raise AirflowException('Celery command failed')
 
 
+class ExceptionWithTraceback(object):
+    """
+    Wrapper class used to propagate exceptions to parent processes from subprocesses.
+    :param exception: The exception to wrap
+    :type exception: Exception
+    :param traceback: The stacktrace to wrap
+    :type traceback: str
+    """
+
+    def __init__(self, exception, exception_traceback):
+        self.exception = exception
+        self.traceback = exception_traceback
+
+
+def fetch_celery_task_state(celery_task):
+    """
+    Fetch and return the state of the given celery task. The scope of this function is
+    global so that it can be called by subprocesses in the pool.
+    :param celery_task: a tuple of the Celery task key and the async Celery object used
+        to fetch the task's state
+    :type celery_task: (str, celery.result.AsyncResult)
+    :return: a tuple of the Celery task key and the Celery state of the task
+    :rtype: (str, str)
+    """
+
+    try:
+        # Accessing state property of celery task will make actual network request
+        # to get the current state of the task.
+        res = (celery_task[0], celery_task[1].state)
+    except Exception as e:
+        exception_traceback = "Celery Task ID: {}\n{}".format(celery_task[0],
+                                                              traceback.format_exc())
+        res = ExceptionWithTraceback(e, exception_traceback)
+    return res
+
+
 class CeleryExecutor(BaseExecutor):
     """
     CeleryExecutor is recommended for production use of Airflow. It allows
@@ -72,10 +114,27 @@ class CeleryExecutor(BaseExecutor):
     vast amounts of messages, while providing operations with the tools
     required to maintain such a system.
     """
-    def start(self):
+
+    def __init__(self):
+        super(CeleryExecutor, self).__init__()
+
+        # Celery doesn't support querying the state of multiple tasks in parallel
+        # (which can become a bottleneck on bigger clusters) so we use
+        # a multiprocessing pool to speed this up.
+        # How many worker processes are created for checking celery task state.
+        self._sync_parallelism = configuration.getint('celery', 'SYNC_PARALLELISM')
+        if self._sync_parallelism == 0:
+            self._sync_parallelism = max(1, cpu_count() - 1)
+
+        self._sync_pool = None
         self.tasks = {}
         self.last_state = {}
 
+    def start(self):
+        self.log.debug(
+            'Starting Celery Executor using {} processes for syncing'.format(
+                self._sync_parallelism))
+
     def execute_async(self, key, command,
                       queue=DEFAULT_CELERY_CONFIG['task_default_queue'],
                       executor_config=None):
@@ -85,11 +144,48 @@ class CeleryExecutor(BaseExecutor):
             args=[command], queue=queue)
         self.last_state[key] = celery_states.PENDING
 
+    def _num_tasks_per_process(self):
+        """
+        How many Celery tasks should be sent to each worker process.
+        :return: Number of tasks that should be used per process
+        :rtype: int
+        """
+        return max(1,
+                   int(math.ceil(1.0 * len(self.tasks) / self._sync_parallelism)))
+
     def sync(self):
-        self.log.debug("Inquiring about %s celery task(s)", len(self.tasks))
-        for key, task in list(self.tasks.items()):
+        num_processes = min(len(self.tasks), self._sync_parallelism)
+        if num_processes == 0:
+            self.log.debug("No task to query celery, skipping sync")
+            return
+
+        self.log.debug("Inquiring about %s celery task(s) using %s processes",
+                       len(self.tasks), num_processes)
+
+        # Recreate the process pool each sync in case processes in the pool die
+        self._sync_pool = Pool(processes=num_processes)
+
+        # Use chunking instead of a work queue to reduce context switching since tasks are
+        # roughly uniform in size
+        chunksize = self._num_tasks_per_process()
+
+        self.log.debug("Waiting for inquiries to complete...")
+        task_keys_to_states = self._sync_pool.map(
+            fetch_celery_task_state,
+            self.tasks.items(),
+            chunksize=chunksize)
+        self._sync_pool.close()
+        self._sync_pool.join()
+        self.log.debug("Inquiries completed.")
+
+        for key_and_state in task_keys_to_states:
+            if isinstance(key_and_state, ExceptionWithTraceback):
+                self.log.error(
+                    CELERY_FETCH_ERR_MSG_HEADER + ", ignoring it:{}\n{}\n".format(
+                        key_and_state.exception, key_and_state.traceback))
+                continue
+            key, state = key_and_state
             try:
-                state = task.state
                 if self.last_state[key] != state:
                     if state == celery_states.SUCCESS:
                         self.success(key)
@@ -104,11 +200,10 @@ class CeleryExecutor(BaseExecutor):
                         del self.tasks[key]
                         del self.last_state[key]
                     else:
-                        self.log.info("Unexpected state: %s", state)
+                        self.log.info("Unexpected state: " + state)
                     self.last_state[key] = state
-            except Exception as e:
-                self.log.error("Error syncing the celery executor, ignoring it:")
-                self.log.exception(e)
+            except Exception:
+                self.log.exception("Error syncing the Celery executor, ignoring it.")
 
     def end(self, synchronous=False):
         if synchronous:
diff --git a/scripts/ci/airflow_travis.cfg b/scripts/ci/airflow_travis.cfg
index 2d412e182c..6895c4d5ad 100644
--- a/scripts/ci/airflow_travis.cfg
+++ b/scripts/ci/airflow_travis.cfg
@@ -59,6 +59,7 @@ broker_url = amqp://guest:guest@rabbitmq:5672/
 result_backend = db+mysql://root@mysql/airflow
 flower_port = 5555
 default_queue = default
+sync_parallelism = 0
 
 [celery_broker_transport_options]
 visibility_timeout = 21600
diff --git a/scripts/ci/kubernetes/kube/configmaps.yaml b/scripts/ci/kubernetes/kube/configmaps.yaml
index f8e99778f5..c0c7e9b2d0 100644
--- a/scripts/ci/kubernetes/kube/configmaps.yaml
+++ b/scripts/ci/kubernetes/kube/configmaps.yaml
@@ -256,6 +256,10 @@ data:
     # Default queue that tasks get assigned to and that worker listen on.
     default_queue = default
 
+    # How many processes CeleryExecutor uses to sync task state.
+    # 0 means to use max(1, number of cores - 1) processes.
+    sync_parallelism = 0
+
     # Import path for celery configuration options
     celery_config_options = airflow.config_templates.default_celery.DEFAULT_CELERY_CONFIG
 
diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py
index 95ad58f6a2..2ebcfd7b63 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -18,10 +18,12 @@
 # under the License.
 import sys
 import unittest
+import mock
 
 from celery.contrib.testing.worker import start_worker
 from airflow.executors.celery_executor import CeleryExecutor
 from airflow.executors.celery_executor import app
+from airflow.executors.celery_executor import CELERY_FETCH_ERR_MSG_HEADER
 from airflow.utils.state import State
 
 # leave this it is used by the test worker
@@ -57,5 +59,26 @@ class CeleryExecutorTest(unittest.TestCase):
         self.assertNotIn('success', executor.last_state)
         self.assertNotIn('fail', executor.last_state)
 
+    def test_exception_propagation(self):
+        @app.task
+        def fake_celery_task():
+            return {}
+
+        mock_log = mock.MagicMock()
+        executor = CeleryExecutor()
+        executor._log = mock_log
+
+        executor.tasks = {'key': fake_celery_task()}
+        executor.sync()
+        mock_log.error.assert_called_once()
+        args, kwargs = mock_log.error.call_args_list[0]
+        log = args[0]
+        # Result of queuing is not a celery task but a dict,
+        # and it should raise AttributeError and then get propagated
+        # to the error log.
+        self.assertIn(CELERY_FETCH_ERR_MSG_HEADER, log)
+        self.assertIn('AttributeError', log)
+
+
 if __name__ == '__main__':
     unittest.main()
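
Note: the following is an illustrative sketch, not part of the patch above. It shows, in isolation, the chunked multiprocessing state-fetch pattern that the new sync() uses, together with an example of setting the new option. The helper names fetch_state and fetch_states_in_parallel and the value sync_parallelism = 4 are assumptions made for this sketch; the patch itself ships the option defaulted to 0, which resolves to max(1, number of cores - 1).

    # airflow.cfg -- example only; 0 remains the shipped default
    # [celery]
    # sync_parallelism = 4

    import math
    from multiprocessing import Pool

    def fetch_state(item):
        # item is a (key, AsyncResult) pair; reading .state performs the network call
        key, async_result = item
        return key, async_result.state

    def fetch_states_in_parallel(tasks, sync_parallelism):
        # tasks: dict mapping task key -> Celery AsyncResult
        if not tasks:
            return []
        num_processes = min(len(tasks), sync_parallelism)
        # One contiguous chunk per process, mirroring _num_tasks_per_process()
        chunksize = max(1, int(math.ceil(1.0 * len(tasks) / sync_parallelism)))
        pool = Pool(processes=num_processes)
        try:
            return pool.map(fetch_state, list(tasks.items()), chunksize=chunksize)
        finally:
            # Recreated on every call, as the patch does, so dead workers don't linger
            pool.close()
            pool.join()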