diff --git a/airflow/operators/sensors.py b/airflow/operators/sensors.py index 2062937b7b..86af970986 100644 --- a/airflow/operators/sensors.py +++ b/airflow/operators/sensors.py @@ -36,7 +36,7 @@ class BaseSensorOperator(BaseOperator): self.poke_interval = poke_interval self.timeout = timeout - def poke(self): + def poke(self, context): ''' Function that the sensors defined while deriving this class should override. @@ -45,7 +45,7 @@ class BaseSensorOperator(BaseOperator): def execute(self, context): started_at = datetime.now() - while not self.poke(): + while not self.poke(context): sleep(self.poke_interval) if (datetime.now() - started_at).seconds > self.timeout: raise AirflowException('Snap. Time is OUT.') @@ -81,7 +81,7 @@ class SqlSensor(BaseSensorOperator): session.commit() session.close() - def poke(self): + def poke(self, context): logging.info('Poking: ' + self.sql) records = self.hook.get_records(self.sql) if not records: @@ -105,30 +105,26 @@ class ExternalTaskSensor(BaseSensorOperator): wait for :type external_task_id: string """ - template_fields = ('execution_date',) @apply_defaults def __init__(self, external_dag_id, external_task_id, *args, **kwargs): super(ExternalTaskSensor, self).__init__(*args, **kwargs) self.external_dag_id = external_dag_id self.external_task_id = external_task_id - self.execution_date = "{{ execution_date }}" - def poke(self): + def poke(self, context): logging.info( 'Poking for ' '{self.external_dag_id}.' '{self.external_task_id} on ' - '{self.execution_date} ... '.format(**locals())) + '{context[execution_date]} ... '.format(**locals())) TI = TaskInstance session = settings.Session() - import dateutil.parser - self.execution_date = dateutil.parser.parse(self.execution_date) count = session.query(TI).filter( TI.dag_id == self.external_dag_id, TI.task_id == self.external_task_id, TI.state == State.SUCCESS, - TI.execution_date == self.execution_date + TI.execution_date == context['execution_date'], ).count() session.commit() session.close() @@ -169,7 +165,7 @@ class HivePartitionSensor(BaseSensorOperator): self.partition = partition self.schema = schema - def poke(self): + def poke(self, context): logging.info( 'Poking for table {self.schema}.{self.table}, ' 'partition {self.partition}'.format(**locals())) @@ -196,7 +192,7 @@ class HdfsSensor(BaseSensorOperator): self.filepath = filepath self.hdfs_conn_id = hdfs_conn_id - def poke(self): + def poke(self, context): sb = hooks.HDFSHook(self.hdfs_conn_id).get_conn() logging.getLogger("snakebite").setLevel(logging.WARNING) logging.info( @@ -255,7 +251,7 @@ class S3KeySensor(BaseSensorOperator): session.commit() session.close() - def poke(self): + def poke(self, context): hook = hooks.S3Hook(s3_conn_id=self.s3_conn_id) full_url = "s3://" + self.bucket_name + self.bucket_key logging.info('Poking for key : {full_url}'.format(**locals())) @@ -307,7 +303,7 @@ class S3PrefixSensor(BaseSensorOperator): session.commit() session.close() - def poke(self): + def poke(self, context): logging.info('Poking for prefix : {self.prefix}\n' 'in bucket s3://{self.bucket_name}'.format(**locals())) hook = hooks.S3Hook(s3_conn_id=self.s3_conn_id) @@ -331,7 +327,7 @@ class TimeSensor(BaseSensorOperator): super(TimeSensor, self).__init__(*args, **kwargs) self.target_time = target_time - def poke(self): + def poke(self, context): logging.info( 'Checking if the time ({0}) has come'.format(self.target_time)) return datetime.now().time() > self.target_time