301 строка
8.8 KiB
Python
301 строка
8.8 KiB
Python
from copy import copy
|
|
from datetime import datetime, timedelta
|
|
from email.mime.text import MIMEText
|
|
from email.mime.multipart import MIMEMultipart
|
|
from functools import wraps
|
|
import inspect
|
|
import logging
|
|
import re
|
|
import smtplib
|
|
|
|
from sqlalchemy import event, exc
|
|
from sqlalchemy.pool import Pool
|
|
|
|
from airflow.configuration import conf
|
|
from airflow import settings
|
|
|
|
|
|
class State(object):
    """
    Static holder for the task instance state names, together with the color
    each state is displayed with, so that state strings are never hard-coded
    elsewhere.
    """
    QUEUED = "queued"
    RUNNING = "running"
    SUCCESS = "success"
    SHUTDOWN = "shutdown"  # External request to shut down
    FAILED = "failed"
    UP_FOR_RETRY = "up_for_retry"

    # Mapping of state name -> color name.
    state_color = dict([
        (QUEUED, 'grey'),
        (RUNNING, 'lime'),
        (SUCCESS, 'green'),
        (SHUTDOWN, 'orange'),
        (FAILED, 'red'),
        (UP_FOR_RETRY, 'yellow'),
    ])

    @classmethod
    def color(cls, state):
        """Return the color for ``state``; raises KeyError for unknown states."""
        palette = cls.state_color
        return palette[state]

    @classmethod
    def runnable(cls):
        """
        Return the list of states treated as runnable: no state yet (None),
        FAILED and UP_FOR_RETRY. A fresh list is built on every call.
        """
        runnable_states = [None]
        runnable_states.extend((cls.FAILED, cls.UP_FOR_RETRY))
        return runnable_states
|
|
|
|
|
|
def pessimistic_connection_handling():
    """
    Register a SQLAlchemy pool "checkout" listener that pings each connection
    before it is handed out.

    Disconnect Handling - Pessimistic, taken from:
    http://docs.sqlalchemy.org/en/rel_0_9/core/pooling.html
    """
    @event.listens_for(Pool, "checkout")
    def ping_connection(dbapi_connection, connection_record, connection_proxy):
        '''
        Run ``SELECT 1`` on the raw DBAPI connection. If it fails, raise
        DisconnectionError, the signal the pessimistic-disconnect recipe uses
        to make the pool discard this connection.
        '''
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("SELECT 1")
        except Exception:
            # Catch only real errors: KeyboardInterrupt / SystemExit must not
            # be disguised as a stale connection.
            try:
                cursor.close()
            except Exception:
                # Best effort: the connection is already considered broken,
                # and a close() failure must not mask DisconnectionError.
                pass
            raise exc.DisconnectionError()
        cursor.close()
|
|
|
|
|
|
def initdb():
    """
    Create all tables, then seed the default connections and known event
    types.

    A row is only inserted when no row with the same key (``conn_id`` /
    ``know_event_type``) exists yet, so re-running this against an already
    initialized database adds nothing.
    """
    from airflow import models
    logging.info("Creating all tables")
    models.Base.metadata.create_all(settings.engine)

    # Keyword arguments for each default models.Connection.
    default_connections = [
        dict(
            conn_id='local_mysql', conn_type='mysql',
            host='localhost', login='airflow', password='airflow',
            schema='airflow'),
        dict(
            conn_id='presto_default', conn_type='presto',
            host='localhost',
            schema='hive', port=3400),
        dict(
            conn_id='hive_cli_default', conn_type='hive_cli',
            schema='default'),
        dict(
            conn_id='hiveserver2_default', conn_type='hiveserver2',
            host='localhost',
            schema='default', port=10000),
        dict(
            conn_id='metastore_default', conn_type='hive_metastore',
            host='localhost',
            port=10001),
    ]
    # NOTE(review): 'Campain' is misspelled but kept byte-identical; existing
    # databases presumably already contain this exact value.
    known_event_types = [
        'Holiday', 'Outage', 'Natural Disaster', 'Marketing Campain']

    C = models.Connection
    KET = models.KnownEventType
    session = settings.Session()
    try:
        for conn_kwargs in default_connections:
            existing = session.query(C).filter(
                C.conn_id == conn_kwargs['conn_id']).first()
            if not existing:
                session.add(C(**conn_kwargs))
                session.commit()

        for event_type in known_event_types:
            existing = session.query(KET).filter(
                KET.know_event_type == event_type).first()
            if not existing:
                session.add(KET(know_event_type=event_type))
        session.commit()
    finally:
        # Always release the session, even if seeding fails part-way.
        session.close()
|
|
|
|
|
|
def resetdb():
    """
    Wipe the database: drop every table registered on the models' metadata,
    then recreate the schema and default rows by calling initdb().
    """
    from airflow import models

    logging.info("Dropping tables that exist")
    metadata = models.Base.metadata
    engine = settings.engine
    metadata.drop_all(engine)
    initdb()
|
|
|
|
|
|
def validate_key(k, max_length=250):
    """
    Check that ``k`` is a usable key.

    :param k: the key to validate; must be a ``str``
    :param max_length: maximum allowed length (inclusive)
    :returns: True when the key is valid
    :raises TypeError: if ``k`` is not a string
    :raises Exception: if ``k`` is too long or contains characters other than
        ASCII letters, digits, '_', '-' and '.'
    """
    if not isinstance(k, str):
        raise TypeError("The key has to be a string")
    elif len(k) > max_length:
        raise Exception("The key has to be less than {0} characters".format(
            max_length))
    # \Z instead of $: '$' also matches just before a trailing newline, which
    # would let keys such as "abc\n" through.
    elif not re.match(r'^[A-Za-z0-9_\-\.]+\Z', k):
        raise Exception(
            "The key ({k}) has to be made of alphanumeric characters, dashes, "
            "dots and underscores exclusively".format(k=k))
    else:
        return True
|
|
|
|
|
|
def date_range(start_date, end_date=None, delta=timedelta(1)):
    """
    Return the list of datetimes from ``start_date`` up to ``end_date``
    (inclusive), stepping by ``delta``.

    :param start_date: first datetime of the range
    :param end_date: upper bound (inclusive); when omitted, the current time
        at call time is used (not at import time)
    :param delta: positive step between consecutive values; one day by default
    :raises Exception: if ``start_date`` is after ``end_date``
    :raises ValueError: if ``delta`` is not positive (it would never terminate)
    """
    if end_date is None:
        end_date = datetime.now()
    if end_date < start_date:
        raise Exception("start_date can't be after end_date")
    if delta <= timedelta(0):
        raise ValueError("delta must be a positive timedelta")
    dates = []
    current = start_date
    while current <= end_date:
        dates.append(current)
        current += delta
    return dates
|
|
|
|
|
|
def json_ser(obj):
    """
    json serializer that deals with dates and datetimes.

    usage: json.dumps(object, default=utils.json_ser)

    ``date`` and ``datetime`` objects are converted to their ISO 8601 string.
    Any other object is returned unchanged; when used as ``default=`` for
    json.dumps, such an object will still make serialization fail.
    """
    from datetime import date  # datetime is a subclass of date
    if isinstance(obj, date):
        return obj.isoformat()
    return obj
|
|
|
|
|
|
def alchemy_to_dict(obj):
    """
    Transform a SQLAlchemy model instance into a dictionary.

    Keys are the column names of ``obj.__table__``. Values are the instance's
    attribute values, with datetimes (including subclasses) converted to ISO
    8601 strings.

    :returns: the dict, or None when ``obj`` is falsy
    """
    if not obj:
        return None
    d = {}
    for column in obj.__table__.columns:
        value = getattr(obj, column.name)
        if isinstance(value, datetime):
            value = value.isoformat()
        d[column.name] = value
    return d
|
|
|
|
|
|
def readfile(filepath):
    """
    Return the entire content of the file at ``filepath``.

    The file is opened with the platform default mode/encoding and is closed
    even if reading fails.
    """
    with open(filepath) as f:
        return f.read()
|
|
|
|
|
|
def apply_defaults(func):
    '''
    Function decorator that looks for an argument named "default_args", and
    fills the unspecified arguments from it.

    Since python2.* isn't clear about which arguments are missing when
    calling a function, and that this can be quite confusing with multi-level
    inheritance and argument defaults, this decorator also alerts with
    specific information about the missing arguments.

    Merging rules, as implemented below:

    - explicit keyword arguments win over ``default_args``, which win over
      ``dag.default_args``;
    - ``params`` is built from ``dag.params``, then updated with the
      ``params`` kwarg, then with ``default_args['params']``, and is always
      passed to ``func`` as ``params``;
    - the caller's ``default_args`` dict is never modified.

    :raises Exception: when positional arguments (other than self) are used,
        or when required arguments are still missing after applying defaults
    '''
    # Inspect the signature once, at decoration time.
    try:
        getargspec = inspect.getfullargspec  # Python 3
    except AttributeError:
        getargspec = inspect.getargspec  # Python 2
    arg_spec = getargspec(func)
    num_defaults = len(arg_spec.defaults) if arg_spec.defaults else 0
    # Explicit end index: args[:-0] would be [] and silently disable the
    # missing-argument check for functions without defaults.
    non_optional_args = [
        name for name in arg_spec.args[:len(arg_spec.args) - num_defaults]
        if name != 'self']

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Only ``self`` may be passed positionally.
        if len(args) > 1:
            logging.debug("Positional arguments received: %s", args)
            raise Exception(
                "Use keyword arguments when initializing operators")
        dag_args = {}
        dag_params = {}
        if 'dag' in kwargs and kwargs['dag']:
            dag = kwargs['dag']
            dag_args = copy(dag.default_args) or {}
            dag_params = copy(dag.params) or {}

        dag_params.update(kwargs.get('params') or {})

        default_args = {}
        if kwargs.get('default_args'):
            # Work on a copy: removing 'params' from the caller's dict would
            # strip it for every other operator sharing that default_args.
            default_args = dict(kwargs['default_args'])
            if 'params' in default_args:
                dag_params.update(default_args.pop('params'))
                # func sees default_args without 'params', as before.
                kwargs['default_args'] = default_args

        dag_args.update(default_args)
        default_args = dag_args

        # co_varnames covers the function's parameter names (and its locals).
        for arg in func.__code__.co_varnames:
            if arg in default_args and arg not in kwargs:
                kwargs[arg] = default_args[arg]
        missing_args = list(set(non_optional_args) - set(kwargs))
        if missing_args:
            msg = "Argument {0} is required".format(missing_args)
            raise Exception(msg)

        kwargs['params'] = dag_params
        return func(*args, **kwargs)
    return wrapper
|
|
|
|
|
|
def ask_yesno(question):
    """
    Print ``question``, then read answers from stdin until one is recognized.

    :returns: True for 'yes'/'y', False for 'no'/'n' (case and surrounding
        whitespace are ignored); any other answer prints a reminder and
        asks again
    """
    try:
        read_line = raw_input  # Python 2
    except NameError:
        read_line = input  # Python 3 (input() does not eval here)
    yes = set(['yes', 'y'])
    no = set(['no', 'n'])

    print(question)
    while True:
        choice = read_line().strip().lower()
        if choice in yes:
            return True
        if choice in no:
            return False
        print("Please respond by yes or no.")
|
|
|
|
|
|
def send_email(to, subject, html_content):
    """
    Send an HTML email through the SMTP server configured in the ``smtp``
    section of the configuration.

    :param to: recipients, either as a list of addresses or as one string
        holding a single address or several separated by ',' and/or ';'
    :param subject: subject line
    :param html_content: HTML body of the message
    """
    SMTP_HOST = conf.get('smtp', 'SMTP_HOST')
    SMTP_MAIL_FROM = conf.get('smtp', 'SMTP_MAIL_FROM')
    SMTP_PORT = conf.get('smtp', 'SMTP_PORT')
    SMTP_USER = conf.get('smtp', 'SMTP_USER')
    SMTP_PASSWORD = conf.get('smtp', 'SMTP_PASSWORD')

    try:
        string_types = basestring  # Python 2: covers str and unicode
    except NameError:
        string_types = str
    if isinstance(to, string_types):
        # Accept either separator (or both) and drop the spaces commonly
        # typed after them, e.g. "a@x.com, b@x.com".
        to = [addr.strip() for addr in re.split(r'[,;]', to) if addr.strip()]

    msg = MIMEMultipart('alternative')
    msg['Subject'] = subject
    msg['From'] = SMTP_MAIL_FROM
    msg['To'] = ", ".join(to)
    mime_text = MIMEText(html_content, 'html')
    msg.attach(mime_text)

    s = smtplib.SMTP(SMTP_HOST, SMTP_PORT)
    try:
        s.starttls()
        if SMTP_USER and SMTP_PASSWORD:
            s.login(SMTP_USER, SMTP_PASSWORD)
        s.sendmail(SMTP_MAIL_FROM, to, msg.as_string())
        # Log only once the server has accepted the message.
        logging.info("Sent an altert email to " + str(to))
    except Exception:
        # Drop the socket without the QUIT round-trip, which could itself
        # fail and mask the original error.
        s.close()
        raise
    s.quit()
|