Merge pull request #53 from airbnb/poke_context
Passing context to BaseSensor poke method
This commit is contained in:
Коммит
259f0d1615
|
@ -36,7 +36,7 @@ class BaseSensorOperator(BaseOperator):
|
||||||
self.poke_interval = poke_interval
|
self.poke_interval = poke_interval
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
|
||||||
def poke(self):
|
def poke(self, context):
|
||||||
'''
|
'''
|
||||||
Function that the sensors defined while deriving this class should
|
Function that the sensors defined while deriving this class should
|
||||||
override.
|
override.
|
||||||
|
@ -45,7 +45,7 @@ class BaseSensorOperator(BaseOperator):
|
||||||
|
|
||||||
def execute(self, context):
|
def execute(self, context):
|
||||||
started_at = datetime.now()
|
started_at = datetime.now()
|
||||||
while not self.poke():
|
while not self.poke(context):
|
||||||
sleep(self.poke_interval)
|
sleep(self.poke_interval)
|
||||||
if (datetime.now() - started_at).seconds > self.timeout:
|
if (datetime.now() - started_at).seconds > self.timeout:
|
||||||
raise AirflowException('Snap. Time is OUT.')
|
raise AirflowException('Snap. Time is OUT.')
|
||||||
|
@ -81,7 +81,7 @@ class SqlSensor(BaseSensorOperator):
|
||||||
session.commit()
|
session.commit()
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
def poke(self):
|
def poke(self, context):
|
||||||
logging.info('Poking: ' + self.sql)
|
logging.info('Poking: ' + self.sql)
|
||||||
records = self.hook.get_records(self.sql)
|
records = self.hook.get_records(self.sql)
|
||||||
if not records:
|
if not records:
|
||||||
|
@ -105,30 +105,26 @@ class ExternalTaskSensor(BaseSensorOperator):
|
||||||
wait for
|
wait for
|
||||||
:type external_task_id: string
|
:type external_task_id: string
|
||||||
"""
|
"""
|
||||||
template_fields = ('execution_date',)
|
|
||||||
|
|
||||||
@apply_defaults
|
@apply_defaults
|
||||||
def __init__(self, external_dag_id, external_task_id, *args, **kwargs):
|
def __init__(self, external_dag_id, external_task_id, *args, **kwargs):
|
||||||
super(ExternalTaskSensor, self).__init__(*args, **kwargs)
|
super(ExternalTaskSensor, self).__init__(*args, **kwargs)
|
||||||
self.external_dag_id = external_dag_id
|
self.external_dag_id = external_dag_id
|
||||||
self.external_task_id = external_task_id
|
self.external_task_id = external_task_id
|
||||||
self.execution_date = "{{ execution_date }}"
|
|
||||||
|
|
||||||
def poke(self):
|
def poke(self, context):
|
||||||
logging.info(
|
logging.info(
|
||||||
'Poking for '
|
'Poking for '
|
||||||
'{self.external_dag_id}.'
|
'{self.external_dag_id}.'
|
||||||
'{self.external_task_id} on '
|
'{self.external_task_id} on '
|
||||||
'{self.execution_date} ... '.format(**locals()))
|
'{context[execution_date]} ... '.format(**locals()))
|
||||||
TI = TaskInstance
|
TI = TaskInstance
|
||||||
session = settings.Session()
|
session = settings.Session()
|
||||||
import dateutil.parser
|
|
||||||
self.execution_date = dateutil.parser.parse(self.execution_date)
|
|
||||||
count = session.query(TI).filter(
|
count = session.query(TI).filter(
|
||||||
TI.dag_id == self.external_dag_id,
|
TI.dag_id == self.external_dag_id,
|
||||||
TI.task_id == self.external_task_id,
|
TI.task_id == self.external_task_id,
|
||||||
TI.state == State.SUCCESS,
|
TI.state == State.SUCCESS,
|
||||||
TI.execution_date == self.execution_date
|
TI.execution_date == context['execution_date'],
|
||||||
).count()
|
).count()
|
||||||
session.commit()
|
session.commit()
|
||||||
session.close()
|
session.close()
|
||||||
|
@ -169,7 +165,7 @@ class HivePartitionSensor(BaseSensorOperator):
|
||||||
self.partition = partition
|
self.partition = partition
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
|
|
||||||
def poke(self):
|
def poke(self, context):
|
||||||
logging.info(
|
logging.info(
|
||||||
'Poking for table {self.schema}.{self.table}, '
|
'Poking for table {self.schema}.{self.table}, '
|
||||||
'partition {self.partition}'.format(**locals()))
|
'partition {self.partition}'.format(**locals()))
|
||||||
|
@ -196,7 +192,7 @@ class HdfsSensor(BaseSensorOperator):
|
||||||
self.filepath = filepath
|
self.filepath = filepath
|
||||||
self.hdfs_conn_id = hdfs_conn_id
|
self.hdfs_conn_id = hdfs_conn_id
|
||||||
|
|
||||||
def poke(self):
|
def poke(self, context):
|
||||||
sb = hooks.HDFSHook(self.hdfs_conn_id).get_conn()
|
sb = hooks.HDFSHook(self.hdfs_conn_id).get_conn()
|
||||||
logging.getLogger("snakebite").setLevel(logging.WARNING)
|
logging.getLogger("snakebite").setLevel(logging.WARNING)
|
||||||
logging.info(
|
logging.info(
|
||||||
|
@ -255,7 +251,7 @@ class S3KeySensor(BaseSensorOperator):
|
||||||
session.commit()
|
session.commit()
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
def poke(self):
|
def poke(self, context):
|
||||||
hook = hooks.S3Hook(s3_conn_id=self.s3_conn_id)
|
hook = hooks.S3Hook(s3_conn_id=self.s3_conn_id)
|
||||||
full_url = "s3://" + self.bucket_name + self.bucket_key
|
full_url = "s3://" + self.bucket_name + self.bucket_key
|
||||||
logging.info('Poking for key : {full_url}'.format(**locals()))
|
logging.info('Poking for key : {full_url}'.format(**locals()))
|
||||||
|
@ -307,7 +303,7 @@ class S3PrefixSensor(BaseSensorOperator):
|
||||||
session.commit()
|
session.commit()
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
def poke(self):
|
def poke(self, context):
|
||||||
logging.info('Poking for prefix : {self.prefix}\n'
|
logging.info('Poking for prefix : {self.prefix}\n'
|
||||||
'in bucket s3://{self.bucket_name}'.format(**locals()))
|
'in bucket s3://{self.bucket_name}'.format(**locals()))
|
||||||
hook = hooks.S3Hook(s3_conn_id=self.s3_conn_id)
|
hook = hooks.S3Hook(s3_conn_id=self.s3_conn_id)
|
||||||
|
@ -331,7 +327,7 @@ class TimeSensor(BaseSensorOperator):
|
||||||
super(TimeSensor, self).__init__(*args, **kwargs)
|
super(TimeSensor, self).__init__(*args, **kwargs)
|
||||||
self.target_time = target_time
|
self.target_time = target_time
|
||||||
|
|
||||||
def poke(self):
|
def poke(self, context):
|
||||||
logging.info(
|
logging.info(
|
||||||
'Checking if the time ({0}) has come'.format(self.target_time))
|
'Checking if the time ({0}) has come'.format(self.target_time))
|
||||||
return datetime.now().time() > self.target_time
|
return datetime.now().time() > self.target_time
|
||||||
|
|
Загрузка…
Ссылка в новой задаче