[AIRFLOW-2200] Add snowflake operator with tests
This commit is contained in:
Parent
9c0c4264c3
Commit
c4ba1051a7
|
@ -0,0 +1,93 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import snowflake.connector
|
||||
|
||||
from airflow.hooks.dbapi_hook import DbApiHook
|
||||
|
||||
|
||||
class SnowflakeHook(DbApiHook):
    """
    Interact with Snowflake.

    Connection ``extra`` may supply ``account``, ``warehouse`` and
    ``database``; keyword arguments of the same names passed to the hook
    take precedence over the connection's extras.

    get_sqlalchemy_engine() depends on snowflake-sqlalchemy
    """

    conn_name_attr = 'snowflake_conn_id'
    default_conn_name = 'snowflake_default'
    supports_autocommit = True

    def __init__(self, *args, **kwargs):
        super(SnowflakeHook, self).__init__(*args, **kwargs)
        # Optional per-hook overrides for values stored in the
        # connection's extras.
        self.account = kwargs.pop("account", None)
        self.warehouse = kwargs.pop("warehouse", None)
        self.database = kwargs.pop("database", None)

    def _get_conn_params(self):
        """
        One method to fetch connection params as a dict.

        Used in get_uri() and get_conn().
        """
        conn = self.get_connection(self.snowflake_conn_id)
        account = conn.extra_dejson.get('account', None)
        warehouse = conn.extra_dejson.get('warehouse', None)
        database = conn.extra_dejson.get('database', None)

        conn_config = {
            "user": conn.login,
            "password": conn.password or '',
            "schema": conn.schema or '',
            # Hook-level overrides win over the connection's extras.
            "database": self.database or database or '',
            "account": self.account or account or '',
            "warehouse": self.warehouse or warehouse or ''
        }
        return conn_config

    def get_uri(self):
        """
        Override DbApiHook get_uri method for get_sqlalchemy_engine().
        """
        conn_config = self._get_conn_params()
        uri = 'snowflake://{user}:{password}@{account}/{database}/'
        uri += '{schema}?warehouse={warehouse}'
        return uri.format(**conn_config)

    def get_conn(self):
        """
        Returns a snowflake.connector connection object.
        """
        conn_config = self._get_conn_params()
        conn = snowflake.connector.connect(**conn_config)
        return conn

    def _get_aws_credentials(self):
        """
        Returns (aws_access_key_id, aws_secret_access_key) from the
        connection's extras, or (None, None) when they are not set.

        Intended to be used by external import and export statements.
        """
        # Bug fix: initialize both names so we never raise
        # UnboundLocalError when the extras carry no AWS credentials.
        aws_access_key_id = None
        aws_secret_access_key = None
        if self.snowflake_conn_id:
            connection_object = self.get_connection(self.snowflake_conn_id)
            if 'aws_secret_access_key' in connection_object.extra_dejson:
                aws_access_key_id = connection_object.extra_dejson.get(
                    'aws_access_key_id')
                aws_secret_access_key = connection_object.extra_dejson.get(
                    'aws_secret_access_key')
        return aws_access_key_id, aws_secret_access_key

    def set_autocommit(self, conn, autocommit):
        # Snowflake's connector exposes autocommit as a method, not an
        # attribute, so DbApiHook's default attribute-set would not work.
        conn.autocommit(autocommit)
|
|
@ -0,0 +1,62 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from airflow.contrib.hooks.snowflake_hook import SnowflakeHook
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class SnowflakeOperator(BaseOperator):
    """
    Executes sql code in a Snowflake database

    :param snowflake_conn_id: reference to specific snowflake connection id
    :type snowflake_conn_id: string
    :param sql: the sql code to be executed
    :type sql: Can receive a str representing a sql statement,
        a list of str (sql statements), or reference to a template file.
        Template reference are recognized by str ending in '.sql'
    :param warehouse: name of warehouse which overwrite defined
        one in connection
    :type warehouse: string
    :param database: name of database which overwrite defined one in connection
    :type database: string
    """

    template_fields = ('sql',)
    template_ext = ('.sql',)
    ui_color = '#ededed'

    @apply_defaults
    def __init__(
            self, sql, snowflake_conn_id='snowflake_default', parameters=None,
            autocommit=True, warehouse=None, database=None, *args, **kwargs):
        super(SnowflakeOperator, self).__init__(*args, **kwargs)
        # What to run and how.
        self.sql = sql
        self.parameters = parameters
        self.autocommit = autocommit
        # Where to run it: connection plus optional overrides.
        self.snowflake_conn_id = snowflake_conn_id
        self.warehouse = warehouse
        self.database = database

    def get_hook(self):
        """Build a SnowflakeHook wired to this operator's settings."""
        return SnowflakeHook(
            snowflake_conn_id=self.snowflake_conn_id,
            warehouse=self.warehouse,
            database=self.database)

    def execute(self, context):
        """Run the configured SQL against Snowflake."""
        self.log.info('Executing: %s', self.sql)
        self.get_hook().run(
            self.sql,
            autocommit=self.autocommit,
            parameters=self.parameters)
|
|
@ -593,6 +593,7 @@ class Connection(Base, LoggingMixin):
|
|||
('databricks', 'Databricks',),
|
||||
('aws', 'Amazon Web Services',),
|
||||
('emr', 'Elastic MapReduce',),
|
||||
('snowflake', 'Snowflake',),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
|
|
9
setup.py
9
setup.py
|
@ -166,7 +166,8 @@ cloudant = ['cloudant>=0.5.9,<2.0'] # major update coming soon, clamp to 0.x
|
|||
redis = ['redis>=2.10.5']
|
||||
kubernetes = ['kubernetes>=3.0.0',
|
||||
'cryptography>=2.0.0']
|
||||
|
||||
snowflake = ['snowflake-connector-python>=1.5.2',
|
||||
'snowflake-sqlalchemy>=1.1.0']
|
||||
zendesk = ['zdesk']
|
||||
|
||||
all_dbs = postgres + mysql + hive + mssql + hdfs + vertica + cloudant + druid
|
||||
|
@ -191,7 +192,8 @@ devel_minreq = devel + kubernetes + mysql + doc + password + s3 + cgroups
|
|||
devel_hadoop = devel_minreq + hive + hdfs + webhdfs + kerberos
|
||||
devel_all = (sendgrid + devel + all_dbs + doc + samba + s3 + slack + crypto + oracle +
|
||||
docker + ssh + kubernetes + celery + azure + redis + gcp_api + datadog +
|
||||
zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins + druid)
|
||||
zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins +
|
||||
druid + snowflake)
|
||||
|
||||
# Snakebite & Google Cloud Dataflow are not Python 3 compatible :'(
|
||||
if PY3:
|
||||
|
@ -298,7 +300,8 @@ def do_setup():
|
|||
'webhdfs': webhdfs,
|
||||
'jira': jira,
|
||||
'redis': redis,
|
||||
'kubernetes': kubernetes
|
||||
'kubernetes': kubernetes,
|
||||
'snowflake': snowflake
|
||||
},
|
||||
classifiers=[
|
||||
'Development Status :: 5 - Production/Stable',
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import mock
|
||||
import unittest
|
||||
|
||||
from airflow.contrib.hooks.snowflake_hook import SnowflakeHook
|
||||
|
||||
|
||||
class TestSnowflakeHook(unittest.TestCase):
    """Unit tests for SnowflakeHook against a fully mocked connection."""

    def setUp(self):
        super(TestSnowflakeHook, self).setUp()

        self.cur = mock.MagicMock()
        mocked_conn = mock.MagicMock()
        mocked_conn.cursor.return_value = self.cur
        mocked_conn.login = 'user'
        mocked_conn.password = 'pw'
        mocked_conn.schema = 'public'
        mocked_conn.extra_dejson = {'database': 'db',
                                    'account': 'airflow',
                                    'warehouse': 'af_wh'}
        self.conn = mocked_conn

        # Subclass that short-circuits all metastore/network access by
        # returning the mocked connection object for every lookup.
        class UnitTestSnowflakeHook(SnowflakeHook):
            conn_name_attr = 'snowflake_conn_id'

            def get_conn(self):
                return mocked_conn

            def get_connection(self, connection_id):
                return mocked_conn

        self.db_hook = UnitTestSnowflakeHook()

    def test_get_uri(self):
        expected_uri = 'snowflake://user:pw@airflow/db/public?warehouse=af_wh'
        self.assertEqual(expected_uri, self.db_hook.get_uri())

    def test_get_conn_params(self):
        expected_params = {'user': 'user',
                           'password': 'pw',
                           'schema': 'public',
                           'database': 'db',
                           'account': 'airflow',
                           'warehouse': 'af_wh'}
        self.assertEqual(expected_params, self.db_hook._get_conn_params())

    def test_get_conn(self):
        self.assertEqual(self.db_hook.get_conn(), self.conn)
|
|
@ -0,0 +1,62 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
|
||||
from airflow import DAG, configuration
|
||||
from airflow.utils import timezone
|
||||
|
||||
from airflow.contrib.operators.snowflake_operator import SnowflakeOperator
|
||||
|
||||
# Prefer the stdlib mock (Python 3); fall back to the backport, or None
# when neither is available.
try:
    from unittest import mock
except ImportError:
    try:
        import mock
    except ImportError:
        mock = None


DEFAULT_DATE = timezone.datetime(2015, 1, 1)
DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
TEST_DAG_ID = 'unit_test_dag'
# Dotted path of the hook factory we patch out in the operator test.
LONG_MOCK_PATH = ('airflow.contrib.operators.snowflake_operator.'
                  'SnowflakeOperator.get_hook')
|
||||
|
||||
|
||||
class TestSnowflakeOperator(unittest.TestCase):
    """Smoke test: SnowflakeOperator runs end-to-end with a mocked hook."""

    def setUp(self):
        super(TestSnowflakeOperator, self).setUp()
        configuration.load_test_config()
        default_args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
        self.dag = DAG(TEST_DAG_ID, default_args=default_args)

    @mock.patch(LONG_MOCK_PATH)
    def test_snowflake_operator(self, mock_get_hook):
        sql = """
        CREATE TABLE IF NOT EXISTS test_airflow (
            dummy VARCHAR(50)
        );
        """
        operator = SnowflakeOperator(
            task_id='basic_snowflake',
            sql=sql,
            dag=self.dag)
        operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
                     ignore_ti_state=True)
|
Loading…
Reference in a new issue