make a baseSqlHook and apply it to mysql, postgres, sqlite
This commit is contained in:
Родитель
f76b7fef38
Коммит
0daef23e8c
|
@ -4,6 +4,7 @@ abstracting the underlying modules
|
|||
'''
|
||||
from airflow.utils import import_module_attrs as _import_module_attrs
|
||||
from airflow.hooks.base_hook import BaseHook as _BaseHook
|
||||
from airflow.hooks.base_hook import BaseHook as _BaseSqlHook
|
||||
|
||||
_hooks = {
|
||||
'hive_hooks': [
|
||||
|
|
|
@ -48,3 +48,46 @@ class BaseHook(object):
|
|||
|
||||
def run(self, sql):
|
||||
raise NotImplemented()
|
||||
|
||||
|
||||
class BaseSqlHook(BaseHook):
|
||||
"""
|
||||
"""
|
||||
def __init__(self, source):
|
||||
pass
|
||||
|
||||
# add get first
|
||||
|
||||
def get_pandas_df(self, sql):
|
||||
'''
|
||||
Executes the sql and returns a pandas dataframe
|
||||
'''
|
||||
import pandas.io.sql as psql
|
||||
conn = self.get_conn()
|
||||
df = psql.read_sql(sql, con=conn)
|
||||
conn.close()
|
||||
return df
|
||||
|
||||
def get_records(self, sql):
|
||||
'''
|
||||
Executes the sql and returns a set of records.
|
||||
'''
|
||||
conn = self.get_conn()
|
||||
cur = conn.cursor()
|
||||
cur.execute(sql)
|
||||
rows = cur.fetchall()
|
||||
cur.close()
|
||||
conn.close()
|
||||
return rows
|
||||
|
||||
def get_first(self, sql):
|
||||
'''
|
||||
Executes the sql and returns a set of records.
|
||||
'''
|
||||
conn = self.get_conn()
|
||||
cur = conn.cursor()
|
||||
cur.execute(sql)
|
||||
rows = cur.fetchone()
|
||||
cur.close()
|
||||
conn.close()
|
||||
return rows
|
||||
|
|
|
@ -4,10 +4,10 @@ import logging
|
|||
|
||||
import MySQLdb
|
||||
|
||||
from airflow.hooks.base_hook import BaseHook
|
||||
from airflow.hooks.base_hook import BaseSqlHook
|
||||
|
||||
|
||||
class MySqlHook(BaseHook):
|
||||
class MySqlHook(BaseSqlHook):
|
||||
'''
|
||||
Interact with MySQL.
|
||||
'''
|
||||
|
@ -28,28 +28,6 @@ class MySqlHook(BaseHook):
|
|||
conn.schema)
|
||||
return conn
|
||||
|
||||
def get_records(self, sql):
|
||||
'''
|
||||
Executes the sql and returns a set of records.
|
||||
'''
|
||||
conn = self.get_conn()
|
||||
cur = conn.cursor()
|
||||
cur.execute(sql)
|
||||
rows = cur.fetchall()
|
||||
cur.close()
|
||||
conn.close()
|
||||
return rows
|
||||
|
||||
def get_pandas_df(self, sql):
|
||||
'''
|
||||
Executes the sql and returns a pandas dataframe
|
||||
'''
|
||||
import pandas.io.sql as psql
|
||||
conn = self.get_conn()
|
||||
df = psql.read_sql(sql, con=conn)
|
||||
conn.close()
|
||||
return df
|
||||
|
||||
def run(self, sql):
|
||||
conn = self.get_conn()
|
||||
cur = conn.cursor()
|
||||
|
|
|
@ -4,8 +4,10 @@ from airflow import settings
|
|||
from airflow.utils import AirflowException
|
||||
from airflow.models import Connection
|
||||
|
||||
from airflow.hooks.base_hook import BaseSqlHook
|
||||
|
||||
class PostgresHook(object):
|
||||
|
||||
class PostgresHook(BaseSqlHook):
|
||||
'''
|
||||
Interact with Postgres.
|
||||
'''
|
||||
|
@ -45,28 +47,6 @@ class PostgresHook(object):
|
|||
port=self.port)
|
||||
return conn
|
||||
|
||||
def get_records(self, sql):
|
||||
'''
|
||||
Executes the sql and returns a set of records.
|
||||
'''
|
||||
conn = self.get_conn()
|
||||
cur = conn.cursor()
|
||||
cur.execute(sql)
|
||||
rows = cur.fetchall()
|
||||
cur.close()
|
||||
conn.close()
|
||||
return rows
|
||||
|
||||
def get_pandas_df(self, sql):
|
||||
'''
|
||||
Executes the sql and returns a pandas dataframe
|
||||
'''
|
||||
import pandas.io.sql as psql
|
||||
conn = self.get_conn()
|
||||
df = psql.read_sql(sql, con=conn)
|
||||
conn.close()
|
||||
return df
|
||||
|
||||
def run(self, sql, autocommit=False):
|
||||
conn = self.get_conn()
|
||||
conn.autocommit = autocommit
|
||||
|
|
|
@ -2,10 +2,10 @@ import logging
|
|||
|
||||
import sqlite3
|
||||
|
||||
from airflow.hooks.base_hook import BaseHook
|
||||
from airflow.hooks.base_hook import BaseSqlHook
|
||||
|
||||
|
||||
class SqliteHook(BaseHook):
|
||||
class SqliteHook(BaseSqlHook):
|
||||
|
||||
"""
|
||||
Interact with SQLite.
|
||||
|
@ -76,36 +76,3 @@ class SqliteHook(BaseHook):
|
|||
conn.close()
|
||||
logging.info(
|
||||
"Done loading. Loaded a total of {i} rows".format(**locals()))
|
||||
|
||||
def get_records(self, sql):
|
||||
"""
|
||||
Executes the sql and returns a set of records.
|
||||
|
||||
>>> h = SqliteHook()
|
||||
>>> sql = "SELECT * FROM test_table WHERE i=1 LIMIT 1;"
|
||||
>>> h.get_records(sql)
|
||||
[(1,)]
|
||||
"""
|
||||
conn = self.get_conn()
|
||||
cur = conn.cursor()
|
||||
cur.execute(sql)
|
||||
rows = cur.fetchall()
|
||||
cur.close()
|
||||
conn.close()
|
||||
return rows
|
||||
|
||||
def get_pandas_df(self, sql):
|
||||
"""
|
||||
Executes the sql and returns a pandas dataframe
|
||||
|
||||
>>> h = SqliteHook()
|
||||
>>> sql = "SELECT * FROM test_table WHERE i=1 LIMIT 1;"
|
||||
>>> h.get_pandas_df(sql)
|
||||
i
|
||||
0 1
|
||||
"""
|
||||
import pandas.io.sql as psql
|
||||
conn = self.get_conn()
|
||||
df = psql.read_sql(sql, con=conn)
|
||||
conn.close()
|
||||
return df
|
||||
|
|
Загрузка…
Ссылка в новой задаче