diff --git a/airflow/operators/__init__.py b/airflow/operators/__init__.py
index b079822e4b..1e60982e2d 100644
--- a/airflow/operators/__init__.py
+++ b/airflow/operators/__init__.py
@@ -1,6 +1,7 @@
 from bash_operator import BashOperator
 from mysql_operator import MySqlOperator
 from hive_operator import HiveOperator
+from presto_check_operator import PrestoCheckOperator
 from sensors import SqlSensor
 from sensors import ExternalTaskSensor
 from sensors import HivePartitionSensor
diff --git a/airflow/operators/presto_check_operator.py b/airflow/operators/presto_check_operator.py
new file mode 100644
index 0000000000..303908e1ac
--- /dev/null
+++ b/airflow/operators/presto_check_operator.py
@@ -0,0 +1,42 @@
+import logging
+
+from airflow.configuration import conf
+from airflow.hooks import PrestoHook
+from airflow.models import BaseOperator
+from airflow.utils import apply_defaults
+
+
+class PrestoCheckOperator(BaseOperator):
+    """
+    Performs a simple check using sql code in a specific Presto database.
+
+    :param sql: the sql to be executed
+    :type sql: string
+    :param presto_dbid: reference to the Presto database
+    :type presto_dbid: string
+    """
+
+    __mapper_args__ = {
+        'polymorphic_identity': 'PrestoCheckOperator'
+    }
+    template_fields = ('sql',)
+    template_ext = ('.hql', '.sql',)
+
+    @apply_defaults
+    def __init__(
+            self, sql,
+            presto_conn_id=conf.get('hooks', 'PRESTO_DEFAULT_CONN_ID'),
+            *args, **kwargs):
+        super(PrestoCheckOperator, self).__init__(*args, **kwargs)
+
+        self.presto_conn_id = presto_conn_id
+        self.hook = PrestoHook(presto_conn_id=presto_conn_id)
+        self.sql = sql
+
+    def execute(self, execution_date=None):
+        logging.info('Executing SQL check: ' + self.sql)
+        records = self.hook.get_records(hql=self.sql)
+        if not records:
+            return False
+        else:
+            return not any([ bool(r) for r in records[0] ])