[AIRFLOW-219][AIRFLOW-398] Cgroups + impersonation

Submitting on behalf of plypaul

Please accept this PR that addresses the following
issues:
-
https://issues.apache.org/jira/browse/AIRFLOW-219
-
https://issues.apache.org/jira/browse/AIRFLOW-398

Testing Done:
- Running on Airbnb prod (though on a different
mergebase) for many months

Credits:
Impersonation Work: georgeke did most of the work
but plypaul did quite a bit of work too.
Cgroups: plypaul did most of the work, I just did
some touch up/bug fixes (see commit history,
cgroups + impersonation commit is actually plypaul
's not mine)

Closes #1934 from aoen/ddavydov/cgroups_and_impers
onation_after_rebase
This commit is contained in:
Dan Davydov 2017-01-18 18:11:01 -08:00
Родитель 8f9a466dee
Коммит b56cb5cc97
24 изменённых файлов: 1061 добавлений и 96 удалений

Просмотреть файл

@ -89,7 +89,7 @@ cache:
- $HOME/.wheelhouse/
- $HOME/.travis_cache/
before_install:
- ssh-keygen -t rsa -C your_email@youremail.com -P '' -f ~/.ssh/id_rsa
- yes | ssh-keygen -t rsa -C your_email@youremail.com -P '' -f ~/.ssh/id_rsa
- cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys
- ln -s ~/.ssh/authorized_keys ~/.ssh/authorized_keys2
- chmod 600 ~/.ssh/*

Просмотреть файл

@ -22,7 +22,6 @@ import os
import subprocess
import textwrap
import warnings
from datetime import datetime
from importlib import import_module
import argparse
@ -53,7 +52,7 @@ from airflow.models import (DagModel, DagBag, TaskInstance,
from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS)
from airflow.utils import db as db_utils
from airflow.utils import logging as logging_utils
from airflow.utils.state import State
from airflow.utils.file import mkdirs
from airflow.www.app import cached_app
from sqlalchemy import func
@ -300,6 +299,7 @@ def export_helper(filepath):
varfile.write(json.dumps(var_dict, sort_keys=True, indent=4))
print("{} variables successfully exported to {}".format(len(var_dict), filepath))
def pause(args, dag=None):
set_is_paused(True, args, dag)
@ -329,15 +329,61 @@ def run(args, dag=None):
if dag:
args.dag_id = dag.dag_id
# Setting up logging
# Load custom airflow config
if args.cfg_path:
with open(args.cfg_path, 'r') as conf_file:
conf_dict = json.load(conf_file)
if os.path.exists(args.cfg_path):
os.remove(args.cfg_path)
for section, config in conf_dict.items():
for option, value in config.items():
conf.set(section, option, value)
settings.configure_vars()
settings.configure_orm()
logging.root.handlers = []
if args.raw:
# Output to STDOUT for the parent process to read and log
logging.basicConfig(
stream=sys.stdout,
level=settings.LOGGING_LEVEL,
format=settings.LOG_FORMAT)
else:
# Setting up logging to a file.
# To handle log writing when tasks are impersonated, the log files need to
# be writable by the user that runs the Airflow command and the user
# that is impersonated. This is mainly to handle corner cases with the
# SubDagOperator. When the SubDagOperator is run, all of the operators
# run under the impersonated user and create appropriate log files
# as the impersonated user. However, if the user manually runs tasks
# of the SubDagOperator through the UI, then the log files are created
# by the user that runs the Airflow command. For example, the Airflow
# run command may be run by the `airflow_sudoable` user, but the Airflow
# tasks may be run by the `airflow` user. If the log files are not
# writable by both users, then it's possible that re-running a task
# via the UI (or vice versa) results in a permission error as the task
# tries to write to a log file created by the other user.
log_base = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER'))
directory = log_base + "/{args.dag_id}/{args.task_id}".format(args=args)
# Create the log file and give it group writable permissions
# TODO(aoen): Make log dirs and logs globally readable for now since the SubDag
# operator is not compatible with impersonation (e.g. if a Celery executor is used
# for a SubDag operator and the SubDag operator has a different owner than the
# parent DAG)
if not os.path.exists(directory):
os.makedirs(directory)
# Create the directory as globally writable using custom mkdirs
# as os.makedirs doesn't set mode properly.
mkdirs(directory, 0o775)
iso = args.execution_date.isoformat()
filename = "{directory}/{iso}".format(**locals())
logging.root.handlers = []
if not os.path.exists(filename):
open(filename, "a").close()
os.chmod(filename, 0o666)
logging.basicConfig(
filename=filename,
level=settings.LOGGING_LEVEL,
@ -413,6 +459,10 @@ def run(args, dag=None):
executor.heartbeat()
executor.end()
# Child processes should not flush or upload to remote
if args.raw:
return
# Force the log to flush, and set the handler to go back to normal so we
# don't continue logging to the task's log file. The flush is important
# because we subsequently read from the log to insert into S3 or Google
@ -637,7 +687,6 @@ def restart_workers(gunicorn_master_proc, num_workers_expected):
wait_until_true(lambda: num_workers_expected + excess ==
get_num_workers_running(gunicorn_master_proc))
wait_until_true(lambda: num_workers_expected ==
get_num_workers_running(gunicorn_master_proc))
@ -761,7 +810,8 @@ def webserver(args):
if conf.getint('webserver', 'worker_refresh_interval') > 0:
restart_workers(gunicorn_master_proc, num_workers)
else:
while True: time.sleep(1)
while True:
time.sleep(1)
def scheduler(args):
@ -1255,6 +1305,8 @@ class CLIFactory(object):
("-p", "--pickle"),
"Serialized pickle object of the entire dag (used internally)"),
'job_id': Arg(("-j", "--job_id"), argparse.SUPPRESS),
'cfg_path': Arg(
("--cfg_path", ), "Path to config file to use instead of airflow.cfg"),
# webserver
'port': Arg(
("-p", "--port"),
@ -1433,7 +1485,7 @@ class CLIFactory(object):
'help': "Run a single task instance",
'args': (
'dag_id', 'task_id', 'execution_date', 'subdir',
'mark_success', 'force', 'pool',
'mark_success', 'force', 'pool', 'cfg_path',
'local', 'raw', 'ignore_all_dependencies', 'ignore_dependencies',
'ignore_depends_on_past', 'ship_dag', 'pickle', 'job_id'),
}, {

Просмотреть файл

@ -166,6 +166,13 @@ donot_pickle = False
# How long before timing out a python file import while filling the DagBag
dagbag_import_timeout = 30
# The class to use for running task instances in a subprocess
task_runner = BashTaskRunner
# If set, tasks without a `run_as_user` argument will be run with this user
# Can be used to de-elevate a sudo user running Airflow when executing tasks
default_impersonation =
# What security module to use (for example kerberos):
security =

Просмотреть файл

@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Просмотреть файл

@ -0,0 +1,202 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import getpass
import subprocess
import os
import uuid
from cgroupspy import trees
import psutil
from airflow.task_runner.base_task_runner import BaseTaskRunner
from airflow.utils.helpers import kill_process_tree
class CgroupTaskRunner(BaseTaskRunner):
"""
Runs the raw Airflow task in a cgroup that has containment for memory and
cpu. It uses the resource requirements defined in the task to construct
the settings for the cgroup.
Note that this task runner will only work if the Airflow user has root privileges,
e.g. if the airflow user is called `airflow` then the following entries (or an even
less restrictive ones) are needed in the sudoers file (replacing
/CGROUPS_FOLDER with your system's cgroups folder, e.g. '/sys/fs/cgroup/'):
airflow ALL= (root) NOEXEC: /bin/chown /CGROUPS_FOLDER/memory/airflow/*
airflow ALL= (root) NOEXEC: !/bin/chown /CGROUPS_FOLDER/memory/airflow/*..*
airflow ALL= (root) NOEXEC: !/bin/chown /CGROUPS_FOLDER/memory/airflow/* *
airflow ALL= (root) NOEXEC: /bin/chown /CGROUPS_FOLDER/cpu/airflow/*
airflow ALL= (root) NOEXEC: !/bin/chown /CGROUPS_FOLDER/cpu/airflow/*..*
airflow ALL= (root) NOEXEC: !/bin/chown /CGROUPS_FOLDER/cpu/airflow/* *
airflow ALL= (root) NOEXEC: /bin/chmod /CGROUPS_FOLDER/memory/airflow/*
airflow ALL= (root) NOEXEC: !/bin/chmod /CGROUPS_FOLDER/memory/airflow/*..*
airflow ALL= (root) NOEXEC: !/bin/chmod /CGROUPS_FOLDER/memory/airflow/* *
airflow ALL= (root) NOEXEC: /bin/chmod /CGROUPS_FOLDER/cpu/airflow/*
airflow ALL= (root) NOEXEC: !/bin/chmod /CGROUPS_FOLDER/cpu/airflow/*..*
airflow ALL= (root) NOEXEC: !/bin/chmod /CGROUPS_FOLDER/cpu/airflow/* *
"""
def __init__(self, local_task_job):
super(CgroupTaskRunner, self).__init__(local_task_job)
self.process = None
self._finished_running = False
self._cpu_shares = None
self._mem_mb_limit = None
self._created_cpu_cgroup = False
self._created_mem_cgroup = False
self._cur_user = getpass.getuser()
def _create_cgroup(self, path):
"""
Create the specified cgroup.
:param path: The path of the cgroup to create.
E.g. cpu/mygroup/mysubgroup
:return: the Node associated with the created cgroup.
:rtype: cgroupspy.nodes.Node
"""
node = trees.Tree().root
path_split = path.split(os.sep)
for path_element in path_split:
name_to_node = {x.name: x for x in node.children}
if path_element not in name_to_node:
self.logger.debug("Creating cgroup {} in {}"
.format(path_element, node.path))
subprocess.check_output("sudo mkdir -p {}".format(path_element))
subprocess.check_output("sudo chown -R {} {}".format(
self._cur_user, path_element))
else:
self.logger.debug("Not creating cgroup {} in {} "
"since it already exists"
.format(path_element, node.path))
node = name_to_node[path_element]
return node
def _delete_cgroup(self, path):
"""
Delete the specified cgroup.
:param path: The path of the cgroup to delete.
E.g. cpu/mygroup/mysubgroup
"""
node = trees.Tree().root
path_split = path.split("/")
for path_element in path_split:
name_to_node = {x.name: x for x in node.children}
if path_element not in name_to_node:
self.logger.warn("Cgroup does not exist: {}"
.format(path))
return
else:
node = name_to_node[path_element]
# node is now the leaf node
parent = node.parent
self.logger.debug("Deleting cgroup {}/{}".format(parent, node.name))
parent.delete_cgroup(node.name)
def start(self):
# Use bash if it's already in a cgroup
cgroups = self._get_cgroup_names()
if cgroups["cpu"] != "/" or cgroups["memory"] != "/":
self.logger.debug("Already running in a cgroup (cpu: {} memory: {} so "
"not creating another one"
.format(cgroups.get("cpu"),
cgroups.get("memory")))
self.process = self.run_command(['bash', '-c'], join_args=True)
return
# Create a unique cgroup name
cgroup_name = "airflow/{}/{}".format(datetime.datetime.now().
strftime("%Y-%m-%d"),
str(uuid.uuid1()))
self.mem_cgroup_name = "memory/{}".format(cgroup_name)
self.cpu_cgroup_name = "cpu/{}".format(cgroup_name)
# Get the resource requirements from the task
task = self._task_instance.task
resources = task.resources
cpus = resources.cpus.qty
self._cpu_shares = cpus * 1024
self._mem_mb_limit = resources.ram.qty
# Create the memory cgroup
mem_cgroup_node = self._create_cgroup(self.mem_cgroup_name)
self._created_mem_cgroup = True
if self._mem_mb_limit > 0:
self.logger.debug("Setting {} with {} MB of memory"
.format(self.mem_cgroup_name, self._mem_mb_limit))
mem_cgroup_node.controller.limit_in_bytes = self._mem_mb_limit * 1024 * 1024
# Create the CPU cgroup
cpu_cgroup_node = self._create_cgroup(self.cpu_cgroup_name)
self._created_cpu_cgroup = True
if self._cpu_shares > 0:
self.logger.debug("Setting {} with {} CPU shares"
.format(self.cpu_cgroup_name, self._cpu_shares))
cpu_cgroup_node.controller.shares = self._cpu_shares
# Start the process w/ cgroups
self.logger.debug("Starting task process with cgroups cpu,memory:{}"
.format(cgroup_name))
self.process = self.run_command(
['cgexec', '-g', 'cpu,memory:{}'.format(cgroup_name)]
)
def return_code(self):
return_code = self.process.poll()
# TODO(plypaul) Monitoring the the control file in the cgroup fs is better than
# checking the return code here. The PR to use this is here:
# https://github.com/plypaul/airflow/blob/e144e4d41996300ffa93947f136eab7785b114ed/airflow/contrib/task_runner/cgroup_task_runner.py#L43
# but there were some issues installing the python butter package and
# libseccomp-dev on some hosts for some reason.
# I wasn't able to track down the root cause of the package install failures, but
# we might want to revisit that approach at some other point.
if return_code == 137:
self.logger.warn("Task failed with return code of 137. This may indicate "
"that it was killed due to excessive memory usage. "
"Please consider optimizing your task or using the "
"resources argument to reserve more memory for your "
"task")
return return_code
def terminate(self):
if self.process and psutil.pid_exists(self.process.pid):
kill_process_tree(self.logger, self.process.pid)
def on_finish(self):
# Let the OOM watcher thread know we're done to avoid false OOM alarms
self._finished_running = True
# Clean up the cgroups
if self._created_mem_cgroup:
self._delete_cgroup(self.mem_cgroup_name)
if self._created_cpu_cgroup:
self._delete_cgroup(self.cpu_cgroup_name)
def _get_cgroup_names(self):
"""
:return: a mapping between the subsystem name to the cgroup name
:rtype: dict[str, str]
"""
with open("/proc/self/cgroup") as f:
lines = f.readlines()
d = {}
for line in lines:
line_split = line.rstrip().split(":")
subsystem = line_split[1]
group_name = line_split[2]
d[subsystem] = group_name
return d

Просмотреть файл

@ -25,7 +25,6 @@ from datetime import datetime
import getpass
import logging
import socket
import subprocess
import multiprocessing
import os
import signal
@ -35,7 +34,7 @@ import time
from time import sleep
import psutil
from sqlalchemy import Column, Integer, String, DateTime, func, Index, or_
from sqlalchemy import Column, Integer, String, DateTime, func, Index, or_, and_
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm.session import make_transient
from tabulate import tabulate
@ -45,6 +44,7 @@ from airflow import configuration as conf
from airflow.exceptions import AirflowException
from airflow.models import DagRun
from airflow.settings import Stats
from airflow.task_runner import get_task_runner
from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS
from airflow.utils.state import State
from airflow.utils.db import provide_session, pessimistic_connection_handling
@ -54,15 +54,12 @@ from airflow.utils.dag_processing import (AbstractDagFileProcessor,
SimpleDagBag,
list_py_file_paths)
from airflow.utils.email import send_email
from airflow.utils.helpers import kill_descendant_processes
from airflow.utils.logging import LoggingMixin
from airflow.utils import asciiart
Base = models.Base
DagRun = models.DagRun
ID_LEN = models.ID_LEN
Stats = settings.Stats
class BaseJob(Base, LoggingMixin):
@ -956,13 +953,18 @@ class SchedulerJob(BaseJob):
:type states: Tuple[State]
:return: None
"""
# Get all the relevant task instances
# Get all the queued task instances from associated with scheduled
# DagRuns.
TI = models.TaskInstance
task_instances_to_examine = (
session
.query(TI)
.filter(TI.dag_id.in_(simple_dag_bag.dag_ids))
.filter(TI.state.in_(states))
.join(DagRun, and_(TI.dag_id == DagRun.dag_id,
TI.execution_date == DagRun.execution_date,
DagRun.state == State.RUNNING,
DagRun.run_id.like(DagRun.ID_PREFIX + '%')))
.all()
)
@ -1054,7 +1056,7 @@ class SchedulerJob(BaseJob):
task_concurrency_limit))
continue
command = TI.generate_command(
command = " ".join(TI.generate_command(
task_instance.dag_id,
task_instance.task_id,
task_instance.execution_date,
@ -1066,7 +1068,7 @@ class SchedulerJob(BaseJob):
ignore_ti_state=False,
pool=task_instance.pool,
file_path=simple_dag_bag.get_dag(task_instance.dag_id).full_filepath,
pickle_id=simple_dag_bag.get_dag(task_instance.dag_id).pickle_id)
pickle_id=simple_dag_bag.get_dag(task_instance.dag_id).pickle_id))
priority = task_instance.priority_weight
queue = task_instance.queue
@ -1659,7 +1661,7 @@ class BackfillJob(BaseJob):
# consider max_active_runs but ignore when running subdags
# "parent.child" as a dag_id is by convention a subdag
if self.dag.schedule_interval and not "." in self.dag.dag_id:
if self.dag.schedule_interval and "." not in self.dag.dag_id:
active_runs = DagRun.find(
dag_id=self.dag.dag_id,
state=State.RUNNING,
@ -1915,7 +1917,6 @@ class BackfillJob(BaseJob):
self.logger.error(msg)
ti.handle_failure(msg)
tasks_to_run.pop(key)
msg = ' | '.join([
"[backfill progress]",
"dag run {6} of {7}",
@ -2026,23 +2027,14 @@ class LocalTaskJob(BaseJob):
super(LocalTaskJob, self).__init__(*args, **kwargs)
def _execute(self):
self.task_runner = get_task_runner(self)
try:
command = self.task_instance.command(
raw=True,
ignore_all_deps = self.ignore_all_deps,
ignore_depends_on_past = self.ignore_depends_on_past,
ignore_task_deps = self.ignore_task_deps,
ignore_ti_state = self.ignore_ti_state,
pickle_id = self.pickle_id,
mark_success = self.mark_success,
job_id = self.id,
pool = self.pool
)
self.process = subprocess.Popen(['bash', '-c', command])
self.logger.info("Subprocess PID is {}".format(self.process.pid))
self.task_runner.start()
ti = self.task_instance
session = settings.Session()
ti.pid = self.process.pid
if self.task_runner.process:
ti.pid = self.task_runner.process.pid
ti.hostname = socket.getfqdn()
session.merge(ti)
session.commit()
@ -2053,8 +2045,10 @@ class LocalTaskJob(BaseJob):
'scheduler_zombie_task_threshold')
while True:
# Monitor the task to see if it's done
return_code = self.process.poll()
return_code = self.task_runner.return_code()
if return_code is not None:
self.logger.info("Task exited with return code {}"
.format(return_code))
return
# Periodically heartbeat so that the scheduler doesn't think this
@ -2079,11 +2073,11 @@ class LocalTaskJob(BaseJob):
.format(time_since_last_heartbeat,
heartbeat_time_limit))
finally:
# Kill processes that were left running
kill_descendant_processes(self.logger)
self.on_kill()
def on_kill(self):
self.process.terminate()
self.task_runner.terminate()
self.task_runner.on_finish()
@provide_session
def heartbeat_callback(self, session=None):
@ -2102,18 +2096,19 @@ class LocalTaskJob(BaseJob):
if new_ti.state == State.RUNNING:
self.was_running = True
fqdn = socket.getfqdn()
if not (fqdn == new_ti.hostname and self.process.pid == new_ti.pid):
if not (fqdn == new_ti.hostname and
self.task_runner.process.pid == new_ti.pid):
logging.warning("Recorded hostname and pid of {new_ti.hostname} "
"and {new_ti.pid} do not match this instance's "
"which are {fqdn} and "
"{self.process.pid}. Taking the poison pill. So "
"long."
"{self.task_runner.process.pid}. Taking the poison pill. "
"So long."
.format(**locals()))
raise AirflowException("Another worker/process is running this job")
elif self.was_running and hasattr(self, 'process'):
elif self.was_running and hasattr(self.task_runner, 'process'):
logging.warning(
"State of this instance has been externally set to "
"{self.task_instance.state}. "
"Taking the poison pill. So long.".format(**locals()))
self.process.terminate()
self.task_runner.terminate()
self.terminating = True

Просмотреть файл

@ -0,0 +1,37 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Add state index for dagruns to allow the quick lookup of active dagruns
Revision ID: 1a5a9e6bf2b5
Revises: 5e7d17757c7a
Create Date: 2017-01-17 10:22:53.193711
"""
# revision identifiers, used by Alembic.
revision = '1a5a9e6bf2b5'
down_revision = '5e7d17757c7a'
branch_labels = None
depends_on = None
from alembic import op
import sqlalchemy as sa
def upgrade():
op.create_index('dr_state', 'dag_run', ['state'], unique=False)
def downgrade():
op.drop_index('state', table_name='dag_run')

Просмотреть файл

@ -251,9 +251,9 @@ class DagBag(BaseDagBag, LoggingMixin):
self.logger.debug("Importing {}".format(filepath))
org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
mod_name = ('unusual_prefix_'
+ hashlib.sha1(filepath.encode('utf-8')).hexdigest()
+ '_' + org_mod_name)
mod_name = ('unusual_prefix_' +
hashlib.sha1(filepath.encode('utf-8')).hexdigest() +
'_' + org_mod_name)
if mod_name in sys.modules:
del sys.modules[mod_name]
@ -756,6 +756,7 @@ class TaskInstance(Base):
self.priority_weight = task.priority_weight_total
self.try_number = 0
self.unixname = getpass.getuser()
self.run_as_user = task.run_as_user
if state:
self.state = state
self.hostname = ''
@ -777,7 +778,39 @@ class TaskInstance(Base):
pickle_id=None,
raw=False,
job_id=None,
pool=None):
pool=None,
cfg_path=None):
"""
Returns a command that can be executed anywhere where airflow is
installed. This command is part of the message sent to executors by
the orchestrator.
"""
return " ".join(self.command_as_list(
mark_success=mark_success,
ignore_all_deps=ignore_all_deps,
ignore_depends_on_past=ignore_depends_on_past,
ignore_task_deps=ignore_task_deps,
ignore_ti_state=ignore_ti_state,
local=local,
pickle_id=pickle_id,
raw=raw,
job_id=job_id,
pool=pool,
cfg_path=cfg_path))
def command_as_list(
self,
mark_success=False,
ignore_all_deps=False,
ignore_task_deps=False,
ignore_depends_on_past=False,
ignore_ti_state=False,
local=False,
pickle_id=None,
raw=False,
job_id=None,
pool=None,
cfg_path=None):
"""
Returns a command that can be executed anywhere where airflow is
installed. This command is part of the message sent to executors by
@ -799,15 +832,16 @@ class TaskInstance(Base):
self.execution_date,
mark_success=mark_success,
ignore_all_deps=ignore_all_deps,
ignore_depends_on_past=ignore_depends_on_past,
ignore_task_deps=ignore_task_deps,
ignore_depends_on_past=ignore_depends_on_past,
ignore_ti_state=ignore_ti_state,
local=local,
pickle_id=pickle_id,
file_path=path,
raw=raw,
job_id=job_id,
pool=pool)
pool=pool,
cfg_path=cfg_path)
@staticmethod
def generate_command(dag_id,
@ -823,7 +857,8 @@ class TaskInstance(Base):
file_path=None,
raw=False,
job_id=None,
pool=None
pool=None,
cfg_path=None
):
"""
Generates the shell command required to execute this task instance.
@ -860,19 +895,20 @@ class TaskInstance(Base):
:return: shell command that can be used to run the task instance
"""
iso = execution_date.isoformat()
cmd = "airflow run {dag_id} {task_id} {iso} "
cmd += "--mark_success " if mark_success else ""
cmd += "--pickle {pickle_id} " if pickle_id else ""
cmd += "--job_id {job_id} " if job_id else ""
cmd += "-A " if ignore_all_deps else ""
cmd += "-i " if ignore_task_deps else ""
cmd += "-I " if ignore_depends_on_past else ""
cmd += "--force " if ignore_ti_state else ""
cmd += "--local " if local else ""
cmd += "--pool {pool} " if pool else ""
cmd += "--raw " if raw else ""
cmd += "-sd {file_path}" if file_path else ""
return cmd.format(**locals())
cmd = ["airflow", "run", str(dag_id), str(task_id), str(iso)]
cmd.extend(["--mark_success"]) if mark_success else None
cmd.extend(["--pickle", str(pickle_id)]) if pickle_id else None
cmd.extend(["--job_id", str(job_id)]) if job_id else None
cmd.extend(["-A "]) if ignore_all_deps else None
cmd.extend(["-i"]) if ignore_task_deps else None
cmd.extend(["-I"]) if ignore_depends_on_past else None
cmd.extend(["--force"]) if ignore_ti_state else None
cmd.extend(["--local"]) if local else None
cmd.extend(["--pool", pool]) if pool else None
cmd.extend(["--raw"]) if raw else None
cmd.extend(["-sd", file_path]) if file_path else None
cmd.extend(["--cfg_path", cfg_path]) if cfg_path else None
return cmd
@property
def log_filepath(self):
@ -1825,6 +1861,8 @@ class BaseOperator(object):
:param resources: A map of resource parameter names (the argument names of the
Resources constructor) to their values.
:type resources: dict
:param run_as_user: unix username to impersonate while running the task
:type run_as_user: str
"""
# For derived classes to define which fields will get jinjaified
@ -1866,6 +1904,7 @@ class BaseOperator(object):
on_retry_callback=None,
trigger_rule=TriggerRule.ALL_SUCCESS,
resources=None,
run_as_user=None,
*args,
**kwargs):
@ -1929,6 +1968,7 @@ class BaseOperator(object):
self.adhoc = adhoc
self.priority_weight = priority_weight
self.resources = Resources(**(resources or {}))
self.run_as_user = run_as_user
# Private attributes
self._upstream_task_ids = []
@ -2854,13 +2894,7 @@ class DAG(BaseDag, LoggingMixin):
:param session:
:return: List of execution dates
"""
runs = (
session.query(DagRun)
.filter(
DagRun.dag_id == self.dag_id,
DagRun.state == State.RUNNING)
.order_by(DagRun.execution_date)
.all())
runs = DagRun.find(dag_id=self.dag_id, state=State.RUNNING)
active_dates = []
for run in runs:
@ -3488,7 +3522,6 @@ class Variable(Base):
else:
return obj.val
@classmethod
@provide_session
def get(cls, key, default_var=None, deserialize_json=False, session=None):
@ -3695,7 +3728,6 @@ class DagStat(Base):
:type full_query: bool
"""
dag_ids = set(dag_ids)
ds_ids = set(session.query(DagStat.dag_id).all())
qry = (
session.query(DagStat)

Просмотреть файл

@ -68,10 +68,7 @@ ___ ___ | / _ / _ __/ _ / / /_/ /_ |/ |/ /
"""
BASE_LOG_URL = '/admin/airflow/log'
AIRFLOW_HOME = os.path.expanduser(conf.get('core', 'AIRFLOW_HOME'))
SQL_ALCHEMY_CONN = conf.get('core', 'SQL_ALCHEMY_CONN')
LOGGING_LEVEL = logging.INFO
DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
# the prefix to append to gunicorn worker processes after init
GUNICORN_WORKER_READY_PREFIX = "[ready] "
@ -85,6 +82,13 @@ LOG_FORMAT_WITH_THREAD_NAME = (
'[%(asctime)s] {%(filename)s:%(lineno)d} %(threadName)s %(levelname)s - %(message)s')
SIMPLE_LOG_FORMAT = '%(asctime)s %(levelname)s - %(message)s'
AIRFLOW_HOME = None
SQL_ALCHEMY_CONN = None
DAGS_FOLDER = None
engine = None
Session = None
def policy(task_instance):
"""
@ -118,8 +122,14 @@ def configure_logging(log_format=LOG_FORMAT):
logging.basicConfig(
format=log_format, stream=sys.stdout, level=LOGGING_LEVEL)
engine = None
Session = None
def configure_vars():
global AIRFLOW_HOME
global SQL_ALCHEMY_CONN
global DAGS_FOLDER
AIRFLOW_HOME = os.path.expanduser(conf.get('core', 'AIRFLOW_HOME'))
SQL_ALCHEMY_CONN = conf.get('core', 'SQL_ALCHEMY_CONN')
DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
def configure_orm(disable_connection_pool=False):
@ -146,6 +156,7 @@ except:
pass
configure_logging()
configure_vars()
configure_orm()
# Const stuff

Просмотреть файл

@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from airflow import configuration
from airflow.contrib.task_runner.cgroup_task_runner import CgroupTaskRunner
from airflow.task_runner.bash_task_runner import BashTaskRunner
from airflow.exceptions import AirflowException
_TASK_RUNNER = configuration.get('core', 'TASK_RUNNER')
def get_task_runner(local_task_job):
"""
Get the task runner that can be used to run the given job.
:param local_task_job: The LocalTaskJob associated with the TaskInstance
that needs to be executed.
:type local_task_job: airflow.jobs.LocalTaskJob
:return: The task runner to use to run the task.
:rtype: airflow.task_runner.base_task_runner.BaseTaskRunner
"""
if _TASK_RUNNER == "BashTaskRunner":
return BashTaskRunner(local_task_job)
elif _TASK_RUNNER == "CgroupTaskRunner":
return CgroupTaskRunner(local_task_job)
else:
raise AirflowException("Unknown task runner type {}".format(_TASK_RUNNER))

Просмотреть файл

@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import getpass
import os
import json
import subprocess
import threading
from airflow import configuration as conf
from airflow.utils.logging import LoggingMixin
from tempfile import mkstemp
class BaseTaskRunner(LoggingMixin):
"""
Runs Airflow task instances by invoking the `airflow run` command with raw
mode enabled in a subprocess.
"""
def __init__(self, local_task_job):
"""
:param local_task_job: The local task job associated with running the
associated task instance.
:type local_task_job: airflow.jobs.LocalTaskJob
"""
self._task_instance = local_task_job.task_instance
popen_prepend = []
cfg_path = None
if self._task_instance.run_as_user:
self.run_as_user = self._task_instance.run_as_user
else:
try:
self.run_as_user = conf.get('core', 'default_impersonation')
except conf.AirflowConfigException:
self.run_as_user = None
# Add sudo commands to change user if we need to. Needed to handle SubDagOperator
# case using a SequentialExecutor.
if self.run_as_user and (self.run_as_user != getpass.getuser()):
self.logger.debug("Planning to run as the {} user".format(self.run_as_user))
cfg_dict = conf.as_dict(display_sensitive=True)
cfg_subset = {
'core': cfg_dict.get('core', {}),
'smtp': cfg_dict.get('smtp', {}),
'scheduler': cfg_dict.get('scheduler', {}),
'webserver': cfg_dict.get('webserver', {}),
}
temp_fd, cfg_path = mkstemp()
# Give ownership of file to user; only they can read and write
subprocess.call(
['sudo', 'chown', self.run_as_user, cfg_path]
)
subprocess.call(
['sudo', 'chmod', '600', cfg_path]
)
with os.fdopen(temp_fd, 'w') as temp_file:
json.dump(cfg_subset, temp_file)
popen_prepend = ['sudo', '-H', '-u', self.run_as_user]
self._cfg_path = cfg_path
self._command = popen_prepend + self._task_instance.command_as_list(
raw=True,
ignore_all_deps=local_task_job.ignore_all_deps,
ignore_depends_on_past=local_task_job.ignore_depends_on_past,
ignore_ti_state=local_task_job.ignore_ti_state,
pickle_id=local_task_job.pickle_id,
mark_success=local_task_job.mark_success,
job_id=local_task_job.id,
pool=local_task_job.pool,
cfg_path=cfg_path,
)
self.process = None
def _read_task_logs(self, stream):
while True:
line = stream.readline()
if len(line) == 0:
break
self.logger.info('Subtask: {}'.format(line.rstrip('\n')))
def run_command(self, run_with, join_args=False):
"""
Run the task command
:param run_with: list of tokens to run the task command with
E.g. ['bash', '-c']
:type run_with: list
:param join_args: whether to concatenate the list of command tokens
E.g. ['airflow', 'run'] vs ['airflow run']
:param join_args: bool
:return: the process that was run
:rtype: subprocess.Popen
"""
cmd = [" ".join(self._command)] if join_args else self._command
full_cmd = run_with + cmd
self.logger.info('Running: {}'.format(full_cmd))
proc = subprocess.Popen(
full_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT
)
# Start daemon thread to read subprocess logging output
log_reader = threading.Thread(
target=self._read_task_logs,
args=(proc.stdout,),
)
log_reader.daemon = True
log_reader.start()
return proc
def start(self):
"""
Start running the task instance in a subprocess.
"""
raise NotImplementedError()
def return_code(self):
"""
:return: The return code associated with running the task instance or
None if the task is not yet done.
:rtype int:
"""
raise NotImplementedError()
def terminate(self):
"""
Kill the running task instance.
"""
raise NotImplementedError()
def on_finish(self):
"""
A callback that should be called when this is done running.
"""
if self._cfg_path and os.path.isfile(self._cfg_path):
subprocess.call(['sudo', 'rm', self._cfg_path])

Просмотреть файл

@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import psutil
from airflow.task_runner.base_task_runner import BaseTaskRunner
from airflow.utils.helpers import kill_process_tree
class BashTaskRunner(BaseTaskRunner):
"""
Runs the raw Airflow task by invoking through the Bash shell.
"""
def __init__(self, local_task_job):
super(BashTaskRunner, self).__init__(local_task_job)
def start(self):
self.process = self.run_command(['bash', '-c'], join_args=True)
def return_code(self):
return self.process.poll()
def terminate(self):
if self.process and psutil.pid_exists(self.process.pid):
kill_process_tree(self.logger, self.process.pid)
def on_finish(self):
super(BashTaskRunner, self).on_finish()

Просмотреть файл

@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import unicode_literals
import errno
import os
import shutil
from tempfile import mkdtemp
@ -34,3 +35,25 @@ def TemporaryDirectory(suffix='', prefix=None, dir=None):
# ENOENT - no such file or directory
if e.errno != errno.ENOENT:
raise e
def mkdirs(path, mode):
"""
Creates the directory specified by path, creating intermediate directories
as necessary. If directory already exists, this is a no-op.
:param path: The directory to create
:type path: str
:param mode: The mode to give to the directory e.g. 0o755
:type mode: int
:return: A list of directories that were created
:rtype: list[str]
"""
if not path or os.path.exists(path):
return []
(head, _) = os.path.split(path)
res = mkdirs(head, mode)
os.mkdir(path)
os.chmod(path, mode)
res += [path]
return res

Просмотреть файл

@ -22,10 +22,12 @@ import psutil
from builtins import input
from past.builtins import basestring
from datetime import datetime
import getpass
import imp
import logging
import os
import re
import signal
import subprocess
import sys
import warnings
@ -35,6 +37,7 @@ from airflow.exceptions import AirflowException
# SIGKILL.
TIME_TO_WAIT_AFTER_SIGTERM = 5
def validate_key(k, max_length=250):
if not isinstance(k, basestring):
raise TypeError("The key has to be a string")
@ -179,6 +182,80 @@ def pprinttable(rows):
return s
def kill_using_shell(pid, signal=signal.SIGTERM):
process = psutil.Process(pid)
# Use sudo only when necessary - consider SubDagOperator and SequentialExecutor case.
if process.username() != getpass.getuser():
args = ["sudo", "kill", "-{}".format(int(signal)), str(pid)]
else:
args = ["kill", "-{}".format(int(signal)), str(pid)]
# PID may not exist and return a non-zero error code
subprocess.call(args)
def kill_process_tree(logger, pid):
"""
Kills the process and all of the descendants. Kills using the `kill`
shell command so that it can change users. Note: killing via PIDs
has the potential to the wrong process if the process dies and the
PID gets recycled in a narrow time window.
:param logger: logger
:type logger: logging.Logger
"""
try:
root_process = psutil.Process(pid)
except psutil.NoSuchProcess:
logger.warn("PID: {} does not exist".format(pid))
return
# Check child processes to reduce cases where a child process died but
# the PID got reused.
descendant_processes = [x for x in root_process.children(recursive=True)
if x.is_running()]
if len(descendant_processes) != 0:
logger.warn("Terminating descendant processes of {} PID: {}"
.format(root_process.cmdline(),
root_process.pid))
temp_processes = descendant_processes[:]
for descendant in temp_processes:
logger.warn("Terminating descendant process {} PID: {}"
.format(descendant.cmdline(), descendant.pid))
try:
kill_using_shell(descendant.pid, signal.SIGTERM)
except psutil.NoSuchProcess:
descendant_processes.remove(descendant)
logger.warn("Waiting up to {}s for processes to exit..."
.format(TIME_TO_WAIT_AFTER_SIGTERM))
try:
psutil.wait_procs(descendant_processes, TIME_TO_WAIT_AFTER_SIGTERM)
logger.warn("Done waiting")
except psutil.TimeoutExpired:
logger.warn("Ran out of time while waiting for "
"processes to exit")
# Then SIGKILL
descendant_processes = [x for x in root_process.children(recursive=True)
if x.is_running()]
if len(descendant_processes) > 0:
temp_processes = descendant_processes[:]
for descendant in temp_processes:
logger.warn("Killing descendant process {} PID: {}"
.format(descendant.cmdline(), descendant.pid))
try:
kill_using_shell(descendant.pid, signal.SIGTERM)
descendant.wait()
except psutil.NoSuchProcess:
descendant_processes.remove(descendant)
logger.warn("Killed all descendant processes of {} PID: {}"
.format(root_process.cmdline(),
root_process.pid))
else:
logger.debug("There are no descendant processes to kill")
def kill_descendant_processes(logger, pids_to_kill=None):
"""
Kills all descendant processes of this process.

Просмотреть файл

@ -310,3 +310,25 @@ standard port 443, you'll need to configure that too. Be aware that super user p
# Optionally, set the server to listen on the standard SSL port.
web_server_port = 443
base_url = http://<hostname or IP>:443
Impersonation
'''''''''''''
Airflow has the ability to impersonate a unix user while running task
instances based on the task's ``run_as_user`` parameter, which takes a user's name.
*NOTE* For impersonations to work, Airflow must be run with `sudo` as subtasks are run
with `sudo -u` and permissions of files are changed. Furthermore, the unix user needs to
exist on the worker. Here is what a simple sudoers file entry could look like to achieve
this, assuming as airflow is running as the `airflow` user. Note that this means that
the airflow user must be trusted and treated the same way as the root user.
.. code-block:: none
airflow ALL=(ALL) NOPASSWD: ALL
Subtasks with impersonation will still log to the same folder, except that the files they
log to will have permissions changed such that only the unix user can write to it.
*Default impersonation* To prevent tasks that don't use impersonation to be run with
`sudo` privileges, you can set the `default_impersonation` config in `core` which sets a
default user impersonate if `run_as_user` is not set.

Просмотреть файл

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
set -x
# environment
export AIRFLOW_HOME=${AIRFLOW_HOME:=~/airflow}
@ -48,6 +49,19 @@ echo "Initializing the DB"
yes | airflow resetdb
airflow initdb
if [ "${TRAVIS}" ]; then
# For impersonation tests running on SQLite on Travis, make the database world readable so other
# users can update it
AIRFLOW_DB="/home/travis/airflow/airflow.db"
if [ -f "${AIRFLOW_DB}" ]; then
sudo chmod a+rw "${AIRFLOW_DB}"
fi
# For impersonation tests on Travis, make airflow accessible to other users via the global PATH
# (which contains /usr/local/bin)
sudo ln -s "${VIRTUAL_ENV}/bin/airflow" /usr/local/bin/
fi
echo "Starting the unit tests with the following nose arguments: "$nose_args
nosetests $nose_args

Просмотреть файл

@ -22,6 +22,7 @@ load_examples = True
donot_pickle = False
dag_concurrency = 16
dags_are_paused_at_creation = False
default_impersonation =
fernet_key = af7CN0q6ag5U3g08IsPsw3K45U7Xa0axgVFhoh-3zB8=
[webserver]

Просмотреть файл

@ -2,6 +2,7 @@ alembic
bcrypt
boto
celery
cgroupspy
chartkick
cloudant
coverage

Просмотреть файл

@ -108,6 +108,9 @@ celery = [
'celery>=3.1.17',
'flower>=0.7.3'
]
cgroups = [
'cgroupspy>=0.1.4',
]
crypto = ['cryptography>=0.9.3']
datadog = ['datadog>=0.14.0']
doc = [
@ -227,6 +230,7 @@ def do_setup():
'all_dbs': all_dbs,
'async': async,
'celery': celery,
'cgroups': cgroups,
'cloudant': cloudant,
'crypto': crypto,
'datadog': datadog,

Просмотреть файл

@ -18,6 +18,7 @@ from .configuration import *
from .contrib import *
from .core import *
from .jobs import *
from .impersonation import *
from .models import *
from .operators import *
from .utils import *

Просмотреть файл

@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from airflow.models import DAG
from airflow.operators.bash_operator import BashOperator
from datetime import datetime
from textwrap import dedent
DEFAULT_DATE = datetime(2016, 1, 1)
args = {
'owner': 'airflow',
'start_date': DEFAULT_DATE,
}
dag = DAG(dag_id='test_default_impersonation', default_args=args)
deelevated_user = 'airflow_test_user'
test_command = dedent(
"""\
if [ '{user}' != "$(whoami)" ]; then
echo current user $(whoami) is not {user}!
exit 1
fi
""".format(user=deelevated_user))
task = BashOperator(
task_id='test_deelevated_user',
bash_command=test_command,
dag=dag,
)

Просмотреть файл

@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from airflow.models import DAG
from airflow.operators.bash_operator import BashOperator
from datetime import datetime
from textwrap import dedent
DEFAULT_DATE = datetime(2016, 1, 1)
args = {
'owner': 'airflow',
'start_date': DEFAULT_DATE,
}
dag = DAG(dag_id='test_impersonation', default_args=args)
run_as_user = 'airflow_test_user'
test_command = dedent(
"""\
if [ '{user}' != "$(whoami)" ]; then
echo current user is not {user}!
exit 1
fi
""".format(user=run_as_user))
task = BashOperator(
task_id='test_impersonated_user',
bash_command=test_command,
dag=dag,
run_as_user=run_as_user,
)

Просмотреть файл

@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from airflow.models import DAG
from airflow.operators.bash_operator import BashOperator
from datetime import datetime
from textwrap import dedent
DEFAULT_DATE = datetime(2016, 1, 1)
args = {
'owner': 'airflow',
'start_date': DEFAULT_DATE,
}
dag = DAG(dag_id='test_no_impersonation', default_args=args)
test_command = dedent(
"""\
sudo ls
if [ $? -ne 0 ]; then
echo 'current uid does not have root privileges!'
exit 1
fi
""")
task = BashOperator(
task_id='test_superuser',
bash_command=test_command,
dag=dag,
)

111
tests/impersonation.py Normal file
Просмотреть файл

@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from __future__ import print_function
import errno
import os
import subprocess
import unittest
from airflow import jobs, models
from airflow.utils.state import State
from datetime import datetime
DEV_NULL = '/dev/null'
TEST_DAG_FOLDER = os.path.join(
os.path.dirname(os.path.realpath(__file__)), 'dags')
DEFAULT_DATE = datetime(2015, 1, 1)
TEST_USER = 'airflow_test_user'
# TODO(aoen): Adding/remove a user as part of a test is very bad (especially if the user
# already existed to begin with on the OS), this logic should be moved into a test
# that is wrapped in a container like docker so that the user can be safely added/removed.
# When this is done we can also modify the sudoers file to ensure that useradd will work
# without any manual modification of the sudoers file by the agent that is running these
# tests.
class ImpersonationTest(unittest.TestCase):
def setUp(self):
self.dagbag = models.DagBag(
dag_folder=TEST_DAG_FOLDER,
include_examples=False,
)
try:
subprocess.check_output(['sudo', 'useradd', '-m', TEST_USER, '-g',
str(os.getegid())])
except OSError as e:
if e.errno == errno.ENOENT:
raise unittest.SkipTest(
"The 'useradd' command did not exist so unable to test "
"impersonation; Skipping Test. These tests can only be run on a "
"linux host that supports 'useradd'."
)
else:
raise unittest.SkipTest(
"The 'useradd' command exited non-zero; Skipping tests. Does the "
"current user have permission to run 'useradd' without a password "
"prompt (check sudoers file)?"
)
def tearDown(self):
subprocess.check_output(['sudo', 'userdel', '-r', TEST_USER])
def run_backfill(self, dag_id, task_id):
dag = self.dagbag.get_dag(dag_id)
dag.clear()
jobs.BackfillJob(
dag=dag,
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE).run()
ti = models.TaskInstance(
task=dag.get_task(task_id),
execution_date=DEFAULT_DATE)
ti.refresh_from_db()
self.assertEqual(ti.state, State.SUCCESS)
def test_impersonation(self):
"""
Tests that impersonating a unix user works
"""
self.run_backfill(
'test_impersonation',
'test_impersonated_user'
)
def test_no_impersonation(self):
"""
If default_impersonation=None, tests that the job is run
as the current user (which will be a sudoer)
"""
self.run_backfill(
'test_no_impersonation',
'test_superuser',
)
def test_default_impersonation(self):
"""
If default_impersonation=TEST_USER, tests that the job defaults
to running as TEST_USER for a test without run_as_user set
"""
os.environ['AIRFLOW__CORE__DEFAULT_IMPERSONATION'] = TEST_USER
try:
self.run_backfill(
'test_default_impersonation',
'test_deelevated_user'
)
finally:
del os.environ['AIRFLOW__CORE__DEFAULT_IMPERSONATION']