[AIRFLOW-2156] Parallelize Celery Executor task state fetching (#3830)

yrqls21 2018-09-11 09:12:18 -07:00, committed by Tao Feng
Parent 1f038a7919
Commit 9b82fcb5fb
7 changed files with 143 additions and 10 deletions

View file

@@ -17,6 +17,11 @@ so you might need to update your config.
The scheduler.min_file_parsing_loop_time config option has been temporarily removed due to
some bugs.
### New `sync_parallelism` config option in celery section
The new `sync_parallelism` config option controls how many processes the CeleryExecutor uses to
fetch Celery task state in parallel. The default value is `max(1, number of cores - 1)`.
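For example, to cap state fetching at four processes, set the option explicitly in the `[celery]` section of `airflow.cfg` (the value 4 is illustrative):
```ini
[celery]
# 0 (the default) means max(1, number of cores - 1) processes
sync_parallelism = 4
```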
## Airflow 1.10
Installation and upgrading requires setting `SLUGIFY_USES_TEXT_UNIDECODE=yes` in your environment or

View file

@@ -380,6 +380,10 @@ flower_port = 5555
# Default queue that tasks get assigned to and that workers listen on.
default_queue = default
# How many processes CeleryExecutor uses to sync task state.
# 0 means to use max(1, number of cores - 1) processes.
sync_parallelism = 0
# Import path for celery configuration options
celery_config_options = airflow.config_templates.default_celery.DEFAULT_CELERY_CONFIG
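A minimal Python sketch (illustrative, not the Airflow API) of how the special value 0 resolves to a concrete process count, matching the comment above:
```python
from multiprocessing import cpu_count

sync_parallelism = 0  # value as read from the [celery] section
if sync_parallelism == 0:
    sync_parallelism = max(1, cpu_count() - 1)
print(sync_parallelism)  # e.g. 7 on an 8-core machine
```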

View file

@@ -97,6 +97,7 @@ result_backend = db+mysql://airflow:airflow@localhost:3306/airflow
flower_host = 0.0.0.0
flower_port = 5555
default_queue = default
sync_parallelism = 0
[mesos]
master = localhost:5050

View file

@@ -17,20 +17,26 @@
# specific language governing permissions and limitations
# under the License.
import math
import os
import subprocess
import time
import traceback

from multiprocessing import Pool, cpu_count

from celery import Celery
from celery import states as celery_states

from airflow import configuration
from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
from airflow.exceptions import AirflowException
from airflow.executors.base_executor import BaseExecutor
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.module_loading import import_string
# Module-level constant so unit tests can import it.
CELERY_FETCH_ERR_MSG_HEADER = 'Error fetching Celery task state'
'''
To start the celery worker, run the command:
airflow worker
@@ -63,6 +69,42 @@ def execute_command(command):
raise AirflowException('Celery command failed')
class ExceptionWithTraceback(object):
"""
Wrapper class used to propagate exceptions to parent processes from subprocesses.
:param exception: The exception to wrap
:type exception: Exception
:param exception_traceback: The stacktrace to wrap
:type exception_traceback: str
"""
def __init__(self, exception, exception_traceback):
self.exception = exception
self.traceback = exception_traceback
def fetch_celery_task_state(celery_task):
"""
Fetch and return the state of the given celery task. The scope of this function is
global so that it can be called by subprocesses in the pool.
:param celery_task: a tuple of the Celery task key and the async Celery object used
to fetch the task's state
:type celery_task: (str, celery.result.AsyncResult)
:return: a tuple of the Celery task key and the Celery state of the task
:rtype: (str, str)
"""
try:
# Accessing the state property of a celery task will make an actual
# network request to get the current state of the task.
res = (celery_task[0], celery_task[1].state)
except Exception as e:
exception_traceback = "Celery Task ID: {}\n{}".format(celery_task[0],
traceback.format_exc())
res = ExceptionWithTraceback(e, exception_traceback)
return res
class CeleryExecutor(BaseExecutor):
"""
CeleryExecutor is recommended for production use of Airflow. It allows
@@ -72,10 +114,27 @@ class CeleryExecutor(BaseExecutor):
vast amounts of messages, while providing operations with the tools
required to maintain such a system.
"""
def __init__(self):
super(CeleryExecutor, self).__init__()
# Celery doesn't support querying the state of multiple tasks in parallel
# (which can become a bottleneck on bigger clusters) so we use
# a multiprocessing pool to speed this up.
# How many worker processes are created for checking celery task state.
self._sync_parallelism = configuration.getint('celery', 'SYNC_PARALLELISM')
if self._sync_parallelism == 0:
self._sync_parallelism = max(1, cpu_count() - 1)
self._sync_pool = None
self.tasks = {}
self.last_state = {}
def start(self):
self.log.debug(
'Starting Celery Executor using {} processes for syncing'.format(
self._sync_parallelism))
def execute_async(self, key, command,
queue=DEFAULT_CELERY_CONFIG['task_default_queue'],
executor_config=None):
@@ -85,11 +144,48 @@ class CeleryExecutor(BaseExecutor):
args=[command], queue=queue)
self.last_state[key] = celery_states.PENDING
def _num_tasks_per_process(self):
"""
How many Celery tasks should be sent to each worker process.
:return: Number of tasks that should be used per process
:rtype: int
"""
return max(1,
int(math.ceil(1.0 * len(self.tasks) / self._sync_parallelism)))
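# Worked example (illustrative numbers): with 10 queued tasks and
# _sync_parallelism = 4, ceil(1.0 * 10 / 4) = 3, so sync() below hands
# each pool worker chunks of 3 tasks.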
def sync(self):
self.log.debug("Inquiring about %s celery task(s)", len(self.tasks))
for key, task in list(self.tasks.items()):
num_processes = min(len(self.tasks), self._sync_parallelism)
if num_processes == 0:
self.log.debug("No task to query celery, skipping sync")
return
self.log.debug("Inquiring about %s celery task(s) using %s processes",
len(self.tasks), num_processes)
# Recreate the process pool each sync in case processes in the pool die
self._sync_pool = Pool(processes=num_processes)
# Use chunking instead of a work queue to reduce context switching since tasks are
# roughly uniform in size
chunksize = self._num_tasks_per_process()
self.log.debug("Waiting for inquiries to complete...")
task_keys_to_states = self._sync_pool.map(
fetch_celery_task_state,
self.tasks.items(),
chunksize=chunksize)
self._sync_pool.close()
self._sync_pool.join()
self.log.debug("Inquiries completed.")
for key_and_state in task_keys_to_states:
if isinstance(key_and_state, ExceptionWithTraceback):
self.log.error(
CELERY_FETCH_ERR_MSG_HEADER + ", ignoring it:{}\n{}\n".format(
key_and_state.exception, key_and_state.traceback))
continue
key, state = key_and_state
try:
if self.last_state[key] != state:
if state == celery_states.SUCCESS:
self.success(key)
@@ -104,11 +200,10 @@ class CeleryExecutor(BaseExecutor):
del self.tasks[key]
del self.last_state[key]
else:
self.log.info("Unexpected state: %s", state)
self.log.info("Unexpected state: " + state)
self.last_state[key] = state
except Exception:
self.log.exception("Error syncing the Celery executor, ignoring it.")
def end(self, synchronous=False):
if synchronous:

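Reduced to a minimal, self-contained sketch, the pattern introduced above looks like this (`FakeAsyncResult` and `fetch_state` are illustrative stand-ins, not Celery or Airflow APIs): a worker pool maps a state-fetching function over (key, result) pairs in uniform chunks, and failures are returned as values so one bad task cannot kill the pool.
```python
import math
import traceback
from multiprocessing import Pool, cpu_count


class FakeAsyncResult(object):
    """Illustrative stand-in for celery.result.AsyncResult."""

    def __init__(self, state):
        self._state = state

    @property
    def state(self):
        # A real AsyncResult would issue a network request here.
        return self._state


def fetch_state(task):
    """Return (key, state); on failure, return (key, formatted traceback)."""
    key, async_result = task
    try:
        return key, async_result.state
    except Exception:
        return key, traceback.format_exc()


if __name__ == '__main__':
    tasks = {'task-%d' % i: FakeAsyncResult('SUCCESS') for i in range(10)}
    num_processes = min(len(tasks), max(1, cpu_count() - 1))
    # Chunk the work instead of using a shared queue, as in the commit.
    chunksize = max(1, int(math.ceil(1.0 * len(tasks) / num_processes)))
    pool = Pool(processes=num_processes)
    try:
        for key, state in pool.map(fetch_state, list(tasks.items()),
                                   chunksize=chunksize):
            print(key, state)
    finally:
        pool.close()
        pool.join()
```
Returning the formatted traceback as data mirrors `ExceptionWithTraceback` above: traceback objects cannot be pickled across process boundaries, so the commit ships them back to the parent as plain strings.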
View file

@@ -59,6 +59,7 @@ broker_url = amqp://guest:guest@rabbitmq:5672/
result_backend = db+mysql://root@mysql/airflow
flower_port = 5555
default_queue = default
sync_parallelism = 0
[celery_broker_transport_options]
visibility_timeout = 21600

View file

@@ -256,6 +256,10 @@ data:
# Default queue that tasks get assigned to and that workers listen on.
default_queue = default
# How many processes CeleryExecutor uses to sync task state.
# 0 means to use max(1, number of cores - 1) processes.
sync_parallelism = 0
# Import path for celery configuration options
celery_config_options = airflow.config_templates.default_celery.DEFAULT_CELERY_CONFIG

View file

@@ -18,10 +18,12 @@
# under the License.
import sys
import unittest
import mock
from celery.contrib.testing.worker import start_worker
from airflow.executors.celery_executor import CeleryExecutor
from airflow.executors.celery_executor import app
from airflow.executors.celery_executor import CELERY_FETCH_ERR_MSG_HEADER
from airflow.utils.state import State
# leave this; it is used by the test worker
@@ -57,5 +59,26 @@ class CeleryExecutorTest(unittest.TestCase):
self.assertNotIn('success', executor.last_state)
self.assertNotIn('fail', executor.last_state)
def test_exception_propagation(self):
@app.task
def fake_celery_task():
return {}
mock_log = mock.MagicMock()
executor = CeleryExecutor()
executor._log = mock_log
executor.tasks = {'key': fake_celery_task()}
executor.sync()
mock_log.error.assert_called_once()
args, kwargs = mock_log.error.call_args_list[0]
log = args[0]
# The result of queuing is not a celery task but a dict,
# so fetching its state raises an AttributeError, which is
# propagated to the error log.
self.assertIn(CELERY_FETCH_ERR_MSG_HEADER, log)
self.assertIn('AttributeError', log)
if __name__ == '__main__':
unittest.main()