diff --git a/airflow/hooks/hive_hook.py b/airflow/hooks/hive_hook.py
index d36c3e09e1..0b5503d501 100644
--- a/airflow/hooks/hive_hook.py
+++ b/airflow/hooks/hive_hook.py
@@ -45,7 +45,7 @@ class HiveHook(BaseHook):
 
     def __setstate__(self, d):
         self.__dict__.update(d)
-        d['hive'] = self.get_hive_client()
+        self.__dict__['hive'] = self.get_hive_client()
 
     def get_hive_client(self):
         transport = TSocket.TSocket(self.host, self.port)
@@ -61,7 +61,7 @@ class HiveHook(BaseHook):
         self.hive._oprot.trans.open()
         partitions = self.hive.get_partitions_by_filter(
             schema, table, partition, 1)
-        self.transport.close()
+        self.hive._oprot.trans.close()
         if partitions:
             return True
         else:
diff --git a/airflow/models.py b/airflow/models.py
index 41c10a23b9..ebc3262add 100644
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -864,10 +864,11 @@ class BaseOperator(Base):
     def __repr__(self):
         return "".format(self=self)
 
-    @staticmethod
-    def append_only_new(l, item):
+    def append_only_new(self, l, item):
         if item in l:
-            raise Exception('Dependency already registered')
+            raise Exception(
+                'Dependency {self}, {item} already registered'
+                ''.format(**locals()))
         else:
             l.append(item)
 
@@ -921,7 +922,7 @@ class DAG(Base):
     full_filepath = Column(String(2000))
 
     tasks = relationship(
-        "BaseOperator", cascade="merge, delete, delete-orphan", backref='dag')
+        "BaseOperator", cascade="all, delete-orphan", backref='dag')
 
     def __init__(
             self, dag_id,
@@ -1061,6 +1062,7 @@ class DAG(Base):
         for task in self.tasks:
             if task.task_id == task_id:
                 return task
+        raise Exception("Task {task_id} not found".format(**locals()))
 
     def tree_view(self):
         """
@@ -1085,6 +1087,10 @@
             task.dag = self
         self.task_count = len(self.tasks)
 
+    def add_tasks(self, tasks):
+        for task in tasks:
+            self.add_task(task)
+
     def db_merge(self):
         session = settings.Session()
         session.merge(self)
diff --git a/airflow/operators/sensors.py b/airflow/operators/sensors.py
index 0a39da0281..8eb7558465 100644
--- a/airflow/operators/sensors.py
+++ b/airflow/operators/sensors.py
@@ -118,11 +118,13 @@ class HivePartitionSensor(BaseSensorOperator):
 
     def __init__(
             self,
-            table, partition,
+            table, partition="ds='{{ ds }}'",
            hive_dbid=getconf().get('hooks', 'HIVE_DEFAULT_DBID'),
            schema='default',
            *args, **kwargs):
         super(HivePartitionSensor, self).__init__(*args, **kwargs)
+        if '.' in table:
+            schema, table = table.split('.')
         self.hive_dbid = hive_dbid
         self.hook = HiveHook(hive_dbid=hive_dbid)
         self.table = table
@@ -131,7 +133,7 @@ class HivePartitionSensor(BaseSensorOperator):
 
     def poke(self):
         logging.info(
-            'Poking for table {self.table}, '
+            'Poking for table {self.schema}.{self.table}, '
             'partition {self.partition}'.format(**locals()))
         return self.hook.check_for_partition(
             self.schema, self.table, self.partition)
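
Reviewer notes follow; nothing below is part of the patch. The __setstate__ fix matters because hooks get pickled and shipped between processes: the old code wrote the rebuilt Thrift client into the throwaway state dict d, so an unpickled HiveHook was left without a hive attribute. A minimal round-trip sketch, assuming __getstate__ drops the client before pickling, that constructing the client does not open a socket, and a made-up dbid value:

    import pickle

    from airflow.hooks.hive_hook import HiveHook

    hook = HiveHook(hive_dbid='hive_default')  # 'hive_default' is illustrative
    restored = pickle.loads(pickle.dumps(hook))

    # Pre-patch, the rebuilt client landed in the discarded dict `d` and the
    # next attribute access raised AttributeError; post-patch __setstate__
    # writes it into the instance __dict__, so this succeeds.
    restored.hive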
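
The models.py changes tighten DAG wiring: the cascade="all, delete-orphan" relationship makes SQLAlchemy cascade every standard operation (save-update, merge, delete, and so on) from a DAG to its tasks, get_task now raises instead of silently returning None, and add_tasks registers a list of operators in one call. A sketch of the new surface, with a hypothetical operator import and invented task ids:

    from airflow.models import DAG
    from airflow.operators import BashOperator  # import path assumed

    dag = DAG('example_dag')
    t1 = BashOperator(task_id='t1', bash_command='echo 1')
    t2 = BashOperator(task_id='t2', bash_command='echo 2')

    # add_tasks is a thin loop over add_task, so each operator still gets
    # task.dag set via the backref and task_count stays in sync.
    dag.add_tasks([t1, t2])

    dag.get_task('t1')    # returns t1, as before
    dag.get_task('nope')  # now raises: "Task nope not found"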
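
Finally, HivePartitionSensor gains a templated default partition, "ds='{{ ds }}'", and accepts a dotted schema.table name, which __init__ splits into schema and table, overriding the schema argument. Both are exercised below, reusing the dag from the previous sketch; the table name is invented:

    from airflow.operators.sensors import HivePartitionSensor

    # 'core.events' is split into schema='core', table='events' by the new
    # '.' handling; partition falls back to the templated "ds='{{ ds }}'".
    sensor = HivePartitionSensor(
        task_id='wait_for_events',
        table='core.events')
    dag.add_task(sensor)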