Merge pull request #7 from mistercrunch/minor_touchups

Minor fixes, utility functions
Maxime Beauchemin 2014-11-10 13:47:00 -08:00
Parents: fc8f28a261 9a033dc93b
Commit: 01597a5cbd
3 changed files with 16 additions and 8 deletions

View file

@@ -45,7 +45,7 @@ class HiveHook(BaseHook):
 
     def __setstate__(self, d):
         self.__dict__.update(d)
-        d['hive'] = self.get_hive_client()
+        self.__dict__['hive'] = self.get_hive_client()
 
     def get_hive_client(self):
         transport = TSocket.TSocket(self.host, self.port)
@@ -61,7 +61,7 @@ class HiveHook(BaseHook):
         self.hive._oprot.trans.open()
         partitions = self.hive.get_partitions_by_filter(
             schema, table, partition, 1)
-        self.transport.close()
+        self.hive._oprot.trans.close()
         if partitions:
             return True
         else:
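
The __setstate__ change matters when the hook is pickled and rebuilt in another process: writing the recreated client into the state dict d does nothing to the instance, so the unpickled hook ends up without a hive attribute, whereas writing into self.__dict__ restores it. The second hunk closes the same Thrift transport that check_for_partition opened. A minimal, self-contained sketch of the pickle round trip, using a made-up FakeHook with a placeholder client instead of the real hook:

import pickle


# Made-up stand-in for HiveHook; only the pickle-related pieces are modeled.
class FakeHook(object):
    def __init__(self):
        self.host, self.port = 'localhost', 10000
        self.hive = self.get_hive_client()

    def get_hive_client(self):
        # Placeholder for the real Thrift client, which is not picklable.
        return object()

    def __getstate__(self):
        d = dict(self.__dict__)
        del d['hive']
        return d

    def __setstate__(self, d):
        self.__dict__.update(d)
        # The fix: rebuild the client on the instance itself; mutating the
        # state dict d here would leave the unpickled hook without .hive.
        self.__dict__['hive'] = self.get_hive_client()


hook = pickle.loads(pickle.dumps(FakeHook()))
assert hook.hive is not None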

View file

@@ -864,10 +864,11 @@ class BaseOperator(Base):
 
     def __repr__(self):
         return "<Task({self.task_type}): {self.task_id}>".format(self=self)
 
-    @staticmethod
-    def append_only_new(l, item):
+    def append_only_new(self, l, item):
         if item in l:
-            raise Exception('Dependency already registered')
+            raise Exception(
+                'Dependency {self}, {item} already registered'
+                ''.format(**locals()))
         else:
             l.append(item)
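
append_only_new is presumably the guard used when an operator registers an upstream or downstream dependency; the change makes it an instance method and names both the operator and the offending item in the error. A standalone copy just to show the new message (the string arguments are stand-ins for real operators):

def append_only_new(self, l, item):
    if item in l:
        raise Exception(
            'Dependency {self}, {item} already registered'
            ''.format(**locals()))
    else:
        l.append(item)


deps = []
append_only_new('task_a', deps, 'task_b')      # appends 'task_b'
try:
    append_only_new('task_a', deps, 'task_b')  # duplicate registration
except Exception as e:
    print(e)  # Dependency task_a, task_b already registered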
@@ -921,7 +922,7 @@ class DAG(Base):
     full_filepath = Column(String(2000))
 
     tasks = relationship(
-        "BaseOperator", cascade="merge, delete, delete-orphan", backref='dag')
+        "BaseOperator", cascade="all, delete-orphan", backref='dag')
 
     def __init__(
             self, dag_id,
@@ -1061,6 +1062,7 @@ class DAG(Base):
         for task in self.tasks:
             if task.task_id == task_id:
                 return task
+        raise Exception("Task {task_id} not found".format(**locals()))
 
     def tree_view(self):
         """
@@ -1085,6 +1087,10 @@ class DAG(Base):
         task.dag = self
         self.task_count = len(self.tasks)
 
+    def add_tasks(self, tasks):
+        for task in tasks:
+            self.add_task(task)
+
     def db_merge(self):
         session = settings.Session()
         session.merge(self)
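
add_tasks is a thin convenience wrapper over add_task. Hypothetical usage, assuming dag and the operators t1, t2, t3 already exist elsewhere:

# dag, t1, t2, t3 are assumed to be an existing DAG and operators.
dag.add_tasks([t1, t2, t3])
# equivalent to calling dag.add_task(t) for each operator in turn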

View file

@@ -118,11 +118,13 @@ class HivePartitionSensor(BaseSensorOperator):
 
     def __init__(
             self,
-            table, partition,
+            table, partition="ds='{{ ds }}'",
             hive_dbid=getconf().get('hooks', 'HIVE_DEFAULT_DBID'),
             schema='default',
             *args, **kwargs):
         super(HivePartitionSensor, self).__init__(*args, **kwargs)
+        if '.' in table:
+            schema, table = table.split('.')
         self.hive_dbid = hive_dbid
         self.hook = HiveHook(hive_dbid=hive_dbid)
         self.table = table
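
Two behaviour changes in the constructor: the partition filter now defaults to the templated "ds='{{ ds }}'" value, and a dotted table name overrides the schema argument. A standalone sketch of the dotted-name handling (split_schema is a made-up helper; the sensor inlines this logic):

# Made-up helper mirroring the two added lines in __init__.
def split_schema(table, schema='default'):
    if '.' in table:
        schema, table = table.split('.')
    return schema, table


assert split_schema('my_table') == ('default', 'my_table')
assert split_schema('analytics.events') == ('analytics', 'events')

Under those assumptions, a sensor built with table='analytics.events' would poke schema 'analytics' for table 'events', using the default ds partition unless one is passed explicitly.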
@@ -131,7 +133,7 @@ class HivePartitionSensor(BaseSensorOperator):
 
     def poke(self):
         logging.info(
-            'Poking for table {self.table}, '
+            'Poking for table {self.schema}{self.table}, '
             'partition {self.partition}'.format(**locals()))
         return self.hook.check_for_partition(
             self.schema, self.table, self.partition)
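
For reference, the updated log line renders along these lines (the class and values below are illustrative; poke itself formats with **locals()):

class Sensor(object):
    schema, table, partition = 'default', 'events', "ds='2014-11-10'"


print(
    'Poking for table {self.schema}{self.table}, '
    'partition {self.partition}'.format(self=Sensor()))
# -> Poking for table defaultevents, partition ds='2014-11-10'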