From 0daef23e8cc9b54382dd8b79f84e341d4f400feb Mon Sep 17 00:00:00 2001 From: Steve Mardenfeld Date: Thu, 16 Jul 2015 17:04:10 -0700 Subject: [PATCH] make a baseSqlHook and apply it to mysql, postgres, sqlite --- airflow/hooks/__init__.py | 1 + airflow/hooks/base_hook.py | 43 ++++++++++++++++++++++++++++++++++ airflow/hooks/mysql_hook.py | 26 ++------------------ airflow/hooks/postgres_hook.py | 26 +++----------------- airflow/hooks/sqlite_hook.py | 37 ++--------------------------- 5 files changed, 51 insertions(+), 82 deletions(-) diff --git a/airflow/hooks/__init__.py b/airflow/hooks/__init__.py index 1297a9c0f3..60ea190399 100644 --- a/airflow/hooks/__init__.py +++ b/airflow/hooks/__init__.py @@ -4,6 +4,7 @@ abstracting the underlying modules ''' from airflow.utils import import_module_attrs as _import_module_attrs from airflow.hooks.base_hook import BaseHook as _BaseHook +from airflow.hooks.base_hook import BaseHook as _BaseSqlHook _hooks = { 'hive_hooks': [ diff --git a/airflow/hooks/base_hook.py b/airflow/hooks/base_hook.py index 4265b2e20e..62c06c8331 100644 --- a/airflow/hooks/base_hook.py +++ b/airflow/hooks/base_hook.py @@ -48,3 +48,46 @@ class BaseHook(object): def run(self, sql): raise NotImplemented() + + +class BaseSqlHook(BaseHook): + """ + """ + def __init__(self, source): + pass + + # add get first + + def get_pandas_df(self, sql): + ''' + Executes the sql and returns a pandas dataframe + ''' + import pandas.io.sql as psql + conn = self.get_conn() + df = psql.read_sql(sql, con=conn) + conn.close() + return df + + def get_records(self, sql): + ''' + Executes the sql and returns a set of records. + ''' + conn = self.get_conn() + cur = conn.cursor() + cur.execute(sql) + rows = cur.fetchall() + cur.close() + conn.close() + return rows + + def get_first(self, sql): + ''' + Executes the sql and returns a set of records. + ''' + conn = self.get_conn() + cur = conn.cursor() + cur.execute(sql) + rows = cur.fetchone() + cur.close() + conn.close() + return rows diff --git a/airflow/hooks/mysql_hook.py b/airflow/hooks/mysql_hook.py index 93d92d3b4b..41a89ae5bd 100644 --- a/airflow/hooks/mysql_hook.py +++ b/airflow/hooks/mysql_hook.py @@ -4,10 +4,10 @@ import logging import MySQLdb -from airflow.hooks.base_hook import BaseHook +from airflow.hooks.base_hook import BaseSqlHook -class MySqlHook(BaseHook): +class MySqlHook(BaseSqlHook): ''' Interact with MySQL. ''' @@ -28,28 +28,6 @@ class MySqlHook(BaseHook): conn.schema) return conn - def get_records(self, sql): - ''' - Executes the sql and returns a set of records. - ''' - conn = self.get_conn() - cur = conn.cursor() - cur.execute(sql) - rows = cur.fetchall() - cur.close() - conn.close() - return rows - - def get_pandas_df(self, sql): - ''' - Executes the sql and returns a pandas dataframe - ''' - import pandas.io.sql as psql - conn = self.get_conn() - df = psql.read_sql(sql, con=conn) - conn.close() - return df - def run(self, sql): conn = self.get_conn() cur = conn.cursor() diff --git a/airflow/hooks/postgres_hook.py b/airflow/hooks/postgres_hook.py index 4b029b0e98..c663c0c1a2 100644 --- a/airflow/hooks/postgres_hook.py +++ b/airflow/hooks/postgres_hook.py @@ -4,8 +4,10 @@ from airflow import settings from airflow.utils import AirflowException from airflow.models import Connection +from airflow.hooks.base_hook import BaseSqlHook -class PostgresHook(object): + +class PostgresHook(BaseSqlHook): ''' Interact with Postgres. ''' @@ -45,28 +47,6 @@ class PostgresHook(object): port=self.port) return conn - def get_records(self, sql): - ''' - Executes the sql and returns a set of records. - ''' - conn = self.get_conn() - cur = conn.cursor() - cur.execute(sql) - rows = cur.fetchall() - cur.close() - conn.close() - return rows - - def get_pandas_df(self, sql): - ''' - Executes the sql and returns a pandas dataframe - ''' - import pandas.io.sql as psql - conn = self.get_conn() - df = psql.read_sql(sql, con=conn) - conn.close() - return df - def run(self, sql, autocommit=False): conn = self.get_conn() conn.autocommit = autocommit diff --git a/airflow/hooks/sqlite_hook.py b/airflow/hooks/sqlite_hook.py index 5bb4c76d7a..64ef05b667 100644 --- a/airflow/hooks/sqlite_hook.py +++ b/airflow/hooks/sqlite_hook.py @@ -2,10 +2,10 @@ import logging import sqlite3 -from airflow.hooks.base_hook import BaseHook +from airflow.hooks.base_hook import BaseSqlHook -class SqliteHook(BaseHook): +class SqliteHook(BaseSqlHook): """ Interact with SQLite. @@ -76,36 +76,3 @@ class SqliteHook(BaseHook): conn.close() logging.info( "Done loading. Loaded a total of {i} rows".format(**locals())) - - def get_records(self, sql): - """ - Executes the sql and returns a set of records. - - >>> h = SqliteHook() - >>> sql = "SELECT * FROM test_table WHERE i=1 LIMIT 1;" - >>> h.get_records(sql) - [(1,)] - """ - conn = self.get_conn() - cur = conn.cursor() - cur.execute(sql) - rows = cur.fetchall() - cur.close() - conn.close() - return rows - - def get_pandas_df(self, sql): - """ - Executes the sql and returns a pandas dataframe - - >>> h = SqliteHook() - >>> sql = "SELECT * FROM test_table WHERE i=1 LIMIT 1;" - >>> h.get_pandas_df(sql) - i - 0 1 - """ - import pandas.io.sql as psql - conn = self.get_conn() - df = psql.read_sql(sql, con=conn) - conn.close() - return df