make a baseSqlHook and apply it to mysql, postgres, sqlite

This commit is contained in:
Steve Mardenfeld 2015-07-16 17:04:10 -07:00
Родитель f76b7fef38
Коммит 0daef23e8c
5 изменённых файлов: 51 добавлений и 82 удалений

Просмотреть файл

@ -4,6 +4,7 @@ abstracting the underlying modules
'''
from airflow.utils import import_module_attrs as _import_module_attrs
from airflow.hooks.base_hook import BaseHook as _BaseHook
# Alias the SQL base class itself: importing BaseHook under this name would
# make _BaseSqlHook just a second name for the generic hook.
from airflow.hooks.base_hook import BaseSqlHook as _BaseSqlHook
_hooks = {
'hive_hooks': [

Просмотреть файл

@ -48,3 +48,46 @@ class BaseHook(object):
def run(self, sql):
    """
    Run a SQL statement; this base implementation only signals that the
    method has not been implemented.

    :param sql: the statement to run (unused here)
    :raises NotImplementedError: always
    """
    # ``NotImplemented`` is a comparison sentinel and is not callable, so
    # ``raise NotImplemented()`` fails with a TypeError instead of the
    # intended exception.
    raise NotImplementedError()
class BaseSqlHook(BaseHook):
    """
    Shared query helpers for SQL hooks.

    Every helper obtains its connection from ``self.get_conn()``, which is
    not defined in this class -- presumably each concrete hook supplies it
    (TODO confirm against the subclasses).
    """

    def __init__(self, source):
        # Placeholder constructor: ``source`` is accepted and ignored.
        pass

    def get_pandas_df(self, sql):
        """
        Executes the sql and returns a pandas dataframe.

        :param sql: the query to run
        :return: the result of ``pandas.io.sql.read_sql`` for the query
        """
        # Imported lazily so pandas is only needed when this is called.
        import pandas.io.sql as psql
        conn = self.get_conn()
        try:
            return psql.read_sql(sql, con=conn)
        finally:
            # Close the connection even if the query fails.
            conn.close()

    def get_records(self, sql):
        """
        Executes the sql and returns every row of the result.

        :param sql: the query to run
        :return: whatever the cursor's ``fetchall()`` returns
        """
        conn = self.get_conn()
        try:
            cur = conn.cursor()
            try:
                cur.execute(sql)
                return cur.fetchall()
            finally:
                cur.close()
        finally:
            # Close the connection even if execute/fetch raised.
            conn.close()

    def get_first(self, sql):
        """
        Executes the sql and returns only the first row of the result.

        :param sql: the query to run
        :return: whatever the cursor's ``fetchone()`` returns
        """
        conn = self.get_conn()
        try:
            cur = conn.cursor()
            try:
                cur.execute(sql)
                return cur.fetchone()
            finally:
                cur.close()
        finally:
            # Close the connection even if execute/fetch raised.
            conn.close()

Просмотреть файл

@ -4,10 +4,10 @@ import logging
import MySQLdb
from airflow.hooks.base_hook import BaseHook
from airflow.hooks.base_hook import BaseSqlHook
class MySqlHook(BaseHook):
class MySqlHook(BaseSqlHook):
'''
Interact with MySQL.
'''
@ -28,28 +28,6 @@ class MySqlHook(BaseHook):
conn.schema)
return conn
def get_records(self, sql):
    '''
    Executes the sql and returns a set of records.

    :param sql: the query to run
    :return: whatever the cursor's ``fetchall()`` returns
    '''
    conn = self.get_conn()
    try:
        cur = conn.cursor()
        try:
            cur.execute(sql)
            return cur.fetchall()
        finally:
            cur.close()
    finally:
        # Release the connection even when execute/fetch raises.
        conn.close()
def get_pandas_df(self, sql):
    '''
    Executes the sql and returns a pandas dataframe

    :param sql: the query to run
    :return: the result of ``pandas.io.sql.read_sql`` for the query
    '''
    # Imported lazily so pandas is only needed when this is called.
    import pandas.io.sql as psql
    conn = self.get_conn()
    try:
        return psql.read_sql(sql, con=conn)
    finally:
        # Release the connection even when the query fails.
        conn.close()
def run(self, sql):
conn = self.get_conn()
cur = conn.cursor()

Просмотреть файл

@ -4,8 +4,10 @@ from airflow import settings
from airflow.utils import AirflowException
from airflow.models import Connection
from airflow.hooks.base_hook import BaseSqlHook
class PostgresHook(object):
class PostgresHook(BaseSqlHook):
'''
Interact with Postgres.
'''
@ -45,28 +47,6 @@ class PostgresHook(object):
port=self.port)
return conn
def get_records(self, sql):
    '''
    Executes the sql and returns a set of records.

    :param sql: the query to run
    :return: whatever the cursor's ``fetchall()`` returns
    '''
    conn = self.get_conn()
    try:
        cur = conn.cursor()
        try:
            cur.execute(sql)
            return cur.fetchall()
        finally:
            cur.close()
    finally:
        # Always hand the connection back, including on query errors.
        conn.close()
def get_pandas_df(self, sql):
    '''
    Executes the sql and returns a pandas dataframe

    :param sql: the query to run
    :return: the result of ``pandas.io.sql.read_sql`` for the query
    '''
    # Imported lazily so pandas is only needed when this is called.
    import pandas.io.sql as psql
    conn = self.get_conn()
    try:
        return psql.read_sql(sql, con=conn)
    finally:
        # Always hand the connection back, including on query errors.
        conn.close()
def run(self, sql, autocommit=False):
conn = self.get_conn()
conn.autocommit = autocommit

Просмотреть файл

@ -2,10 +2,10 @@ import logging
import sqlite3
from airflow.hooks.base_hook import BaseHook
from airflow.hooks.base_hook import BaseSqlHook
class SqliteHook(BaseHook):
class SqliteHook(BaseSqlHook):
"""
Interact with SQLite.
@ -76,36 +76,3 @@ class SqliteHook(BaseHook):
conn.close()
logging.info(
"Done loading. Loaded a total of {i} rows".format(**locals()))
def get_records(self, sql):
    """
    Executes the sql and returns a set of records.

    :param sql: the query to run
    :return: whatever the cursor's ``fetchall()`` returns

    >>> h = SqliteHook()
    >>> sql = "SELECT * FROM test_table WHERE i=1 LIMIT 1;"
    >>> h.get_records(sql)
    [(1,)]
    """
    conn = self.get_conn()
    try:
        cur = conn.cursor()
        try:
            cur.execute(sql)
            return cur.fetchall()
        finally:
            cur.close()
    finally:
        # Close the connection on both the success and the error path.
        conn.close()
def get_pandas_df(self, sql):
    """
    Executes the sql and returns a pandas dataframe

    :param sql: the query to run
    :return: the result of ``pandas.io.sql.read_sql`` for the query

    >>> h = SqliteHook()
    >>> sql = "SELECT * FROM test_table WHERE i=1 LIMIT 1;"
    >>> h.get_pandas_df(sql)
       i
    0  1
    """
    # Imported lazily so pandas is only needed when this is called.
    import pandas.io.sql as psql
    conn = self.get_conn()
    try:
        return psql.read_sql(sql, con=conn)
    finally:
        # Close the connection on both the success and the error path.
        conn.close()