[AIRFLOW-2200] Add snowflake operator with tests

devinXL8 2018-04-04 15:10:00 -04:00
Parent 9c0c4264c3
Commit c4ba1051a7
6 changed files with 288 additions and 3 deletions

View file

@@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import snowflake.connector

from airflow.hooks.dbapi_hook import DbApiHook


class SnowflakeHook(DbApiHook):
    """
    Interact with Snowflake.

    get_sqlalchemy_engine() depends on snowflake-sqlalchemy
    """

    conn_name_attr = 'snowflake_conn_id'
    default_conn_name = 'snowflake_default'
    supports_autocommit = True

    def __init__(self, *args, **kwargs):
        super(SnowflakeHook, self).__init__(*args, **kwargs)
        self.account = kwargs.pop("account", None)
        self.warehouse = kwargs.pop("warehouse", None)
        self.database = kwargs.pop("database", None)

    def _get_conn_params(self):
        """
        One method to fetch connection params as a dict,
        used in get_uri() and get_conn().
        """
        conn = self.get_connection(self.snowflake_conn_id)
        account = conn.extra_dejson.get('account', None)
        warehouse = conn.extra_dejson.get('warehouse', None)
        database = conn.extra_dejson.get('database', None)

        conn_config = {
            "user": conn.login,
            "password": conn.password or '',
            "schema": conn.schema or '',
            "database": self.database or database or '',
            "account": self.account or account or '',
            "warehouse": self.warehouse or warehouse or ''
        }
        return conn_config

    def get_uri(self):
        """
        Override DbApiHook get_uri method for get_sqlalchemy_engine().
        """
        conn_config = self._get_conn_params()
        uri = 'snowflake://{user}:{password}@{account}/{database}/'
        uri += '{schema}?warehouse={warehouse}'
        return uri.format(**conn_config)

    def get_conn(self):
        """
        Returns a snowflake.connector connection object.
        """
        conn_config = self._get_conn_params()
        conn = snowflake.connector.connect(**conn_config)
        return conn

    def _get_aws_credentials(self):
        """
        Returns aws_access_key_id, aws_secret_access_key from extra;
        intended to be used by external import and export statements.
        """
        # Default to None so the return below cannot raise a NameError
        # when the keys are absent from the connection extras.
        aws_access_key_id = None
        aws_secret_access_key = None
        if self.snowflake_conn_id:
            connection_object = self.get_connection(self.snowflake_conn_id)
            if 'aws_secret_access_key' in connection_object.extra_dejson:
                aws_access_key_id = connection_object.extra_dejson.get(
                    'aws_access_key_id')
                aws_secret_access_key = connection_object.extra_dejson.get(
                    'aws_secret_access_key')
        return aws_access_key_id, aws_secret_access_key

    def set_autocommit(self, conn, autocommit):
        conn.autocommit(autocommit)
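
For orientation (not part of the diff), a minimal usage sketch of the hook above. It assumes a snowflake_default connection whose Extra field carries the JSON keys the hook reads (account, warehouse, database); the connection id and values here are illustrative:

from airflow.contrib.hooks.snowflake_hook import SnowflakeHook

# Constructor kwargs override the same-named values from the
# connection's Extra JSON (see _get_conn_params above).
hook = SnowflakeHook(snowflake_conn_id='snowflake_default',
                     warehouse='compute_wh')
# get_records() comes from DbApiHook and uses get_conn() internally.
records = hook.get_records('SELECT CURRENT_DATE')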

View file

@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from airflow.contrib.hooks.snowflake_hook import SnowflakeHook
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults


class SnowflakeOperator(BaseOperator):
    """
    Executes sql code in a Snowflake database.

    :param snowflake_conn_id: reference to a specific Snowflake connection id
    :type snowflake_conn_id: string
    :param sql: the sql code to be executed
    :type sql: Can receive a str representing a sql statement,
        a list of str (sql statements), or a reference to a template file.
        Template references are recognized by str ending in '.sql'
    :param parameters: (optional) the parameters to render the SQL query with
    :type parameters: mapping or iterable
    :param autocommit: if True, each command is automatically committed
    :type autocommit: bool
    :param warehouse: name of the warehouse which overrides the one
        defined in the connection
    :type warehouse: string
    :param database: name of the database which overrides the one
        defined in the connection
    :type database: string
    """

    template_fields = ('sql',)
    template_ext = ('.sql',)
    ui_color = '#ededed'

    @apply_defaults
    def __init__(
            self, sql, snowflake_conn_id='snowflake_default', parameters=None,
            autocommit=True, warehouse=None, database=None, *args, **kwargs):
        super(SnowflakeOperator, self).__init__(*args, **kwargs)
        self.snowflake_conn_id = snowflake_conn_id
        self.sql = sql
        self.autocommit = autocommit
        self.parameters = parameters
        self.warehouse = warehouse
        self.database = database

    def get_hook(self):
        return SnowflakeHook(snowflake_conn_id=self.snowflake_conn_id,
                             warehouse=self.warehouse, database=self.database)

    def execute(self, context):
        self.log.info('Executing: %s', self.sql)
        hook = self.get_hook()
        hook.run(
            self.sql,
            autocommit=self.autocommit,
            parameters=self.parameters)
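
A minimal DAG sketch showing the operator in use (not part of the diff; the dag/task ids and warehouse name are illustrative):

from airflow import DAG
from airflow.contrib.operators.snowflake_operator import SnowflakeOperator
from airflow.utils import timezone

dag = DAG('snowflake_example',
          start_date=timezone.datetime(2018, 1, 1),
          schedule_interval=None)

create_table = SnowflakeOperator(
    task_id='create_table',
    snowflake_conn_id='snowflake_default',
    sql="CREATE TABLE IF NOT EXISTS test_airflow (dummy VARCHAR(50))",
    warehouse='compute_wh',  # overrides the connection's warehouse
    dag=dag)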

View file

@@ -593,6 +593,7 @@ class Connection(Base, LoggingMixin):
        ('databricks', 'Databricks',),
        ('aws', 'Amazon Web Services',),
        ('emr', 'Elastic MapReduce',),
+       ('snowflake', 'Snowflake',),
    ]

    def __init__(
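
The new ('snowflake', 'Snowflake') entry only registers the connection type in the UI dropdown; the connection itself still has to be created. A hedged sketch of doing that programmatically (illustrative credentials; the same can be done through the web UI):

from airflow import settings
from airflow.models import Connection

conn = Connection(
    conn_id='snowflake_default',
    conn_type='snowflake',
    login='user',
    password='pw',
    schema='public',
    # keys read by SnowflakeHook._get_conn_params()
    extra='{"account": "xy12345", "warehouse": "compute_wh", '
          '"database": "db"}')
session = settings.Session()
session.add(conn)
session.commit()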

View file

@@ -166,7 +166,8 @@ cloudant = ['cloudant>=0.5.9,<2.0'] # major update coming soon, clamp to 0.x
redis = ['redis>=2.10.5']
kubernetes = ['kubernetes>=3.0.0',
              'cryptography>=2.0.0']
+snowflake = ['snowflake-connector-python>=1.5.2',
+             'snowflake-sqlalchemy>=1.1.0']
zendesk = ['zdesk']

all_dbs = postgres + mysql + hive + mssql + hdfs + vertica + cloudant + druid
@@ -191,7 +192,8 @@ devel_minreq = devel + kubernetes + mysql + doc + password + s3 + cgroups
devel_hadoop = devel_minreq + hive + hdfs + webhdfs + kerberos
devel_all = (sendgrid + devel + all_dbs + doc + samba + s3 + slack + crypto + oracle +
             docker + ssh + kubernetes + celery + azure + redis + gcp_api + datadog +
-             zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins + druid)
+             zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins +
+             druid + snowflake)

# Snakebite & Google Cloud Dataflow are not Python 3 compatible :'(
if PY3:
@@ -298,7 +300,8 @@ def do_setup():
            'webhdfs': webhdfs,
            'jira': jira,
            'redis': redis,
-           'kubernetes': kubernetes
+           'kubernetes': kubernetes,
+           'snowflake': snowflake
        },
        classifiers=[
            'Development Status :: 5 - Production/Stable',
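
With the new extras_require entry, the Snowflake dependencies come in through pip's extras syntax; assuming the package name of that era, something like:

    pip install apache-airflow[snowflake]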

View file

@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import mock
import unittest

from airflow.contrib.hooks.snowflake_hook import SnowflakeHook


class TestSnowflakeHook(unittest.TestCase):
    def setUp(self):
        super(TestSnowflakeHook, self).setUp()

        self.cur = mock.MagicMock()
        self.conn = conn = mock.MagicMock()
        self.conn.cursor.return_value = self.cur

        self.conn.login = 'user'
        self.conn.password = 'pw'
        self.conn.schema = 'public'
        self.conn.extra_dejson = {'database': 'db',
                                  'account': 'airflow',
                                  'warehouse': 'af_wh'}

        class UnitTestSnowflakeHook(SnowflakeHook):
            conn_name_attr = 'snowflake_conn_id'

            def get_conn(self):
                return conn

            def get_connection(self, connection_id):
                return conn

        self.db_hook = UnitTestSnowflakeHook()

    def test_get_uri(self):
        uri_shouldbe = 'snowflake://user:pw@airflow/db/public?warehouse=af_wh'
        self.assertEqual(uri_shouldbe, self.db_hook.get_uri())

    def test_get_conn_params(self):
        conn_params_shouldbe = {'user': 'user',
                                'password': 'pw',
                                'schema': 'public',
                                'database': 'db',
                                'account': 'airflow',
                                'warehouse': 'af_wh'}
        self.assertEqual(conn_params_shouldbe, self.db_hook._get_conn_params())

    def test_get_conn(self):
        self.assertEqual(self.db_hook.get_conn(), self.conn)

View file

@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

from airflow import DAG, configuration
from airflow.utils import timezone
from airflow.contrib.operators.snowflake_operator import SnowflakeOperator

try:
    from unittest import mock
except ImportError:
    try:
        import mock
    except ImportError:
        mock = None

DEFAULT_DATE = timezone.datetime(2015, 1, 1)
DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
TEST_DAG_ID = 'unit_test_dag'
LONG_MOCK_PATH = 'airflow.contrib.operators.snowflake_operator.'
LONG_MOCK_PATH += 'SnowflakeOperator.get_hook'


class TestSnowflakeOperator(unittest.TestCase):
    def setUp(self):
        super(TestSnowflakeOperator, self).setUp()
        configuration.load_test_config()
        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
        dag = DAG(TEST_DAG_ID, default_args=args)
        self.dag = dag

    @mock.patch(LONG_MOCK_PATH)
    def test_snowflake_operator(self, mock_get_hook):
        sql = """
        CREATE TABLE IF NOT EXISTS test_airflow (
            dummy VARCHAR(50)
        );
        """
        t = SnowflakeOperator(
            task_id='basic_snowflake',
            sql=sql,
            dag=self.dag)
        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
              ignore_ti_state=True)
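
To run just this module in a development checkout of that era, something like the following should work (assuming the nose-based test setup Airflow used at the time, with mock installed on Python 2):

    nosetests tests/contrib/operators/test_snowflake_operator.py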