incubator-airflow/airflow/models.py


import copy
from datetime import datetime, timedelta
import getpass
import imp
import jinja2
import logging
import os
import dill
import re
import signal
import socket
from sqlalchemy import (
Column, Integer, String, DateTime, Text, Boolean, ForeignKey, PickleType,
Index,)
from sqlalchemy import func, or_
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.orm import relationship
from airflow.executors import DEFAULT_EXECUTOR, LocalExecutor
from airflow.configuration import conf
from airflow import settings
from airflow import utils
from airflow.utils import State
from airflow.utils import apply_defaults
Base = declarative_base()
ID_LEN = 250
SQL_ALCHEMY_CONN = conf.get('core', 'SQL_ALCHEMY_CONN')
DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
if 'mysql' in SQL_ALCHEMY_CONN:
LongText = LONGTEXT
else:
LongText = Text
def clear_task_instances(tis, session):
'''
Clears a set of task instances, but makes sure the running ones
get killed.
'''
job_ids = []
for ti in tis:
if ti.state == State.RUNNING:
if ti.job_id:
ti.state = State.SHUTDOWN
job_ids.append(ti.job_id)
else:
session.delete(ti)
if job_ids:
from airflow.jobs import BaseJob as BJ # HA!
for job in session.query(BJ).filter(BJ.id.in_(job_ids)).all():
job.state = State.SHUTDOWN
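# Illustrative usage sketch ('example_dag' is a placeholder dag_id): callers
# pass a list or query of TaskInstance rows together with the session that
# produced them, e.g.:
#
#     session = settings.Session()
#     tis = session.query(TaskInstance).filter(
#         TaskInstance.dag_id == 'example_dag').all()
#     clear_task_instances(tis, session)
#     session.commit()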
class DagBag(object):
"""
A dagbag is a collection of dags, parsed out of a folder tree, that carries
high level configuration settings, like what database to use as a backend
and what executor to use to fire off tasks. This makes it easier to run
distinct environments for, say, production and development, tests, or for
different teams or security profiles. What would have been system level
settings are now dagbag level, so that one system can run multiple,
independent sets of settings.
"""
def __init__(
self,
dag_folder=None,
executor=DEFAULT_EXECUTOR,
include_examples=True):
if not dag_folder:
dag_folder = DAGS_FOLDER
logging.info("Filling up the DagBag from " + dag_folder)
self.dag_folder = dag_folder
self.dags = {}
self.file_last_changed = {}
self.executor = executor
self.collect_dags(dag_folder)
if include_examples:
example_dag_folder = os.path.join(
os.path.dirname(__file__),
'example_dags')
self.collect_dags(example_dag_folder)
self.merge_dags()
def process_file(self, filepath, only_if_updated=True, safe_mode=True):
"""
Given a path to a python module, this method imports the module and
looks for DAG objects within it.
"""
dttm = datetime.fromtimestamp(os.path.getmtime(filepath))
mod_name, file_ext = os.path.splitext(os.path.split(filepath)[-1])
if safe_mode:
# Skip file if no obvious references to airflow or DAG are found.
f = open(filepath, 'r')
content = f.read()
f.close()
if not all([s in content for s in ('DAG', 'airflow')]):
return
if (
not only_if_updated or
filepath not in self.file_last_changed or
dttm != self.file_last_changed[filepath]):
try:
logging.info("Importing " + filepath)
m = imp.load_source(mod_name, filepath)
except:
logging.error("Failed to import: " + filepath)
self.file_last_changed[filepath] = dttm
return
for dag in m.__dict__.values():
if type(dag) == DAG:
dag.full_filepath = filepath
self.bag_dag(dag)
self.file_last_changed[filepath] = dttm
def bag_dag(self, dag):
'''
Adds the DAG into the bag, recurses into sub dags.
'''
self.dags[dag.dag_id] = dag
dag.resolve_template_files()
for subdag in dag.subdags:
subdag.full_filepath = dag.full_filepath
subdag.parent_dag = dag
self.bag_dag(subdag)
logging.info('Loaded DAG {dag}'.format(**locals()))
def collect_dags(
self,
dag_folder=DAGS_FOLDER,
only_if_updated=True):
"""
Given a file path or a folder, this method looks for python modules,
imports them and adds them to the dagbag collection.
Note that if a .airflowignore file is found while processing
the directory, it behaves much like a .gitignore does,
ignoring files that match any of the regex patterns specified
in the file.
"""
if os.path.isfile(dag_folder):
self.process_file(dag_folder, only_if_updated=only_if_updated)
elif os.path.isdir(dag_folder):
patterns = []
for root, dirs, files in os.walk(dag_folder):
ignore_file = [f for f in files if f == '.airflowignore']
if ignore_file:
f = open(os.path.join(root, ignore_file[0]), 'r')
patterns += [p for p in f.read().split('\n') if p]
f.close()
for f in files:
filepath = os.path.join(root, f)
if not os.path.isfile(filepath):
continue
mod_name, file_ext = os.path.splitext(
os.path.split(filepath)[-1])
if file_ext != '.py':
continue
if not any([re.findall(p, filepath) for p in patterns]):
self.process_file(
filepath, only_if_updated=only_if_updated)
def merge_dags(self):
session = settings.Session()
for dag in self.dags.values():
session.merge(dag)
session.commit()
session.close()
def paused_dags(self):
session = settings.Session()
dag_ids = [dp.dag_id for dp in session.query(DAG).filter(
DAG.is_paused == True)]
session.commit()
session.close()
return dag_ids
class User(Base):
"""
Eventually should be used for security purposes
"""
__tablename__ = "user"
id = Column(Integer, primary_key=True)
username = Column(String(ID_LEN), unique=True)
email = Column(String(500))
def __init__(self, username=None, email=None):
self.username = username
self.email = email
def __repr__(self):
return self.username
def get_id(self):
return unicode(self.id)
def is_active(self):
return True
def is_authenticated(self):
return True
def is_anonymous(self):
return False
class Connection(Base):
"""
Placeholder to store connection information for different database
instances. The idea here is that scripts use references to database
instances (conn_id) instead of hard coding hostnames, logins and
passwords when using operators or hooks.
"""
__tablename__ = "connection"
id = Column(Integer(), primary_key=True)
conn_id = Column(String(ID_LEN), unique=True)
conn_type = Column(String(500))
host = Column(String(500))
schema = Column(String(500))
login = Column(String(500))
password = Column(String(500))
port = Column(Integer())
extra = Column(String(5000))
def __init__(
self, conn_id=None, conn_type=None,
host=None, login=None, password=None,
schema=None, port=None):
self.conn_id = conn_id
self.conn_type = conn_type
self.host = host
self.login = login
self.password = password
self.schema = schema
self.port = port
def get_hook(self):
from airflow import hooks
try:
if self.conn_type == 'mysql':
return hooks.MySqlHook(mysql_conn_id=self.conn_id)
elif self.conn_type == 'postgres':
return hooks.PostgresHook(postgres_conn_id=self.conn_id)
elif self.conn_type == 'hive_cli':
return hooks.HiveCliHook(hive_cli_conn_id=self.conn_id)
elif self.conn_type == 'presto':
return hooks.PrestoHook(presto_conn_id=self.conn_id)
elif self.conn_type == 'hiveserver2':
return hooks.HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
except:
return None
def __repr__(self):
return self.conn_id
class DagPickle(Base):
"""
Dags can originate from different places (user repos, master repo, ...)
and also get executed in different places (different executors). This
object represents a version of a DAG and becomes a source of truth for
a BackfillJob execution. A pickle is a native python serialized object,
and in this case gets stored in the database for the duration of the job.
The executors pick up the DagPickle id and read the dag definition from
the database.
"""
id = Column(Integer, primary_key=True)
pickle = Column(PickleType(pickler=dill))
__tablename__ = "dag_pickle"
def __init__(self, dag):
self.dag_id = dag.dag_id
if hasattr(dag, 'template_env'):
dag.template_env = None
self.pickle = dag
class TaskInstance(Base):
"""
Task instances store the state of a task instance. This table is the
authority and single source of truth around what tasks have run and the
state they are in.
The SqlAlchemy model deliberately doesn't have a SqlAlchemy foreign key to
the task or dag model, to allow more control over transactions.
Database transactions on this table should guard against double triggers
and any confusion around which task instances are or aren't ready to run,
even while multiple schedulers may be firing task instances.
"""
__tablename__ = "task_instance"
task_id = Column(String(ID_LEN), primary_key=True)
dag_id = Column(String(ID_LEN), primary_key=True)
execution_date = Column(DateTime, primary_key=True)
start_date = Column(DateTime)
end_date = Column(DateTime)
duration = Column(Integer)
state = Column(String(20))
try_number = Column(Integer)
hostname = Column(String(1000))
unixname = Column(String(1000))
job_id = Column(Integer)
__table_args__ = (
Index('ti_dag_state', dag_id, state),
Index('ti_state_lkp', dag_id, task_id, execution_date, state),
)
def __init__(self, task, execution_date, state=None, job=None):
self.dag_id = task.dag_id
self.task_id = task.task_id
self.execution_date = execution_date
self.state = state
self.task = task
self.try_number = 1
self.unixname = getpass.getuser()
if job:
self.job_id = job.id
def command(
self,
mark_success=False,
ignore_dependencies=False,
force=False,
local=False,
pickle_id=None,
raw=False,
job_id=None):
"""
Returns a command that can be executed anywhere airflow is
installed. This command is part of the message sent to executors by
the orchestrator.
"""
iso = self.execution_date.isoformat()
mark_success = "--mark_success" if mark_success else ""
pickle = "--pickle {0}".format(pickle_id) if pickle_id else ""
job_id = "--job_id {0}".format(job_id) if job_id else ""
ignore_dependencies = "-i" if ignore_dependencies else ""
force = "--force" if force else ""
local = "--local" if local else ""
raw = "--raw" if raw else ""
subdir = ""
if not pickle and self.task.dag and self.task.dag.full_filepath:
subdir = "-sd {0}".format(self.task.dag.full_filepath)
return (
"airflow run "
"{self.dag_id} {self.task_id} {iso} "
"{mark_success} "
"{pickle} "
"{local} "
"{ignore_dependencies} "
"{force} "
"{job_id} "
"{raw} "
"{subdir} "
).format(**locals())
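# For reference, a rendered command looks roughly like the following
# (dag/task ids, date and path are illustrative):
#
#     airflow run example_dag example_task 2015-01-01T00:00:00 \
#         --local -sd /home/airflow/dags/example_dag.py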
@property
def log_filepath(self):
iso = self.execution_date.isoformat()
log = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER'))
return (
"{log}/{self.dag_id}/{self.task_id}/{iso}.log".format(**locals()))
@property
def log_url(self):
iso = self.execution_date.isoformat()
BASE_URL = conf.get('webserver', 'BASE_URL')
return BASE_URL + (
"/admin/airflow/log"
"?dag_id={self.dag_id}"
"&task_id={self.task_id}"
"&execution_date={iso}"
).format(**locals())
def current_state(self, main_session=None):
"""
Get the very latest state from the database. If a session is passed,
we use it and looking up the state becomes part of that session;
otherwise a new session is used.
"""
session = main_session or settings.Session()
TI = TaskInstance
ti = session.query(TI).filter(
TI.dag_id == self.dag_id,
TI.task_id == self.task_id,
TI.execution_date == self.execution_date,
).all()
if ti:
state = ti[0].state
else:
state = None
if not main_session:
session.commit()
session.close()
return state
def error(self, main_session=None):
"""
Forces the task instance's state to FAILED in the database.
"""
session = settings.Session()
logging.error("Recording the task instance as FAILED")
self.state = State.FAILED
session.merge(self)
session.commit()
session.close()
def refresh_from_db(self, main_session=None):
"""
Refreshes the task instance from the database based on the primary key
"""
session = main_session or settings.Session()
TI = TaskInstance
ti = session.query(TI).filter(
TI.dag_id == self.dag_id,
TI.task_id == self.task_id,
TI.execution_date == self.execution_date,
).first()
if ti:
self.state = ti.state
self.start_date = ti.start_date
self.end_date = ti.end_date
self.try_number = ti.try_number
if not main_session:
session.commit()
session.close()
@property
def key(self):
"""
Returns a tuple that identifies the task instance uniquely
"""
return (self.dag_id, self.task_id, self.execution_date)
def is_runnable(self):
"""
Returns a boolean on whether the task instance has met all dependencies
and is ready to run. It considers the task's state, the state
of its dependencies, depends_on_past and makes sure the execution
isn't in the future.
"""
if self.execution_date > datetime.now() - self.task.schedule_interval:
return False
elif self.state == State.UP_FOR_RETRY and not self.ready_for_retry():
return False
elif self.task.end_date and self.execution_date > self.task.end_date:
return False
elif self.state in State.runnable() and self.are_dependencies_met():
return True
else:
return False
def are_dependents_done(self, main_session=None):
"""
Checks whether the dependents of this task instance have all succeeded.
This is meant to be used by wait_for_downstream.
This is useful when you do not want to start processing the next
schedule of a task until the dependents are done. For instance,
if the task DROPs and recreates a table.
"""
session = main_session or settings.Session()
task = self.task
if not task._downstream_list:
return True
downstream_task_ids = [t.task_id for t in task._downstream_list]
ti = session.query(func.count(TaskInstance.task_id)).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id.in_(downstream_task_ids),
TaskInstance.execution_date == self.execution_date,
TaskInstance.state == State.SUCCESS,
)
count = ti[0][0]
if not main_session:
session.commit()
session.close()
return count == len(task._downstream_list)
def are_dependencies_met(self, main_session=None):
"""
Returns a boolean on whether the upstream tasks are in a SUCCESS state
and considers depends_on_past and the previous run's state.
"""
TI = TaskInstance
# Using the session if passed as param
session = main_session or settings.Session()
task = self.task
# Checking that the depends_on_past is fulfilled
if (task.depends_on_past and
not self.execution_date == task.start_date):
previous_ti = session.query(TI).filter(
TI.dag_id == self.dag_id,
TI.task_id == task.task_id,
TI.execution_date ==
self.execution_date-task.schedule_interval,
TI.state == State.SUCCESS,
).first()
if not previous_ti:
return False
# Applying wait_for_downstream
previous_ti.task = self.task
if task.wait_for_downstream and not \
previous_ti.are_dependents_done(session):
return False
# Checking that all upstream dependencies have succeeded
if task._upstream_list:
upstream_task_ids = [t.task_id for t in task._upstream_list]
ti = session.query(func.count(TI.task_id)).filter(
TI.dag_id == self.dag_id,
TI.task_id.in_(upstream_task_ids),
TI.execution_date == self.execution_date,
TI.state == State.SUCCESS,
)
count = ti[0][0]
if count < len(task._upstream_list):
return False
if not main_session:
session.commit()
session.close()
return True
def __repr__(self):
return (
"<TaskInstance: {ti.dag_id}.{ti.task_id} "
"{ti.execution_date} [{ti.state}]>"
).format(ti=self)
def ready_for_retry(self):
"""
Checks on whether the task instance is in the right state and timeframe
to be retried.
"""
return self.state == State.UP_FOR_RETRY and \
self.end_date + self.task.retry_delay < datetime.now()
def run(
self,
verbose=True,
ignore_dependencies=False, # Doesn't check for deps, just runs
force=False, # Disregards previous successes
mark_success=False, # Don't run the task, act as if it succeeded
test_mode=False, # Doesn't record success or failure in the DB
job_id=None,):
"""
Runs the task instance.
"""
task = self.task
session = settings.Session()
self.refresh_from_db(session)
session.commit()
self.job_id = job_id
iso = datetime.now().isoformat()
self.hostname = socket.gethostname()
msg = "\n"
msg += ("-" * 80)
if self.state == State.UP_FOR_RETRY:
msg += "\nRetry run {self.try_number} out of {task.retries} "
msg += "starting @{iso}\n"
else:
msg += "\nNew run starting @{iso}\n"
msg += ("-" * 80)
logging.info(msg.format(**locals()))
if not force and self.state == State.SUCCESS:
logging.info(
"Task {self} previously succeeded"
" on {self.end_date}".format(**locals())
)
elif not ignore_dependencies and \
not self.are_dependencies_met(session):
logging.warning("Dependencies not met yet")
elif self.state == State.UP_FOR_RETRY and \
not self.ready_for_retry():
next_run = (self.end_date + task.retry_delay).isoformat()
logging.info(
"Not ready for retry yet. " +
"Next run after {0}".format(next_run)
)
elif force or self.state in State.runnable():
if self.state == State.UP_FOR_RETRY:
self.try_number += 1
else:
self.try_number = 1
if not test_mode:
session.add(Log(State.RUNNING, self))
self.state = State.RUNNING
self.start_date = datetime.now()
self.end_date = None
if not test_mode:
session.merge(self)
session.commit()
if verbose:
if mark_success:
msg = "Marking success for "
else:
msg = "Executing "
msg += "{self.task} for {self.execution_date}"
try:
logging.info(msg.format(self=self))
if not mark_success:
task_copy = copy.copy(task)
self.task = task_copy
def signal_handler(signum, frame):
'''Setting kill signal handler'''
logging.error("Killing subprocess")
task_copy.on_kill()
raise Exception("Task received SIGTERM signal")
signal.signal(signal.SIGTERM, signal_handler)
self.render_templates()
task_copy.execute(context=self.get_template_context())
except (Exception, StandardError, KeyboardInterrupt) as e:
self.record_failure(e, test_mode)
raise e
# Recording SUCCESS
session = settings.Session()
self.end_date = datetime.now()
self.set_duration()
self.state = State.SUCCESS
if not test_mode:
session.add(Log(State.SUCCESS, self))
session.merge(self)
session.commit()
def record_failure(self, error, test_mode=False):
logging.exception(error)
task = self.task
session = settings.Session()
self.end_date = datetime.now()
self.set_duration()
if not test_mode:
session.add(Log(State.FAILED, self))
# Let's go deeper
try:
if self.try_number <= task.retries:
self.state = State.UP_FOR_RETRY
if task.email_on_retry and task.email:
self.email_alert(error, is_retry=True)
else:
self.state = State.FAILED
if task.email_on_failure and task.email:
self.email_alert(error, is_retry=False)
except Exception as e2:
logging.error(
'Failed to send email to: ' + str(task.email))
logging.error(str(e2))
if not test_mode:
session.merge(self)
session.commit()
logging.error(str(error))
def get_template_context(self):
task = self.task
from airflow import macros
tables = None
if 'tables' in task.params:
tables = task.params['tables']
ds = self.execution_date.isoformat()[:10]
yesterday_ds = (self.execution_date - timedelta(1)).isoformat()[:10]
tomorrow_ds = (self.execution_date + timedelta(1)).isoformat()[:10]
ds_nodash = ds.replace('-', '')
ti_key_str = "{task.dag_id}__{task.task_id}__{ds_nodash}"
ti_key_str = ti_key_str.format(**locals())
params = {}
if hasattr(task, 'dag') and task.dag.params:
params.update(task.dag.params)
if task.params:
params.update(task.params)
return {
'dag': task.dag,
'ds': ds,
'yesterday_ds': yesterday_ds,
'tomorrow_ds': tomorrow_ds,
'END_DATE': ds,
'ds_nodash': ds_nodash,
'end_date': ds,
'execution_date': self.execution_date,
'latest_date': ds,
'macros': macros,
'params': params,
'tables': tables,
'task': task,
'task_instance': self,
'ti': self,
'task_instance_key_str': ti_key_str
}
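# The keys above are what templated fields can reference at render time; a
# hedged example of a jinja-templated value for some operator field:
#
#     sql = "SELECT * FROM src WHERE ds = '{{ ds }}'  -- task {{ ti.task_id }}"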
def render_templates(self):
task = self.task
jinja_context = self.get_template_context()
if hasattr(self, 'task') and hasattr(self.task, 'dag'):
if self.task.dag.user_defined_macros:
jinja_context.update(
self.task.dag.user_defined_macros)
for attr in task.__class__.template_fields:
result = getattr(task, attr)
template = self.task.get_template(attr)
result = template.render(**jinja_context)
setattr(task, attr, result)
def email_alert(self, exception, is_retry=False):
task = self.task
title = "Airflow alert: {self}".format(**locals())
exception = str(exception).replace('\n', '<br>')
try_ = task.retries + 1
body = (
"Try {self.try_number} out of {try_}<br>"
"Exception:<br>{exception}<br>"
"Log: <a href='{self.log_url}'>Link</a><br>"
"Host: {self.hostname}<br>"
"Log file: {self.log_filepath}<br>"
).format(**locals())
utils.send_email(task.email, title, body)
def set_duration(self):
if self.end_date and self.start_date:
self.duration = (self.end_date - self.start_date).seconds
else:
self.duration = None
class Log(Base):
"""
Used to actively log events to the database
"""
__tablename__ = "log"
id = Column(Integer, primary_key=True)
dttm = Column(DateTime)
dag_id = Column(String(ID_LEN))
task_id = Column(String(ID_LEN))
event = Column(String(30))
execution_date = Column(DateTime)
owner = Column(String(500))
def __init__(self, event, task_instance):
self.dttm = datetime.now()
self.dag_id = task_instance.dag_id
self.task_id = task_instance.task_id
self.execution_date = task_instance.execution_date
self.event = event
self.owner = task_instance.task.owner
class BaseOperator(Base):
"""
Abstract base class for all operators. Since operators create objects that
become nodes in the dag, BaseOperator contains many recursive methods for
dag crawling behavior. To derive from this class, you are expected to
override the constructor as well as the 'execute' method.
Operators derived from this class should perform or trigger certain tasks
synchronously (wait for completion). Examples of operators include an
operator that runs a Pig job (PigOperator), a sensor operator that
waits for a partition to land in Hive (HiveSensorOperator), or one that
moves data from Hive to MySQL (Hive2MySqlOperator). Instances of these
operators (tasks) target specific operations, running specific scripts,
functions or data transfers.
This class is abstract and shouldn't be instantiated. Instantiating a
class derived from this one results in the creation of a task object,
which ultimately becomes a node in DAG objects. Task dependencies should
be set by using the set_upstream and/or set_downstream methods.
Note that this class is derived from SQLAlchemy's Base class, which
allows us to push metadata regarding tasks to the database. Classes
derived from this one need to implement the polymorphic specifics
documented in SQLAlchemy. This should become clear while reading the
code for other operators.
:param task_id: a unique, meaningful id for the task
:type task_id: string
:param owner: the owner of the task, using the unix username is recommended
:type owner: string
:param retries: the number of retries that should be performed before
failing the task
:type retries: int
:param retry_delay: delay between retries
:type retry_delay: timedelta
:param start_date: start date for the task, the scheduler will start from
this point in time
:type start_date: datetime
:param end_date: if specified, the scheduler won't go beyond this date
:type end_date: datetime
:param schedule_interval: interval at which to schedule the task
:type schedule_interval: timedelta
:param depends_on_past: when set to true, task instances will run
sequentially while relying on the previous task's schedule to
succeed. The task instance for the start_date is allowed to run.
:type depends_on_past: boolean
:param dag: a reference to the dag the task is attached to (if any)
:type dag: DAG
"""
# For derived classes to define which fields will get jinjaified
template_fields = []
# Defines which file extensions to look for in the templated fields
template_ext = []
# Defines the color in the UI
ui_color = '#fff'
ui_fgcolor = '#000'
__tablename__ = "task"
dag_id = Column(String(ID_LEN), primary_key=True)
task_id = Column(String(ID_LEN), primary_key=True)
owner = Column(String(500))
task_type = Column(String(20))
start_date = Column(DateTime())
end_date = Column(DateTime())
depends_on_past = Column(Integer())
__mapper_args__ = {
'polymorphic_on': task_type,
'polymorphic_identity': 'BaseOperator'
}
@apply_defaults
def __init__(
self,
task_id,
owner,
email=None,
email_on_retry=True,
email_on_failure=True,
retries=0,
retry_delay=timedelta(seconds=300),
start_date=None,
end_date=None,
schedule_interval=timedelta(days=1),
depends_on_past=False,
wait_for_downstream=False,
dag=None,
params=None,
default_args=None,
adhoc=False,
*args,
**kwargs):
utils.validate_key(task_id)
self.dag_id = dag.dag_id if dag else 'adhoc_' + owner
self.task_id = task_id
self.owner = owner
self.email = email
self.email_on_retry = email_on_retry
self.email_on_failure = email_on_failure
self.start_date = start_date
self.end_date = end_date
self.depends_on_past = depends_on_past
self.wait_for_downstream = wait_for_downstream
self._schedule_interval = schedule_interval
self.retries = retries
if isinstance(retry_delay, timedelta):
self.retry_delay = retry_delay
else:
logging.info("retry_delay isn't timedelta object, assuming secs")
self.retry_delay = timedelta(seconds=retry_delay)
self.params = params or {} # Available in templates!
self.adhoc = adhoc
if dag:
dag.add_task(self)
self.dag = dag
# Private attributes
self._upstream_list = []
self._downstream_list = []
@property
def schedule_interval(self):
"""
The schedule interval of the DAG always wins over individual tasks so
that tasks within a DAG always line up. The task still needs its own
schedule_interval as it may not be attached to a DAG.
"""
if hasattr(self, 'dag') and self.dag:
return self.dag.schedule_interval
else:
return self._schedule_interval
def execute(self, context):
'''
This is the main method to derive when creating an operator.
Context is the same dictionary used as when rendering jinja templates.
Refer to get_template_context for more context.
'''
raise NotImplementedError()
def on_kill(self):
'''
Override this method to cleanup subprocesses when a task instance
gets killed. Any use of the threading, subprocess or multiprocessing
module whithin an operator needs to be cleaned up or it will leave
ghost processes behind.
'''
pass
def get_template(self, attr):
content = getattr(self, attr)
if hasattr(self, 'dag'):
env = self.dag.get_template_env()
else:
env = jinja2.Environment(cache_size=0)
exts = self.__class__.template_ext
if any([content.endswith(ext) for ext in exts]):
template = env.get_template(content)
else:
template = env.from_string(content)
return template
def prepare_template(self):
'''
Hook that is triggered after the templated fields get replaced
by their content. If you need your operator to alter the
content of the file before the template is rendered,
it should override this method to do so.
'''
pass
def resolve_template_files(self):
# Getting the content of files for template_field / template_ext
for attr in self.template_fields:
content = getattr(self, attr)
if any([content.endswith(ext) for ext in self.template_ext]):
env = self.dag.get_template_env()
try:
setattr(self, attr, env.loader.get_source(env, content)[0])
except Exception as e:
logging.exception(e)
self.prepare_template()
@property
def upstream_list(self):
"""@property: list of tasks directly upstream"""
return self._upstream_list
@property
def downstream_list(self):
"""@property: list of tasks directly downstream"""
return self._downstream_list
def clear(
self, start_date=None, end_date=None,
upstream=False, downstream=False):
"""
Clears the state of task instances associated with the task, following
the parameters specified.
"""
session = settings.Session()
TI = TaskInstance
qry = session.query(TI).filter(TI.dag_id == self.dag_id)
if start_date:
qry = qry.filter(TI.execution_date >= start_date)
if end_date:
qry = qry.filter(TI.execution_date <= end_date)
tasks = [self.task_id]
if upstream:
tasks += \
[t.task_id for t in self.get_flat_relatives(upstream=True)]
if downstream:
tasks += \
[t.task_id for t in self.get_flat_relatives(upstream=False)]
qry = qry.filter(TI.task_id.in_(tasks))
count = qry.count()
clear_task_instances(qry, session)
session.commit()
session.close()
return count
def get_task_instances(self, session, start_date=None, end_date=None):
"""
Get the set of task instances related to this task for a specific date
range.
"""
TI = TaskInstance
end_date = end_date or datetime.now()
return session.query(TI).filter(
TI.dag_id == self.dag_id,
TI.task_id == self.task_id,
TI.execution_date >= start_date,
TI.execution_date <= end_date,
).order_by(TI.execution_date).all()
def get_flat_relatives(self, upstream=False, l=None):
"""
Get a flat list of relatives, either upstream or downstream.
"""
if not l:
l = []
for t in self.get_direct_relatives(upstream):
if t not in l:
l.append(t)
t.get_flat_relatives(upstream, l)
return l
def detect_downstream_cycle(self, task=None):
"""
When invoked, this routine will raise an exception if a cycle is
detected downstream from self. It is invoked when tasks are added to
the DAG to detect cycles.
"""
if not task:
task = self
for t in self.get_direct_relatives():
if task == t:
msg = "Cycle detect in DAG. Faulty task: {0}".format(task)
raise Exception(msg)
else:
t.detect_downstream_cycle(task=task)
return False
def run(
self, start_date=None, end_date=None, ignore_dependencies=False,
force=False, mark_success=False):
"""
Run a set of task instances for a date range.
"""
start_date = start_date or self.start_date
end_date = end_date or self.end_date or datetime.now()
for dt in utils.date_range(
start_date, end_date, self.schedule_interval):
TaskInstance(self, dt).run(
mark_success=mark_success,
ignore_dependencies=ignore_dependencies,
force=force,)
def get_direct_relatives(self, upstream=False):
"""
Get the direct relatives to the current task, upstream or
downstream.
"""
if upstream:
return self.upstream_list
else:
return self.downstream_list
def __repr__(self):
return "<Task({self.task_type}): {self.task_id}>".format(self=self)
def append_only_new(self, l, item):
if item in l:
raise Exception(
'Dependency {self}, {item} already registered'
''.format(**locals()))
else:
l.append(item)
def _set_relatives(self, task_or_task_list, upstream=False):
if isinstance(task_or_task_list, BaseOperator):
task_or_task_list = [task_or_task_list]
for task in task_or_task_list:
if not isinstance(task, BaseOperator):
raise Exception('Expecting a BaseOperator task')
if upstream:
self.append_only_new(task._downstream_list, self)
self.append_only_new(self._upstream_list, task)
else:
self.append_only_new(task._upstream_list, self)
self.append_only_new(self._downstream_list, task)
self.detect_downstream_cycle()
def set_downstream(self, task_or_task_list):
"""
Set a task, or a list of tasks, to be directly downstream from the
current task.
"""
self._set_relatives(task_or_task_list, upstream=False)
def set_upstream(self, task_or_task_list):
"""
Set a task, or a list of tasks, to be directly upstream from the
current task.
"""
self._set_relatives(task_or_task_list, upstream=True)
class DAG(Base):
"""
A dag (directed acyclic graph) is a collection of tasks with directional
dependencies. A dag also has a schedule, a start date and an end date
(optional). For each schedule (say daily or hourly), the DAG needs to run
each individual task as its dependencies are met. Certain tasks have
the property of depending on their own past, meaning that they can't run
until their previous schedule (and upstream tasks) are completed.
DAGs essentially act as namespaces for tasks. A task_id can only be
added once to a DAG.
:param dag_id: The id of the DAG
:type dag_id: string
:param schedule_interval: Defines how often that DAG runs
:type schedule_interval: datetime.timedelta
:param start_date: The timestamp from which the scheduler will
attempt to backfill
:type start_date: datetime.datetime
:param end_date: A date beyond which your DAG won't run, leave to None
for open ended scheduling
:type end_date: datetime.datetime
:param template_searchpath: This list of folders (non relative)
defines where jinja will look for your templates. Order matters.
Note that jinja/airflow includes the path of your DAG file by
default
:type template_searchpath: string or list of strings
:param user_defined_macros: a dictionary of macros that will be merged
:type user_defined_macros: dict
:param default_args: A dictionary of default parameters to be used
as constructor keyword parameters when initialising operators.
Note that operators have the same hook, and precede those defined
here, meaning that if your dict contains `'depends_on_past': True`
here and `'depends_on_past': False` in the operator's call
`default_args`, the actual value will be `False`.
:type default_args: dict
"""
__tablename__ = "dag"
dag_id = Column(String(ID_LEN), primary_key=True)
is_paused = Column(Boolean, default=False)
def __init__(
self, dag_id,
schedule_interval=timedelta(days=1),
start_date=None, end_date=None,
full_filepath=None,
template_searchpath=None,
user_defined_macros=None,
default_args=None,
params=None):
self.user_defined_macros = user_defined_macros
self.default_args = default_args or {}
self.params = params
utils.validate_key(dag_id)
self.tasks = []
self.dag_id = dag_id
self.start_date = start_date
self.end_date = end_date or datetime.now()
self.schedule_interval = schedule_interval
self.full_filepath = full_filepath if full_filepath else ''
if isinstance(template_searchpath, basestring):
template_searchpath = [template_searchpath]
self.template_searchpath = template_searchpath
self.parent_dag = None # Gets set when DAGs are loaded
def __repr__(self):
return "<DAG: {self.dag_id}>".format(self=self)
@property
def task_ids(self):
return [t.task_id for t in self.tasks]
@property
def filepath(self):
fn = self.full_filepath.replace(DAGS_FOLDER + '/', '')
fn = fn.replace(os.path.dirname(__file__) + '/', '')
return fn
@property
def folder(self):
return os.path.dirname(self.full_filepath)
@property
def owner(self):
return ", ".join(list(set([t.owner for t in self.tasks])))
@property
def latest_execution_date(self):
TI = TaskInstance
session = settings.Session()
execution_date = session.query(func.max(TI.execution_date)).filter(
TI.dag_id == self.dag_id,
TI.task_id.in_(self.task_ids)
).scalar()
session.commit()
session.close()
return execution_date
@property
def subdags(self):
# Late import to prevent circular imports
from airflow.operators import SubDagOperator
l = []
for task in self.tasks:
if isinstance(task, SubDagOperator):
l.append(task.subdag)
l += task.subdag.subdags
return l
def resolve_template_files(self):
for t in self.tasks:
t.resolve_template_files()
def crawl_for_tasks(self, objects):
"""
Typically called at the end of a script by passing globals() as a
parameter. This allows one to not explicitly add every single task to
the dag.
"""
raise NotImplementedError()
def override_start_date(self, start_date):
"""
Sets start_date of all tasks and of the DAG itself to a certain date.
This is used by BackfillJob.
"""
for t in self.tasks:
t.start_date = start_date
self.start_date = start_date
def get_template_env(self):
'''
Returns a jinja2 Environment while taking into account the DAG's
template_searchpath and user_defined_macros
'''
searchpath = [self.folder]
if self.template_searchpath:
searchpath += self.template_searchpath
env = jinja2.Environment(
loader=jinja2.FileSystemLoader(searchpath),
extensions=["jinja2.ext.do"],
cache_size=0)
if self.user_defined_macros:
env.globals.update(self.user_defined_macros)
return env
def set_dependency(self, upstream_task_id, downstream_task_id):
"""
Simple utility method to set dependency between two tasks that
already have been added to the DAG using add_task()
"""
self.get_task(upstream_task_id).set_downstream(
self.get_task(downstream_task_id))
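# Equivalent ways of wiring a dependency (task ids are placeholders):
#
#     dag.set_dependency('extract', 'load')
#     # ... is the same as:
#     dag.get_task('extract').set_downstream(dag.get_task('load'))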
def get_task_instances(self, session, start_date=None, end_date=None):
TI = TaskInstance
if not start_date:
start_date = (datetime.today()-timedelta(30)).date()
start_date = datetime.combine(start_date, datetime.min.time())
if not end_date:
end_date = datetime.now()
tis = session.query(TI).filter(
TI.dag_id == self.dag_id,
TI.execution_date >= start_date,
TI.execution_date <= end_date,
).all()
return tis
@property
def roots(self):
return [t for t in self.tasks if not t.downstream_list]
def clear(
self, start_date=None, end_date=None,
only_failed=False,
only_running=False,
confirm_prompt=False,
include_subdags=True,
dry_run=False):
"""
Clears a set of task instances associated with the current dag for
a specified date range.
"""
session = settings.Session()
TI = TaskInstance
tis = session.query(TI)
if include_subdags:
# Crafting the right filter for dag_id and task_ids combo
conditions = []
for dag in self.subdags + [self]:
conditions.append(
TI.dag_id.like(dag.dag_id) & TI.task_id.in_(dag.task_ids)
)
tis = tis.filter(or_(*conditions))
else:
tis = session.query(TI).filter(TI.dag_id == self.dag_id)
tis = tis.filter(TI.task_id.in_(self.task_ids))
if start_date:
tis = tis.filter(TI.execution_date >= start_date)
if end_date:
tis = tis.filter(TI.execution_date <= end_date)
if only_failed:
tis = tis.filter(TI.state == State.FAILED)
if only_running:
tis = tis.filter(TI.state == State.RUNNING)
if dry_run:
tis = tis.all()
session.expunge_all()
return tis
count = tis.count()
if count == 0:
print("Nothing to clear.")
return 0
if confirm_prompt:
ti_list = "\n".join([str(t) for t in tis])
question = (
"You are about to delete these {count} tasks:\n"
"{ti_list}\n\n"
"Are you sure? (yes/no): ").format(**locals())
if utils.ask_yesno(question):
clear_task_instances(tis, session)
else:
count = 0
print("Bail. Nothing was cleared.")
else:
clear_task_instances(tis, session)
session.commit()
session.close()
return count
def __deepcopy__(self, memo):
# Switcharoo to get around deepcopying objects coming through the
# backdoor
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k not in ('user_defined_macros', 'params'):
setattr(result, k, copy.deepcopy(v, memo))
result.user_defined_macros = self.user_defined_macros
result.params = self.params
return result
def sub_dag(
self, task_regex,
include_downstream=False, include_upstream=True):
"""
Returns a subset of the current dag as a deep copy, based on a regex
that should match one or many tasks, and includes the upstream and
downstream neighbours based on the flags passed.
"""
dag = copy.deepcopy(self)
regex_match = [
t for t in dag.tasks if re.findall(task_regex, t.task_id)]
also_include = []
for t in regex_match:
if include_downstream:
also_include += t.get_flat_relatives(upstream=False)
if include_upstream:
also_include += t.get_flat_relatives(upstream=True)
# Compiling the unique list of tasks that made the cut
tasks = list(set(regex_match + also_include))
dag.tasks = tasks
for t in dag.tasks:
# Removing upstream/downstream references to tasks that did not
# make the cut
t._upstream_list = [
ut for ut in t._upstream_list if ut in tasks]
t._downstream_list = [
ut for ut in t._downstream_list if ut in tasks]
return dag
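# Sketch: isolating one task and everything upstream of it, e.g. for a
# targeted backfill (the regex and dates are illustrative):
#
#     partial = dag.sub_dag(task_regex='^load_fact$', include_upstream=True)
#     partial.run(
#         start_date=datetime(2015, 1, 1), end_date=datetime(2015, 1, 7))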
def get_task(self, task_id):
for task in self.tasks:
if task.task_id == task_id:
return task
raise Exception("Task {task_id} not found".format(**locals()))
def tree_view(self):
"""
Shows an ascii tree representation of the DAG
"""
def get_downstream(task, level=0):
print (" " * level * 4) + str(task)
level += 1
for t in task.upstream_list:
get_downstream(t, level)
for t in self.roots:
get_downstream(t)
def add_task(self, task):
'''
Add a task to the DAG
:param task: the task you want to add
:type task: task
'''
if not self.start_date and not task.start_date:
raise Exception("Task is missing the start_date parameter")
if not task.start_date:
task.start_date = self.start_date
if task.task_id in [t.task_id for t in self.tasks]:
raise Exception(
"Task id '{0}' has already been added "
"to the DAG ".format(task.task_id))
else:
self.tasks.append(task)
task.dag_id = self.dag_id
task.dag = self
self.task_count = len(self.tasks)
def add_tasks(self, tasks):
'''
Add a list of tasks to the DAG
:param tasks: a list of tasks you want to add
:type tasks: list of tasks
'''
for task in tasks:
self.add_task(task)
def db_merge(self):
BO = BaseOperator
session = settings.Session()
tasks = session.query(BO).filter(BO.dag_id == self.dag_id).all()
for t in tasks:
session.delete(t)
session.commit()
session.merge(self)
session.commit()
def run(
self, start_date=None, end_date=None, mark_success=False,
include_adhoc=False, local=False, executor=None,
donot_pickle=False):
from airflow.jobs import BackfillJob
if not executor and local:
executor = LocalExecutor()
elif not executor:
executor = DEFAULT_EXECUTOR
job = BackfillJob(
self,
start_date=start_date,
end_date=end_date,
mark_success=mark_success,
include_adhoc=include_adhoc,
executor=executor,
donot_pickle=donot_pickle)
job.run()
class Chart(Base):
__tablename__ = "chart"
id = Column(Integer, primary_key=True)
label = Column(String(200))
conn_id = Column(
String(ID_LEN), ForeignKey('connection.conn_id'), nullable=False)
user_id = Column(Integer(), ForeignKey('user.id'),)
chart_type = Column(String(100), default="line")
sql_layout = Column(String(50), default="series")
sql = Column(Text, default="SELECT series, x, y FROM table")
y_log_scale = Column(Boolean)
show_datatable = Column(Boolean)
show_sql = Column(Boolean, default=True)
height = Column(Integer, default=600)
default_params = Column(String(5000), default="{}")
owner = relationship(
"User", cascade=False, cascade_backrefs=False, backref='charts')
x_is_date = Column(Boolean, default=True)
db = relationship("Connection")
iteration_no = Column(Integer, default=0)
last_modified = Column(DateTime, default=datetime.now)
class KnownEventType(Base):
__tablename__ = "known_event_type"
id = Column(Integer, primary_key=True)
know_event_type = Column(String(200))
def __repr__(self):
return self.know_event_type
class KnownEvent(Base):
__tablename__ = "known_event"
id = Column(Integer, primary_key=True)
label = Column(String(200))
start_date = Column(DateTime)
end_date = Column(DateTime)
user_id = Column(Integer(), ForeignKey('user.id'),)
known_event_type_id = Column(Integer(), ForeignKey('known_event_type.id'),)
reported_by = relationship(
"User", cascade=False, cascade_backrefs=False, backref='known_events')
event_type = relationship(
"KnownEventType",
cascade=False,
cascade_backrefs=False, backref='known_events')
description = Column(Text)