This commit is contained in:
Maxime 2015-02-24 07:43:48 +00:00
Родитель df772e2a59
Коммит a1089dd067
4 изменённых файлов: 56 добавлений и 23 удалений

Просмотреть файл

@ -2,11 +2,16 @@ from airflow import settings
from airflow.models import Connection
from airflow.hooks.base_hook import BaseHook
from pyhive import presto
from pyhive.exc import DatabaseError
import logging
logging.getLogger("pyhive").setLevel(logging.INFO)
class PrestoException(Exception):
pass
class PrestoHook(BaseHook):
"""
Interact with Presto through PyHive!
@ -56,16 +61,26 @@ class PrestoHook(BaseHook):
'''
Get a set of records from Presto
'''
self.cursor.execute(self._strip_sql(hql), parameters)
return self.cursor.fetchall()
try:
self.cursor.execute(self._strip_sql(hql), parameters)
records = self.cursor.fetchall()
except DatabaseError as e:
obj = eval(str(e))
raise PrestoException(obj['message'])
return records
def get_first(self, hql, parameters=None):
'''
Returns only the first row, regardless of how many rows the query
returns.
'''
self.cursor.execute(self._strip_sql(hql), parameters)
return self.cursor.fetchone()
try:
self.cursor.execute(self._strip_sql(hql), parameters)
record = self.cursor.fetchone()
except DatabaseError as e:
obj = eval(str(e))
raise PrestoException(obj['message'])
return record
def get_pandas_df(self, hql, parameters=None):
'''

Просмотреть файл

@ -612,20 +612,25 @@ class TaskInstance(Base):
session.add(Log(State.FAILED, self))
# Let's go deeper
try:
if self.try_number <= task.retries:
self.state = State.UP_FOR_RETRY
if task.email_on_retry and task.email:
self.email_alert(error, is_retry=True)
else:
self.state = State.FAILED
if task.email_on_failure and task.email:
self.email_alert(error, is_retry=False)
#try:
if self.try_number <= task.retries:
self.state = State.UP_FOR_RETRY
if task.email_on_retry and task.email:
self.email_alert(error, is_retry=True)
else:
self.state = State.FAILED
if task.email_on_failure and task.email:
self.email_alert(error, is_retry=False)
'''
except Exception as e2:
logging.error(
'Failed to send email to: ' + str(task.email))
logging.error(str(e2))
'''
if not test_mode:
session.merge(self)
session.commit()
@ -675,17 +680,14 @@ class TaskInstance(Base):
for attr in task.__class__.template_fields:
result = getattr(task, attr)
try:
template = self.task.get_template(attr)
result = template.render(**jinja_context)
except Exception as e:
logging.exception(e)
template = self.task.get_template(attr)
result = template.render(**jinja_context)
setattr(task, attr, result)
def email_alert(self, exception, is_retry=False):
task = self.task
title = "Airflow alert: {self}".format(**locals())
exception = exception.replace('\n', '<br>')
exception = str(exception).replace('\n', '<br>')
try_ = task.retries + 1
body = (
"Try {self.try_number} out of {try_}<br>"
@ -880,7 +882,7 @@ class BaseOperator(Base):
if hasattr(self, 'dag'):
env = self.dag.get_template_env()
else:
env = jinja2.Environment()
env = jinja2.Environment(cache_size=0)
exts = self.__class__.template_ext
if any([content.endswith(ext) for ext in exts]):

Просмотреть файл

@ -1,6 +1,5 @@
import logging
from airflow.configuration import conf
from airflow.hooks import PrestoHook
from airflow.models import BaseOperator
from airflow.utils import apply_defaults

Просмотреть файл

@ -147,6 +147,7 @@ class Airflow(BaseView):
models.Connection.conn_id)
db_choices = [(db.conn_id, db.conn_id) for db in dbs]
conn_id_str = request.args.get('conn_id')
csv = request.args.get('csv') == "true"
sql = request.args.get('sql')
class QueryForm(Form):
@ -184,6 +185,12 @@ class Airflow(BaseView):
if not has_data and error:
flash('No data', 'error')
if csv:
return Response(
response=df.to_csv(index=False),
status=200,
mimetype="application/text")
form = QueryForm(request.form, data=data)
session.commit()
session.close()
@ -196,10 +203,11 @@ class Airflow(BaseView):
@expose('/chart_data')
@login_required
@wwwutils.gzipped
#@cache.cached(timeout=3600, key_prefix=wwwutils.make_cache_key)
@cache.cached(timeout=3600, key_prefix=wwwutils.make_cache_key)
def chart_data(self):
session = settings.Session()
chart_id = request.args.get('chart_id')
csv = request.args.get('csv') == "true"
chart = session.query(models.Chart).filter_by(id=chart_id).all()[0]
db = session.query(
models.Connection).filter_by(conn_id=chart.conn_id).all()[0]
@ -243,6 +251,12 @@ class Airflow(BaseView):
except Exception as e:
payload['error'] += "SQL execution failed. Details: " + str(e)
if csv:
return Response(
response=df.to_csv(index=False),
status=200,
mimetype="application/text")
if not payload['error'] and len(df) == CHART_LIMIT:
payload['warning'] = (
"Data has been truncated to {0}"
@ -610,7 +624,10 @@ class Airflow(BaseView):
dag = dagbag.dags[dag_id]
task = copy.copy(dag.get_task(task_id))
ti = models.TaskInstance(task=task, execution_date=dttm)
ti.render_templates()
try:
ti.render_templates()
except Exception as e:
flash("Error rendering template: " + str(e), "error")
title = "{dag_id}.{task_id} [{execution_date}] rendered"
html_dict = {}
for template_field in task.__class__.template_fields: