541 строка
22 KiB
Python
541 строка
22 KiB
Python
#
|
|
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
import json
|
|
import re
|
|
import unittest
|
|
from collections import namedtuple
|
|
from unittest import mock
|
|
|
|
import sqlalchemy
|
|
from cryptography.fernet import Fernet
|
|
from parameterized import parameterized
|
|
|
|
from airflow import AirflowException
|
|
from airflow.hooks.base_hook import BaseHook
|
|
from airflow.models import Connection, crypto
|
|
from airflow.models.connection import CONN_TYPE_TO_HOOK
|
|
from airflow.providers.sqlite.hooks.sqlite import SqliteHook
|
|
from airflow.utils.module_loading import import_string
|
|
from tests.test_utils.config import conf_vars
|
|
|
|
# Expected components of a connection parsed from a URI; used as expected values in the
# parameterized URI-parsing tests.
ConnectionParts = namedtuple("ConnectionParts", "conn_type login password host port schema")
|
|
|
|
|
|
class UriTestCaseConfig:
    """Bundle a connection URI, the attributes expected after parsing it, and a test label."""

    def __init__(
        self,
        test_conn_uri: str,
        test_conn_attributes: dict,
        description: str,
    ):
        """
        :param test_conn_uri: URI used to create the connection under test
        :param test_conn_attributes: attribute values that a connection created from
            ``test_conn_uri`` is expected to have
        :param description: human-friendly label appended to the parameterized test name
        """
        self.description = description
        self.test_conn_attributes = test_conn_attributes
        self.test_uri = test_conn_uri

    @staticmethod
    def uri_test_name(func, num, param):
        """
        Name a parameterized test case as ``<func name>_<num>_<description>``.

        Spaces in the description are replaced with underscores so the result is a valid
        identifier-like test name.
        """
        config = param.args[0]
        suffix = config.description.replace(' ', '_')
        return "{}_{}_{}".format(func.__name__, num, suffix)
|
|
|
|
|
|
class TestConnection(unittest.TestCase):
    """Tests for ``Connection``: extra encryption, URI parsing/generation and env-var lookup."""

    def setUp(self):
        # Drop the cached Fernet object so each test uses the fernet key configured for it.
        crypto._fernet = None

    def tearDown(self):
        crypto._fernet = None

    def _assert_connection_attributes(self, connection, expected_attributes):
        """
        Assert that ``connection`` has every attribute value listed in ``expected_attributes``.

        :param connection: the ``Connection`` under test
        :param expected_attributes: mapping of attribute name to expected value. ``None`` means
            the attribute must be ``None``; dict values are compared with ``assertDictEqual``.
        """
        for conn_attr, expected_val in expected_attributes.items():
            actual_val = getattr(connection, conn_attr)
            if expected_val is None:
                # Check the actual value; asserting on ``expected_val`` here would always pass.
                self.assertIsNone(actual_val)
            elif isinstance(expected_val, dict):
                self.assertDictEqual(expected_val, actual_val)
            else:
                self.assertEqual(expected_val, actual_val)

    @conf_vars({('core', 'fernet_key'): ''})
    def test_connection_extra_no_encryption(self):
        """
        Tests extras on a new connection without encryption. The fernet key
        is set to an empty string and the extra is stored without encryption.
        """
        test_connection = Connection(extra='testextra')
        self.assertFalse(test_connection.is_extra_encrypted)
        self.assertEqual(test_connection.extra, 'testextra')

    @conf_vars({('core', 'fernet_key'): Fernet.generate_key().decode()})
    def test_connection_extra_with_encryption(self):
        """
        Tests extras on a new connection with encryption.
        """
        test_connection = Connection(extra='testextra')
        self.assertTrue(test_connection.is_extra_encrypted)
        self.assertEqual(test_connection.extra, 'testextra')

    def test_connection_extra_with_encryption_rotate_fernet_key(self):
        """
        Tests rotating encrypted extras.
        """
        key1 = Fernet.generate_key()
        key2 = Fernet.generate_key()

        with conf_vars({('core', 'fernet_key'): key1.decode()}):
            test_connection = Connection(extra='testextra')
            self.assertTrue(test_connection.is_extra_encrypted)
            self.assertEqual(test_connection.extra, 'testextra')
            self.assertEqual(Fernet(key1).decrypt(test_connection._extra.encode()), b'testextra')

        # Test decrypt of old value with new key
        with conf_vars({('core', 'fernet_key'): ','.join([key2.decode(), key1.decode()])}):
            # Reset the cached Fernet so the new comma-separated key list is picked up.
            crypto._fernet = None
            self.assertEqual(test_connection.extra, 'testextra')

            # Test decrypt of new value with new key
            test_connection.rotate_fernet_key()
            self.assertTrue(test_connection.is_extra_encrypted)
            self.assertEqual(test_connection.extra, 'testextra')
            self.assertEqual(Fernet(key2).decrypt(test_connection._extra.encode()), b'testextra')

    test_from_uri_params = [
        UriTestCaseConfig(
            test_conn_uri='scheme://user:password@host%2Flocation:1234/schema',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host/location',
                schema='schema',
                login='user',
                password='password',
                port=1234,
                extra=None,
            ),
            description='without extras',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://user:password@host%2Flocation:1234/schema?'
            'extra1=a%20value&extra2=%2Fpath%2F',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host/location',
                schema='schema',
                login='user',
                password='password',
                port=1234,
                extra_dejson={'extra1': 'a value', 'extra2': '/path/'}
            ),
            description='with extras'
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://user:password@host%2Flocation:1234/schema?extra1=a%20value&extra2=',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host/location',
                schema='schema',
                login='user',
                password='password',
                port=1234,
                extra_dejson={'extra1': 'a value', 'extra2': ''}
            ),
            description='with empty extras'
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://user:password@host%2Flocation%3Ax%3Ay:1234/schema?'
            'extra1=a%20value&extra2=%2Fpath%2F',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host/location:x:y',
                schema='schema',
                login='user',
                password='password',
                port=1234,
                extra_dejson={'extra1': 'a value', 'extra2': '/path/'},
            ),
            description='with colon in hostname'
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://user:password%20with%20space@host%2Flocation%3Ax%3Ay:1234/schema',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host/location:x:y',
                schema='schema',
                login='user',
                password='password with space',
                port=1234,
            ),
            description='with encoded password'
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://domain%2Fuser:password@host%2Flocation%3Ax%3Ay:1234/schema',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host/location:x:y',
                schema='schema',
                login='domain/user',
                password='password',
                port=1234,
            ),
            description='with encoded user',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://user:password%20with%20space@host:1234/schema%2Ftest',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host',
                schema='schema/test',
                login='user',
                password='password with space',
                port=1234,
            ),
            description='with encoded schema'
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://user:password%20with%20space@host:1234',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host',
                schema='',
                login='user',
                password='password with space',
                port=1234,
            ),
            description='no schema'
        ),
        UriTestCaseConfig(
            test_conn_uri='google-cloud-platform://?extra__google_cloud_platform__key_'
            'path=%2Fkeys%2Fkey.json&extra__google_cloud_platform__scope='
            'https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform&extra'
            '__google_cloud_platform__project=airflow',
            test_conn_attributes=dict(
                conn_type='google_cloud_platform',
                host='',
                schema='',
                login=None,
                password=None,
                port=None,
                extra_dejson=dict(
                    extra__google_cloud_platform__key_path='/keys/key.json',
                    extra__google_cloud_platform__scope='https://www.googleapis.com/auth/cloud-platform',
                    extra__google_cloud_platform__project='airflow',
                )
            ),
            description='with underscore',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://host:1234',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host',
                schema='',
                login=None,
                password=None,
                port=1234,
            ),
            description='without auth info'
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://%2FTmP%2F:1234',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='/TmP/',
                schema='',
                login=None,
                password=None,
                port=1234,
            ),
            description='with path',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme:///airflow',
            test_conn_attributes=dict(
                conn_type='scheme',
                schema='airflow',
            ),
            description='schema only',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://@:1234',
            test_conn_attributes=dict(
                conn_type='scheme',
                port=1234,
            ),
            description='port only',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://:password%2F%21%40%23%24%25%5E%26%2A%28%29%7B%7D@',
            test_conn_attributes=dict(
                conn_type='scheme',
                password='password/!@#$%^&*(){}',
            ),
            description='password only',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://login%2F%21%40%23%24%25%5E%26%2A%28%29%7B%7D@',
            test_conn_attributes=dict(
                conn_type='scheme',
                login='login/!@#$%^&*(){}',
            ),
            description='login only',
        ),
    ]

    # pylint: disable=undefined-variable
    @parameterized.expand([(x,) for x in test_from_uri_params], UriTestCaseConfig.uri_test_name)
    def test_connection_from_uri(self, test_config: UriTestCaseConfig):
        """A connection parsed from ``test_config.test_uri`` has the expected attributes."""
        connection = Connection(uri=test_config.test_uri)
        self._assert_connection_attributes(connection, test_config.test_conn_attributes)

    # pylint: disable=undefined-variable
    @parameterized.expand([(x,) for x in test_from_uri_params], UriTestCaseConfig.uri_test_name)
    def test_connection_get_uri_from_uri(self, test_config: UriTestCaseConfig):
        """
        This test verifies that when we create a conn_1 from URI, and we generate a URI from that conn, that
        when we create a conn_2 from the generated URI, we get an equivalent conn.
        1. Parse URI to create `Connection` object, `connection`.
        2. Using this connection, generate URI `generated_uri`.
        3. Using this `generated_uri`, parse and create new Connection `new_conn`.
        4. Verify that `new_conn` has same attributes as `connection`.
        """
        connection = Connection(uri=test_config.test_uri)
        generated_uri = connection.get_uri()
        new_conn = Connection(uri=generated_uri)
        self.assertEqual(connection.conn_type, new_conn.conn_type)
        self.assertEqual(connection.login, new_conn.login)
        self.assertEqual(connection.password, new_conn.password)
        self.assertEqual(connection.host, new_conn.host)
        self.assertEqual(connection.port, new_conn.port)
        self.assertEqual(connection.schema, new_conn.schema)
        self.assertDictEqual(connection.extra_dejson, new_conn.extra_dejson)

    # pylint: disable=undefined-variable
    @parameterized.expand([(x,) for x in test_from_uri_params], UriTestCaseConfig.uri_test_name)
    def test_connection_get_uri_from_conn(self, test_config: UriTestCaseConfig):
        """
        This test verifies that if we create conn_1 from attributes (rather than from URI), and we generate a
        URI, that when we create conn_2 from this URI, we get an equivalent conn.
        1. Build conn init params using `test_conn_attributes` and store in `conn_kwargs`
        2. Instantiate conn `connection` from `conn_kwargs`.
        3. Generate uri `get_uri` from this conn.
        4. Create conn `new_conn` from this uri.
        5. Verify `new_conn` has same attributes as `connection`.
        """
        conn_kwargs = {}
        for k, v in test_config.test_conn_attributes.items():
            if k == 'extra_dejson':
                # ``Connection`` takes extras as a JSON string, not as a parsed dict.
                conn_kwargs['extra'] = json.dumps(v)
            else:
                conn_kwargs[k] = v

        connection = Connection(conn_id='test_conn', **conn_kwargs)  # type: ignore
        gen_uri = connection.get_uri()
        new_conn = Connection(conn_id='test_conn', uri=gen_uri)
        self._assert_connection_attributes(new_conn, test_config.test_conn_attributes)

    @parameterized.expand(
        [
            (
                "http://:password@host:80/database",
                ConnectionParts(
                    conn_type="http", login='', password="password", host="host", port=80, schema="database"
                ),
            ),
            (
                "http://user:@host:80/database",
                ConnectionParts(
                    conn_type="http", login="user", password=None, host="host", port=80, schema="database"
                ),
            ),
            (
                "http://user:password@/database",
                ConnectionParts(
                    conn_type="http", login="user", password="password", host="", port=None, schema="database"
                ),
            ),
            (
                "http://user:password@host:80/",
                ConnectionParts(
                    conn_type="http", login="user", password="password", host="host", port=80, schema=""
                ),
            ),
            (
                "http://user:password@/",
                ConnectionParts(
                    conn_type="http", login="user", password="password", host="", port=None, schema=""
                ),
            ),
            (
                "postgresql://user:password@%2Ftmp%2Fz6rqdzqh%2Fexample%3Awest1%3Atestdb/testdb",
                ConnectionParts(
                    conn_type="postgres",
                    login="user",
                    password="password",
                    host="/tmp/z6rqdzqh/example:west1:testdb",
                    port=None,
                    schema="testdb",
                ),
            ),
            (
                "postgresql://user@%2Ftmp%2Fz6rqdzqh%2Fexample%3Aeurope-west1%3Atestdb/testdb",
                ConnectionParts(
                    conn_type="postgres",
                    login="user",
                    password=None,
                    host="/tmp/z6rqdzqh/example:europe-west1:testdb",
                    port=None,
                    schema="testdb",
                ),
            ),
            (
                "postgresql://%2Ftmp%2Fz6rqdzqh%2Fexample%3Aeurope-west1%3Atestdb",
                ConnectionParts(
                    conn_type="postgres",
                    login=None,
                    password=None,
                    host="/tmp/z6rqdzqh/example:europe-west1:testdb",
                    port=None,
                    schema="",
                ),
            ),
        ]
    )
    def test_connection_from_with_auth_info(self, uri, uri_parts):
        """Each URI parses into exactly the components given by ``uri_parts``."""
        connection = Connection(uri=uri)

        self.assertEqual(connection.conn_type, uri_parts.conn_type)
        self.assertEqual(connection.login, uri_parts.login)
        self.assertEqual(connection.password, uri_parts.password)
        self.assertEqual(connection.host, uri_parts.host)
        self.assertEqual(connection.port, uri_parts.port)
        self.assertEqual(connection.schema, uri_parts.schema)

    @mock.patch.dict('os.environ', {
        'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
    })
    def test_using_env_var(self):
        """A connection defined in an ``AIRFLOW_CONN_*`` env var is returned by ``get_connection``."""
        conn = SqliteHook.get_connection(conn_id='test_uri')
        self.assertEqual('ec2.compute.com', conn.host)
        self.assertEqual('the_database', conn.schema)
        self.assertEqual('username', conn.login)
        self.assertEqual('password', conn.password)
        self.assertEqual(5432, conn.port)

    @mock.patch.dict('os.environ', {
        'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
    })
    def test_using_unix_socket_env_var(self):
        """
        An env-var URI without credentials or port yields ``None`` for login, password and port.

        NOTE(review): despite the test name, the URI uses a regular hostname, not a unix socket path.
        """
        conn = SqliteHook.get_connection(conn_id='test_uri_no_creds')
        self.assertEqual('ec2.compute.com', conn.host)
        self.assertEqual('the_database', conn.schema)
        self.assertIsNone(conn.login)
        self.assertIsNone(conn.password)
        self.assertIsNone(conn.port)

    def test_param_setup(self):
        """Attributes passed to the constructor are exposed unchanged; port defaults to ``None``."""
        conn = Connection(conn_id='local_mysql', conn_type='mysql',
                          host='localhost', login='airflow',
                          password='airflow', schema='airflow')
        self.assertEqual('localhost', conn.host)
        self.assertEqual('airflow', conn.schema)
        self.assertEqual('airflow', conn.login)
        self.assertEqual('airflow', conn.password)
        self.assertIsNone(conn.port)

    def test_env_var_priority(self):
        """A connection defined in an env var takes precedence over one with the same id otherwise found."""
        conn = SqliteHook.get_connection(conn_id='airflow_db')
        self.assertNotEqual('ec2.compute.com', conn.host)

        with mock.patch.dict('os.environ', {
            'AIRFLOW_CONN_AIRFLOW_DB': 'postgres://username:password@ec2.compute.com:5432/the_database',
        }):
            conn = SqliteHook.get_connection(conn_id='airflow_db')
            self.assertEqual('ec2.compute.com', conn.host)
            self.assertEqual('the_database', conn.schema)
            self.assertEqual('username', conn.login)
            self.assertEqual('password', conn.password)
            self.assertEqual(5432, conn.port)

    @mock.patch.dict('os.environ', {
        'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
        'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
    })
    def test_dbapi_get_uri(self):
        """The hook built from an env-var connection reproduces the original URI."""
        conn = BaseHook.get_connection(conn_id='test_uri')
        hook = conn.get_hook()
        self.assertEqual('postgres://username:password@ec2.compute.com:5432/the_database', hook.get_uri())
        conn2 = BaseHook.get_connection(conn_id='test_uri_no_creds')
        hook2 = conn2.get_hook()
        self.assertEqual('postgres://ec2.compute.com/the_database', hook2.get_uri())

    @mock.patch.dict('os.environ', {
        'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
        'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
    })
    def test_dbapi_get_sqlalchemy_engine(self):
        """The hook's SQLAlchemy engine is created with the connection's URI."""
        conn = BaseHook.get_connection(conn_id='test_uri')
        hook = conn.get_hook()
        engine = hook.get_sqlalchemy_engine()
        self.assertIsInstance(engine, sqlalchemy.engine.Engine)
        self.assertEqual('postgres://username:password@ec2.compute.com:5432/the_database', str(engine.url))

    @mock.patch.dict('os.environ', {
        'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
        'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
    })
    def test_get_connections_env_var(self):
        """``get_connections`` returns exactly one connection for an env-var conn id."""
        conns = SqliteHook.get_connections(conn_id='test_uri')
        # Use unittest assertions (like the rest of this class) so failures show both values.
        self.assertEqual(1, len(conns))
        self.assertEqual('ec2.compute.com', conns[0].host)
        self.assertEqual('the_database', conns[0].schema)
        self.assertEqual('username', conns[0].login)
        self.assertEqual('password', conns[0].password)
        self.assertEqual(5432, conns[0].port)

    def test_connection_mixed(self):
        """Passing both ``uri`` and an individual attribute raises ``AirflowException``."""
        with self.assertRaisesRegex(
            AirflowException,
            re.escape(
                "You must create an object using the URI or individual values (conn_type, host, login, "
                "password, schema, port or extra).You can't mix these two ways to create this object."
            )
        ):
            Connection(conn_id="TEST_ID", uri="mysql://", schema="AAA")
|
|
|
|
|
|
class TestConnTypeToHook(unittest.TestCase):
    """Sanity checks for the ``CONN_TYPE_TO_HOOK`` registry."""

    def test_enforce_alphabetical_order(self):
        """The registry's conn_type keys must be declared in sorted order."""
        current_keys = list(CONN_TYPE_TO_HOOK.keys())
        expected_keys = sorted(current_keys)

        self.assertEqual(expected_keys, current_keys)

    def test_hooks_importable(self):
        """Every registered hook path must import to a class that subclasses ``BaseHook``."""
        for hook_path, _ in CONN_TYPE_TO_HOOK.values():
            # subTest reports each failing hook path individually instead of stopping at the first.
            with self.subTest(hook_path=hook_path):
                hook_class = import_string(hook_path)
                # Check isinstance(..., type) first: issubclass raises TypeError for non-classes.
                self.assertTrue(
                    isinstance(hook_class, type) and issubclass(hook_class, BaseHook),
                    f"{hook_path} does not resolve to a subclass of BaseHook",
                )
|