[AIRFLOW-219][AIRFLOW-398] Cgroups + impersonation
Submitting on behalf of plypaul. Please accept this PR that addresses the following issues:
- https://issues.apache.org/jira/browse/AIRFLOW-219
- https://issues.apache.org/jira/browse/AIRFLOW-398

Testing Done:
- Running on Airbnb prod (though on a different mergebase) for many months

Credits:
- Impersonation work: georgeke did most of the work, but plypaul did quite a bit of work too.
- Cgroups: plypaul did most of the work; I just did some touch-up/bug fixes (see commit history; the cgroups + impersonation commit is actually plypaul's, not mine).

Closes #1934 from aoen/ddavydov/cgroups_and_impersonation_after_rebase
This commit is contained in:
Parent: 8f9a466dee
Commit: b56cb5cc97
@@ -89,7 +89,7 @@ cache:
  - $HOME/.wheelhouse/
  - $HOME/.travis_cache/
before_install:
  - ssh-keygen -t rsa -C your_email@youremail.com -P '' -f ~/.ssh/id_rsa
  - yes | ssh-keygen -t rsa -C your_email@youremail.com -P '' -f ~/.ssh/id_rsa
  - cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys
  - ln -s ~/.ssh/authorized_keys ~/.ssh/authorized_keys2
  - chmod 600 ~/.ssh/*
@@ -22,7 +22,6 @@ import os
import subprocess
import textwrap
import warnings
from datetime import datetime
from importlib import import_module

import argparse

@@ -53,7 +52,7 @@ from airflow.models import (DagModel, DagBag, TaskInstance,
from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS)
from airflow.utils import db as db_utils
from airflow.utils import logging as logging_utils
from airflow.utils.state import State
from airflow.utils.file import mkdirs
from airflow.www.app import cached_app

from sqlalchemy import func

@@ -300,6 +299,7 @@ def export_helper(filepath):
        varfile.write(json.dumps(var_dict, sort_keys=True, indent=4))
    print("{} variables successfully exported to {}".format(len(var_dict), filepath))


def pause(args, dag=None):
    set_is_paused(True, args, dag)
@@ -329,15 +329,61 @@ def run(args, dag=None):
    if dag:
        args.dag_id = dag.dag_id

    # Setting up logging
    # Load custom airflow config
    if args.cfg_path:
        with open(args.cfg_path, 'r') as conf_file:
            conf_dict = json.load(conf_file)

        if os.path.exists(args.cfg_path):
            os.remove(args.cfg_path)

        for section, config in conf_dict.items():
            for option, value in config.items():
                conf.set(section, option, value)
        settings.configure_vars()
        settings.configure_orm()

    logging.root.handlers = []
    if args.raw:
        # Output to STDOUT for the parent process to read and log
        logging.basicConfig(
            stream=sys.stdout,
            level=settings.LOGGING_LEVEL,
            format=settings.LOG_FORMAT)
    else:
        # Setting up logging to a file.

        # To handle log writing when tasks are impersonated, the log files need to
        # be writable by the user that runs the Airflow command and the user
        # that is impersonated. This is mainly to handle corner cases with the
        # SubDagOperator. When the SubDagOperator is run, all of the operators
        # run under the impersonated user and create appropriate log files
        # as the impersonated user. However, if the user manually runs tasks
        # of the SubDagOperator through the UI, then the log files are created
        # by the user that runs the Airflow command. For example, the Airflow
        # run command may be run by the `airflow_sudoable` user, but the Airflow
        # tasks may be run by the `airflow` user. If the log files are not
        # writable by both users, then it's possible that re-running a task
        # via the UI (or vice versa) results in a permission error as the task
        # tries to write to a log file created by the other user.
        log_base = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER'))
        directory = log_base + "/{args.dag_id}/{args.task_id}".format(args=args)
        # Create the log file and give it group writable permissions
        # TODO(aoen): Make log dirs and logs globally readable for now since the SubDag
        # operator is not compatible with impersonation (e.g. if a Celery executor is used
        # for a SubDag operator and the SubDag operator has a different owner than the
        # parent DAG)
        if not os.path.exists(directory):
            os.makedirs(directory)
            # Create the directory as globally writable using custom mkdirs
            # as os.makedirs doesn't set mode properly.
            mkdirs(directory, 0o775)
        iso = args.execution_date.isoformat()
        filename = "{directory}/{iso}".format(**locals())

        logging.root.handlers = []
        if not os.path.exists(filename):
            open(filename, "a").close()
            os.chmod(filename, 0o666)

        logging.basicConfig(
            filename=filename,
            level=settings.LOGGING_LEVEL,
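The file passed via `--cfg_path` is a JSON object keyed by config section, the same subset that BaseTaskRunner dumps later in this diff. A minimal sketch of producing such a file (the option values are illustrative, not part of this diff):

import json
import os
from tempfile import mkstemp

# Illustrative values only
cfg_subset = {
    'core': {'sql_alchemy_conn': 'sqlite:////tmp/airflow.db'},
    'smtp': {},
    'scheduler': {},
    'webserver': {},
}
fd, cfg_path = mkstemp()
with os.fdopen(fd, 'w') as cfg_file:
    json.dump(cfg_subset, cfg_file)
# `airflow run ... --cfg_path <cfg_path>` then applies each section/option via conf.set()
# and removes the file, as shown in the hunk above.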
@@ -413,6 +459,10 @@ def run(args, dag=None):
        executor.heartbeat()
        executor.end()

    # Child processes should not flush or upload to remote
    if args.raw:
        return

    # Force the log to flush, and set the handler to go back to normal so we
    # don't continue logging to the task's log file. The flush is important
    # because we subsequently read from the log to insert into S3 or Google

@@ -637,7 +687,6 @@ def restart_workers(gunicorn_master_proc, num_workers_expected):
    wait_until_true(lambda: num_workers_expected + excess ==
                    get_num_workers_running(gunicorn_master_proc))

    wait_until_true(lambda: num_workers_expected ==
                    get_num_workers_running(gunicorn_master_proc))

@@ -761,7 +810,8 @@ def webserver(args):
            if conf.getint('webserver', 'worker_refresh_interval') > 0:
                restart_workers(gunicorn_master_proc, num_workers)
            else:
                while True: time.sleep(1)
                while True:
                    time.sleep(1)


def scheduler(args):

@@ -1255,6 +1305,8 @@ class CLIFactory(object):
            ("-p", "--pickle"),
            "Serialized pickle object of the entire dag (used internally)"),
        'job_id': Arg(("-j", "--job_id"), argparse.SUPPRESS),
        'cfg_path': Arg(
            ("--cfg_path", ), "Path to config file to use instead of airflow.cfg"),
        # webserver
        'port': Arg(
            ("-p", "--port"),

@@ -1433,7 +1485,7 @@ class CLIFactory(object):
            'help': "Run a single task instance",
            'args': (
                'dag_id', 'task_id', 'execution_date', 'subdir',
                'mark_success', 'force', 'pool',
                'mark_success', 'force', 'pool', 'cfg_path',
                'local', 'raw', 'ignore_all_dependencies', 'ignore_dependencies',
                'ignore_depends_on_past', 'ship_dag', 'pickle', 'job_id'),
        }, {
@@ -166,6 +166,13 @@ donot_pickle = False
# How long before timing out a python file import while filling the DagBag
dagbag_import_timeout = 30

# The class to use for running task instances in a subprocess
task_runner = BashTaskRunner

# If set, tasks without a `run_as_user` argument will be run with this user
# Can be used to de-elevate a sudo user running Airflow when executing tasks
default_impersonation =

# What security module to use (for example kerberos):
security =
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,202 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import getpass
import subprocess
import os
import uuid

from cgroupspy import trees
import psutil

from airflow.task_runner.base_task_runner import BaseTaskRunner
from airflow.utils.helpers import kill_process_tree


class CgroupTaskRunner(BaseTaskRunner):
    """
    Runs the raw Airflow task in a cgroup that has containment for memory and
    cpu. It uses the resource requirements defined in the task to construct
    the settings for the cgroup.

    Note that this task runner will only work if the Airflow user has root privileges,
    e.g. if the airflow user is called `airflow` then the following entries (or even
    less restrictive ones) are needed in the sudoers file (replacing
    /CGROUPS_FOLDER with your system's cgroups folder, e.g. '/sys/fs/cgroup/'):
    airflow ALL= (root) NOEXEC: /bin/chown /CGROUPS_FOLDER/memory/airflow/*
    airflow ALL= (root) NOEXEC: !/bin/chown /CGROUPS_FOLDER/memory/airflow/*..*
    airflow ALL= (root) NOEXEC: !/bin/chown /CGROUPS_FOLDER/memory/airflow/* *
    airflow ALL= (root) NOEXEC: /bin/chown /CGROUPS_FOLDER/cpu/airflow/*
    airflow ALL= (root) NOEXEC: !/bin/chown /CGROUPS_FOLDER/cpu/airflow/*..*
    airflow ALL= (root) NOEXEC: !/bin/chown /CGROUPS_FOLDER/cpu/airflow/* *
    airflow ALL= (root) NOEXEC: /bin/chmod /CGROUPS_FOLDER/memory/airflow/*
    airflow ALL= (root) NOEXEC: !/bin/chmod /CGROUPS_FOLDER/memory/airflow/*..*
    airflow ALL= (root) NOEXEC: !/bin/chmod /CGROUPS_FOLDER/memory/airflow/* *
    airflow ALL= (root) NOEXEC: /bin/chmod /CGROUPS_FOLDER/cpu/airflow/*
    airflow ALL= (root) NOEXEC: !/bin/chmod /CGROUPS_FOLDER/cpu/airflow/*..*
    airflow ALL= (root) NOEXEC: !/bin/chmod /CGROUPS_FOLDER/cpu/airflow/* *
    """

    def __init__(self, local_task_job):
        super(CgroupTaskRunner, self).__init__(local_task_job)
        self.process = None
        self._finished_running = False
        self._cpu_shares = None
        self._mem_mb_limit = None
        self._created_cpu_cgroup = False
        self._created_mem_cgroup = False
        self._cur_user = getpass.getuser()

    def _create_cgroup(self, path):
        """
        Create the specified cgroup.

        :param path: The path of the cgroup to create.
            E.g. cpu/mygroup/mysubgroup
        :return: the Node associated with the created cgroup.
        :rtype: cgroupspy.nodes.Node
        """
        node = trees.Tree().root
        path_split = path.split(os.sep)
        for path_element in path_split:
            name_to_node = {x.name: x for x in node.children}
            if path_element not in name_to_node:
                self.logger.debug("Creating cgroup {} in {}"
                                  .format(path_element, node.path))
                subprocess.check_output("sudo mkdir -p {}".format(path_element))
                subprocess.check_output("sudo chown -R {} {}".format(
                    self._cur_user, path_element))
            else:
                self.logger.debug("Not creating cgroup {} in {} "
                                  "since it already exists"
                                  .format(path_element, node.path))
                node = name_to_node[path_element]
        return node

    def _delete_cgroup(self, path):
        """
        Delete the specified cgroup.

        :param path: The path of the cgroup to delete.
            E.g. cpu/mygroup/mysubgroup
        """
        node = trees.Tree().root
        path_split = path.split("/")
        for path_element in path_split:
            name_to_node = {x.name: x for x in node.children}
            if path_element not in name_to_node:
                self.logger.warn("Cgroup does not exist: {}"
                                 .format(path))
                return
            else:
                node = name_to_node[path_element]
        # node is now the leaf node
        parent = node.parent
        self.logger.debug("Deleting cgroup {}/{}".format(parent, node.name))
        parent.delete_cgroup(node.name)

    def start(self):
        # Use bash if it's already in a cgroup
        cgroups = self._get_cgroup_names()
        if cgroups["cpu"] != "/" or cgroups["memory"] != "/":
            self.logger.debug("Already running in a cgroup (cpu: {} memory: {}) so "
                              "not creating another one"
                              .format(cgroups.get("cpu"),
                                      cgroups.get("memory")))
            self.process = self.run_command(['bash', '-c'], join_args=True)
            return

        # Create a unique cgroup name
        cgroup_name = "airflow/{}/{}".format(datetime.datetime.now().
                                             strftime("%Y-%m-%d"),
                                             str(uuid.uuid1()))

        self.mem_cgroup_name = "memory/{}".format(cgroup_name)
        self.cpu_cgroup_name = "cpu/{}".format(cgroup_name)

        # Get the resource requirements from the task
        task = self._task_instance.task
        resources = task.resources
        cpus = resources.cpus.qty
        self._cpu_shares = cpus * 1024
        self._mem_mb_limit = resources.ram.qty

        # Create the memory cgroup
        mem_cgroup_node = self._create_cgroup(self.mem_cgroup_name)
        self._created_mem_cgroup = True
        if self._mem_mb_limit > 0:
            self.logger.debug("Setting {} with {} MB of memory"
                              .format(self.mem_cgroup_name, self._mem_mb_limit))
            mem_cgroup_node.controller.limit_in_bytes = self._mem_mb_limit * 1024 * 1024

        # Create the CPU cgroup
        cpu_cgroup_node = self._create_cgroup(self.cpu_cgroup_name)
        self._created_cpu_cgroup = True
        if self._cpu_shares > 0:
            self.logger.debug("Setting {} with {} CPU shares"
                              .format(self.cpu_cgroup_name, self._cpu_shares))
            cpu_cgroup_node.controller.shares = self._cpu_shares

        # Start the process w/ cgroups
        self.logger.debug("Starting task process with cgroups cpu,memory:{}"
                          .format(cgroup_name))
        self.process = self.run_command(
            ['cgexec', '-g', 'cpu,memory:{}'.format(cgroup_name)]
        )

    def return_code(self):
        return_code = self.process.poll()
        # TODO(plypaul) Monitoring the control file in the cgroup fs is better than
        # checking the return code here. The PR to use this is here:
        # https://github.com/plypaul/airflow/blob/e144e4d41996300ffa93947f136eab7785b114ed/airflow/contrib/task_runner/cgroup_task_runner.py#L43
        # but there were some issues installing the python butter package and
        # libseccomp-dev on some hosts for some reason.
        # I wasn't able to track down the root cause of the package install failures, but
        # we might want to revisit that approach at some other point.
        if return_code == 137:
            self.logger.warn("Task failed with return code of 137. This may indicate "
                             "that it was killed due to excessive memory usage. "
                             "Please consider optimizing your task or using the "
                             "resources argument to reserve more memory for your "
                             "task")
        return return_code

    def terminate(self):
        if self.process and psutil.pid_exists(self.process.pid):
            kill_process_tree(self.logger, self.process.pid)

    def on_finish(self):
        # Let the OOM watcher thread know we're done to avoid false OOM alarms
        self._finished_running = True
        # Clean up the cgroups
        if self._created_mem_cgroup:
            self._delete_cgroup(self.mem_cgroup_name)
        if self._created_cpu_cgroup:
            self._delete_cgroup(self.cpu_cgroup_name)

    def _get_cgroup_names(self):
        """
        :return: a mapping from subsystem name to the cgroup name
        :rtype: dict[str, str]
        """
        with open("/proc/self/cgroup") as f:
            lines = f.readlines()
            d = {}
            for line in lines:
                line_split = line.rstrip().split(":")
                subsystem = line_split[1]
                group_name = line_split[2]
                d[subsystem] = group_name
            return d
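For the cgroup limits above to take effect, the task has to declare its resource needs. A hedged sketch, assuming `cpus` and `ram` are the Resources constructor argument names that the runner reads above as `resources.cpus.qty` and `resources.ram.qty`; the DAG/task names and values are illustrative:

from datetime import datetime

from airflow.models import DAG
from airflow.operators.bash_operator import BashOperator

dag = DAG(dag_id='cgroup_limits_example', start_date=datetime(2016, 1, 1))

task = BashOperator(
    task_id='bounded_task',
    bash_command='echo hello',
    resources={'cpus': 1, 'ram': 512},  # ~1024 CPU shares and a 512 MB memory limit
    dag=dag,
)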
@@ -25,7 +25,6 @@ from datetime import datetime
import getpass
import logging
import socket
import subprocess
import multiprocessing
import os
import signal

@@ -35,7 +34,7 @@ import time
from time import sleep

import psutil
from sqlalchemy import Column, Integer, String, DateTime, func, Index, or_
from sqlalchemy import Column, Integer, String, DateTime, func, Index, or_, and_
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm.session import make_transient
from tabulate import tabulate

@@ -45,6 +44,7 @@ from airflow import configuration as conf
from airflow.exceptions import AirflowException
from airflow.models import DagRun
from airflow.settings import Stats
from airflow.task_runner import get_task_runner
from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS
from airflow.utils.state import State
from airflow.utils.db import provide_session, pessimistic_connection_handling

@@ -54,15 +54,12 @@ from airflow.utils.dag_processing import (AbstractDagFileProcessor,
                                          SimpleDagBag,
                                          list_py_file_paths)
from airflow.utils.email import send_email
from airflow.utils.helpers import kill_descendant_processes
from airflow.utils.logging import LoggingMixin
from airflow.utils import asciiart


Base = models.Base
DagRun = models.DagRun
ID_LEN = models.ID_LEN
Stats = settings.Stats


class BaseJob(Base, LoggingMixin):
@@ -956,13 +953,18 @@ class SchedulerJob(BaseJob):
        :type states: Tuple[State]
        :return: None
        """
        # Get all the relevant task instances
        # Get all the queued task instances associated with scheduled
        # DagRuns.
        TI = models.TaskInstance
        task_instances_to_examine = (
            session
            .query(TI)
            .filter(TI.dag_id.in_(simple_dag_bag.dag_ids))
            .filter(TI.state.in_(states))
            .join(DagRun, and_(TI.dag_id == DagRun.dag_id,
                               TI.execution_date == DagRun.execution_date,
                               DagRun.state == State.RUNNING,
                               DagRun.run_id.like(DagRun.ID_PREFIX + '%')))
            .all()
        )

@@ -1054,7 +1056,7 @@ class SchedulerJob(BaseJob):
                        task_concurrency_limit))
                continue

            command = TI.generate_command(
            command = " ".join(TI.generate_command(
                task_instance.dag_id,
                task_instance.task_id,
                task_instance.execution_date,

@@ -1066,7 +1068,7 @@ class SchedulerJob(BaseJob):
                ignore_ti_state=False,
                pool=task_instance.pool,
                file_path=simple_dag_bag.get_dag(task_instance.dag_id).full_filepath,
                pickle_id=simple_dag_bag.get_dag(task_instance.dag_id).pickle_id)
                pickle_id=simple_dag_bag.get_dag(task_instance.dag_id).pickle_id))

            priority = task_instance.priority_weight
            queue = task_instance.queue
@@ -1659,7 +1661,7 @@ class BackfillJob(BaseJob):

        # consider max_active_runs but ignore when running subdags
        # "parent.child" as a dag_id is by convention a subdag
        if self.dag.schedule_interval and not "." in self.dag.dag_id:
        if self.dag.schedule_interval and "." not in self.dag.dag_id:
            active_runs = DagRun.find(
                dag_id=self.dag.dag_id,
                state=State.RUNNING,

@@ -1915,7 +1917,6 @@ class BackfillJob(BaseJob):
                self.logger.error(msg)
                ti.handle_failure(msg)
                tasks_to_run.pop(key)

            msg = ' | '.join([
                "[backfill progress]",
                "dag run {6} of {7}",
@@ -2026,23 +2027,14 @@ class LocalTaskJob(BaseJob):
        super(LocalTaskJob, self).__init__(*args, **kwargs)

    def _execute(self):
        self.task_runner = get_task_runner(self)
        try:
            command = self.task_instance.command(
                raw=True,
                ignore_all_deps = self.ignore_all_deps,
                ignore_depends_on_past = self.ignore_depends_on_past,
                ignore_task_deps = self.ignore_task_deps,
                ignore_ti_state = self.ignore_ti_state,
                pickle_id = self.pickle_id,
                mark_success = self.mark_success,
                job_id = self.id,
                pool = self.pool
            )
            self.process = subprocess.Popen(['bash', '-c', command])
            self.logger.info("Subprocess PID is {}".format(self.process.pid))
            self.task_runner.start()

            ti = self.task_instance
            session = settings.Session()
            ti.pid = self.process.pid
            if self.task_runner.process:
                ti.pid = self.task_runner.process.pid
            ti.hostname = socket.getfqdn()
            session.merge(ti)
            session.commit()

@@ -2053,8 +2045,10 @@ class LocalTaskJob(BaseJob):
                'scheduler_zombie_task_threshold')
            while True:
                # Monitor the task to see if it's done
                return_code = self.process.poll()
                return_code = self.task_runner.return_code()
                if return_code is not None:
                    self.logger.info("Task exited with return code {}"
                                     .format(return_code))
                    return

                # Periodically heartbeat so that the scheduler doesn't think this

@@ -2079,11 +2073,11 @@ class LocalTaskJob(BaseJob):
                        .format(time_since_last_heartbeat,
                                heartbeat_time_limit))
        finally:
            # Kill processes that were left running
            kill_descendant_processes(self.logger)
            self.on_kill()

    def on_kill(self):
        self.process.terminate()
        self.task_runner.terminate()
        self.task_runner.on_finish()

    @provide_session
    def heartbeat_callback(self, session=None):

@@ -2102,18 +2096,19 @@ class LocalTaskJob(BaseJob):
        if new_ti.state == State.RUNNING:
            self.was_running = True
            fqdn = socket.getfqdn()
            if not (fqdn == new_ti.hostname and self.process.pid == new_ti.pid):
            if not (fqdn == new_ti.hostname and
                    self.task_runner.process.pid == new_ti.pid):
                logging.warning("Recorded hostname and pid of {new_ti.hostname} "
                                "and {new_ti.pid} do not match this instance's "
                                "which are {fqdn} and "
                                "{self.process.pid}. Taking the poison pill. So "
                                "long."
                                "{self.task_runner.process.pid}. Taking the poison pill. "
                                "So long."
                                .format(**locals()))
                raise AirflowException("Another worker/process is running this job")
        elif self.was_running and hasattr(self, 'process'):
        elif self.was_running and hasattr(self.task_runner, 'process'):
            logging.warning(
                "State of this instance has been externally set to "
                "{self.task_instance.state}. "
                "Taking the poison pill. So long.".format(**locals()))
            self.process.terminate()
            self.task_runner.terminate()
            self.terminating = True
|
@ -0,0 +1,37 @@
|
|||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Add state index for dagruns to allow the quick lookup of active dagruns
|
||||
|
||||
Revision ID: 1a5a9e6bf2b5
|
||||
Revises: 5e7d17757c7a
|
||||
Create Date: 2017-01-17 10:22:53.193711
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '1a5a9e6bf2b5'
|
||||
down_revision = '5e7d17757c7a'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_index('dr_state', 'dag_run', ['state'], unique=False)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_index('state', table_name='dag_run')
|
|
@@ -251,9 +251,9 @@ class DagBag(BaseDagBag, LoggingMixin):

        self.logger.debug("Importing {}".format(filepath))
        org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
        mod_name = ('unusual_prefix_'
                    + hashlib.sha1(filepath.encode('utf-8')).hexdigest()
                    + '_' + org_mod_name)
        mod_name = ('unusual_prefix_' +
                    hashlib.sha1(filepath.encode('utf-8')).hexdigest() +
                    '_' + org_mod_name)

        if mod_name in sys.modules:
            del sys.modules[mod_name]

@@ -756,6 +756,7 @@ class TaskInstance(Base):
        self.priority_weight = task.priority_weight_total
        self.try_number = 0
        self.unixname = getpass.getuser()
        self.run_as_user = task.run_as_user
        if state:
            self.state = state
        self.hostname = ''
@@ -777,7 +778,39 @@ class TaskInstance(Base):
            pickle_id=None,
            raw=False,
            job_id=None,
            pool=None):
            pool=None,
            cfg_path=None):
        """
        Returns a command that can be executed anywhere where airflow is
        installed. This command is part of the message sent to executors by
        the orchestrator.
        """
        return " ".join(self.command_as_list(
            mark_success=mark_success,
            ignore_all_deps=ignore_all_deps,
            ignore_depends_on_past=ignore_depends_on_past,
            ignore_task_deps=ignore_task_deps,
            ignore_ti_state=ignore_ti_state,
            local=local,
            pickle_id=pickle_id,
            raw=raw,
            job_id=job_id,
            pool=pool,
            cfg_path=cfg_path))

    def command_as_list(
            self,
            mark_success=False,
            ignore_all_deps=False,
            ignore_task_deps=False,
            ignore_depends_on_past=False,
            ignore_ti_state=False,
            local=False,
            pickle_id=None,
            raw=False,
            job_id=None,
            pool=None,
            cfg_path=None):
        """
        Returns a command that can be executed anywhere where airflow is
        installed. This command is part of the message sent to executors by

@@ -799,15 +832,16 @@ class TaskInstance(Base):
            self.execution_date,
            mark_success=mark_success,
            ignore_all_deps=ignore_all_deps,
            ignore_depends_on_past=ignore_depends_on_past,
            ignore_task_deps=ignore_task_deps,
            ignore_depends_on_past=ignore_depends_on_past,
            ignore_ti_state=ignore_ti_state,
            local=local,
            pickle_id=pickle_id,
            file_path=path,
            raw=raw,
            job_id=job_id,
            pool=pool)
            pool=pool,
            cfg_path=cfg_path)

    @staticmethod
    def generate_command(dag_id,

@@ -823,7 +857,8 @@ class TaskInstance(Base):
                         file_path=None,
                         raw=False,
                         job_id=None,
                         pool=None
                         pool=None,
                         cfg_path=None
                         ):
        """
        Generates the shell command required to execute this task instance.

@@ -860,19 +895,20 @@ class TaskInstance(Base):
        :return: shell command that can be used to run the task instance
        """
        iso = execution_date.isoformat()
        cmd = "airflow run {dag_id} {task_id} {iso} "
        cmd += "--mark_success " if mark_success else ""
        cmd += "--pickle {pickle_id} " if pickle_id else ""
        cmd += "--job_id {job_id} " if job_id else ""
        cmd += "-A " if ignore_all_deps else ""
        cmd += "-i " if ignore_task_deps else ""
        cmd += "-I " if ignore_depends_on_past else ""
        cmd += "--force " if ignore_ti_state else ""
        cmd += "--local " if local else ""
        cmd += "--pool {pool} " if pool else ""
        cmd += "--raw " if raw else ""
        cmd += "-sd {file_path}" if file_path else ""
        return cmd.format(**locals())
        cmd = ["airflow", "run", str(dag_id), str(task_id), str(iso)]
        cmd.extend(["--mark_success"]) if mark_success else None
        cmd.extend(["--pickle", str(pickle_id)]) if pickle_id else None
        cmd.extend(["--job_id", str(job_id)]) if job_id else None
        cmd.extend(["-A "]) if ignore_all_deps else None
        cmd.extend(["-i"]) if ignore_task_deps else None
        cmd.extend(["-I"]) if ignore_depends_on_past else None
        cmd.extend(["--force"]) if ignore_ti_state else None
        cmd.extend(["--local"]) if local else None
        cmd.extend(["--pool", pool]) if pool else None
        cmd.extend(["--raw"]) if raw else None
        cmd.extend(["-sd", file_path]) if file_path else None
        cmd.extend(["--cfg_path", cfg_path]) if cfg_path else None
        return cmd

    @property
    def log_filepath(self):
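Since generate_command now returns a token list (so a task runner can safely prepend e.g. `sudo -u`), callers that need a shell string join it, as the scheduler hunk above does. An illustrative call with made-up dag/task ids and cfg path:

from datetime import datetime

from airflow.models import TaskInstance

cmd = TaskInstance.generate_command(
    'example_dag', 'example_task', datetime(2016, 1, 1),
    local=True, cfg_path='/tmp/tmp_airflow_cfg.json')
print(cmd)
# ['airflow', 'run', 'example_dag', 'example_task', '2016-01-01T00:00:00',
#  '--local', '--cfg_path', '/tmp/tmp_airflow_cfg.json']
print(" ".join(cmd))  # what the scheduler now hands to executors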
@@ -1825,6 +1861,8 @@ class BaseOperator(object):
    :param resources: A map of resource parameter names (the argument names of the
        Resources constructor) to their values.
    :type resources: dict
    :param run_as_user: unix username to impersonate while running the task
    :type run_as_user: str
    """

    # For derived classes to define which fields will get jinjaified

@@ -1866,6 +1904,7 @@ class BaseOperator(object):
            on_retry_callback=None,
            trigger_rule=TriggerRule.ALL_SUCCESS,
            resources=None,
            run_as_user=None,
            *args,
            **kwargs):

@@ -1929,6 +1968,7 @@ class BaseOperator(object):
        self.adhoc = adhoc
        self.priority_weight = priority_weight
        self.resources = Resources(**(resources or {}))
        self.run_as_user = run_as_user

        # Private attributes
        self._upstream_task_ids = []

@@ -2854,13 +2894,7 @@ class DAG(BaseDag, LoggingMixin):
        :param session:
        :return: List of execution dates
        """
        runs = (
            session.query(DagRun)
            .filter(
                DagRun.dag_id == self.dag_id,
                DagRun.state == State.RUNNING)
            .order_by(DagRun.execution_date)
            .all())
        runs = DagRun.find(dag_id=self.dag_id, state=State.RUNNING)

        active_dates = []
        for run in runs:

@@ -3488,7 +3522,6 @@ class Variable(Base):
        else:
            return obj.val


    @classmethod
    @provide_session
    def get(cls, key, default_var=None, deserialize_json=False, session=None):

@@ -3695,7 +3728,6 @@ class DagStat(Base):
        :type full_query: bool
        """
        dag_ids = set(dag_ids)
        ds_ids = set(session.query(DagStat.dag_id).all())

        qry = (
            session.query(DagStat)
|
@ -68,10 +68,7 @@ ___ ___ | / _ / _ __/ _ / / /_/ /_ |/ |/ /
|
|||
"""
|
||||
|
||||
BASE_LOG_URL = '/admin/airflow/log'
|
||||
AIRFLOW_HOME = os.path.expanduser(conf.get('core', 'AIRFLOW_HOME'))
|
||||
SQL_ALCHEMY_CONN = conf.get('core', 'SQL_ALCHEMY_CONN')
|
||||
LOGGING_LEVEL = logging.INFO
|
||||
DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
|
||||
|
||||
# the prefix to append to gunicorn worker processes after init
|
||||
GUNICORN_WORKER_READY_PREFIX = "[ready] "
|
||||
|
@ -85,6 +82,13 @@ LOG_FORMAT_WITH_THREAD_NAME = (
|
|||
'[%(asctime)s] {%(filename)s:%(lineno)d} %(threadName)s %(levelname)s - %(message)s')
|
||||
SIMPLE_LOG_FORMAT = '%(asctime)s %(levelname)s - %(message)s'
|
||||
|
||||
AIRFLOW_HOME = None
|
||||
SQL_ALCHEMY_CONN = None
|
||||
DAGS_FOLDER = None
|
||||
|
||||
engine = None
|
||||
Session = None
|
||||
|
||||
|
||||
def policy(task_instance):
|
||||
"""
|
||||
|
@ -118,8 +122,14 @@ def configure_logging(log_format=LOG_FORMAT):
|
|||
logging.basicConfig(
|
||||
format=log_format, stream=sys.stdout, level=LOGGING_LEVEL)
|
||||
|
||||
engine = None
|
||||
Session = None
|
||||
|
||||
def configure_vars():
|
||||
global AIRFLOW_HOME
|
||||
global SQL_ALCHEMY_CONN
|
||||
global DAGS_FOLDER
|
||||
AIRFLOW_HOME = os.path.expanduser(conf.get('core', 'AIRFLOW_HOME'))
|
||||
SQL_ALCHEMY_CONN = conf.get('core', 'SQL_ALCHEMY_CONN')
|
||||
DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
|
||||
|
||||
|
||||
def configure_orm(disable_connection_pool=False):
|
||||
|
@ -146,6 +156,7 @@ except:
|
|||
pass
|
||||
|
||||
configure_logging()
|
||||
configure_vars()
|
||||
configure_orm()
|
||||
|
||||
# Const stuff
|
||||
|
|
|
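This is the pattern `airflow run --cfg_path` relies on in the cli hunk above: mutate the in-memory config, then re-derive the settings globals and rebuild the ORM session. A minimal sketch with an illustrative connection string:

from airflow import configuration as conf
from airflow import settings

conf.set('core', 'SQL_ALCHEMY_CONN', 'sqlite:////tmp/alt_airflow.db')  # illustrative value
settings.configure_vars()
settings.configure_orm()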
@@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from airflow import configuration
from airflow.contrib.task_runner.cgroup_task_runner import CgroupTaskRunner
from airflow.task_runner.bash_task_runner import BashTaskRunner
from airflow.exceptions import AirflowException

_TASK_RUNNER = configuration.get('core', 'TASK_RUNNER')


def get_task_runner(local_task_job):
    """
    Get the task runner that can be used to run the given job.

    :param local_task_job: The LocalTaskJob associated with the TaskInstance
        that needs to be executed.
    :type local_task_job: airflow.jobs.LocalTaskJob
    :return: The task runner to use to run the task.
    :rtype: airflow.task_runner.base_task_runner.BaseTaskRunner
    """
    if _TASK_RUNNER == "BashTaskRunner":
        return BashTaskRunner(local_task_job)
    elif _TASK_RUNNER == "CgroupTaskRunner":
        return CgroupTaskRunner(local_task_job)
    else:
        raise AirflowException("Unknown task runner type {}".format(_TASK_RUNNER))
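Selecting the cgroup runner happens through the `[core] task_runner` option added in the default config above. A sketch using the environment-variable override form (the impersonation test below uses the same `AIRFLOW__CORE__...` pattern for default_impersonation); the exact env var name is an assumption, and it has to be set before `airflow.task_runner` is imported since `_TASK_RUNNER` is read at import time:

import os

# Assumed env-var spelling of the [core] task_runner option.
os.environ['AIRFLOW__CORE__TASK_RUNNER'] = 'CgroupTaskRunner'

from airflow.task_runner import get_task_runner  # noqa: E402
# get_task_runner(local_task_job) would now return a CgroupTaskRunner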
@@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import getpass
import os
import json
import subprocess
import threading

from airflow import configuration as conf
from airflow.utils.logging import LoggingMixin
from tempfile import mkstemp


class BaseTaskRunner(LoggingMixin):
    """
    Runs Airflow task instances by invoking the `airflow run` command with raw
    mode enabled in a subprocess.
    """

    def __init__(self, local_task_job):
        """
        :param local_task_job: The local task job associated with running the
            associated task instance.
        :type local_task_job: airflow.jobs.LocalTaskJob
        """
        self._task_instance = local_task_job.task_instance

        popen_prepend = []
        cfg_path = None
        if self._task_instance.run_as_user:
            self.run_as_user = self._task_instance.run_as_user
        else:
            try:
                self.run_as_user = conf.get('core', 'default_impersonation')
            except conf.AirflowConfigException:
                self.run_as_user = None

        # Add sudo commands to change user if we need to. Needed to handle SubDagOperator
        # case using a SequentialExecutor.
        if self.run_as_user and (self.run_as_user != getpass.getuser()):
            self.logger.debug("Planning to run as the {} user".format(self.run_as_user))
            cfg_dict = conf.as_dict(display_sensitive=True)
            cfg_subset = {
                'core': cfg_dict.get('core', {}),
                'smtp': cfg_dict.get('smtp', {}),
                'scheduler': cfg_dict.get('scheduler', {}),
                'webserver': cfg_dict.get('webserver', {}),
            }
            temp_fd, cfg_path = mkstemp()

            # Give ownership of file to user; only they can read and write
            subprocess.call(
                ['sudo', 'chown', self.run_as_user, cfg_path]
            )
            subprocess.call(
                ['sudo', 'chmod', '600', cfg_path]
            )

            with os.fdopen(temp_fd, 'w') as temp_file:
                json.dump(cfg_subset, temp_file)

            popen_prepend = ['sudo', '-H', '-u', self.run_as_user]

        self._cfg_path = cfg_path
        self._command = popen_prepend + self._task_instance.command_as_list(
            raw=True,
            ignore_all_deps=local_task_job.ignore_all_deps,
            ignore_depends_on_past=local_task_job.ignore_depends_on_past,
            ignore_ti_state=local_task_job.ignore_ti_state,
            pickle_id=local_task_job.pickle_id,
            mark_success=local_task_job.mark_success,
            job_id=local_task_job.id,
            pool=local_task_job.pool,
            cfg_path=cfg_path,
        )
        self.process = None

    def _read_task_logs(self, stream):
        while True:
            line = stream.readline()
            if len(line) == 0:
                break
            self.logger.info('Subtask: {}'.format(line.rstrip('\n')))

    def run_command(self, run_with, join_args=False):
        """
        Run the task command

        :param run_with: list of tokens to run the task command with
            E.g. ['bash', '-c']
        :type run_with: list
        :param join_args: whether to concatenate the list of command tokens
            E.g. ['airflow', 'run'] vs ['airflow run']
        :type join_args: bool
        :return: the process that was run
        :rtype: subprocess.Popen
        """
        cmd = [" ".join(self._command)] if join_args else self._command
        full_cmd = run_with + cmd
        self.logger.info('Running: {}'.format(full_cmd))
        proc = subprocess.Popen(
            full_cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT
        )

        # Start daemon thread to read subprocess logging output
        log_reader = threading.Thread(
            target=self._read_task_logs,
            args=(proc.stdout,),
        )
        log_reader.daemon = True
        log_reader.start()
        return proc

    def start(self):
        """
        Start running the task instance in a subprocess.
        """
        raise NotImplementedError()

    def return_code(self):
        """
        :return: The return code associated with running the task instance or
            None if the task is not yet done.
        :rtype: int
        """
        raise NotImplementedError()

    def terminate(self):
        """
        Kill the running task instance.
        """
        raise NotImplementedError()

    def on_finish(self):
        """
        A callback that should be called when this is done running.
        """
        if self._cfg_path and os.path.isfile(self._cfg_path):
            subprocess.call(['sudo', 'rm', self._cfg_path])
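When impersonation applies, the runner's command is `popen_prepend + command_as_list(...)`. Roughly, the argv handed to `run_command` has the following shape (user, ids and temp path are made up, and the exact flags depend on the job):

# Illustrative only - not produced by this diff verbatim.
command = [
    'sudo', '-H', '-u', 'airflow_test_user',
    'airflow', 'run', 'example_dag', 'example_task', '2016-01-01T00:00:00',
    '--raw', '--job_id', '123', '--cfg_path', '/tmp/tmpXXXXXX',
]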
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import psutil

from airflow.task_runner.base_task_runner import BaseTaskRunner
from airflow.utils.helpers import kill_process_tree


class BashTaskRunner(BaseTaskRunner):
    """
    Runs the raw Airflow task by invoking it through the Bash shell.
    """
    def __init__(self, local_task_job):
        super(BashTaskRunner, self).__init__(local_task_job)

    def start(self):
        self.process = self.run_command(['bash', '-c'], join_args=True)

    def return_code(self):
        return self.process.poll()

    def terminate(self):
        if self.process and psutil.pid_exists(self.process.pid):
            kill_process_tree(self.logger, self.process.pid)

    def on_finish(self):
        super(BashTaskRunner, self).on_finish()
@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import unicode_literals

import errno
import os
import shutil
from tempfile import mkdtemp

@@ -34,3 +35,25 @@ def TemporaryDirectory(suffix='', prefix=None, dir=None):
            # ENOENT - no such file or directory
            if e.errno != errno.ENOENT:
                raise e


def mkdirs(path, mode):
    """
    Creates the directory specified by path, creating intermediate directories
    as necessary. If directory already exists, this is a no-op.

    :param path: The directory to create
    :type path: str
    :param mode: The mode to give to the directory e.g. 0o755
    :type mode: int
    :return: A list of directories that were created
    :rtype: list[str]
    """
    if not path or os.path.exists(path):
        return []
    (head, _) = os.path.split(path)
    res = mkdirs(head, mode)
    os.mkdir(path)
    os.chmod(path, mode)
    res += [path]
    return res
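A quick usage sketch of the new helper (the path is hypothetical):

from airflow.utils.file import mkdirs

# Creates any missing directories with mode 0o775 and returns only the ones it created.
created = mkdirs('/tmp/airflow/logs/example_dag/example_task', 0o775)
print(created)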
@@ -22,10 +22,12 @@ import psutil
from builtins import input
from past.builtins import basestring
from datetime import datetime
import getpass
import imp
import logging
import os
import re
import signal
import subprocess
import sys
import warnings

@@ -35,6 +37,7 @@ from airflow.exceptions import AirflowException
# SIGKILL.
TIME_TO_WAIT_AFTER_SIGTERM = 5


def validate_key(k, max_length=250):
    if not isinstance(k, basestring):
        raise TypeError("The key has to be a string")

@@ -179,6 +182,80 @@ def pprinttable(rows):
    return s


def kill_using_shell(pid, signal=signal.SIGTERM):
    process = psutil.Process(pid)
    # Use sudo only when necessary - consider SubDagOperator and SequentialExecutor case.
    if process.username() != getpass.getuser():
        args = ["sudo", "kill", "-{}".format(int(signal)), str(pid)]
    else:
        args = ["kill", "-{}".format(int(signal)), str(pid)]
    # PID may not exist and return a non-zero error code
    subprocess.call(args)


def kill_process_tree(logger, pid):
    """
    Kills the process and all of the descendants. Kills using the `kill`
    shell command so that it can change users. Note: killing via PIDs
    has the potential to kill the wrong process if the process dies and the
    PID gets recycled in a narrow time window.

    :param logger: logger
    :type logger: logging.Logger
    """
    try:
        root_process = psutil.Process(pid)
    except psutil.NoSuchProcess:
        logger.warn("PID: {} does not exist".format(pid))
        return

    # Check child processes to reduce cases where a child process died but
    # the PID got reused.
    descendant_processes = [x for x in root_process.children(recursive=True)
                            if x.is_running()]

    if len(descendant_processes) != 0:
        logger.warn("Terminating descendant processes of {} PID: {}"
                    .format(root_process.cmdline(),
                            root_process.pid))
        temp_processes = descendant_processes[:]
        for descendant in temp_processes:
            logger.warn("Terminating descendant process {} PID: {}"
                        .format(descendant.cmdline(), descendant.pid))
            try:
                kill_using_shell(descendant.pid, signal.SIGTERM)
            except psutil.NoSuchProcess:
                descendant_processes.remove(descendant)

        logger.warn("Waiting up to {}s for processes to exit..."
                    .format(TIME_TO_WAIT_AFTER_SIGTERM))
        try:
            psutil.wait_procs(descendant_processes, TIME_TO_WAIT_AFTER_SIGTERM)
            logger.warn("Done waiting")
        except psutil.TimeoutExpired:
            logger.warn("Ran out of time while waiting for "
                        "processes to exit")
        # Then SIGKILL
        descendant_processes = [x for x in root_process.children(recursive=True)
                                if x.is_running()]

        if len(descendant_processes) > 0:
            temp_processes = descendant_processes[:]
            for descendant in temp_processes:
                logger.warn("Killing descendant process {} PID: {}"
                            .format(descendant.cmdline(), descendant.pid))
                try:
                    kill_using_shell(descendant.pid, signal.SIGTERM)
                    descendant.wait()
                except psutil.NoSuchProcess:
                    descendant_processes.remove(descendant)
            logger.warn("Killed all descendant processes of {} PID: {}"
                        .format(root_process.cmdline(),
                                root_process.pid))
    else:
        logger.debug("There are no descendant processes to kill")


def kill_descendant_processes(logger, pids_to_kill=None):
    """
    Kills all descendant processes of this process.
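Usage is simply a logger plus the root PID; descendants are terminated, waited on, then SIGKILLed through the `kill` shell command so sudo can cross user boundaries (the PID here is hypothetical):

import logging

from airflow.utils.helpers import kill_process_tree

kill_process_tree(logging.getLogger(__name__), 12345)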
@@ -310,3 +310,25 @@ standard port 443, you'll need to configure that too. Be aware that super user p
    # Optionally, set the server to listen on the standard SSL port.
    web_server_port = 443
    base_url = http://<hostname or IP>:443

Impersonation
'''''''''''''

Airflow has the ability to impersonate a unix user while running task
instances based on the task's ``run_as_user`` parameter, which takes a user's name.

*NOTE* For impersonations to work, Airflow must be run with `sudo` as subtasks are run
with `sudo -u` and permissions of files are changed. Furthermore, the unix user needs to
exist on the worker. Here is what a simple sudoers file entry could look like to achieve
this, assuming airflow is running as the `airflow` user. Note that this means that
the airflow user must be trusted and treated the same way as the root user.

.. code-block:: none

    airflow ALL=(ALL) NOPASSWD: ALL

Subtasks with impersonation will still log to the same folder, except that the files they
log to will have permissions changed such that only the unix user can write to them.

*Default impersonation* To prevent tasks that don't use impersonation from being run with
`sudo` privileges, you can set the `default_impersonation` config in `core`, which sets a
default user to impersonate if `run_as_user` is not set.
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
set -x
|
||||
|
||||
# environment
|
||||
export AIRFLOW_HOME=${AIRFLOW_HOME:=~/airflow}
|
||||
|
@ -48,6 +49,19 @@ echo "Initializing the DB"
|
|||
yes | airflow resetdb
|
||||
airflow initdb
|
||||
|
||||
if [ "${TRAVIS}" ]; then
|
||||
# For impersonation tests running on SQLite on Travis, make the database world readable so other
|
||||
# users can update it
|
||||
AIRFLOW_DB="/home/travis/airflow/airflow.db"
|
||||
if [ -f "${AIRFLOW_DB}" ]; then
|
||||
sudo chmod a+rw "${AIRFLOW_DB}"
|
||||
fi
|
||||
|
||||
# For impersonation tests on Travis, make airflow accessible to other users via the global PATH
|
||||
# (which contains /usr/local/bin)
|
||||
sudo ln -s "${VIRTUAL_ENV}/bin/airflow" /usr/local/bin/
|
||||
fi
|
||||
|
||||
echo "Starting the unit tests with the following nose arguments: "$nose_args
|
||||
nosetests $nose_args
|
||||
|
||||
|
|
|
@@ -22,6 +22,7 @@ load_examples = True
donot_pickle = False
dag_concurrency = 16
dags_are_paused_at_creation = False
default_impersonation =
fernet_key = af7CN0q6ag5U3g08IsPsw3K45U7Xa0axgVFhoh-3zB8=

[webserver]
|
|||
bcrypt
|
||||
boto
|
||||
celery
|
||||
cgroupspy
|
||||
chartkick
|
||||
cloudant
|
||||
coverage
|
||||
|
|
setup.py

@@ -108,6 +108,9 @@ celery = [
    'celery>=3.1.17',
    'flower>=0.7.3'
]
cgroups = [
    'cgroupspy>=0.1.4',
]
crypto = ['cryptography>=0.9.3']
datadog = ['datadog>=0.14.0']
doc = [

@@ -227,6 +230,7 @@ def do_setup():
        'all_dbs': all_dbs,
        'async': async,
        'celery': celery,
        'cgroups': cgroups,
        'cloudant': cloudant,
        'crypto': crypto,
        'datadog': datadog,
|
|||
from .contrib import *
|
||||
from .core import *
|
||||
from .jobs import *
|
||||
from .impersonation import *
|
||||
from .models import *
|
||||
from .operators import *
|
||||
from .utils import *
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from airflow.models import DAG
|
||||
from airflow.operators.bash_operator import BashOperator
|
||||
from datetime import datetime
|
||||
from textwrap import dedent
|
||||
|
||||
|
||||
DEFAULT_DATE = datetime(2016, 1, 1)
|
||||
|
||||
args = {
|
||||
'owner': 'airflow',
|
||||
'start_date': DEFAULT_DATE,
|
||||
}
|
||||
|
||||
dag = DAG(dag_id='test_default_impersonation', default_args=args)
|
||||
|
||||
deelevated_user = 'airflow_test_user'
|
||||
|
||||
test_command = dedent(
|
||||
"""\
|
||||
if [ '{user}' != "$(whoami)" ]; then
|
||||
echo current user $(whoami) is not {user}!
|
||||
exit 1
|
||||
fi
|
||||
""".format(user=deelevated_user))
|
||||
|
||||
task = BashOperator(
|
||||
task_id='test_deelevated_user',
|
||||
bash_command=test_command,
|
||||
dag=dag,
|
||||
)
|
|
@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from airflow.models import DAG
from airflow.operators.bash_operator import BashOperator
from datetime import datetime
from textwrap import dedent


DEFAULT_DATE = datetime(2016, 1, 1)

args = {
    'owner': 'airflow',
    'start_date': DEFAULT_DATE,
}

dag = DAG(dag_id='test_impersonation', default_args=args)

run_as_user = 'airflow_test_user'

test_command = dedent(
    """\
    if [ '{user}' != "$(whoami)" ]; then
        echo current user is not {user}!
        exit 1
    fi
    """.format(user=run_as_user))

task = BashOperator(
    task_id='test_impersonated_user',
    bash_command=test_command,
    dag=dag,
    run_as_user=run_as_user,
)
@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from airflow.models import DAG
from airflow.operators.bash_operator import BashOperator
from datetime import datetime
from textwrap import dedent


DEFAULT_DATE = datetime(2016, 1, 1)

args = {
    'owner': 'airflow',
    'start_date': DEFAULT_DATE,
}

dag = DAG(dag_id='test_no_impersonation', default_args=args)

test_command = dedent(
    """\
    sudo ls
    if [ $? -ne 0 ]; then
        echo 'current uid does not have root privileges!'
        exit 1
    fi
    """)

task = BashOperator(
    task_id='test_superuser',
    bash_command=test_command,
    dag=dag,
)
@@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from __future__ import print_function
import errno
import os
import subprocess
import unittest

from airflow import jobs, models
from airflow.utils.state import State
from datetime import datetime

DEV_NULL = '/dev/null'
TEST_DAG_FOLDER = os.path.join(
    os.path.dirname(os.path.realpath(__file__)), 'dags')
DEFAULT_DATE = datetime(2015, 1, 1)
TEST_USER = 'airflow_test_user'


# TODO(aoen): Adding/removing a user as part of a test is very bad (especially if the user
# already existed to begin with on the OS), this logic should be moved into a test
# that is wrapped in a container like docker so that the user can be safely added/removed.
# When this is done we can also modify the sudoers file to ensure that useradd will work
# without any manual modification of the sudoers file by the agent that is running these
# tests.

class ImpersonationTest(unittest.TestCase):
    def setUp(self):
        self.dagbag = models.DagBag(
            dag_folder=TEST_DAG_FOLDER,
            include_examples=False,
        )
        try:
            subprocess.check_output(['sudo', 'useradd', '-m', TEST_USER, '-g',
                                     str(os.getegid())])
        except OSError as e:
            if e.errno == errno.ENOENT:
                raise unittest.SkipTest(
                    "The 'useradd' command did not exist so unable to test "
                    "impersonation; Skipping Test. These tests can only be run on a "
                    "linux host that supports 'useradd'."
                )
            else:
                raise unittest.SkipTest(
                    "The 'useradd' command exited non-zero; Skipping tests. Does the "
                    "current user have permission to run 'useradd' without a password "
                    "prompt (check sudoers file)?"
                )

    def tearDown(self):
        subprocess.check_output(['sudo', 'userdel', '-r', TEST_USER])

    def run_backfill(self, dag_id, task_id):
        dag = self.dagbag.get_dag(dag_id)
        dag.clear()

        jobs.BackfillJob(
            dag=dag,
            start_date=DEFAULT_DATE,
            end_date=DEFAULT_DATE).run()

        ti = models.TaskInstance(
            task=dag.get_task(task_id),
            execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        self.assertEqual(ti.state, State.SUCCESS)

    def test_impersonation(self):
        """
        Tests that impersonating a unix user works
        """
        self.run_backfill(
            'test_impersonation',
            'test_impersonated_user'
        )

    def test_no_impersonation(self):
        """
        If default_impersonation=None, tests that the job is run
        as the current user (which will be a sudoer)
        """
        self.run_backfill(
            'test_no_impersonation',
            'test_superuser',
        )

    def test_default_impersonation(self):
        """
        If default_impersonation=TEST_USER, tests that the job defaults
        to running as TEST_USER for a test without run_as_user set
        """
        os.environ['AIRFLOW__CORE__DEFAULT_IMPERSONATION'] = TEST_USER

        try:
            self.run_backfill(
                'test_default_impersonation',
                'test_deelevated_user'
            )
        finally:
            del os.environ['AIRFLOW__CORE__DEFAULT_IMPERSONATION']