Fixing HiveHook to be picklable
This commit is contained in:
Родитель
eb885d2f24
Коммит
4faaa13a0d
|
@ -32,18 +32,31 @@ class HiveHook(BaseHook):
|
|||
session.close()
|
||||
|
||||
# Connection to Hive
|
||||
self.hive = self.get_hive_client()
|
||||
|
||||
def __getstate__(self):
|
||||
# This is for pickling to work despite the thrift hive client not
|
||||
# being picklable
|
||||
d = dict(self.__dict__)
|
||||
del d['hive']
|
||||
return d
|
||||
|
||||
def __setstate__(self, d):
|
||||
d['hive'] = self.get_hive_client()
|
||||
self.__dict__.update(d)
|
||||
|
||||
def get_hive_client(self):
|
||||
transport = TSocket.TSocket(self.host, self.port)
|
||||
self.transport = TTransport.TBufferedTransport(transport)
|
||||
protocol = TBinaryProtocol.TBinaryProtocol(self.transport)
|
||||
self.hive = ThriftHive.Client(protocol)
|
||||
transport = TTransport.TBufferedTransport(transport)
|
||||
protocol = TBinaryProtocol.TBinaryProtocol(transport)
|
||||
return ThriftHive.Client(protocol)
|
||||
|
||||
def get_conn(self):
|
||||
self.transport.open()
|
||||
return self.hive
|
||||
|
||||
def check_for_partition(self, schema, table, partition):
|
||||
try:
|
||||
self.transport.open()
|
||||
self.hive._oprot.trans.open()
|
||||
partitions = self.hive.get_partitions_by_filter(
|
||||
schema, table, partition, 1)
|
||||
self.transport.close()
|
||||
|
@ -56,31 +69,31 @@ class HiveHook(BaseHook):
|
|||
return False
|
||||
|
||||
def get_records(self, hql, schema=None):
|
||||
self.transport.open()
|
||||
self.hive._oprot.trans.open()
|
||||
if schema:
|
||||
self.hive.execute("USE " + schema)
|
||||
self.hive.execute(hql)
|
||||
records = self.hive.fetchAll()
|
||||
self.transport.close()
|
||||
self.hive._oprot.trans.close()
|
||||
return [row.split("\t") for row in records]
|
||||
|
||||
def get_pandas_df(self, hql, schema=None):
|
||||
import pandas as pd
|
||||
self.transport.open()
|
||||
self.hive._oprot.trans.open()
|
||||
if schema:
|
||||
self.hive.execute("USE " + schema)
|
||||
self.hive.execute(hql)
|
||||
records = self.hive.fetchAll()
|
||||
self.transport.close()
|
||||
self.hive._oprot.trans.close()
|
||||
df = pd.DataFrame([row.split("\t") for row in records])
|
||||
return df
|
||||
|
||||
def run(self, hql, schema=None):
|
||||
self.transport.open()
|
||||
self.hive._oprot.trans.open()
|
||||
if schema:
|
||||
self.hive.execute("USE " + schema)
|
||||
self.hive.execute(hql)
|
||||
self.transport.close()
|
||||
self.hive._oprot.trans.close()
|
||||
|
||||
def run_cli(self, hql, schema=None):
|
||||
if schema:
|
||||
|
@ -94,6 +107,7 @@ class HiveHook(BaseHook):
|
|||
for tables that have a single partition key. For subpartitioned
|
||||
tables, we recommend using signal tables.
|
||||
'''
|
||||
self.hive._oprot.trans.open()
|
||||
table = self.hive.get_table(dbname=schema, tbl_name=table)
|
||||
if len(table.partitionKeys) == 0:
|
||||
raise Exception("The table isn't partitionned")
|
||||
|
@ -104,4 +118,6 @@ class HiveHook(BaseHook):
|
|||
else:
|
||||
parts = self.hive.get_partitions(
|
||||
db_name='core_data', tbl_name='dim_users', max_parts=32767)
|
||||
|
||||
self.hive._oprot.trans.close()
|
||||
return max([p.values[0] for p in parts])
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import subprocess
|
||||
import StringIO
|
||||
|
||||
from airflow import settings
|
||||
from airflow.models import DatabaseConnection
|
||||
|
|
|
@ -15,6 +15,7 @@ from sqlalchemy import (
|
|||
ForeignKey, func
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.ext.serializer import loads, dumps
|
||||
from sqlalchemy.orm import relationship
|
||||
from airflow.executors import DEFAULT_EXECUTOR
|
||||
from airflow import settings
|
||||
|
@ -156,12 +157,12 @@ class DagPickle(Base):
|
|||
__tablename__ = "dag_pickle"
|
||||
|
||||
def __init__(self, dag, job):
|
||||
self.pickle = pickle.dumps(dag)
|
||||
self.dag_id = dag.dag_id
|
||||
self.job = job
|
||||
self.pickle = dumps(dag)
|
||||
|
||||
def get_object(self):
|
||||
return pickle.loads(self.pickle)
|
||||
return loads(self.pickle)
|
||||
|
||||
|
||||
class TaskInstance(Base):
|
||||
|
|
Загрузка…
Ссылка в новой задаче