From b56cb5cc97de074bb0e520f66b79e7eb2d913fb1 Mon Sep 17 00:00:00 2001 From: Dan Davydov Date: Wed, 18 Jan 2017 18:11:01 -0800 Subject: [PATCH] [AIRFLOW-219][AIRFLOW-398] Cgroups + impersonation Submitting on behalf of plypaul Please accept this PR that addresses the following issues: - https://issues.apache.org/jira/browse/AIRFLOW-219 - https://issues.apache.org/jira/browse/AIRFLOW-398 Testing Done: - Running on Airbnb prod (though on a different mergebase) for many months Credits: Impersonation Work: georgeke did most of the work but plypaul did quite a bit of work too. Cgroups: plypaul did most of the work, I just did some touch up/bug fixes (see commit history, cgroups + impersonation commit is actually plypaul's not mine) Closes #1934 from aoen/ddavydov/cgroups_and_impersonation_after_rebase --- .travis.yml | 2 +- airflow/bin/cli.py | 96 +++++++-- airflow/configuration.py | 7 + airflow/contrib/task_runner/__init__.py | 13 ++ .../contrib/task_runner/cgroup_task_runner.py | 202 ++++++++++++++++++ airflow/jobs.py | 67 +++--- ...a5a9e6bf2b5_add_state_index_for_dagruns.py | 37 ++++ airflow/models.py | 92 +++++--- airflow/settings.py | 23 +- airflow/task_runner/__init__.py | 38 ++++ airflow/task_runner/base_task_runner.py | 153 +++++++++++++ airflow/task_runner/bash_task_runner.py | 39 ++++ airflow/utils/file.py | 23 ++ airflow/utils/helpers.py | 79 ++++++- docs/security.rst | 22 ++ run_unit_tests.sh | 14 ++ scripts/ci/airflow_travis.cfg | 1 + scripts/ci/requirements.txt | 1 + setup.py | 4 + tests/__init__.py | 1 + tests/dags/test_default_impersonation.py | 44 ++++ tests/dags/test_impersonation.py | 45 ++++ tests/dags/test_no_impersonation.py | 43 ++++ tests/impersonation.py | 111 ++++++++++ 24 files changed, 1061 insertions(+), 96 deletions(-) create mode 100644 airflow/contrib/task_runner/__init__.py create mode 100644 airflow/contrib/task_runner/cgroup_task_runner.py create mode 100644 airflow/migrations/versions/1a5a9e6bf2b5_add_state_index_for_dagruns.py create mode 100644 airflow/task_runner/__init__.py create mode 100644 airflow/task_runner/base_task_runner.py create mode 100644 airflow/task_runner/bash_task_runner.py create mode 100644 tests/dags/test_default_impersonation.py create mode 100644 tests/dags/test_impersonation.py create mode 100644 tests/dags/test_no_impersonation.py create mode 100644 tests/impersonation.py diff --git a/.travis.yml b/.travis.yml index 407e7f9c9b..90f33e3fcb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -89,7 +89,7 @@ cache: - $HOME/.wheelhouse/ - $HOME/.travis_cache/ before_install: - - ssh-keygen -t rsa -C your_email@youremail.com -P '' -f ~/.ssh/id_rsa + - yes | ssh-keygen -t rsa -C your_email@youremail.com -P '' -f ~/.ssh/id_rsa - cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys - ln -s ~/.ssh/authorized_keys ~/.ssh/authorized_keys2 - chmod 600 ~/.ssh/* diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index d55fdfc4be..736df0a9bf 100755 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -22,7 +22,6 @@ import os import subprocess import textwrap import warnings -from datetime import datetime from importlib import import_module import argparse @@ -53,7 +52,7 @@ from airflow.models import (DagModel, DagBag, TaskInstance, from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS) from airflow.utils import db as db_utils from airflow.utils import logging as logging_utils -from airflow.utils.state import State +from airflow.utils.file import mkdirs from airflow.www.app import cached_app from sqlalchemy import func @@ -300,6 +299,7 @@ def
export_helper(filepath): varfile.write(json.dumps(var_dict, sort_keys=True, indent=4)) print("{} variables successfully exported to {}".format(len(var_dict), filepath)) + def pause(args, dag=None): set_is_paused(True, args, dag) @@ -329,19 +329,65 @@ def run(args, dag=None): if dag: args.dag_id = dag.dag_id - # Setting up logging - log_base = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER')) - directory = log_base + "/{args.dag_id}/{args.task_id}".format(args=args) - if not os.path.exists(directory): - os.makedirs(directory) - iso = args.execution_date.isoformat() - filename = "{directory}/{iso}".format(**locals()) + # Load custom airflow config + if args.cfg_path: + with open(args.cfg_path, 'r') as conf_file: + conf_dict = json.load(conf_file) + + if os.path.exists(args.cfg_path): + os.remove(args.cfg_path) + + for section, config in conf_dict.items(): + for option, value in config.items(): + conf.set(section, option, value) + settings.configure_vars() + settings.configure_orm() logging.root.handlers = [] - logging.basicConfig( - filename=filename, - level=settings.LOGGING_LEVEL, - format=settings.LOG_FORMAT) + if args.raw: + # Output to STDOUT for the parent process to read and log + logging.basicConfig( + stream=sys.stdout, + level=settings.LOGGING_LEVEL, + format=settings.LOG_FORMAT) + else: + # Setting up logging to a file. + + # To handle log writing when tasks are impersonated, the log files need to + # be writable by the user that runs the Airflow command and the user + # that is impersonated. This is mainly to handle corner cases with the + # SubDagOperator. When the SubDagOperator is run, all of the operators + # run under the impersonated user and create appropriate log files + # as the impersonated user. However, if the user manually runs tasks + # of the SubDagOperator through the UI, then the log files are created + # by the user that runs the Airflow command. For example, the Airflow + # run command may be run by the `airflow_sudoable` user, but the Airflow + # tasks may be run by the `airflow` user. If the log files are not + # writable by both users, then it's possible that re-running a task + # via the UI (or vice versa) results in a permission error as the task + # tries to write to a log file created by the other user. + log_base = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER')) + directory = log_base + "/{args.dag_id}/{args.task_id}".format(args=args) + # Create the log file and give it group writable permissions + # TODO(aoen): Make log dirs and logs globally readable for now since the SubDag + # operator is not compatible with impersonation (e.g. if a Celery executor is used + # for a SubDag operator and the SubDag operator has a different owner than the + # parent DAG) + if not os.path.exists(directory): + # Create the directory as globally writable using custom mkdirs + # as os.makedirs doesn't set mode properly. 
+ mkdirs(directory, 0o775) + iso = args.execution_date.isoformat() + filename = "{directory}/{iso}".format(**locals()) + + if not os.path.exists(filename): + open(filename, "a").close() + os.chmod(filename, 0o666) + + logging.basicConfig( + filename=filename, + level=settings.LOGGING_LEVEL, + format=settings.LOG_FORMAT) if not args.pickle and not dag: dag = get_dag(args) @@ -413,6 +459,10 @@ def run(args, dag=None): executor.heartbeat() executor.end() + # Child processes should not flush or upload to remote + if args.raw: + return + # Force the log to flush, and set the handler to go back to normal so we # don't continue logging to the task's log file. The flush is important # because we subsequently read from the log to insert into S3 or Google @@ -626,7 +676,7 @@ def restart_workers(gunicorn_master_proc, num_workers_expected): def start_refresh(gunicorn_master_proc): batch_size = conf.getint('webserver', 'worker_refresh_batch_size') logging.debug('%s doing a refresh of %s workers', - state, batch_size) + state, batch_size) sys.stdout.flush() sys.stderr.flush() @@ -635,11 +685,10 @@ def restart_workers(gunicorn_master_proc, num_workers_expected): gunicorn_master_proc.send_signal(signal.SIGTTIN) excess += 1 wait_until_true(lambda: num_workers_expected + excess == - get_num_workers_running(gunicorn_master_proc)) - + get_num_workers_running(gunicorn_master_proc)) wait_until_true(lambda: num_workers_expected == - get_num_workers_running(gunicorn_master_proc)) + get_num_workers_running(gunicorn_master_proc)) while True: num_workers_running = get_num_workers_running(gunicorn_master_proc) @@ -662,7 +711,7 @@ def restart_workers(gunicorn_master_proc, num_workers_expected): gunicorn_master_proc.send_signal(signal.SIGTTOU) excess -= 1 wait_until_true(lambda: num_workers_expected + excess == - get_num_workers_running(gunicorn_master_proc)) + get_num_workers_running(gunicorn_master_proc)) # Start a new worker by asking gunicorn to increase number of workers elif num_workers_running == num_workers_expected: @@ -761,7 +810,8 @@ def webserver(args): if conf.getint('webserver', 'worker_refresh_interval') > 0: restart_workers(gunicorn_master_proc, num_workers) else: - while True: time.sleep(1) + while True: + time.sleep(1) def scheduler(args): @@ -920,7 +970,7 @@ def connections(args): Connection.is_encrypted, Connection.is_extra_encrypted, Connection.extra).all() - conns = [map(reprlib.repr, conn) for conn in conns] + conns = [map(reprlib.repr, conn) for conn in conns] print(tabulate(conns, ['Conn Id', 'Conn Type', 'Host', 'Port', 'Is Encrypted', 'Is Extra Encrypted', 'Extra'], tablefmt="fancy_grid")) @@ -1255,6 +1305,8 @@ class CLIFactory(object): ("-p", "--pickle"), "Serialized pickle object of the entire dag (used internally)"), 'job_id': Arg(("-j", "--job_id"), argparse.SUPPRESS), + 'cfg_path': Arg( + ("--cfg_path", ), "Path to config file to use instead of airflow.cfg"), # webserver 'port': Arg( ("-p", "--port"), @@ -1433,7 +1485,7 @@ class CLIFactory(object): 'help': "Run a single task instance", 'args': ( 'dag_id', 'task_id', 'execution_date', 'subdir', - 'mark_success', 'force', 'pool', + 'mark_success', 'force', 'pool', 'cfg_path', 'local', 'raw', 'ignore_all_dependencies', 'ignore_dependencies', 'ignore_depends_on_past', 'ship_dag', 'pickle', 'job_id'), }, { @@ -1486,7 +1538,7 @@ class CLIFactory(object): 'func': upgradedb, 'help': "Upgrade the metadata database to latest version", 'args': tuple(), - },{ + }, { 'func': scheduler, 'help': "Start a scheduler instance", 'args': ('dag_id_opt', 
'subdir', 'run_duration', 'num_runs', diff --git a/airflow/configuration.py b/airflow/configuration.py index 9b27328129..979b071ea1 100644 --- a/airflow/configuration.py +++ b/airflow/configuration.py @@ -166,6 +166,13 @@ donot_pickle = False # How long before timing out a python file import while filling the DagBag dagbag_import_timeout = 30 +# The class to use for running task instances in a subprocess +task_runner = BashTaskRunner + +# If set, tasks without a `run_as_user` argument will be run with this user +# Can be used to de-elevate a sudo user running Airflow when executing tasks +default_impersonation = + # What security module to use (for example kerberos): security = diff --git a/airflow/contrib/task_runner/__init__.py b/airflow/contrib/task_runner/__init__.py new file mode 100644 index 0000000000..d4cd6f7c7c --- /dev/null +++ b/airflow/contrib/task_runner/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/airflow/contrib/task_runner/cgroup_task_runner.py b/airflow/contrib/task_runner/cgroup_task_runner.py new file mode 100644 index 0000000000..79aafc8d65 --- /dev/null +++ b/airflow/contrib/task_runner/cgroup_task_runner.py @@ -0,0 +1,202 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import getpass +import subprocess +import os +import uuid + +from cgroupspy import trees +import psutil + +from airflow.task_runner.base_task_runner import BaseTaskRunner +from airflow.utils.helpers import kill_process_tree + + +class CgroupTaskRunner(BaseTaskRunner): + """ + Runs the raw Airflow task in a cgroup that has containment for memory and + cpu. It uses the resource requirements defined in the task to construct + the settings for the cgroup. + + Note that this task runner will only work if the Airflow user has root privileges, + e.g. if the airflow user is called `airflow` then the following entries (or an even + less restrictive ones) are needed in the sudoers file (replacing + /CGROUPS_FOLDER with your system's cgroups folder, e.g. 
'/sys/fs/cgroup/'): + airflow ALL= (root) NOEXEC: /bin/chown /CGROUPS_FOLDER/memory/airflow/* + airflow ALL= (root) NOEXEC: !/bin/chown /CGROUPS_FOLDER/memory/airflow/*..* + airflow ALL= (root) NOEXEC: !/bin/chown /CGROUPS_FOLDER/memory/airflow/* * + airflow ALL= (root) NOEXEC: /bin/chown /CGROUPS_FOLDER/cpu/airflow/* + airflow ALL= (root) NOEXEC: !/bin/chown /CGROUPS_FOLDER/cpu/airflow/*..* + airflow ALL= (root) NOEXEC: !/bin/chown /CGROUPS_FOLDER/cpu/airflow/* * + airflow ALL= (root) NOEXEC: /bin/chmod /CGROUPS_FOLDER/memory/airflow/* + airflow ALL= (root) NOEXEC: !/bin/chmod /CGROUPS_FOLDER/memory/airflow/*..* + airflow ALL= (root) NOEXEC: !/bin/chmod /CGROUPS_FOLDER/memory/airflow/* * + airflow ALL= (root) NOEXEC: /bin/chmod /CGROUPS_FOLDER/cpu/airflow/* + airflow ALL= (root) NOEXEC: !/bin/chmod /CGROUPS_FOLDER/cpu/airflow/*..* + airflow ALL= (root) NOEXEC: !/bin/chmod /CGROUPS_FOLDER/cpu/airflow/* * + """ + + def __init__(self, local_task_job): + super(CgroupTaskRunner, self).__init__(local_task_job) + self.process = None + self._finished_running = False + self._cpu_shares = None + self._mem_mb_limit = None + self._created_cpu_cgroup = False + self._created_mem_cgroup = False + self._cur_user = getpass.getuser() + + def _create_cgroup(self, path): + """ + Create the specified cgroup. + + :param path: The path of the cgroup to create. + E.g. cpu/mygroup/mysubgroup + :return: the Node associated with the created cgroup. + :rtype: cgroupspy.nodes.Node + """ + node = trees.Tree().root + path_split = path.split(os.sep) + for path_element in path_split: + name_to_node = {x.name: x for x in node.children} + if path_element not in name_to_node: + self.logger.debug("Creating cgroup {} in {}" + .format(path_element, node.path)) + subprocess.check_output("sudo mkdir -p {}".format(path_element)) + subprocess.check_output("sudo chown -R {} {}".format( + self._cur_user, path_element)) + else: + self.logger.debug("Not creating cgroup {} in {} " + "since it already exists" + .format(path_element, node.path)) + node = name_to_node[path_element] + return node + + def _delete_cgroup(self, path): + """ + Delete the specified cgroup. + + :param path: The path of the cgroup to delete. + E.g. cpu/mygroup/mysubgroup + """ + node = trees.Tree().root + path_split = path.split("/") + for path_element in path_split: + name_to_node = {x.name: x for x in node.children} + if path_element not in name_to_node: + self.logger.warn("Cgroup does not exist: {}" + .format(path)) + return + else: + node = name_to_node[path_element] + # node is now the leaf node + parent = node.parent + self.logger.debug("Deleting cgroup {}/{}".format(parent, node.name)) + parent.delete_cgroup(node.name) + + def start(self): + # Use bash if it's already in a cgroup + cgroups = self._get_cgroup_names() + if cgroups["cpu"] != "/" or cgroups["memory"] != "/": + self.logger.debug("Already running in a cgroup (cpu: {} memory: {} so " + "not creating another one" + .format(cgroups.get("cpu"), + cgroups.get("memory"))) + self.process = self.run_command(['bash', '-c'], join_args=True) + return + + # Create a unique cgroup name + cgroup_name = "airflow/{}/{}".format(datetime.datetime.now(). 
+ strftime("%Y-%m-%d"), + str(uuid.uuid1())) + + self.mem_cgroup_name = "memory/{}".format(cgroup_name) + self.cpu_cgroup_name = "cpu/{}".format(cgroup_name) + + # Get the resource requirements from the task + task = self._task_instance.task + resources = task.resources + cpus = resources.cpus.qty + self._cpu_shares = cpus * 1024 + self._mem_mb_limit = resources.ram.qty + + # Create the memory cgroup + mem_cgroup_node = self._create_cgroup(self.mem_cgroup_name) + self._created_mem_cgroup = True + if self._mem_mb_limit > 0: + self.logger.debug("Setting {} with {} MB of memory" + .format(self.mem_cgroup_name, self._mem_mb_limit)) + mem_cgroup_node.controller.limit_in_bytes = self._mem_mb_limit * 1024 * 1024 + + # Create the CPU cgroup + cpu_cgroup_node = self._create_cgroup(self.cpu_cgroup_name) + self._created_cpu_cgroup = True + if self._cpu_shares > 0: + self.logger.debug("Setting {} with {} CPU shares" + .format(self.cpu_cgroup_name, self._cpu_shares)) + cpu_cgroup_node.controller.shares = self._cpu_shares + + # Start the process w/ cgroups + self.logger.debug("Starting task process with cgroups cpu,memory:{}" + .format(cgroup_name)) + self.process = self.run_command( + ['cgexec', '-g', 'cpu,memory:{}'.format(cgroup_name)] + ) + + def return_code(self): + return_code = self.process.poll() + # TODO(plypaul) Monitoring the the control file in the cgroup fs is better than + # checking the return code here. The PR to use this is here: + # https://github.com/plypaul/airflow/blob/e144e4d41996300ffa93947f136eab7785b114ed/airflow/contrib/task_runner/cgroup_task_runner.py#L43 + # but there were some issues installing the python butter package and + # libseccomp-dev on some hosts for some reason. + # I wasn't able to track down the root cause of the package install failures, but + # we might want to revisit that approach at some other point. + if return_code == 137: + self.logger.warn("Task failed with return code of 137. This may indicate " + "that it was killed due to excessive memory usage. 
" + "Please consider optimizing your task or using the " + "resources argument to reserve more memory for your " + "task") + return return_code + + def terminate(self): + if self.process and psutil.pid_exists(self.process.pid): + kill_process_tree(self.logger, self.process.pid) + + def on_finish(self): + # Let the OOM watcher thread know we're done to avoid false OOM alarms + self._finished_running = True + # Clean up the cgroups + if self._created_mem_cgroup: + self._delete_cgroup(self.mem_cgroup_name) + if self._created_cpu_cgroup: + self._delete_cgroup(self.cpu_cgroup_name) + + def _get_cgroup_names(self): + """ + :return: a mapping between the subsystem name to the cgroup name + :rtype: dict[str, str] + """ + with open("/proc/self/cgroup") as f: + lines = f.readlines() + d = {} + for line in lines: + line_split = line.rstrip().split(":") + subsystem = line_split[1] + group_name = line_split[2] + d[subsystem] = group_name + return d diff --git a/airflow/jobs.py b/airflow/jobs.py index 2a6af3902b..f1de333c10 100644 --- a/airflow/jobs.py +++ b/airflow/jobs.py @@ -25,7 +25,6 @@ from datetime import datetime import getpass import logging import socket -import subprocess import multiprocessing import os import signal @@ -35,7 +34,7 @@ import time from time import sleep import psutil -from sqlalchemy import Column, Integer, String, DateTime, func, Index, or_ +from sqlalchemy import Column, Integer, String, DateTime, func, Index, or_, and_ from sqlalchemy.exc import OperationalError from sqlalchemy.orm.session import make_transient from tabulate import tabulate @@ -45,6 +44,7 @@ from airflow import configuration as conf from airflow.exceptions import AirflowException from airflow.models import DagRun from airflow.settings import Stats +from airflow.task_runner import get_task_runner from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS from airflow.utils.state import State from airflow.utils.db import provide_session, pessimistic_connection_handling @@ -54,15 +54,12 @@ from airflow.utils.dag_processing import (AbstractDagFileProcessor, SimpleDagBag, list_py_file_paths) from airflow.utils.email import send_email -from airflow.utils.helpers import kill_descendant_processes from airflow.utils.logging import LoggingMixin from airflow.utils import asciiart Base = models.Base -DagRun = models.DagRun ID_LEN = models.ID_LEN -Stats = settings.Stats class BaseJob(Base, LoggingMixin): @@ -956,13 +953,18 @@ class SchedulerJob(BaseJob): :type states: Tuple[State] :return: None """ - # Get all the relevant task instances + # Get all the queued task instances from associated with scheduled + # DagRuns. 
TI = models.TaskInstance task_instances_to_examine = ( session .query(TI) .filter(TI.dag_id.in_(simple_dag_bag.dag_ids)) .filter(TI.state.in_(states)) + .join(DagRun, and_(TI.dag_id == DagRun.dag_id, + TI.execution_date == DagRun.execution_date, + DagRun.state == State.RUNNING, + DagRun.run_id.like(DagRun.ID_PREFIX + '%'))) .all() ) @@ -1017,7 +1019,7 @@ class SchedulerJob(BaseJob): self.logger.debug("Not handling task {} as the executor reports it is running" .format(task_instance.key)) continue - + if simple_dag_bag.get_dag(task_instance.dag_id).is_paused: self.logger.info("Not executing queued {} since {} is paused" .format(task_instance, task_instance.dag_id)) @@ -1054,7 +1056,7 @@ class SchedulerJob(BaseJob): task_concurrency_limit)) continue - command = TI.generate_command( + command = " ".join(TI.generate_command( task_instance.dag_id, task_instance.task_id, task_instance.execution_date, @@ -1066,7 +1068,7 @@ class SchedulerJob(BaseJob): ignore_ti_state=False, pool=task_instance.pool, file_path=simple_dag_bag.get_dag(task_instance.dag_id).full_filepath, - pickle_id=simple_dag_bag.get_dag(task_instance.dag_id).pickle_id) + pickle_id=simple_dag_bag.get_dag(task_instance.dag_id).pickle_id)) priority = task_instance.priority_weight queue = task_instance.queue @@ -1659,7 +1661,7 @@ class BackfillJob(BaseJob): # consider max_active_runs but ignore when running subdags # "parent.child" as a dag_id is by convention a subdag - if self.dag.schedule_interval and not "." in self.dag.dag_id: + if self.dag.schedule_interval and "." not in self.dag.dag_id: active_runs = DagRun.find( dag_id=self.dag.dag_id, state=State.RUNNING, @@ -1915,7 +1917,6 @@ class BackfillJob(BaseJob): self.logger.error(msg) ti.handle_failure(msg) tasks_to_run.pop(key) - msg = ' | '.join([ "[backfill progress]", "dag run {6} of {7}", @@ -2026,23 +2027,14 @@ class LocalTaskJob(BaseJob): super(LocalTaskJob, self).__init__(*args, **kwargs) def _execute(self): + self.task_runner = get_task_runner(self) try: - command = self.task_instance.command( - raw=True, - ignore_all_deps = self.ignore_all_deps, - ignore_depends_on_past = self.ignore_depends_on_past, - ignore_task_deps = self.ignore_task_deps, - ignore_ti_state = self.ignore_ti_state, - pickle_id = self.pickle_id, - mark_success = self.mark_success, - job_id = self.id, - pool = self.pool - ) - self.process = subprocess.Popen(['bash', '-c', command]) - self.logger.info("Subprocess PID is {}".format(self.process.pid)) + self.task_runner.start() + ti = self.task_instance session = settings.Session() - ti.pid = self.process.pid + if self.task_runner.process: + ti.pid = self.task_runner.process.pid ti.hostname = socket.getfqdn() session.merge(ti) session.commit() @@ -2053,8 +2045,10 @@ class LocalTaskJob(BaseJob): 'scheduler_zombie_task_threshold') while True: # Monitor the task to see if it's done - return_code = self.process.poll() + return_code = self.task_runner.return_code() if return_code is not None: + self.logger.info("Task exited with return code {}" + .format(return_code)) return # Periodically heartbeat so that the scheduler doesn't think this @@ -2079,11 +2073,11 @@ class LocalTaskJob(BaseJob): .format(time_since_last_heartbeat, heartbeat_time_limit)) finally: - # Kill processes that were left running - kill_descendant_processes(self.logger) + self.on_kill() def on_kill(self): - self.process.terminate() + self.task_runner.terminate() + self.task_runner.on_finish() @provide_session def heartbeat_callback(self, session=None): @@ -2097,23 +2091,24 @@ class 
LocalTaskJob(BaseJob): TI = models.TaskInstance ti = self.task_instance new_ti = session.query(TI).filter( - TI.dag_id==ti.dag_id, TI.task_id==ti.task_id, - TI.execution_date==ti.execution_date).scalar() + TI.dag_id == ti.dag_id, TI.task_id == ti.task_id, + TI.execution_date == ti.execution_date).scalar() if new_ti.state == State.RUNNING: self.was_running = True fqdn = socket.getfqdn() - if not (fqdn == new_ti.hostname and self.process.pid == new_ti.pid): + if not (fqdn == new_ti.hostname and + self.task_runner.process.pid == new_ti.pid): logging.warning("Recorded hostname and pid of {new_ti.hostname} " "and {new_ti.pid} do not match this instance's " "which are {fqdn} and " - "{self.process.pid}. Taking the poison pill. So " - "long." + "{self.task_runner.process.pid}. Taking the poison pill. " + "So long." .format(**locals())) raise AirflowException("Another worker/process is running this job") - elif self.was_running and hasattr(self, 'process'): + elif self.was_running and hasattr(self.task_runner, 'process'): logging.warning( "State of this instance has been externally set to " "{self.task_instance.state}. " "Taking the poison pill. So long.".format(**locals())) - self.process.terminate() + self.task_runner.terminate() self.terminating = True diff --git a/airflow/migrations/versions/1a5a9e6bf2b5_add_state_index_for_dagruns.py b/airflow/migrations/versions/1a5a9e6bf2b5_add_state_index_for_dagruns.py new file mode 100644 index 0000000000..29ffaf1180 --- /dev/null +++ b/airflow/migrations/versions/1a5a9e6bf2b5_add_state_index_for_dagruns.py @@ -0,0 +1,37 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Add state index for dagruns to allow the quick lookup of active dagruns + +Revision ID: 1a5a9e6bf2b5 +Revises: 5e7d17757c7a +Create Date: 2017-01-17 10:22:53.193711 + +""" + +# revision identifiers, used by Alembic. 
+revision = '1a5a9e6bf2b5' +down_revision = '5e7d17757c7a' +branch_labels = None +depends_on = None + +from alembic import op +import sqlalchemy as sa + + +def upgrade(): + op.create_index('dr_state', 'dag_run', ['state'], unique=False) + + +def downgrade(): + op.drop_index('state', table_name='dag_run') diff --git a/airflow/models.py b/airflow/models.py index 8682f35898..a16603d546 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -251,9 +251,9 @@ class DagBag(BaseDagBag, LoggingMixin): self.logger.debug("Importing {}".format(filepath)) org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1]) - mod_name = ('unusual_prefix_' - + hashlib.sha1(filepath.encode('utf-8')).hexdigest() - + '_' + org_mod_name) + mod_name = ('unusual_prefix_' + + hashlib.sha1(filepath.encode('utf-8')).hexdigest() + + '_' + org_mod_name) if mod_name in sys.modules: del sys.modules[mod_name] @@ -756,6 +756,7 @@ class TaskInstance(Base): self.priority_weight = task.priority_weight_total self.try_number = 0 self.unixname = getpass.getuser() + self.run_as_user = task.run_as_user if state: self.state = state self.hostname = '' @@ -777,7 +778,39 @@ class TaskInstance(Base): pickle_id=None, raw=False, job_id=None, - pool=None): + pool=None, + cfg_path=None): + """ + Returns a command that can be executed anywhere where airflow is + installed. This command is part of the message sent to executors by + the orchestrator. + """ + return " ".join(self.command_as_list( + mark_success=mark_success, + ignore_all_deps=ignore_all_deps, + ignore_depends_on_past=ignore_depends_on_past, + ignore_task_deps=ignore_task_deps, + ignore_ti_state=ignore_ti_state, + local=local, + pickle_id=pickle_id, + raw=raw, + job_id=job_id, + pool=pool, + cfg_path=cfg_path)) + + def command_as_list( + self, + mark_success=False, + ignore_all_deps=False, + ignore_task_deps=False, + ignore_depends_on_past=False, + ignore_ti_state=False, + local=False, + pickle_id=None, + raw=False, + job_id=None, + pool=None, + cfg_path=None): """ Returns a command that can be executed anywhere where airflow is installed. This command is part of the message sent to executors by @@ -799,15 +832,16 @@ class TaskInstance(Base): self.execution_date, mark_success=mark_success, ignore_all_deps=ignore_all_deps, - ignore_depends_on_past=ignore_depends_on_past, ignore_task_deps=ignore_task_deps, + ignore_depends_on_past=ignore_depends_on_past, ignore_ti_state=ignore_ti_state, local=local, pickle_id=pickle_id, file_path=path, raw=raw, job_id=job_id, - pool=pool) + pool=pool, + cfg_path=cfg_path) @staticmethod def generate_command(dag_id, @@ -823,7 +857,8 @@ class TaskInstance(Base): file_path=None, raw=False, job_id=None, - pool=None + pool=None, + cfg_path=None ): """ Generates the shell command required to execute this task instance. 
@@ -860,19 +895,20 @@ class TaskInstance(Base): :return: shell command that can be used to run the task instance """ iso = execution_date.isoformat() - cmd = "airflow run {dag_id} {task_id} {iso} " - cmd += "--mark_success " if mark_success else "" - cmd += "--pickle {pickle_id} " if pickle_id else "" - cmd += "--job_id {job_id} " if job_id else "" - cmd += "-A " if ignore_all_deps else "" - cmd += "-i " if ignore_task_deps else "" - cmd += "-I " if ignore_depends_on_past else "" - cmd += "--force " if ignore_ti_state else "" - cmd += "--local " if local else "" - cmd += "--pool {pool} " if pool else "" - cmd += "--raw " if raw else "" - cmd += "-sd {file_path}" if file_path else "" - return cmd.format(**locals()) + cmd = ["airflow", "run", str(dag_id), str(task_id), str(iso)] + cmd.extend(["--mark_success"]) if mark_success else None + cmd.extend(["--pickle", str(pickle_id)]) if pickle_id else None + cmd.extend(["--job_id", str(job_id)]) if job_id else None + cmd.extend(["-A "]) if ignore_all_deps else None + cmd.extend(["-i"]) if ignore_task_deps else None + cmd.extend(["-I"]) if ignore_depends_on_past else None + cmd.extend(["--force"]) if ignore_ti_state else None + cmd.extend(["--local"]) if local else None + cmd.extend(["--pool", pool]) if pool else None + cmd.extend(["--raw"]) if raw else None + cmd.extend(["-sd", file_path]) if file_path else None + cmd.extend(["--cfg_path", cfg_path]) if cfg_path else None + return cmd @property def log_filepath(self): @@ -1825,6 +1861,8 @@ class BaseOperator(object): :param resources: A map of resource parameter names (the argument names of the Resources constructor) to their values. :type resources: dict + :param run_as_user: unix username to impersonate while running the task + :type run_as_user: str """ # For derived classes to define which fields will get jinjaified @@ -1866,6 +1904,7 @@ class BaseOperator(object): on_retry_callback=None, trigger_rule=TriggerRule.ALL_SUCCESS, resources=None, + run_as_user=None, *args, **kwargs): @@ -1929,6 +1968,7 @@ class BaseOperator(object): self.adhoc = adhoc self.priority_weight = priority_weight self.resources = Resources(**(resources or {})) + self.run_as_user = run_as_user # Private attributes self._upstream_task_ids = [] @@ -2854,13 +2894,7 @@ class DAG(BaseDag, LoggingMixin): :param session: :return: List of execution dates """ - runs = ( - session.query(DagRun) - .filter( - DagRun.dag_id == self.dag_id, - DagRun.state == State.RUNNING) - .order_by(DagRun.execution_date) - .all()) + runs = DagRun.find(dag_id=self.dag_id, state=State.RUNNING) active_dates = [] for run in runs: @@ -2959,7 +2993,7 @@ class DAG(BaseDag, LoggingMixin): self, session, start_date=None, end_date=None, state=None): TI = TaskInstance if not start_date: - start_date = (datetime.today()-timedelta(30)).date() + start_date = (datetime.today() - timedelta(30)).date() start_date = datetime.combine(start_date, datetime.min.time()) end_date = end_date or datetime.now() tis = session.query(TI).filter( @@ -3488,7 +3522,6 @@ class Variable(Base): else: return obj.val - @classmethod @provide_session def get(cls, key, default_var=None, deserialize_json=False, session=None): @@ -3695,7 +3728,6 @@ class DagStat(Base): :type full_query: bool """ dag_ids = set(dag_ids) - ds_ids = set(session.query(DagStat.dag_id).all()) qry = ( session.query(DagStat) diff --git a/airflow/settings.py b/airflow/settings.py index ce2ca9228a..4882875ea8 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -68,10 +68,7 @@ ___ ___ | / _ / _ __/ _ / / /_/ 
/_ |/ |/ / """ BASE_LOG_URL = '/admin/airflow/log' -AIRFLOW_HOME = os.path.expanduser(conf.get('core', 'AIRFLOW_HOME')) -SQL_ALCHEMY_CONN = conf.get('core', 'SQL_ALCHEMY_CONN') LOGGING_LEVEL = logging.INFO -DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER')) # the prefix to append to gunicorn worker processes after init GUNICORN_WORKER_READY_PREFIX = "[ready] " @@ -85,6 +82,13 @@ LOG_FORMAT_WITH_THREAD_NAME = ( '[%(asctime)s] {%(filename)s:%(lineno)d} %(threadName)s %(levelname)s - %(message)s') SIMPLE_LOG_FORMAT = '%(asctime)s %(levelname)s - %(message)s' +AIRFLOW_HOME = None +SQL_ALCHEMY_CONN = None +DAGS_FOLDER = None + +engine = None +Session = None + def policy(task_instance): """ @@ -118,8 +122,14 @@ def configure_logging(log_format=LOG_FORMAT): logging.basicConfig( format=log_format, stream=sys.stdout, level=LOGGING_LEVEL) -engine = None -Session = None + +def configure_vars(): + global AIRFLOW_HOME + global SQL_ALCHEMY_CONN + global DAGS_FOLDER + AIRFLOW_HOME = os.path.expanduser(conf.get('core', 'AIRFLOW_HOME')) + SQL_ALCHEMY_CONN = conf.get('core', 'SQL_ALCHEMY_CONN') + DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER')) def configure_orm(disable_connection_pool=False): @@ -133,7 +143,7 @@ def configure_orm(disable_connection_pool=False): engine_args['pool_size'] = conf.getint('core', 'SQL_ALCHEMY_POOL_SIZE') engine_args['pool_recycle'] = conf.getint('core', 'SQL_ALCHEMY_POOL_RECYCLE') - #engine_args['echo'] = True + # engine_args['echo'] = True engine = create_engine(SQL_ALCHEMY_CONN, **engine_args) Session = scoped_session( @@ -146,6 +156,7 @@ except: pass configure_logging() +configure_vars() configure_orm() # Const stuff diff --git a/airflow/task_runner/__init__.py b/airflow/task_runner/__init__.py new file mode 100644 index 0000000000..f134e8e1d6 --- /dev/null +++ b/airflow/task_runner/__init__.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from airflow import configuration +from airflow.contrib.task_runner.cgroup_task_runner import CgroupTaskRunner +from airflow.task_runner.bash_task_runner import BashTaskRunner +from airflow.exceptions import AirflowException + +_TASK_RUNNER = configuration.get('core', 'TASK_RUNNER') + + +def get_task_runner(local_task_job): + """ + Get the task runner that can be used to run the given job. + + :param local_task_job: The LocalTaskJob associated with the TaskInstance + that needs to be executed. + :type local_task_job: airflow.jobs.LocalTaskJob + :return: The task runner to use to run the task. 
+ :rtype: airflow.task_runner.base_task_runner.BaseTaskRunner + """ + if _TASK_RUNNER == "BashTaskRunner": + return BashTaskRunner(local_task_job) + elif _TASK_RUNNER == "CgroupTaskRunner": + return CgroupTaskRunner(local_task_job) + else: + raise AirflowException("Unknown task runner type {}".format(_TASK_RUNNER)) diff --git a/airflow/task_runner/base_task_runner.py b/airflow/task_runner/base_task_runner.py new file mode 100644 index 0000000000..69802a87b1 --- /dev/null +++ b/airflow/task_runner/base_task_runner.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import getpass +import os +import json +import subprocess +import threading + +from airflow import configuration as conf +from airflow.utils.logging import LoggingMixin +from tempfile import mkstemp + + +class BaseTaskRunner(LoggingMixin): + """ + Runs Airflow task instances by invoking the `airflow run` command with raw + mode enabled in a subprocess. + """ + + def __init__(self, local_task_job): + """ + :param local_task_job: The local task job associated with running the + associated task instance. + :type local_task_job: airflow.jobs.LocalTaskJob + """ + self._task_instance = local_task_job.task_instance + + popen_prepend = [] + cfg_path = None + if self._task_instance.run_as_user: + self.run_as_user = self._task_instance.run_as_user + else: + try: + self.run_as_user = conf.get('core', 'default_impersonation') + except conf.AirflowConfigException: + self.run_as_user = None + + # Add sudo commands to change user if we need to. Needed to handle SubDagOperator + # case using a SequentialExecutor. 
+ if self.run_as_user and (self.run_as_user != getpass.getuser()): + self.logger.debug("Planning to run as the {} user".format(self.run_as_user)) + cfg_dict = conf.as_dict(display_sensitive=True) + cfg_subset = { + 'core': cfg_dict.get('core', {}), + 'smtp': cfg_dict.get('smtp', {}), + 'scheduler': cfg_dict.get('scheduler', {}), + 'webserver': cfg_dict.get('webserver', {}), + } + temp_fd, cfg_path = mkstemp() + + # Give ownership of file to user; only they can read and write + subprocess.call( + ['sudo', 'chown', self.run_as_user, cfg_path] + ) + subprocess.call( + ['sudo', 'chmod', '600', cfg_path] + ) + + with os.fdopen(temp_fd, 'w') as temp_file: + json.dump(cfg_subset, temp_file) + + popen_prepend = ['sudo', '-H', '-u', self.run_as_user] + + self._cfg_path = cfg_path + self._command = popen_prepend + self._task_instance.command_as_list( + raw=True, + ignore_all_deps=local_task_job.ignore_all_deps, + ignore_depends_on_past=local_task_job.ignore_depends_on_past, + ignore_ti_state=local_task_job.ignore_ti_state, + pickle_id=local_task_job.pickle_id, + mark_success=local_task_job.mark_success, + job_id=local_task_job.id, + pool=local_task_job.pool, + cfg_path=cfg_path, + ) + self.process = None + + def _read_task_logs(self, stream): + while True: + line = stream.readline() + if len(line) == 0: + break + self.logger.info('Subtask: {}'.format(line.rstrip('\n'))) + + def run_command(self, run_with, join_args=False): + """ + Run the task command + + :param run_with: list of tokens to run the task command with + E.g. ['bash', '-c'] + :type run_with: list + :param join_args: whether to concatenate the list of command tokens + E.g. ['airflow', 'run'] vs ['airflow run'] + :param join_args: bool + :return: the process that was run + :rtype: subprocess.Popen + """ + cmd = [" ".join(self._command)] if join_args else self._command + full_cmd = run_with + cmd + self.logger.info('Running: {}'.format(full_cmd)) + proc = subprocess.Popen( + full_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT + ) + + # Start daemon thread to read subprocess logging output + log_reader = threading.Thread( + target=self._read_task_logs, + args=(proc.stdout,), + ) + log_reader.daemon = True + log_reader.start() + return proc + + def start(self): + """ + Start running the task instance in a subprocess. + """ + raise NotImplementedError() + + def return_code(self): + """ + :return: The return code associated with running the task instance or + None if the task is not yet done. + :rtype int: + """ + raise NotImplementedError() + + def terminate(self): + """ + Kill the running task instance. + """ + raise NotImplementedError() + + def on_finish(self): + """ + A callback that should be called when this is done running. + """ + if self._cfg_path and os.path.isfile(self._cfg_path): + subprocess.call(['sudo', 'rm', self._cfg_path]) diff --git a/airflow/task_runner/bash_task_runner.py b/airflow/task_runner/bash_task_runner.py new file mode 100644 index 0000000000..b73e25818d --- /dev/null +++ b/airflow/task_runner/bash_task_runner.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import psutil + +from airflow.task_runner.base_task_runner import BaseTaskRunner +from airflow.utils.helpers import kill_process_tree + + +class BashTaskRunner(BaseTaskRunner): + """ + Runs the raw Airflow task by invoking through the Bash shell. + """ + def __init__(self, local_task_job): + super(BashTaskRunner, self).__init__(local_task_job) + + def start(self): + self.process = self.run_command(['bash', '-c'], join_args=True) + + def return_code(self): + return self.process.poll() + + def terminate(self): + if self.process and psutil.pid_exists(self.process.pid): + kill_process_tree(self.logger, self.process.pid) + + def on_finish(self): + super(BashTaskRunner, self).on_finish() diff --git a/airflow/utils/file.py b/airflow/utils/file.py index d4526e9d8e..78ddeaa569 100644 --- a/airflow/utils/file.py +++ b/airflow/utils/file.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import unicode_literals import errno +import os import shutil from tempfile import mkdtemp @@ -34,3 +35,25 @@ def TemporaryDirectory(suffix='', prefix=None, dir=None): # ENOENT - no such file or directory if e.errno != errno.ENOENT: raise e + + +def mkdirs(path, mode): + """ + Creates the directory specified by path, creating intermediate directories + as necessary. If directory already exists, this is a no-op. + + :param path: The directory to create + :type path: str + :param mode: The mode to give to the directory e.g. 0o755 + :type mode: int + :return: A list of directories that were created + :rtype: list[str] + """ + if not path or os.path.exists(path): + return [] + (head, _) = os.path.split(path) + res = mkdirs(head, mode) + os.mkdir(path) + os.chmod(path, mode) + res += [path] + return res diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py index 6bd7a64ac5..e66745cce9 100644 --- a/airflow/utils/helpers.py +++ b/airflow/utils/helpers.py @@ -22,10 +22,12 @@ import psutil from builtins import input from past.builtins import basestring from datetime import datetime +import getpass import imp -import logging import os import re +import signal +import subprocess import sys import warnings @@ -35,6 +37,7 @@ from airflow.exceptions import AirflowException # SIGKILL. TIME_TO_WAIT_AFTER_SIGTERM = 5 + def validate_key(k, max_length=250): if not isinstance(k, basestring): raise TypeError("The key has to be a string") @@ -179,6 +182,80 @@ def pprinttable(rows): return s +def kill_using_shell(pid, signal=signal.SIGTERM): + process = psutil.Process(pid) + # Use sudo only when necessary - consider SubDagOperator and SequentialExecutor case. + if process.username() != getpass.getuser(): + args = ["sudo", "kill", "-{}".format(int(signal)), str(pid)] + else: + args = ["kill", "-{}".format(int(signal)), str(pid)] + # PID may not exist and return a non-zero error code + subprocess.call(args) + + +def kill_process_tree(logger, pid): + """ + Kills the process and all of the descendants. Kills using the `kill` + shell command so that it can change users. Note: killing via PIDs + has the potential to the wrong process if the process dies and the + PID gets recycled in a narrow time window. + + :param logger: logger + :type logger: logging.Logger + """ + try: + root_process = psutil.Process(pid) + except psutil.NoSuchProcess: + logger.warn("PID: {} does not exist".format(pid)) + return + + # Check child processes to reduce cases where a child process died but + # the PID got reused. 
+ descendant_processes = [x for x in root_process.children(recursive=True) + if x.is_running()] + + if len(descendant_processes) != 0: + logger.warn("Terminating descendant processes of {} PID: {}" + .format(root_process.cmdline(), + root_process.pid)) + temp_processes = descendant_processes[:] + for descendant in temp_processes: + logger.warn("Terminating descendant process {} PID: {}" + .format(descendant.cmdline(), descendant.pid)) + try: + kill_using_shell(descendant.pid, signal.SIGTERM) + except psutil.NoSuchProcess: + descendant_processes.remove(descendant) + + logger.warn("Waiting up to {}s for processes to exit..." + .format(TIME_TO_WAIT_AFTER_SIGTERM)) + try: + psutil.wait_procs(descendant_processes, TIME_TO_WAIT_AFTER_SIGTERM) + logger.warn("Done waiting") + except psutil.TimeoutExpired: + logger.warn("Ran out of time while waiting for " + "processes to exit") + # Then SIGKILL + descendant_processes = [x for x in root_process.children(recursive=True) + if x.is_running()] + + if len(descendant_processes) > 0: + temp_processes = descendant_processes[:] + for descendant in temp_processes: + logger.warn("Killing descendant process {} PID: {}" + .format(descendant.cmdline(), descendant.pid)) + try: + kill_using_shell(descendant.pid, signal.SIGTERM) + descendant.wait() + except psutil.NoSuchProcess: + descendant_processes.remove(descendant) + logger.warn("Killed all descendant processes of {} PID: {}" + .format(root_process.cmdline(), + root_process.pid)) + else: + logger.debug("There are no descendant processes to kill") + + def kill_descendant_processes(logger, pids_to_kill=None): """ Kills all descendant processes of this process. diff --git a/docs/security.rst b/docs/security.rst index 29f228d1e7..70db606728 100644 --- a/docs/security.rst +++ b/docs/security.rst @@ -310,3 +310,25 @@ standard port 443, you'll need to configure that too. Be aware that super user p # Optionally, set the server to listen on the standard SSL port. web_server_port = 443 base_url = http://:443 + +Impersonation +''''''''''''' + +Airflow has the ability to impersonate a unix user while running task +instances based on the task's ``run_as_user`` parameter, which takes a user's name. + +*NOTE* For impersonation to work, Airflow must be run with `sudo` as subtasks are run +with `sudo -u` and permissions of files are changed. Furthermore, the unix user needs to +exist on the worker. Here is what a simple sudoers file entry could look like to achieve +this, assuming airflow is running as the `airflow` user. Note that this means that +the airflow user must be trusted and treated the same way as the root user. + +.. code-block:: none + airflow ALL=(ALL) NOPASSWD: ALL + +Subtasks with impersonation will still log to the same folder, except that the files they +log to will have permissions changed such that only the unix user can write to them. + +*Default impersonation* To prevent tasks that don't use impersonation from being run with +`sudo` privileges, you can set the `default_impersonation` config in `core`, which sets a +default user to impersonate if `run_as_user` is not set. diff --git a/run_unit_tests.sh b/run_unit_tests.sh index c2912927bc..c922a55c92 100755 --- a/run_unit_tests.sh +++ b/run_unit_tests.sh @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License.
+set -x # environment export AIRFLOW_HOME=${AIRFLOW_HOME:=~/airflow} @@ -48,6 +49,19 @@ echo "Initializing the DB" yes | airflow resetdb airflow initdb +if [ "${TRAVIS}" ]; then + # For impersonation tests running on SQLite on Travis, make the database world readable so other + # users can update it + AIRFLOW_DB="/home/travis/airflow/airflow.db" + if [ -f "${AIRFLOW_DB}" ]; then + sudo chmod a+rw "${AIRFLOW_DB}" + fi + + # For impersonation tests on Travis, make airflow accessible to other users via the global PATH + # (which contains /usr/local/bin) + sudo ln -s "${VIRTUAL_ENV}/bin/airflow" /usr/local/bin/ +fi + echo "Starting the unit tests with the following nose arguments: "$nose_args nosetests $nose_args diff --git a/scripts/ci/airflow_travis.cfg b/scripts/ci/airflow_travis.cfg index 505bc0e48f..2834ad4556 100644 --- a/scripts/ci/airflow_travis.cfg +++ b/scripts/ci/airflow_travis.cfg @@ -22,6 +22,7 @@ load_examples = True donot_pickle = False dag_concurrency = 16 dags_are_paused_at_creation = False +default_impersonation = fernet_key = af7CN0q6ag5U3g08IsPsw3K45U7Xa0axgVFhoh-3zB8= [webserver] diff --git a/scripts/ci/requirements.txt b/scripts/ci/requirements.txt index 9e503f9608..a5786f6b6b 100644 --- a/scripts/ci/requirements.txt +++ b/scripts/ci/requirements.txt @@ -2,6 +2,7 @@ alembic bcrypt boto celery +cgroupspy chartkick cloudant coverage diff --git a/setup.py b/setup.py index aad9984b8e..b8fe67769c 100644 --- a/setup.py +++ b/setup.py @@ -108,6 +108,9 @@ celery = [ 'celery>=3.1.17', 'flower>=0.7.3' ] +cgroups = [ + 'cgroupspy>=0.1.4', +] crypto = ['cryptography>=0.9.3'] datadog = ['datadog>=0.14.0'] doc = [ @@ -227,6 +230,7 @@ def do_setup(): 'all_dbs': all_dbs, 'async': async, 'celery': celery, + 'cgroups': cgroups, 'cloudant': cloudant, 'crypto': crypto, 'datadog': datadog, diff --git a/tests/__init__.py b/tests/__init__.py index 69abb33eba..e1e8551747 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -18,6 +18,7 @@ from .configuration import * from .contrib import * from .core import * from .jobs import * +from .impersonation import * from .models import * from .operators import * from .utils import * diff --git a/tests/dags/test_default_impersonation.py b/tests/dags/test_default_impersonation.py new file mode 100644 index 0000000000..41cca00e83 --- /dev/null +++ b/tests/dags/test_default_impersonation.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from airflow.models import DAG +from airflow.operators.bash_operator import BashOperator +from datetime import datetime +from textwrap import dedent + + +DEFAULT_DATE = datetime(2016, 1, 1) + +args = { + 'owner': 'airflow', + 'start_date': DEFAULT_DATE, +} + +dag = DAG(dag_id='test_default_impersonation', default_args=args) + +deelevated_user = 'airflow_test_user' + +test_command = dedent( + """\ + if [ '{user}' != "$(whoami)" ]; then + echo current user $(whoami) is not {user}! 
+ exit 1 + fi + """.format(user=deelevated_user)) + +task = BashOperator( + task_id='test_deelevated_user', + bash_command=test_command, + dag=dag, +) diff --git a/tests/dags/test_impersonation.py b/tests/dags/test_impersonation.py new file mode 100644 index 0000000000..3727903c9e --- /dev/null +++ b/tests/dags/test_impersonation.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from airflow.models import DAG +from airflow.operators.bash_operator import BashOperator +from datetime import datetime +from textwrap import dedent + + +DEFAULT_DATE = datetime(2016, 1, 1) + +args = { + 'owner': 'airflow', + 'start_date': DEFAULT_DATE, +} + +dag = DAG(dag_id='test_impersonation', default_args=args) + +run_as_user = 'airflow_test_user' + +test_command = dedent( + """\ + if [ '{user}' != "$(whoami)" ]; then + echo current user is not {user}! + exit 1 + fi + """.format(user=run_as_user)) + +task = BashOperator( + task_id='test_impersonated_user', + bash_command=test_command, + dag=dag, + run_as_user=run_as_user, +) diff --git a/tests/dags/test_no_impersonation.py b/tests/dags/test_no_impersonation.py new file mode 100644 index 0000000000..0fc63dade0 --- /dev/null +++ b/tests/dags/test_no_impersonation.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from airflow.models import DAG +from airflow.operators.bash_operator import BashOperator +from datetime import datetime +from textwrap import dedent + + +DEFAULT_DATE = datetime(2016, 1, 1) + +args = { + 'owner': 'airflow', + 'start_date': DEFAULT_DATE, +} + +dag = DAG(dag_id='test_no_impersonation', default_args=args) + +test_command = dedent( + """\ + sudo ls + if [ $? -ne 0 ]; then + echo 'current uid does not have root privileges!' + exit 1 + fi + """) + +task = BashOperator( + task_id='test_superuser', + bash_command=test_command, + dag=dag, +) diff --git a/tests/impersonation.py b/tests/impersonation.py new file mode 100644 index 0000000000..0777defe24 --- /dev/null +++ b/tests/impersonation.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from __future__ import print_function +import errno +import os +import subprocess +import unittest + +from airflow import jobs, models +from airflow.utils.state import State +from datetime import datetime + +DEV_NULL = '/dev/null' +TEST_DAG_FOLDER = os.path.join( + os.path.dirname(os.path.realpath(__file__)), 'dags') +DEFAULT_DATE = datetime(2015, 1, 1) +TEST_USER = 'airflow_test_user' + + +# TODO(aoen): Adding/remove a user as part of a test is very bad (especially if the user +# already existed to begin with on the OS), this logic should be moved into a test +# that is wrapped in a container like docker so that the user can be safely added/removed. +# When this is done we can also modify the sudoers file to ensure that useradd will work +# without any manual modification of the sudoers file by the agent that is running these +# tests. + +class ImpersonationTest(unittest.TestCase): + def setUp(self): + self.dagbag = models.DagBag( + dag_folder=TEST_DAG_FOLDER, + include_examples=False, + ) + try: + subprocess.check_output(['sudo', 'useradd', '-m', TEST_USER, '-g', + str(os.getegid())]) + except OSError as e: + if e.errno == errno.ENOENT: + raise unittest.SkipTest( + "The 'useradd' command did not exist so unable to test " + "impersonation; Skipping Test. These tests can only be run on a " + "linux host that supports 'useradd'." + ) + else: + raise unittest.SkipTest( + "The 'useradd' command exited non-zero; Skipping tests. Does the " + "current user have permission to run 'useradd' without a password " + "prompt (check sudoers file)?" + ) + + def tearDown(self): + subprocess.check_output(['sudo', 'userdel', '-r', TEST_USER]) + + def run_backfill(self, dag_id, task_id): + dag = self.dagbag.get_dag(dag_id) + dag.clear() + + jobs.BackfillJob( + dag=dag, + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE).run() + + ti = models.TaskInstance( + task=dag.get_task(task_id), + execution_date=DEFAULT_DATE) + ti.refresh_from_db() + self.assertEqual(ti.state, State.SUCCESS) + + def test_impersonation(self): + """ + Tests that impersonating a unix user works + """ + self.run_backfill( + 'test_impersonation', + 'test_impersonated_user' + ) + + def test_no_impersonation(self): + """ + If default_impersonation=None, tests that the job is run + as the current user (which will be a sudoer) + """ + self.run_backfill( + 'test_no_impersonation', + 'test_superuser', + ) + + def test_default_impersonation(self): + """ + If default_impersonation=TEST_USER, tests that the job defaults + to running as TEST_USER for a test without run_as_user set + """ + os.environ['AIRFLOW__CORE__DEFAULT_IMPERSONATION'] = TEST_USER + + try: + self.run_backfill( + 'test_default_impersonation', + 'test_deelevated_user' + ) + finally: + del os.environ['AIRFLOW__CORE__DEFAULT_IMPERSONATION']
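For illustration only (not part of the patch above), here is a minimal DAG sketch showing how the `run_as_user` parameter and task-level `resources` introduced by this change are meant to be used. The `airflow_test_user` name and the resource quantities are placeholders; the cgroup limits only take effect when `task_runner = CgroupTaskRunner` is set in the `[core]` section of airflow.cfg, and `default_impersonation` in the same section can be used instead of setting `run_as_user` on each task.

    from datetime import datetime

    from airflow.models import DAG
    from airflow.operators.bash_operator import BashOperator

    args = {
        'owner': 'airflow',
        'start_date': datetime(2017, 1, 1),
    }

    dag = DAG(dag_id='example_impersonation_and_cgroups', default_args=args)

    # The task runs as the impersonated unix user, which must exist on the worker.
    # With the CgroupTaskRunner, 'cpus' maps to cpu shares (cpus * 1024) and 'ram'
    # is the memory limit in MB applied to the task's cgroup.
    task = BashOperator(
        task_id='run_as_other_user',
        bash_command='whoami',
        run_as_user='airflow_test_user',  # placeholder unix username
        resources={'cpus': 1, 'ram': 512},
        dag=dag,
    )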