diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py index b787c2496d..9d2ea280a9 100644 --- a/airflow/hooks/dbapi_hook.py +++ b/airflow/hooks/dbapi_hook.py @@ -26,14 +26,16 @@ class DbApiHook(BaseHook): """ supports_autocommit = False - def __init__(self, **kwargs): - try: - self.conn_id_name = kwargs[self.conn_name_attr] - except NameError: + def __init__(self, *args, **kwargs): + if not self.conn_name_attr: raise AirflowException("conn_name_attr is not defined") - except KeyError: - raise AirflowException( - self.conn_name_attr + " was not passed in the kwargs") + elif len(args) == 1: + setattr(self, self.conn_name_attr, args[0]) + elif self.conn_name_attr not in kwargs: + setattr(self, self.conn_name_attr, self.default_conn_name) + else: + setattr(self, self.conn_name_attr, kwargs[self.conn_name_attr]) + def get_pandas_df(self, sql, parameters=None): ''' diff --git a/airflow/hooks/mysql_hook.py b/airflow/hooks/mysql_hook.py index b4f1bc4631..856a4d53f8 100644 --- a/airflow/hooks/mysql_hook.py +++ b/airflow/hooks/mysql_hook.py @@ -16,7 +16,7 @@ class MySqlHook(DbApiHook): """ Returns a mysql connection object """ - conn = self.get_connection(self.conn_id_name) + conn = self.get_connection(self.mysql_conn_id) conn = MySQLdb.connect( conn.host, conn.login, diff --git a/airflow/hooks/postgres_hook.py b/airflow/hooks/postgres_hook.py index 1682ef5c33..fc5bfa1432 100644 --- a/airflow/hooks/postgres_hook.py +++ b/airflow/hooks/postgres_hook.py @@ -13,7 +13,7 @@ class PostgresHook(DbApiHook): supports_autocommit = True def get_conn(self): - conn = self.get_connection(self.conn_id_name) + conn = self.get_connection(self.postgres_conn_id) return psycopg2.connect( host=conn.host, user=conn.login, diff --git a/airflow/hooks/presto_hook.py b/airflow/hooks/presto_hook.py index 2355eef7eb..f2eabf581e 100644 --- a/airflow/hooks/presto_hook.py +++ b/airflow/hooks/presto_hook.py @@ -26,7 +26,7 @@ class PrestoHook(DbApiHook): def get_conn(self): """Returns a connection object""" - db = self.get_connection(self.conn_id_name) + db = self.get_connection(self.presto_conn_id) return presto.connect( host=db.host, port=db.port, diff --git a/airflow/hooks/sqlite_hook.py b/airflow/hooks/sqlite_hook.py index 1dafd86386..697e762b45 100644 --- a/airflow/hooks/sqlite_hook.py +++ b/airflow/hooks/sqlite_hook.py @@ -17,6 +17,6 @@ class SqliteHook(DbApiHook): """ Returns a sqlite connection object """ - conn = self.get_connection(self.conn_id_name) + conn = self.get_connection(self.sqlite_conn_id) conn = sqlite3.connect(conn.host) return conn