Fixing HiveHook to be picklable

This commit is contained in:
Maxime Beauchemin 2014-11-01 15:50:11 +00:00
Родитель eb885d2f24
Коммит 4faaa13a0d
3 изменённых файла: 30 добавлений и 14 удалений

Просмотреть файл

@ -32,18 +32,31 @@ class HiveHook(BaseHook):
session.close()
# Connection to Hive
self.hive = self.get_hive_client()
def __getstate__(self):
# This is for pickling to work despite the thrift hive client not
# being picklable
d = dict(self.__dict__)
del d['hive']
return d
def __setstate__(self, d):
d['hive'] = self.get_hive_client()
self.__dict__.update(d)
def get_hive_client(self):
transport = TSocket.TSocket(self.host, self.port)
self.transport = TTransport.TBufferedTransport(transport)
protocol = TBinaryProtocol.TBinaryProtocol(self.transport)
self.hive = ThriftHive.Client(protocol)
transport = TTransport.TBufferedTransport(transport)
protocol = TBinaryProtocol.TBinaryProtocol(transport)
return ThriftHive.Client(protocol)
def get_conn(self):
self.transport.open()
return self.hive
def check_for_partition(self, schema, table, partition):
try:
self.transport.open()
self.hive._oprot.trans.open()
partitions = self.hive.get_partitions_by_filter(
schema, table, partition, 1)
self.transport.close()
@ -56,31 +69,31 @@ class HiveHook(BaseHook):
return False
def get_records(self, hql, schema=None):
self.transport.open()
self.hive._oprot.trans.open()
if schema:
self.hive.execute("USE " + schema)
self.hive.execute(hql)
records = self.hive.fetchAll()
self.transport.close()
self.hive._oprot.trans.close()
return [row.split("\t") for row in records]
def get_pandas_df(self, hql, schema=None):
import pandas as pd
self.transport.open()
self.hive._oprot.trans.open()
if schema:
self.hive.execute("USE " + schema)
self.hive.execute(hql)
records = self.hive.fetchAll()
self.transport.close()
self.hive._oprot.trans.close()
df = pd.DataFrame([row.split("\t") for row in records])
return df
def run(self, hql, schema=None):
self.transport.open()
self.hive._oprot.trans.open()
if schema:
self.hive.execute("USE " + schema)
self.hive.execute(hql)
self.transport.close()
self.hive._oprot.trans.close()
def run_cli(self, hql, schema=None):
if schema:
@ -94,6 +107,7 @@ class HiveHook(BaseHook):
for tables that have a single partition key. For subpartitioned
table, we recommend using signal tables.
'''
self.hive._oprot.trans.open()
table = self.hive.get_table(dbname=schema, tbl_name=table)
if len(table.partitionKeys) == 0:
raise Exception("The table isn't partitionned")
@ -104,4 +118,6 @@ class HiveHook(BaseHook):
else:
parts = self.hive.get_partitions(
db_name='core_data', tbl_name='dim_users', max_parts=32767)
self.hive._oprot.trans.close()
return max([p.values[0] for p in parts])

Просмотреть файл

@ -1,5 +1,4 @@
import subprocess
import StringIO
from airflow import settings
from airflow.models import DatabaseConnection

Просмотреть файл

@ -15,6 +15,7 @@ from sqlalchemy import (
ForeignKey, func
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.serializer import loads, dumps
from sqlalchemy.orm import relationship
from airflow.executors import DEFAULT_EXECUTOR
from airflow import settings
@ -156,12 +157,12 @@ class DagPickle(Base):
__tablename__ = "dag_pickle"
def __init__(self, dag, job):
self.pickle = pickle.dumps(dag)
self.dag_id = dag.dag_id
self.job = job
self.pickle = dumps(dag)
def get_object(self):
return pickle.loads(self.pickle)
return loads(self.pickle)
class TaskInstance(Base):