Merge pull request #53 from airbnb/poke_context

Passing context to BaseSensor poke method
This commit is contained in:
Maxime Beauchemin 2015-06-19 22:06:48 -07:00
Родитель a1800221e3 62afdc81f0
Коммит 259f0d1615
1 изменённый файл: 11 добавлений и 15 удалений

Просмотреть файл

@@ -36,7 +36,7 @@ class BaseSensorOperator(BaseOperator):
self.poke_interval = poke_interval self.poke_interval = poke_interval
self.timeout = timeout self.timeout = timeout
def poke(self): def poke(self, context):
''' '''
Function that the sensors defined while deriving this class should Function that the sensors defined while deriving this class should
override. override.
@@ -45,7 +45,7 @@ class BaseSensorOperator(BaseOperator):
def execute(self, context): def execute(self, context):
started_at = datetime.now() started_at = datetime.now()
while not self.poke(): while not self.poke(context):
sleep(self.poke_interval) sleep(self.poke_interval)
if (datetime.now() - started_at).seconds > self.timeout: if (datetime.now() - started_at).seconds > self.timeout:
raise AirflowException('Snap. Time is OUT.') raise AirflowException('Snap. Time is OUT.')
@@ -81,7 +81,7 @@ class SqlSensor(BaseSensorOperator):
session.commit() session.commit()
session.close() session.close()
def poke(self): def poke(self, context):
logging.info('Poking: ' + self.sql) logging.info('Poking: ' + self.sql)
records = self.hook.get_records(self.sql) records = self.hook.get_records(self.sql)
if not records: if not records:
@@ -105,30 +105,26 @@ class ExternalTaskSensor(BaseSensorOperator):
wait for wait for
:type external_task_id: string :type external_task_id: string
""" """
template_fields = ('execution_date',)
@apply_defaults @apply_defaults
def __init__(self, external_dag_id, external_task_id, *args, **kwargs): def __init__(self, external_dag_id, external_task_id, *args, **kwargs):
super(ExternalTaskSensor, self).__init__(*args, **kwargs) super(ExternalTaskSensor, self).__init__(*args, **kwargs)
self.external_dag_id = external_dag_id self.external_dag_id = external_dag_id
self.external_task_id = external_task_id self.external_task_id = external_task_id
self.execution_date = "{{ execution_date }}"
def poke(self): def poke(self, context):
logging.info( logging.info(
'Poking for ' 'Poking for '
'{self.external_dag_id}.' '{self.external_dag_id}.'
'{self.external_task_id} on ' '{self.external_task_id} on '
'{self.execution_date} ... '.format(**locals())) '{context[execution_date]} ... '.format(**locals()))
TI = TaskInstance TI = TaskInstance
session = settings.Session() session = settings.Session()
import dateutil.parser
self.execution_date = dateutil.parser.parse(self.execution_date)
count = session.query(TI).filter( count = session.query(TI).filter(
TI.dag_id == self.external_dag_id, TI.dag_id == self.external_dag_id,
TI.task_id == self.external_task_id, TI.task_id == self.external_task_id,
TI.state == State.SUCCESS, TI.state == State.SUCCESS,
TI.execution_date == self.execution_date TI.execution_date == context['execution_date'],
).count() ).count()
session.commit() session.commit()
session.close() session.close()
@@ -169,7 +165,7 @@ class HivePartitionSensor(BaseSensorOperator):
self.partition = partition self.partition = partition
self.schema = schema self.schema = schema
def poke(self): def poke(self, context):
logging.info( logging.info(
'Poking for table {self.schema}.{self.table}, ' 'Poking for table {self.schema}.{self.table}, '
'partition {self.partition}'.format(**locals())) 'partition {self.partition}'.format(**locals()))
@@ -196,7 +192,7 @@ class HdfsSensor(BaseSensorOperator):
self.filepath = filepath self.filepath = filepath
self.hdfs_conn_id = hdfs_conn_id self.hdfs_conn_id = hdfs_conn_id
def poke(self): def poke(self, context):
sb = hooks.HDFSHook(self.hdfs_conn_id).get_conn() sb = hooks.HDFSHook(self.hdfs_conn_id).get_conn()
logging.getLogger("snakebite").setLevel(logging.WARNING) logging.getLogger("snakebite").setLevel(logging.WARNING)
logging.info( logging.info(
@@ -255,7 +251,7 @@ class S3KeySensor(BaseSensorOperator):
session.commit() session.commit()
session.close() session.close()
def poke(self): def poke(self, context):
hook = hooks.S3Hook(s3_conn_id=self.s3_conn_id) hook = hooks.S3Hook(s3_conn_id=self.s3_conn_id)
full_url = "s3://" + self.bucket_name + self.bucket_key full_url = "s3://" + self.bucket_name + self.bucket_key
logging.info('Poking for key : {full_url}'.format(**locals())) logging.info('Poking for key : {full_url}'.format(**locals()))
@@ -307,7 +303,7 @@ class S3PrefixSensor(BaseSensorOperator):
session.commit() session.commit()
session.close() session.close()
def poke(self): def poke(self, context):
logging.info('Poking for prefix : {self.prefix}\n' logging.info('Poking for prefix : {self.prefix}\n'
'in bucket s3://{self.bucket_name}'.format(**locals())) 'in bucket s3://{self.bucket_name}'.format(**locals()))
hook = hooks.S3Hook(s3_conn_id=self.s3_conn_id) hook = hooks.S3Hook(s3_conn_id=self.s3_conn_id)
@@ -331,7 +327,7 @@ class TimeSensor(BaseSensorOperator):
super(TimeSensor, self).__init__(*args, **kwargs) super(TimeSensor, self).__init__(*args, **kwargs)
self.target_time = target_time self.target_time = target_time
def poke(self): def poke(self, context):
logging.info( logging.info(
'Checking if the time ({0}) has come'.format(self.target_time)) 'Checking if the time ({0}) has come'.format(self.target_time))
return datetime.now().time() > self.target_time return datetime.now().time() > self.target_time