Merge pull request #53 from airbnb/poke_context

Passing context to BaseSensor poke method
This commit is contained in:
Maxime Beauchemin 2015-06-19 22:06:48 -07:00
Родитель a1800221e3 62afdc81f0
Коммит 259f0d1615
1 изменённый файл: 11 добавлений и 15 удалений

Просмотреть файл

@@ -36,7 +36,7 @@ class BaseSensorOperator(BaseOperator):
self.poke_interval = poke_interval
self.timeout = timeout
def poke(self):
def poke(self, context):
'''
Function that the sensors defined while deriving this class should
override.
@@ -45,7 +45,7 @@ class BaseSensorOperator(BaseOperator):
def execute(self, context):
started_at = datetime.now()
while not self.poke():
while not self.poke(context):
sleep(self.poke_interval)
if (datetime.now() - started_at).seconds > self.timeout:
raise AirflowException('Snap. Time is OUT.')
@@ -81,7 +81,7 @@ class SqlSensor(BaseSensorOperator):
session.commit()
session.close()
def poke(self):
def poke(self, context):
logging.info('Poking: ' + self.sql)
records = self.hook.get_records(self.sql)
if not records:
@@ -105,30 +105,26 @@ class ExternalTaskSensor(BaseSensorOperator):
wait for
:type external_task_id: string
"""
template_fields = ('execution_date',)
@apply_defaults
def __init__(self, external_dag_id, external_task_id, *args, **kwargs):
super(ExternalTaskSensor, self).__init__(*args, **kwargs)
self.external_dag_id = external_dag_id
self.external_task_id = external_task_id
self.execution_date = "{{ execution_date }}"
def poke(self):
def poke(self, context):
logging.info(
'Poking for '
'{self.external_dag_id}.'
'{self.external_task_id} on '
'{self.execution_date} ... '.format(**locals()))
'{context[execution_date]} ... '.format(**locals()))
TI = TaskInstance
session = settings.Session()
import dateutil.parser
self.execution_date = dateutil.parser.parse(self.execution_date)
count = session.query(TI).filter(
TI.dag_id == self.external_dag_id,
TI.task_id == self.external_task_id,
TI.state == State.SUCCESS,
TI.execution_date == self.execution_date
TI.execution_date == context['execution_date'],
).count()
session.commit()
session.close()
@@ -169,7 +165,7 @@ class HivePartitionSensor(BaseSensorOperator):
self.partition = partition
self.schema = schema
def poke(self):
def poke(self, context):
logging.info(
'Poking for table {self.schema}.{self.table}, '
'partition {self.partition}'.format(**locals()))
@@ -196,7 +192,7 @@ class HdfsSensor(BaseSensorOperator):
self.filepath = filepath
self.hdfs_conn_id = hdfs_conn_id
def poke(self):
def poke(self, context):
sb = hooks.HDFSHook(self.hdfs_conn_id).get_conn()
logging.getLogger("snakebite").setLevel(logging.WARNING)
logging.info(
@@ -255,7 +251,7 @@ class S3KeySensor(BaseSensorOperator):
session.commit()
session.close()
def poke(self):
def poke(self, context):
hook = hooks.S3Hook(s3_conn_id=self.s3_conn_id)
full_url = "s3://" + self.bucket_name + self.bucket_key
logging.info('Poking for key : {full_url}'.format(**locals()))
@@ -307,7 +303,7 @@ class S3PrefixSensor(BaseSensorOperator):
session.commit()
session.close()
def poke(self):
def poke(self, context):
logging.info('Poking for prefix : {self.prefix}\n'
'in bucket s3://{self.bucket_name}'.format(**locals()))
hook = hooks.S3Hook(s3_conn_id=self.s3_conn_id)
@@ -331,7 +327,7 @@ class TimeSensor(BaseSensorOperator):
super(TimeSensor, self).__init__(*args, **kwargs)
self.target_time = target_time
def poke(self):
def poke(self, context):
logging.info(
'Checking if the time ({0}) has come'.format(self.target_time))
return datetime.now().time() > self.target_time