Merge pull request #140 from mistercrunch/hive_server2

Hive2SambaOperator to use HiveServer2; refactoring (breaking down) HiveHook into HiveCliHook, HiveMetastoreHook, and HiveServer2Hook
Maxime Beauchemin 2015-02-17 08:23:56 -08:00
Parents: d7e12c41ce dadc534e55
Commit: 9274c8b024
11 changed files with 235 additions and 196 deletions


@@ -1,5 +1,7 @@
from airflow.hooks.mysql_hook import MySqlHook
from airflow.hooks.hive_hook import HiveHook
from airflow.hooks.hive_hooks import HiveCliHook
from airflow.hooks.hive_hooks import HiveMetastoreHook
from airflow.hooks.hive_hooks import HiveServer2Hook
from airflow.hooks.presto_hook import PrestoHook
from airflow.hooks.samba_hook import SambaHook
from airflow.hooks.S3_hook import S3Hook
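For reference, a minimal sketch of the import surface after this split; the connection ids are the defaults used elsewhere in this diff:
# Sketch: the old monolithic HiveHook is replaced by three narrower hooks.
from airflow.hooks import HiveCliHook, HiveMetastoreHook, HiveServer2Hook

cli = HiveCliHook(hive_cli_conn_id='hive_cli_default')                 # shells out to the hive CLI
metastore = HiveMetastoreHook(metastore_conn_id='metastore_default')   # Thrift metastore client
hs2 = HiveServer2Hook(hiveserver2_conn_id='hiveserver2_default')       # pyhs2 / HiveServer2 client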


@@ -1,3 +1,7 @@
from airflow import settings
from airflow.models import Connection
class BaseHook(object):
"""
Abstract base class for hooks; hooks are meant as an interface to
@@ -9,6 +13,18 @@ class BaseHook(object):
def __init__(self, source):
pass
def get_connection(self, conn_id):
session = settings.Session()
db = session.query(
Connection).filter(
Connection.conn_id == conn_id).first()
if not db:
raise Exception(
"The conn_id `{0}` isn't defined".format(conn_id))
session.expunge_all()
session.close()
return db
def get_conn(self):
raise NotImplementedError()
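For reference, a minimal sketch of how a hook subclass is expected to use the new get_connection helper; the hook class and conn_id below are hypothetical:
from airflow.hooks.base_hook import BaseHook

class MyServiceHook(BaseHook):
    # Hypothetical hook: resolve the connection row through BaseHook.get_connection
    def __init__(self, my_conn_id='my_service_default'):
        conn = self.get_connection(my_conn_id)  # raises if the conn_id isn't defined
        self.host = conn.host
        self.login = conn.login
        self.password = conn.password
        self.schema = conn.schema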


@@ -1,140 +1,39 @@
import csv
import logging
import json
import subprocess
from tempfile import NamedTemporaryFile
from airflow.models import Connection
from airflow import settings
from thrift.transport import TSocket
from thrift.transport import TTransport
from thrift.protocol import TBinaryProtocol
from hive_service import ThriftHive
import pyhs2
from airflow.hooks.base_hook import BaseHook
class HiveHook(BaseHook):
class HiveCliHook(BaseHook):
'''
Interact with Hive. This class is both a wrapper around the Hive Thrift
client and the Hive CLI.
Simple wrapper around the hive CLI
'''
def __init__(
self,
hive_conn_id='hive_default'):
session = settings.Session()
db = session.query(
Connection).filter(
Connection.conn_id == hive_conn_id)
if db.count() == 0:
raise Exception("The conn_id you provided isn't defined")
else:
db = db.all()[0]
self.host = db.host
self.db = db.schema
hive_cli_conn_id="hive_cli_default"
):
conn = self.get_connection(hive_cli_conn_id)
self.hive_cli_params = ""
try:
self.hive_cli_params = json.loads(db.extra)['hive_cli_params']
self.hive_cli_params = json.loads(
conn.extra)['hive_cli_params']
except:
pass
self.port = db.port
session.commit()
session.close()
# Connection to Hive
self.hive = self.get_hive_client()
def __getstate__(self):
# This is for pickling to work despite the thrift hive client not
# being picklable
d = dict(self.__dict__)
del d['hive']
return d
def __setstate__(self, d):
self.__dict__.update(d)
self.__dict__['hive'] = self.get_hive_client()
def get_hive_client(self):
'''
Returns a Hive thrift client.
'''
transport = TSocket.TSocket(self.host, self.port)
transport = TTransport.TBufferedTransport(transport)
protocol = TBinaryProtocol.TBinaryProtocol(transport)
return ThriftHive.Client(protocol)
def get_conn(self):
return self.hive
def check_for_partition(self, schema, table, partition):
'''
Checks whether a partition exists
>>> hh = HiveHook()
>>> t = 'static_babynames_partitioned'
>>> hh.check_for_partition('airflow', t, "ds='2015-01-01'")
True
'''
self.hive._oprot.trans.open()
partitions = self.hive.get_partitions_by_filter(
schema, table, partition, 1)
self.hive._oprot.trans.close()
if partitions:
return True
else:
return False
def get_records(self, hql, schema=None):
'''
Get a set of records from a Hive query.
>>> hh = HiveHook()
>>> sql = "SELECT count(1) AS num FROM airflow.static_babynames"
>>> hh.get_records(sql)
[['340698']]
'''
self.hive._oprot.trans.open()
if schema:
self.hive.execute("USE " + schema)
self.hive.execute(hql)
records = self.hive.fetchAll()
self.hive._oprot.trans.close()
return [row.split("\t") for row in records]
def get_pandas_df(self, hql, schema=None):
'''
Get a pandas dataframe from a Hive query
>>> hh = HiveHook()
>>> sql = "SELECT count(1) AS num FROM airflow.static_babynames"
>>> df = hh.get_pandas_df(sql)
>>> df.to_dict()
{0: {0: '340698'}}
'''
import pandas as pd
self.hive._oprot.trans.open()
if schema:
self.hive.execute("USE " + schema)
self.hive.execute(hql)
records = self.hive.fetchAll()
self.hive._oprot.trans.close()
df = pd.DataFrame([row.split("\t") for row in records])
return df
def run(self, hql, schema=None):
self.hive._oprot.trans.open()
if schema:
self.hive.execute("USE " + schema)
self.hive.execute(hql)
self.hive._oprot.trans.close()
def run_cli(self, hql, schema=None):
'''
Run an hql statement using the hive cli
>>> hh = HiveHook()
>>> hh = HiveCliHook()
>>> hh.run_cli("USE airflow;")
'''
if schema:
@@ -158,20 +57,77 @@ class HiveHook(BaseHook):
if sp.returncode:
raise Exception(all_err)
def kill(self):
if hasattr(self, 'sp'):
if self.sp.poll() is None:
print("Killing the Hive job")
self.sp.kill()
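A short usage sketch for the slimmed-down HiveCliHook, assuming a working hive_cli_default connection; the statement is illustrative:
from airflow.hooks import HiveCliHook

hive_cli = HiveCliHook(hive_cli_conn_id='hive_cli_default')
# run_cli shells out to the hive CLI with the given hql
hive_cli.run_cli("SHOW TABLES;", schema='airflow')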
class HiveMetastoreHook(BaseHook):
'''
Wrapper to interact with the Hive Metastore
'''
def __init__(self, metastore_conn_id='metastore_default'):
self.metastore_conn = self.get_connection(metastore_conn_id)
self.metastore = self.get_metastore_client()
def __getstate__(self):
# This is for pickling to work despite the thrift hive client not
# being picklable
d = dict(self.__dict__)
del d['metastore']
return d
def __setstate__(self, d):
self.__dict__.update(d)
self.__dict__['metastore'] = self.get_metastore_client()
def get_metastore_client(self):
'''
Returns a Hive thrift client.
'''
ms = self.metastore_conn
transport = TSocket.TSocket(ms.host, ms.port)
transport = TTransport.TBufferedTransport(transport)
protocol = TBinaryProtocol.TBinaryProtocol(transport)
return ThriftHive.Client(protocol)
def get_conn(self):
return self.metastore
def check_for_partition(self, schema, table, partition):
'''
Checks whether a partition exists
>>> hh = HiveMetastoreHook()
>>> t = 'static_babynames_partitioned'
>>> hh.check_for_partition('airflow', t, "ds='2015-01-01'")
True
'''
self.metastore._oprot.trans.open()
partitions = self.metastore.get_partitions_by_filter(
schema, table, partition, 1)
self.metastore._oprot.trans.close()
if partitions:
return True
else:
return False
def get_table(self, db, table_name):
'''
Get a metastore table object
>>> hh = HiveHook()
>>> hh = HiveMetastoreHook()
>>> t = hh.get_table(db='airflow', table_name='static_babynames')
>>> t.tableName
'static_babynames'
>>> [col.name for col in t.sd.cols]
['state', 'year', 'name', 'gender', 'num']
'''
self.hive._oprot.trans.open()
table = self.hive.get_table(dbname=db, tbl_name=table_name)
self.hive._oprot.trans.close()
self.metastore._oprot.trans.open()
table = self.metastore.get_table(dbname=db, tbl_name=table_name)
self.metastore._oprot.trans.close()
return table
def get_partitions(self, schema, table_name):
@@ -180,7 +136,7 @@ class HiveHook(BaseHook):
for tables with less than 32767 (java short max val).
For subpartitioned tables, the number might easily exceed this.
>>> hh = HiveHook()
>>> hh = HiveMetastoreHook()
>>> t = 'static_babynames_partitioned'
>>> parts = hh.get_partitions(schema='airflow', table_name=t)
>>> len(parts)
@@ -188,8 +144,8 @@ class HiveHook(BaseHook):
>>> max(parts)
'2015-01-01'
'''
self.hive._oprot.trans.open()
table = self.hive.get_table(dbname=schema, tbl_name=table_name)
self.metastore._oprot.trans.open()
table = self.metastore.get_table(dbname=schema, tbl_name=table_name)
if len(table.partitionKeys) == 0:
raise Exception("The table isn't partitioned")
elif len(table.partitionKeys) > 1:
@@ -197,10 +153,10 @@ class HiveHook(BaseHook):
"The table is partitioned by multiple columns, "
"use a signal table!")
else:
parts = self.hive.get_partitions(
parts = self.metastore.get_partitions(
db_name=schema, tbl_name=table_name, max_parts=32767)
self.hive._oprot.trans.close()
self.metastore._oprot.trans.close()
return [p.values[0] for p in parts]
def max_partition(self, schema, table_name):
@@ -209,15 +165,84 @@ class HiveHook(BaseHook):
for tables that have a single partition key. For subpartitioned
tables, we recommend using signal tables.
>>> hh = HiveHook()
>>> hh = HiveMetastoreHook()
>>> t = 'static_babynames_partitioned'
>>> hh.max_partition(schema='airflow', table_name=t)
'2015-01-01'
'''
return max(self.get_partitions(schema, table_name))
def kill(self):
if hasattr(self, 'sp'):
if self.sp.poll() is None:
print("Killing the Hive job")
self.sp.kill()
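A usage sketch for HiveMetastoreHook, mirroring the doctests above; the table and partition names come from the test warehouse and are assumptions outside that context:
from airflow.hooks import HiveMetastoreHook

ms = HiveMetastoreHook(metastore_conn_id='metastore_default')
t = ms.get_table(db='airflow', table_name='static_babynames_partitioned')
print([col.name for col in t.sd.cols])  # column names from the metastore table object
print(ms.check_for_partition('airflow', 'static_babynames_partitioned', "ds='2015-01-01'"))
print(ms.max_partition(schema='airflow', table_name='static_babynames_partitioned'))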
class HiveServer2Hook(BaseHook):
'''
Wrapper around the pyhs2 lib
'''
def __init__(self, hiveserver2_conn_id='hiveserver2_default'):
self.hiveserver2_conn = self.get_connection(hiveserver2_conn_id)
def get_results(self, hql, schema='default'):
schema = schema or 'default'
with pyhs2.connect(
host=self.hiveserver2_conn.host,
port=self.hiveserver2_conn.port,
authMechanism="NOSASL",
user='airflow',
database=schema) as conn:
with conn.cursor() as cur:
cur.execute(hql)
return {
'data': cur.fetchall(),
'header': cur.getSchema(),
}
def to_csv(self, hql, csv_filepath, schema='default'):
schema = schema or 'default'
with pyhs2.connect(
host=self.hiveserver2_conn.host,
port=self.hiveserver2_conn.port,
authMechanism="NOSASL",
user='airflow',
database=schema) as conn:
with conn.cursor() as cur:
cur.execute(hql)
schema = cur.getSchema()
with open(csv_filepath, 'w') as f:
writer = csv.writer(f)
writer.writerow([c['columnName'] for c in cur.getSchema()])
while cur.hasMoreRows:
writer.writerows(
[row for row in cur.fetchmany() if row])
def get_records(self, hql, schema='default'):
'''
Get a set of records from a Hive query.
>>> hh = HiveServer2Hook()
>>> sql = "SELECT count(1) AS num FROM airflow.static_babynames"
>>> hh.get_records(sql)
[[340698]]
'''
return self.get_results(hql, schema=schema)['data']
def get_pandas_df(self, hql, schema='default'):
'''
Get a pandas dataframe from a Hive query
>>> hh = HiveServer2Hook()
>>> sql = "SELECT count(1) AS num FROM airflow.static_babynames"
>>> df = hh.get_pandas_df(sql)
>>> df.to_dict()
{'num': {0: 340698}}
'''
import pandas as pd
res = self.get_results(hql, schema=schema)
df = pd.DataFrame(res['data'])
df.columns = [c['columnName'] for c in res['header']]
return df
def run(self, hql, schema=None):
self.hive._oprot.trans.open()
if schema:
self.hive.execute("USE " + schema)
self.hive.execute(hql)
self.hive._oprot.trans.close()
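And a sketch for the new HiveServer2Hook, which talks to HiveServer2 through pyhs2 rather than the CLI or the raw Thrift client; the query is the one used in the doctests and the output path is illustrative:
from airflow.hooks import HiveServer2Hook

hs2 = HiveServer2Hook(hiveserver2_conn_id='hiveserver2_default')
sql = "SELECT count(1) AS num FROM airflow.static_babynames"
records = hs2.get_records(sql)      # e.g. [[340698]]
df = hs2.get_pandas_df(sql)         # pandas DataFrame with a 'num' column
hs2.to_csv(sql, csv_filepath='/tmp/babynames_count.csv')  # stream the result set to a CSV file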


@@ -1,39 +1,27 @@
import os
from smbclient import SambaClient
from airflow import settings
from airflow.models import Connection
from airflow.hooks.base_hook import BaseHook
class SambaHook(object):
class SambaHook(BaseHook):
'''
Allows for interaction with a Samba server.
'''
def __init__(self, samba_conn_id=None):
session = settings.Session()
samba_conn = session.query(
Connection).filter(
Connection.conn_id == samba_conn_id).first()
if not samba_conn:
raise Exception("The samba id you provided isn't defined")
self.host = samba_conn.host
self.login = samba_conn.login
self.psw = samba_conn.password
self.db = samba_conn.schema
session.commit()
session.close()
def __init__(self, samba_conn_id):
self.conn = self.get_connection(samba_conn_id)
def get_conn(self):
samba = SambaClient(
server=self.host, share='', username=self.login, password=self.psw)
server=self.conn.host,
share=self.conn.schema,
username=self.conn.login,
ip=self.conn.host,
password=self.conn.password)
return samba
def push_from_local(self, destination_filepath, local_filepath):
samba = self.get_conn()
samba.cwd(os.path.dirname(self.destination_filepath))
f = open(local_filepath, 'r')
filename = os.path.basename(destination_filepath)
samba.storbinary("STOR " + filename, f)
f.close()
samba.quit()
if samba.exists(destination_filepath):
samba.remove(destination_filepath)
samba.upload(local_filepath, destination_filepath)
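A minimal sketch of the reworked SambaHook; it assumes a samba_default connection whose schema field holds the share name, and the file paths are illustrative:
from airflow.hooks import SambaHook

samba = SambaHook(samba_conn_id='samba_default')
# Removes the remote file if it already exists, then uploads the local one
samba.push_from_local('reports/babynames.csv', '/tmp/babynames.csv')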


@@ -1,10 +1,9 @@
from airflow.configuration import conf
import datetime
def max_partition(
table, schema="default",
hive_conn_id='hive_default'):
metastore_conn_id='metastore_default'):
'''
Gets the max partition for a table.
@@ -17,11 +16,14 @@ def max_partition(
:param metastore_conn_id: The metastore connection you are interested in.
If your default is set you don't need to use this parameter.
:type metastore_conn_id: string
>>> max_partition('airflow.static_babynames_partitioned')
'2015-01-01'
'''
from airflow.hooks.hive_hook import HiveHook
from airflow.hooks import HiveMetastoreHook
if '.' in table:
schema, table = table.split('.')
hh = HiveHook(hive_conn_id=hive_conn_id)
hh = HiveMetastoreHook(metastore_conn_id=metastore_conn_id)
return hh.max_partition(schema=schema, table_name=table)
@@ -52,7 +54,7 @@ def _closest_date(target_dt, date_list, before_target=None):
def closest_ds_partition(
table, ds, before=True, schema="default",
hive_conn_id='hive_default'):
metastore_conn_id='metastore_default'):
'''
This function finds the date in a list closest to the target date.
An optional parameter can be given to get the closest date before or after.
@@ -66,12 +68,15 @@ def closest_ds_partition(
:returns: The closest date
:rtype: str or None
>>> tbl = 'airflow.static_babynames_partitioned'
>>> closest_ds_partition(tbl, '2015-01-02')
'2015-01-01'
'''
from airflow.hooks.hive_hook import HiveHook
from airflow.hooks import HiveMetastoreHook
if '.' in table:
schema, table = table.split('.')
target_dt = datetime.datetime.strptime(ds, '%Y-%m-%d')
hh = HiveHook(hive_conn_id=hive_conn_id)
hh = HiveMetastoreHook(metastore_conn_id=metastore_conn_id)
partitions = hh.get_partitions(schema=schema, table_name=table)
parts = [datetime.datetime.strptime(p, '%Y-%m-%d') for p in partitions]
if partitions is None or partitions == []:
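A sketch of calling the updated macros directly; the import path and the table name are assumptions:
from airflow.macros.hive import max_partition, closest_ds_partition  # assumed module path

latest = max_partition('airflow.static_babynames_partitioned',
                       metastore_conn_id='metastore_default')   # e.g. '2015-01-01'
nearest = closest_ds_partition('airflow.static_babynames_partitioned', '2015-01-02',
                               before=True,
                               metastore_conn_id='metastore_default')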


@@ -1,8 +1,7 @@
import logging
import tempfile
from airflow.configuration import conf
from airflow.hooks import HiveHook, SambaHook
from airflow.hooks import HiveServer2Hook, SambaHook
from airflow.models import BaseOperator
from airflow.utils import apply_defaults
@@ -13,10 +12,10 @@ class Hive2SambaOperator(BaseOperator):
:param hql: the hql to be exported
:type hql: string
:param hive_dbid: reference to the Hive database
:type hive_dbid: string
:param samba_dbid: reference to the samba destination
:type samba_dbid: string
:param hiveserver2_conn_id: reference to the hiveserver2 service
:type hiveserver2_conn_id: string
:param samba_conn_id: reference to the samba destination
:type samba_conn_id: string
"""
__mapper_args__ = {
@@ -28,33 +27,20 @@ class Hive2SambaOperator(BaseOperator):
@apply_defaults
def __init__(
self, hql,
samba_dbid,
destination_filepath,
hive_dbid='hive_default',
samba_conn_id='samba_default',
hiveserver2_conn_id='hiveserver2_default',
*args, **kwargs):
super(Hive2SambaOperator, self).__init__(*args, **kwargs)
self.hive_dbid = hive_dbid
self.samba_dbid = samba_dbid
self.hiveserver2_conn_id = hiveserver2_conn_id
self.samba_conn_id = samba_conn_id
self.destination_filepath = destination_filepath
self.samba = SambaHook(samba_dbid=samba_dbid)
self.hook = HiveHook(hive_dbid=hive_dbid)
self.samba = SambaHook(samba_conn_id=samba_conn_id)
self.hook = HiveServer2Hook(hiveserver2_conn_id=hiveserver2_conn_id)
self.hql = hql.strip().rstrip(';')
def execute(self, context):
tmpfile = tempfile.NamedTemporaryFile()
hql = """\
INSERT OVERWRITE LOCAL DIRECTORY '{tmpfile.name}'
ROW FORMAT DELIMITED
FIELDS TERMINATED BY ','
{self.hql};
""".format(**locals())
logging.info('Executing: ' + hql)
self.hook.run_cli(hql=hql)
self.hook.to_csv(hql=self.hql, csv_filepath=tmpfile.name)
self.samba.push_from_local(self.destination_filepath, tmpfile.name)
# Cleaning up
hql = "DROP TABLE {table};"
self.hook.run_cli(hql=self.hql)
tmpfile.close()
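A task-definition sketch with the renamed connection arguments; it assumes the operator is importable from airflow.operators (as the tests below do via the operators module), that the DAG constructor accepts a start_date, and that the dag id, paths and query are illustrative:
from datetime import datetime
from airflow.models import DAG
from airflow.operators import Hive2SambaOperator

dag = DAG('hive2samba_example', start_date=datetime(2015, 1, 1))  # illustrative DAG
export = Hive2SambaOperator(
    task_id='export_babynames',
    hql="SELECT * FROM airflow.static_babynames LIMIT 1000",
    hiveserver2_conn_id='hiveserver2_default',
    samba_conn_id='samba_default',
    destination_filepath='exports/babynames.csv',
    dag=dag)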


@@ -1,7 +1,7 @@
import logging
import re
from airflow.hooks import HiveHook
from airflow.hooks import HiveCliHook
from airflow.models import BaseOperator
from airflow.utils import apply_defaults
@@ -32,15 +32,14 @@ class HiveOperator(BaseOperator):
@apply_defaults
def __init__(
self, hql,
hive_conn_id='hive_default',
hive_cli_conn_id='hive_cli_default',
hiveconf_jinja_translate=False,
script_begin_tag=None,
*args, **kwargs):
super(HiveOperator, self).__init__(*args, **kwargs)
self.hiveconf_jinja_translate = hiveconf_jinja_translate
self.hive_conn_id = hive_conn_id
self.hook = HiveHook(hive_conn_id=hive_conn_id)
self.hook = HiveCliHook(hive_cli_conn_id=hive_cli_conn_id)
self.hql = hql
self.script_begin_tag = script_begin_tag
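A similar sketch for HiveOperator with the renamed hive_cli_conn_id argument; the dag id, dates and query are illustrative:
from datetime import datetime
from airflow.models import DAG
from airflow.operators import HiveOperator

dag = DAG('hive_cli_example', start_date=datetime(2015, 1, 1))
count = HiveOperator(
    task_id='count_babynames',
    hql="SELECT count(1) FROM airflow.static_babynames;",
    hive_cli_conn_id='hive_cli_default',
    dag=dag)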


@@ -4,7 +4,7 @@ from urlparse import urlparse
from time import sleep
from airflow import settings
from airflow.hooks import HiveHook
from airflow.hooks import HiveMetastoreHook
from airflow.hooks import S3Hook
from airflow.models import BaseOperator
from airflow.models import Connection as DB
@@ -169,7 +169,7 @@ class HivePartitionSensor(BaseSensorOperator):
def __init__(
self,
table, partition="ds='{{ ds }}'",
hive_conn_id='hive_default',
metastore_conn_id='metastore_default',
schema='default',
*args, **kwargs):
super(HivePartitionSensor, self).__init__(*args, **kwargs)
@@ -177,8 +177,8 @@ class HivePartitionSensor(BaseSensorOperator):
schema, table = table.split('.')
if not partition:
partition = "ds='{{ ds }}'"
self.hive_conn_id = hive_conn_id
self.hook = HiveHook(hive_conn_id=hive_conn_id)
self.metastore_conn_id = metastore_conn_id
self.hook = HiveMetastoreHook(metastore_conn_id=metastore_conn_id)
self.table = table
self.partition = partition
self.schema = schema
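And one for HivePartitionSensor with the renamed metastore_conn_id argument; names and dates are illustrative:
from datetime import datetime
from airflow.models import DAG
from airflow.operators import HivePartitionSensor

dag = DAG('partition_sensor_example', start_date=datetime(2015, 1, 1))
wait_for_partition = HivePartitionSensor(
    task_id='wait_for_babynames_partition',
    table='airflow.static_babynames_partitioned',
    partition="ds='{{ ds }}'",
    metastore_conn_id='metastore_default',
    dag=dag)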


@@ -21,6 +21,7 @@ pygments
pyhive
PySmbClient
python-dateutil
pyhs2
requests
setproctitle
snakebite


@@ -31,6 +31,7 @@ setup(
'pygments>=2.0.1',
'pysmbclient>=0.1.3',
'pyhive>=0.1.3',
'pyhs2>=0.6.0',
'python-dateutil>=2.3',
'requests>=2.5.1',
'setproctitle>=1.1.8',


@@ -37,7 +37,7 @@ class HivePrestoTest(unittest.TestCase):
SELECT state, year, name, gender, num FROM static_babynames;
"""
t = operators.HiveOperator(task_id='basic_hql', hql=hql, dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
def test_presto(self):
sql = """
@@ -45,14 +45,14 @@ class HivePrestoTest(unittest.TestCase):
"""
t = operators.PrestoCheckOperator(
task_id='presto_check', sql=sql, dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
def test_hdfs_sensor(self):
t = operators.HdfsSensor(
task_id='hdfs_sensor_check',
filepath='/user/hive/warehouse/airflow.db/static_babynames',
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
def test_sql_sensor(self):
t = operators.SqlSensor(
@@ -60,7 +60,23 @@ class HivePrestoTest(unittest.TestCase):
conn_id='presto_default',
sql="SELECT 'x' FROM airflow.static_babynames LIMIT 1;",
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
def test_hive_partition_sensor(self):
t = operators.HivePartitionSensor(
task_id='hive_partition_check',
table='airflow.static_babynames_partitioned',
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
def test_hive2samba(self):
t = operators.Hive2SambaOperator(
task_id='hive2samba_check',
samba_conn_id='tableau_samba',
hql="SELECT * FROM airflow.static_babynames LIMIT 1000",
destination_filepath='test_airflow3.csv',
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
class CoreTest(unittest.TestCase):