From 4069c55013975c6b3903ce67dedd80090db9543b Mon Sep 17 00:00:00 2001 From: Aaron Keys Date: Thu, 14 May 2015 23:14:17 -0700 Subject: [PATCH] add sqlite hook --- airflow/hooks/__init__.py | 1 + airflow/hooks/sqlite_hook.py | 90 ++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) create mode 100644 airflow/hooks/sqlite_hook.py diff --git a/airflow/hooks/__init__.py b/airflow/hooks/__init__.py index 14c5578046..e678a29aa0 100644 --- a/airflow/hooks/__init__.py +++ b/airflow/hooks/__init__.py @@ -15,6 +15,7 @@ _hooks = { 'postgres_hook': ['PostgresHook'], 'presto_hook': ['PrestoHook'], 'samba_hook': ['SambaHook'], + 'sqlite_hook': ['SqliteHook'], 'S3_hook': ['S3Hook'], } diff --git a/airflow/hooks/sqlite_hook.py b/airflow/hooks/sqlite_hook.py new file mode 100644 index 0000000000..d4cf621577 --- /dev/null +++ b/airflow/hooks/sqlite_hook.py @@ -0,0 +1,90 @@ +import logging + +import sqlite3 + +from airflow.hooks.base_hook import BaseHook + + +class SqliteHook(BaseHook): + + ''' + Interact with SQLite. + ''' + + def __init__( + self, sqlite_conn_id='sqlite_default'): + self.sqlite_conn_id = sqlite_conn_id + + def get_conn(self): + """ + Returns a sqlite connection object + """ + conn = self.get_connection(self.sqlite_conn_id) + conn = sqlite3.connect(conn.host) + return conn + + def get_records(self, sql): + ''' + Executes the sql and returns a set of records. + ''' + conn = self.get_conn() + cur = conn.cursor() + cur.execute(sql) + rows = cur.fetchall() + cur.close() + conn.close() + return rows + + def get_pandas_df(self, sql): + ''' + Executes the sql and returns a pandas dataframe + ''' + import pandas.io.sql as psql + conn = self.get_conn() + df = psql.read_sql(sql, con=conn) + conn.close() + return df + + def run(self, sql): + conn = self.get_conn() + cur = conn.cursor() + cur.execute(sql) + conn.commit() + cur.close() + conn.close() + + def insert_rows(self, table, rows, target_fields=None): + """ + A generic way to insert a set of tuples into a table, + the whole set of inserts is treated as one transaction + """ + if target_fields: + target_fields = ", ".join(target_fields) + target_fields = "({})".format(target_fields) + else: + target_fields = '' + conn = self.get_conn() + cur = conn.cursor() + i = 0 + for row in rows: + i += 1 + l = [] + for cell in row: + if isinstance(cell, basestring): + l.append("'" + str(cell).replace("'", "''") + "'") + elif cell is None: + l.append('NULL') + else: + l.append(str(cell)) + values = tuple(l) + sql = "INSERT INTO {0} {1} VALUES ({2});".format( + table, + target_fields, + ",".join(values)) + cur.execute(sql) + conn.commit() + conn.commit() + cur.close() + conn.close() + logging.info( + "Done loading. Loaded a total of {i} rows".format(**locals()))