Merge pull request #140 from mistercrunch/hive_server2

Hive2SambaOperator now uses HiveServer2; the Hive hooks are refactored (broken down) into HiveCliHook, HiveMetastoreHook, and HiveServer2Hook.

Commit 9274c8b024

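For orientation, the diff below splits the old HiveHook into three hooks with distinct responsibilities: HiveCliHook (runs hql through the hive CLI), HiveMetastoreHook (Thrift client to the metastore for table/partition metadata), and HiveServer2Hook (pyhs2 client to HiveServer2 for fetching results). A minimal usage sketch, assuming the default connection IDs used in this diff (hive_cli_default, metastore_default, hiveserver2_default) are configured and the airflow.static_babynames tables from the doctests exist:

from airflow.hooks import HiveCliHook, HiveMetastoreHook, HiveServer2Hook

# hive CLI wrapper: submit statements through the `hive` binary
HiveCliHook(hive_cli_conn_id='hive_cli_default').run_cli("USE airflow;")

# metastore wrapper: table and partition metadata over Thrift
hm = HiveMetastoreHook(metastore_conn_id='metastore_default')
print(hm.max_partition(schema='airflow', table_name='static_babynames_partitioned'))

# hiveserver2 wrapper: run queries and fetch results over pyhs2
hs2 = HiveServer2Hook(hiveserver2_conn_id='hiveserver2_default')
print(hs2.get_records("SELECT count(1) FROM airflow.static_babynames"))
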
@@ -1,5 +1,7 @@
 from airflow.hooks.mysql_hook import MySqlHook
-from airflow.hooks.hive_hook import HiveHook
+from airflow.hooks.hive_hooks import HiveCliHook
+from airflow.hooks.hive_hooks import HiveMetastoreHook
+from airflow.hooks.hive_hooks import HiveServer2Hook
 from airflow.hooks.presto_hook import PrestoHook
 from airflow.hooks.samba_hook import SambaHook
 from airflow.hooks.S3_hook import S3Hook

@@ -1,3 +1,7 @@
+from airflow import settings
+from airflow.models import Connection
+
+
 class BaseHook(object):
     """
     Abstract base class for hooks, hooks are meant as an interface to

@@ -9,6 +13,18 @@ class BaseHook(object):
     def __init__(self, source):
         pass
 
+    def get_connection(self, conn_id):
+        session = settings.Session()
+        db = session.query(
+            Connection).filter(
+                Connection.conn_id == conn_id).first()
+        if not db:
+            raise Exception(
+                "The conn_id `{0}` isn't defined".format(conn_id))
+        session.expunge_all()
+        session.close()
+        return db
+
     def get_conn(self):
         raise NotImplemented()

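The new BaseHook.get_connection centralizes the Connection lookup that each hook previously did with its own SQLAlchemy session; subclasses just call it with a conn_id. An illustrative sketch of the pattern the refactored hooks follow (MyServiceHook and my_service_default are hypothetical, only to show the call):

from airflow.hooks.base_hook import BaseHook

class MyServiceHook(BaseHook):
    # hypothetical hook, only to illustrate the get_connection() pattern
    def __init__(self, my_conn_id='my_service_default'):
        conn = self.get_connection(my_conn_id)  # raises if the conn_id isn't defined
        self.host = conn.host
        self.port = conn.port

    def get_conn(self):
        # a real hook would return a client built from self.host / self.port
        raise NotImplementedError()
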
@@ -1,140 +1,39 @@
+import csv
 import logging
 import json
 import subprocess
 from tempfile import NamedTemporaryFile
 
-from airflow.models import Connection
-from airflow import settings
-
 from thrift.transport import TSocket
 from thrift.transport import TTransport
 from thrift.protocol import TBinaryProtocol
 from hive_service import ThriftHive
+import pyhs2
 
 from airflow.hooks.base_hook import BaseHook
 
 
-class HiveHook(BaseHook):
+class HiveCliHook(BaseHook):
     '''
-    Interact with Hive. This class is both a wrapper around the Hive Thrift
-    client and the Hive CLI.
+    Simple wrapper around the hive CLI
     '''
     def __init__(
             self,
-            hive_conn_id='hive_default'):
-        session = settings.Session()
-        db = session.query(
-            Connection).filter(
-                Connection.conn_id == hive_conn_id)
-        if db.count() == 0:
-            raise Exception("The conn_id you provided isn't defined")
-        else:
-            db = db.all()[0]
-        self.host = db.host
-        self.db = db.schema
+            hive_cli_conn_id="hive_cli_default"
+            ):
+        conn = self.get_connection(hive_cli_conn_id)
         self.hive_cli_params = ""
         try:
-            self.hive_cli_params = json.loads(db.extra)['hive_cli_params']
+            self.hive_cli_params = json.loads(
+                conn.extra)['hive_cli_params']
         except:
            pass
 
-        self.port = db.port
-        session.commit()
-        session.close()
-
-        # Connection to Hive
-        self.hive = self.get_hive_client()
-
-    def __getstate__(self):
-        # This is for pickling to work despite the thirft hive client not
-        # being pickable
-        d = dict(self.__dict__)
-        del d['hive']
-        return d
-
-    def __setstate__(self, d):
-        self.__dict__.update(d)
-        self.__dict__['hive'] = self.get_hive_client()
-
-    def get_hive_client(self):
-        '''
-        Returns a Hive thrift client.
-        '''
-        transport = TSocket.TSocket(self.host, self.port)
-        transport = TTransport.TBufferedTransport(transport)
-        protocol = TBinaryProtocol.TBinaryProtocol(transport)
-        return ThriftHive.Client(protocol)
-
-    def get_conn(self):
-        return self.hive
-
-    def check_for_partition(self, schema, table, partition):
-        '''
-        Checks whether a partition exists
-
-        >>> hh = HiveHook()
-        >>> t = 'static_babynames_partitioned'
-        >>> hh.check_for_partition('airflow', t, "ds='2015-01-01'")
-        True
-        '''
-        self.hive._oprot.trans.open()
-        partitions = self.hive.get_partitions_by_filter(
-            schema, table, partition, 1)
-        self.hive._oprot.trans.close()
-        if partitions:
-            return True
-        else:
-            return False
-
-    def get_records(self, hql, schema=None):
-        '''
-        Get a set of records from a Hive query.
-
-        >>> hh = HiveHook()
-        >>> sql = "SELECT count(1) AS num FROM airflow.static_babynames"
-        >>> hh.get_records(sql)
-        [['340698']]
-        '''
-        self.hive._oprot.trans.open()
-        if schema:
-            self.hive.execute("USE " + schema)
-        self.hive.execute(hql)
-        records = self.hive.fetchAll()
-        self.hive._oprot.trans.close()
-        return [row.split("\t") for row in records]
-
-    def get_pandas_df(self, hql, schema=None):
-        '''
-        Get a pandas dataframe from a Hive query
-
-        >>> hh = HiveHook()
-        >>> sql = "SELECT count(1) AS num FROM airflow.static_babynames"
-        >>> df = hh.get_pandas_df(sql)
-        >>> df.to_dict()
-        {0: {0: '340698'}}
-        '''
-        import pandas as pd
-        self.hive._oprot.trans.open()
-        if schema:
-            self.hive.execute("USE " + schema)
-        self.hive.execute(hql)
-        records = self.hive.fetchAll()
-        self.hive._oprot.trans.close()
-        df = pd.DataFrame([row.split("\t") for row in records])
-        return df
-
-    def run(self, hql, schema=None):
-        self.hive._oprot.trans.open()
-        if schema:
-            self.hive.execute("USE " + schema)
-        self.hive.execute(hql)
-        self.hive._oprot.trans.close()
-
     def run_cli(self, hql, schema=None):
         '''
         Run an hql statement using the hive cli
 
-        >>> hh = HiveHook()
+        >>> hh = HiveCliHook()
         >>> hh.run_cli("USE airflow;")
         '''
         if schema:

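Usage of the slimmed-down HiveCliHook stays close to the doctest above: the hook only shells out to the hive CLI (the run_cli body is mostly unchanged context not shown in this hunk), with optional hive_cli_params read from the connection's extra field. A sketch, assuming a hive_cli_default connection exists:

from airflow.hooks import HiveCliHook

hook = HiveCliHook(hive_cli_conn_id='hive_cli_default')
# run_cli passes the statement to the `hive` binary, prefixed with USE <schema> when given
hook.run_cli("CREATE TABLE IF NOT EXISTS tmp_airflow_check (x INT);", schema='airflow')
hook.run_cli("DROP TABLE tmp_airflow_check;", schema='airflow')
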
@@ -158,20 +57,77 @@ class HiveHook(BaseHook):
         if sp.returncode:
             raise Exception(all_err)
 
+    def kill(self):
+        if hasattr(self, 'sp'):
+            if self.sp.poll() is None:
+                print("Killing the Hive job")
+                self.sp.kill()
+
+
+class HiveMetastoreHook(BaseHook):
+    '''
+    Wrapper to interact with the Hive Metastore
+    '''
+    def __init__(self, metastore_conn_id='metastore_default'):
+        self.metastore_conn = self.get_connection(metastore_conn_id)
+        self.metastore = self.get_metastore_client()
+
+    def __getstate__(self):
+        # This is for pickling to work despite the thirft hive client not
+        # being pickable
+        d = dict(self.__dict__)
+        del d['metastore']
+        return d
+
+    def __setstate__(self, d):
+        self.__dict__.update(d)
+        self.__dict__['metastore'] = self.get_metastore_client()
+
+    def get_metastore_client(self):
+        '''
+        Returns a Hive thrift client.
+        '''
+        ms = self.metastore_conn
+        transport = TSocket.TSocket(ms.host, ms.port)
+        transport = TTransport.TBufferedTransport(transport)
+        protocol = TBinaryProtocol.TBinaryProtocol(transport)
+        return ThriftHive.Client(protocol)
+
+    def get_conn(self):
+        return self.metastore
+
+    def check_for_partition(self, schema, table, partition):
+        '''
+        Checks whether a partition exists
+
+        >>> hh = HiveMetastoreHook()
+        >>> t = 'static_babynames_partitioned'
+        >>> hh.check_for_partition('airflow', t, "ds='2015-01-01'")
+        True
+        '''
+        self.metastore._oprot.trans.open()
+        partitions = self.metastore.get_partitions_by_filter(
+            schema, table, partition, 1)
+        self.metastore._oprot.trans.close()
+        if partitions:
+            return True
+        else:
+            return False
+
     def get_table(self, db, table_name):
         '''
         Get a metastore table object
 
-        >>> hh = HiveHook()
+        >>> hh = HiveMetastoreHook()
         >>> t = hh.get_table(db='airflow', table_name='static_babynames')
         >>> t.tableName
         'static_babynames'
         >>> [col.name for col in t.sd.cols]
         ['state', 'year', 'name', 'gender', 'num']
         '''
-        self.hive._oprot.trans.open()
-        table = self.hive.get_table(dbname=db, tbl_name=table_name)
-        self.hive._oprot.trans.close()
+        self.metastore._oprot.trans.open()
+        table = self.metastore.get_table(dbname=db, tbl_name=table_name)
+        self.metastore._oprot.trans.close()
         return table
 
     def get_partitions(self, schema, table_name):

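The new HiveMetastoreHook carries over the Thrift-metastore methods of the old HiveHook (check_for_partition, get_table, get_partitions, max_partition), with self.hive renamed to self.metastore. A usage sketch mirroring the doctests, assuming a metastore_default connection and the airflow.static_babynames tables:

from airflow.hooks import HiveMetastoreHook

hm = HiveMetastoreHook(metastore_conn_id='metastore_default')

# partition metadata without running a query
hm.check_for_partition('airflow', 'static_babynames_partitioned', "ds='2015-01-01'")

# raw metastore table object (Thrift struct)
t = hm.get_table(db='airflow', table_name='static_babynames')
print(t.tableName, [col.name for col in t.sd.cols])

# highest value of the single partition key
print(hm.max_partition(schema='airflow', table_name='static_babynames_partitioned'))
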
@@ -180,7 +136,7 @@ class HiveHook(BaseHook):
         for tables with less than 32767 (java short max val).
         For subpartitionned table, the number might easily exceed this.
 
-        >>> hh = HiveHook()
+        >>> hh = HiveMetastoreHook()
         >>> t = 'static_babynames_partitioned'
         >>> parts = hh.get_partitions(schema='airflow', table_name=t)
         >>> len(parts)

@@ -188,8 +144,8 @@ class HiveHook(BaseHook):
         >>> max(parts)
         '2015-01-01'
         '''
-        self.hive._oprot.trans.open()
-        table = self.hive.get_table(dbname=schema, tbl_name=table_name)
+        self.metastore._oprot.trans.open()
+        table = self.metastore.get_table(dbname=schema, tbl_name=table_name)
         if len(table.partitionKeys) == 0:
             raise Exception("The table isn't partitionned")
         elif len(table.partitionKeys) > 1:

@@ -197,10 +153,10 @@ class HiveHook(BaseHook):
                 "The table is partitionned by multiple columns, "
                 "use a signal table!")
         else:
-            parts = self.hive.get_partitions(
+            parts = self.metastore.get_partitions(
                 db_name=schema, tbl_name=table_name, max_parts=32767)
 
-        self.hive._oprot.trans.close()
+        self.metastore._oprot.trans.close()
         return [p.values[0] for p in parts]
 
     def max_partition(self, schema, table_name):

@@ -209,15 +165,84 @@ class HiveHook(BaseHook):
         for tables that have a single partition key. For subpartitionned
         table, we recommend using signal tables.
 
-        >>> hh = HiveHook()
+        >>> hh = HiveMetastoreHook()
         >>> t = 'static_babynames_partitioned'
         >>> hh.max_partition(schema='airflow', table_name=t)
         '2015-01-01'
         '''
         return max(self.get_partitions(schema, table_name))
 
-    def kill(self):
-        if hasattr(self, 'sp'):
-            if self.sp.poll() is None:
-                print("Killing the Hive job")
-                self.sp.kill()
+
+class HiveServer2Hook(BaseHook):
+    '''
+    Wrapper around the pyhs2 lib
+    '''
+    def __init__(self, hiveserver2_conn_id='hiveserver2_default'):
+        self.hiveserver2_conn = self.get_connection(hiveserver2_conn_id)
+
+    def get_results(self, hql, schema='default'):
+        schema = schema or 'default'
+        with pyhs2.connect(
+                host=self.hiveserver2_conn.host,
+                port=self.hiveserver2_conn.port,
+                authMechanism="NOSASL",
+                user='airflow',
+                database=schema) as conn:
+            with conn.cursor() as cur:
+                cur.execute(hql)
+                return {
+                    'data': cur.fetchall(),
+                    'header': cur.getSchema(),
+                }
+
+    def to_csv(self, hql, csv_filepath, schema='default'):
+        schema = schema or 'default'
+        with pyhs2.connect(
+                host=self.hiveserver2_conn.host,
+                port=self.hiveserver2_conn.port,
+                authMechanism="NOSASL",
+                user='airflow',
+                database=schema) as conn:
+            with conn.cursor() as cur:
+                cur.execute(hql)
+                schema = cur.getSchema()
+                with open(csv_filepath, 'w') as f:
+                    writer = csv.writer(f)
+                    writer.writerow([c['columnName'] for c in cur.getSchema()])
+                    while cur.hasMoreRows:
+                        writer.writerows(
+                            [row for row in cur.fetchmany() if row])
+
+    def get_records(self, hql, schema='default'):
+        '''
+        Get a set of records from a Hive query.
+
+        >>> hh = HiveServer2Hook()
+        >>> sql = "SELECT count(1) AS num FROM airflow.static_babynames"
+        >>> hh.get_records(sql)
+        [[340698]]
+        '''
+        return self.get_results(hql, schema=schema)['data']
+
+    def get_pandas_df(self, hql, schema='default'):
+        '''
+        Get a pandas dataframe from a Hive query
+
+        >>> hh = HiveServer2Hook()
+        >>> sql = "SELECT count(1) AS num FROM airflow.static_babynames"
+        >>> df = hh.get_pandas_df(sql)
+        >>> df.to_dict()
+        {'num': {0: 340698}}
+        '''
+        import pandas as pd
+        res = self.get_results(hql, schema=schema)
+        df = pd.DataFrame(res['data'])
+        df.columns = [c['columnName'] for c in res['header']]
+        return df
+
+    def run(self, hql, schema=None):
+        self.hive._oprot.trans.open()
+        if schema:
+            self.hive.execute("USE " + schema)
+        self.hive.execute(hql)
+        self.hive._oprot.trans.close()

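HiveServer2Hook talks to HiveServer2 through pyhs2 (NOSASL auth, user 'airflow' hard-coded in this version) and is what Hive2SambaOperator now relies on for to_csv. A sketch following the doctests, assuming a hiveserver2_default connection; the /tmp output path is just an example:

from airflow.hooks import HiveServer2Hook

hs2 = HiveServer2Hook(hiveserver2_conn_id='hiveserver2_default')
sql = "SELECT state, year, name, gender, num FROM airflow.static_babynames LIMIT 10"

rows = hs2.get_records(sql)                          # list of result rows
df = hs2.get_pandas_df(sql)                          # DataFrame, columns from getSchema()
hs2.to_csv(sql, '/tmp/static_babynames_sample.csv')  # streams the result set to a local csv
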
@@ -1,39 +1,27 @@
 import os
 from smbclient import SambaClient
 
-from airflow import settings
-from airflow.models import Connection
+from airflow.hooks.base_hook import BaseHook
 
 
-class SambaHook(object):
+class SambaHook(BaseHook):
     '''
     Allows for interaction with an samba server.
     '''
 
-    def __init__(self, samba_conn_id=None):
-        session = settings.Session()
-        samba_conn = session.query(
-            Connection).filter(
-                Connection.conn_id == samba_conn_id).first()
-        if not samba_conn:
-            raise Exception("The samba id you provided isn't defined")
-        self.host = samba_conn.host
-        self.login = samba_conn.login
-        self.psw = samba_conn.password
-        self.db = samba_conn.schema
-        session.commit()
-        session.close()
+    def __init__(self, samba_conn_id):
+        self.conn = self.get_connection(samba_conn_id)
 
     def get_conn(self):
         samba = SambaClient(
-            server=self.host, share='', username=self.login, password=self.psw)
+            server=self.conn.host,
+            share=self.conn.schema,
+            username=self.conn.login,
+            ip=self.conn.host,
+            password=self.conn.password)
         return samba
 
     def push_from_local(self, destination_filepath, local_filepath):
         samba = self.get_conn()
-        samba.cwd(os.path.dirname(self.destination_filepath))
-        f = open(local_filepath, 'r')
-        filename = os.path.basename(destination_filepath)
-        samba.storbinary("STOR " + filename, f)
-        f.close()
-        samba.quit()
+        if samba.exists(destination_filepath):
+            samba.remove(destination_filepath)
+        samba.upload(local_filepath, destination_filepath)

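SambaHook now builds its SambaClient straight from the connection record (host, schema as the share, login, password), and push_from_local replaces the remote file if it already exists. A sketch, assuming a samba_default connection; the local file path is hypothetical:

from airflow.hooks import SambaHook

samba = SambaHook(samba_conn_id='samba_default')
# uploads the local file to the share configured on the connection,
# removing any existing file at the destination path first
samba.push_from_local('test_airflow3.csv', '/tmp/test_airflow3.csv')
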
@@ -1,10 +1,9 @@
-from airflow.configuration import conf
 import datetime
 
 
 def max_partition(
         table, schema="default",
-        hive_conn_id='hive_default'):
+        metastore_conn_id='metastore_default'):
     '''
     Gets the max partition for a table.
 

@@ -17,11 +16,14 @@ def max_partition(
     :param hive_conn_id: The hive connection you are interested in.
         If your default is set you don't need to use this parameter.
     :type hive_conn_id: string
 
     >>> max_partition('airflow.static_babynames_partitioned')
     '2015-01-01'
     '''
-    from airflow.hooks.hive_hook import HiveHook
+    from airflow.hooks import HiveMetastoreHook
     if '.' in table:
         schema, table = table.split('.')
-    hh = HiveHook(hive_conn_id=hive_conn_id)
+    hh = HiveMetastoreHook(metastore_conn_id=metastore_conn_id)
     return hh.max_partition(schema=schema, table_name=table)
 

@@ -52,7 +54,7 @@ def _closest_date(target_dt, date_list, before_target=None):
 
 def closest_ds_partition(
         table, ds, before=True, schema="default",
-        hive_conn_id='hive_default'):
+        metastore_conn_id='metastore_default'):
     '''
     This function finds the date in a list closest to the target date.
     An optional paramter can be given to get the closest before or after.

@@ -66,12 +68,15 @@ def closest_ds_partition(
     :returns: The closest date
     :rtype: str or None
 
     >>> tbl = 'airflow.static_babynames_partitioned'
     >>> closest_ds_partition(tbl, '2015-01-02')
     '2015-01-01'
     '''
-    from airflow.hooks.hive_hook import HiveHook
+    from airflow.hooks import HiveMetastoreHook
     if '.' in table:
         schema, table = table.split('.')
     target_dt = datetime.datetime.strptime(ds, '%Y-%m-%d')
-    hh = HiveHook(hive_conn_id=hive_conn_id)
+    hh = HiveMetastoreHook(metastore_conn_id=metastore_conn_id)
     partitions = hh.get_partitions(schema=schema, table_name=table)
     parts = [datetime.datetime.strptime(p, '%Y-%m-%d') for p in partitions]
     if partitions is None or partitions == []:

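Both macros now resolve their hook lazily and take metastore_conn_id instead of hive_conn_id. Plain-Python usage mirroring the doctests, assuming this module is importable as airflow.macros.hive (the import path is an assumption; in a DAG these are normally reached from Jinja templates via the macros namespace):

# assumed module path for the macros shown in this hunk
from airflow.macros.hive import max_partition, closest_ds_partition

# latest partition of a single-key partitioned table
max_partition('airflow.static_babynames_partitioned',
              metastore_conn_id='metastore_default')

# partition closest to a given ds (here: on or before 2015-01-02)
closest_ds_partition('airflow.static_babynames_partitioned', '2015-01-02',
                     metastore_conn_id='metastore_default')
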
@@ -1,8 +1,7 @@
 import logging
 import tempfile
 
-from airflow.configuration import conf
-from airflow.hooks import HiveHook, SambaHook
+from airflow.hooks import HiveServer2Hook, SambaHook
 from airflow.models import BaseOperator
 from airflow.utils import apply_defaults
 

@@ -13,10 +12,10 @@ class Hive2SambaOperator(BaseOperator):
 
     :param hql: the hql to be exported
     :type hql: string
-    :param hive_dbid: reference to the Hive database
-    :type hive_dbid: string
-    :param samba_dbid: reference to the samba destination
-    :type samba_dbid: string
+    :param hiveserver2_conn_id: reference to the hiveserver2 service
+    :type hiveserver2_conn_id: string
+    :param samba_conn_id: reference to the samba destination
+    :type samba_conn_id: string
     """
 
     __mapper_args__ = {

@@ -28,33 +27,20 @@ class Hive2SambaOperator(BaseOperator):
     @apply_defaults
     def __init__(
             self, hql,
-            samba_dbid,
             destination_filepath,
-            hive_dbid='hive_default',
+            samba_conn_id='samba_default',
+            hiveserver2_conn_id='hiveserver2_default',
             *args, **kwargs):
         super(Hive2SambaOperator, self).__init__(*args, **kwargs)
 
-        self.hive_dbid = hive_dbid
-        self.samba_dbid = samba_dbid
+        self.hiveserver2_conn_id = hiveserver2_conn_id
+        self.samba_conn_id = samba_conn_id
         self.destination_filepath = destination_filepath
-        self.samba = SambaHook(samba_dbid=samba_dbid)
-        self.hook = HiveHook(hive_dbid=hive_dbid)
+        self.samba = SambaHook(samba_conn_id=samba_conn_id)
+        self.hook = HiveServer2Hook(hiveserver2_conn_id=hiveserver2_conn_id)
         self.hql = hql.strip().rstrip(';')
 
     def execute(self, context):
         tmpfile = tempfile.NamedTemporaryFile()
-        hql = """\
-        INSERT OVERWRITE LOCAL DIRECTORY '{tmpfile.name}'
-        ROW FORMAT DELIMITED
-        FIELDS TERMINATED BY ','
-        {self.hql};
-        """.format(**locals())
-        logging.info('Executing: ' + hql)
-        self.hook.run_cli(hql=hql)
-
+        self.hook.to_csv(hql=self.hql, csv_filepath=tmpfile.name)
         self.samba.push_from_local(self.destination_filepath, tmpfile.name)
-
-        # Cleaning up
-        hql = "DROP TABLE {table};"
-        self.hook.run_cli(hql=self.hql)
         tmpfile.close()

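With this change, Hive2SambaOperator streams the query result through HiveServer2Hook.to_csv into a temp file and pushes it to Samba, instead of running INSERT OVERWRITE LOCAL DIRECTORY through the CLI. A task definition sketch based on the test case at the bottom of this diff; `dag` and the connection IDs are assumed to exist:

from airflow.operators import Hive2SambaOperator

# `dag` is an existing airflow.models.DAG instance
t = Hive2SambaOperator(
    task_id='hive2samba_check',
    hiveserver2_conn_id='hiveserver2_default',
    samba_conn_id='tableau_samba',
    hql="SELECT * FROM airflow.static_babynames LIMIT 1000",
    destination_filepath='test_airflow3.csv',
    dag=dag)
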
@@ -1,7 +1,7 @@
 import logging
 import re
 
-from airflow.hooks import HiveHook
+from airflow.hooks import HiveCliHook
 from airflow.models import BaseOperator
 from airflow.utils import apply_defaults
 

@@ -32,15 +32,14 @@ class HiveOperator(BaseOperator):
     @apply_defaults
     def __init__(
             self, hql,
-            hive_conn_id='hive_default',
+            hive_cli_conn_id='hive_cli_default',
             hiveconf_jinja_translate=False,
             script_begin_tag=None,
             *args, **kwargs):
 
         super(HiveOperator, self).__init__(*args, **kwargs)
         self.hiveconf_jinja_translate = hiveconf_jinja_translate
-        self.hive_conn_id = hive_conn_id
-        self.hook = HiveHook(hive_conn_id=hive_conn_id)
+        self.hook = HiveCliHook(hive_cli_conn_id=hive_cli_conn_id)
         self.hql = hql
         self.script_begin_tag = script_begin_tag
 

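The only user-facing change to HiveOperator is the connection parameter: hive_cli_conn_id replaces hive_conn_id, and the task now runs through HiveCliHook. A sketch following the basic_hql test in this diff; `dag` is assumed to exist:

from airflow.operators import HiveOperator

# `dag` is an existing airflow.models.DAG instance
t = HiveOperator(
    task_id='basic_hql',
    hive_cli_conn_id='hive_cli_default',
    hql="SELECT state, year, name, gender, num FROM airflow.static_babynames;",
    dag=dag)
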
|
@ -4,7 +4,7 @@ from urlparse import urlparse
|
|||
from time import sleep
|
||||
|
||||
from airflow import settings
|
||||
from airflow.hooks import HiveHook
|
||||
from airflow.hooks import HiveMetastoreHook
|
||||
from airflow.hooks import S3Hook
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.models import Connection as DB
|
||||
|
@@ -169,7 +169,7 @@ class HivePartitionSensor(BaseSensorOperator):
     def __init__(
             self,
             table, partition="ds='{{ ds }}'",
-            hive_conn_id='hive_default',
+            metastore_conn_id='metastore_default',
             schema='default',
             *args, **kwargs):
         super(HivePartitionSensor, self).__init__(*args, **kwargs)

@@ -177,8 +177,8 @@ class HivePartitionSensor(BaseSensorOperator):
             schema, table = table.split('.')
         if not partition:
             partition = "ds='{{ ds }}'"
-        self.hive_conn_id = hive_conn_id
-        self.hook = HiveHook(hive_conn_id=hive_conn_id)
+        self.metastore_conn_id = metastore_conn_id
+        self.hook = HiveMetastoreHook(metastore_conn_id=metastore_conn_id)
         self.table = table
         self.partition = partition
         self.schema = schema

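HivePartitionSensor likewise switches from hive_conn_id to metastore_conn_id and polls the metastore through HiveMetastoreHook.check_for_partition. A sketch based on the sensor test added below; `dag` is assumed to exist and partition defaults to "ds='{{ ds }}'":

from airflow.operators import HivePartitionSensor

# `dag` is an existing airflow.models.DAG instance
t = HivePartitionSensor(
    task_id='hive_partition_check',
    table='airflow.static_babynames_partitioned',
    metastore_conn_id='metastore_default',
    dag=dag)
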
@@ -21,6 +21,7 @@ pygments
 pyhive
 PySmbClient
 python-dateutil
+pyhs2
 requests
 setproctitle
 snakebite

setup.py

@@ -31,6 +31,7 @@ setup(
         'pygments>=2.0.1',
         'pysmbclient>=0.1.3',
         'pyhive>=0.1.3',
+        'pyhs2>=0.6.0',
         'python-dateutil>=2.3',
         'requests>=2.5.1',
         'setproctitle>=1.1.8',

@@ -37,7 +37,7 @@ class HivePrestoTest(unittest.TestCase):
         SELECT state, year, name, gender, num FROM static_babynames;
         """
         t = operators.HiveOperator(task_id='basic_hql', hql=hql, dag=self.dag)
-        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
 
     def test_presto(self):
         sql = """

@@ -45,14 +45,14 @@ class HivePrestoTest(unittest.TestCase):
         """
         t = operators.PrestoCheckOperator(
             task_id='presto_check', sql=sql, dag=self.dag)
-        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
 
     def test_hdfs_sensor(self):
         t = operators.HdfsSensor(
             task_id='hdfs_sensor_check',
             filepath='/user/hive/warehouse/airflow.db/static_babynames',
             dag=self.dag)
-        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
 
     def test_sql_sensor(self):
         t = operators.SqlSensor(

@@ -60,7 +60,23 @@ class HivePrestoTest(unittest.TestCase):
             conn_id='presto_default',
             sql="SELECT 'x' FROM airflow.static_babynames LIMIT 1;",
             dag=self.dag)
-        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
+
+    def test_hive_partition_sensor(self):
+        t = operators.HivePartitionSensor(
+            task_id='hive_partition_check',
+            table='airflow.static_babynames_partitioned',
+            dag=self.dag)
+        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
+
+    def test_hive2samba(self):
+        t = operators.Hive2SambaOperator(
+            task_id='hive2samba_check',
+            samba_conn_id='tableau_samba',
+            hql="SELECT * FROM airflow.static_babynames LIMIT 1000",
+            destination_filepath='test_airflow3.csv',
+            dag=self.dag)
+        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
 
 
 class CoreTest(unittest.TestCase):
