incubator-airflow/tests/models/test_connection.py

277 lines
12 KiB
Python

# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
from collections import namedtuple
from cryptography.fernet import Fernet
from mock import patch
from parameterized import parameterized
from airflow import models
from airflow.models.connection import Connection
# Expected components of a parsed connection URI; each parameterized case of
# ConnectionTest.test_connection_from_with_auth_info pairs a URI with one.
ConnectionParts = namedtuple(
    "ConnectionParts",
    "conn_type login password host port schema",
)
class ConnectionTest(unittest.TestCase):
    """Tests for :class:`airflow.models.connection.Connection`.

    The first group checks how the ``extra`` field is stored depending on the
    configured fernet key (none, a single key, and key rotation). The second
    group checks how connection URIs are parsed into ``conn_type``, ``login``,
    ``password``, ``host``, ``port``, ``schema`` and ``extra``.
    """

    def setUp(self):
        # Reset the module-level fernet object so each test builds it from its
        # own (possibly mocked) configuration.
        models._fernet = None

    def tearDown(self):
        # Do not leak a fernet object built from a mocked key into later tests.
        models._fernet = None

    @patch('airflow.models.configuration.conf.get')
    def test_connection_extra_no_encryption(self, mock_get):
        """
        Tests extras on a new connection without encryption. Every
        ``conf.get`` call is mocked to return an empty string, so no fernet
        key is configured and the extra is stored without encryption.
        """
        mock_get.return_value = ''
        test_connection = Connection(extra='testextra')
        self.assertFalse(test_connection.is_extra_encrypted)
        self.assertEqual(test_connection.extra, 'testextra')

    @patch('airflow.models.configuration.conf.get')
    def test_connection_extra_with_encryption(self, mock_get):
        """
        Tests extras on a new connection with encryption. ``conf.get`` is
        mocked to return a freshly generated fernet key.
        """
        mock_get.return_value = Fernet.generate_key().decode()
        test_connection = Connection(extra='testextra')
        self.assertTrue(test_connection.is_extra_encrypted)
        self.assertEqual(test_connection.extra, 'testextra')

    @patch('airflow.models.configuration.conf.get')
    def test_connection_extra_with_encryption_rotate_fernet_key(self, mock_get):
        """
        Tests rotating encrypted extras.

        A value encrypted with an old key must stay readable after a new key
        is put in front of it in the configured key list. After
        ``rotate_fernet_key`` the value must be encrypted with the new key.
        """
        key1 = Fernet.generate_key()
        key2 = Fernet.generate_key()
        mock_get.return_value = key1.decode()
        test_connection = Connection(extra='testextra')
        self.assertTrue(test_connection.is_extra_encrypted)
        self.assertEqual(test_connection.extra, 'testextra')
        # The raw stored value is ciphertext produced with key1.
        self.assertEqual(Fernet(key1).decrypt(test_connection._extra.encode()), b'testextra')

        # Test decrypt of old value with new key. Configure a comma-separated
        # key list with the new key first. Reset the fernet object so the
        # next access picks up that list.
        mock_get.return_value = ','.join([key2.decode(), key1.decode()])
        models._fernet = None
        self.assertEqual(test_connection.extra, 'testextra')

        # Test decrypt of new value with new key. After rotation the raw
        # stored value is ciphertext produced with key2.
        test_connection.rotate_fernet_key()
        self.assertTrue(test_connection.is_extra_encrypted)
        self.assertEqual(test_connection.extra, 'testextra')
        self.assertEqual(Fernet(key2).decrypt(test_connection._extra.encode()), b'testextra')

    def test_connection_from_uri_without_extras(self):
        """All URI parts are parsed; a percent-encoded ``/`` in the host is decoded."""
        uri = 'scheme://user:password@host%2flocation:1234/schema'
        connection = Connection(uri=uri)
        self.assertEqual(connection.conn_type, 'scheme')
        self.assertEqual(connection.host, 'host/location')
        self.assertEqual(connection.schema, 'schema')
        self.assertEqual(connection.login, 'user')
        self.assertEqual(connection.password, 'password')
        self.assertEqual(connection.port, 1234)
        self.assertIsNone(connection.extra)

    def test_connection_from_uri_with_extras(self):
        """Query parameters are percent-decoded into ``extra``."""
        uri = ('scheme://user:password@host%2flocation:1234/schema?'
               'extra1=a%20value&extra2=%2fpath%2f')
        connection = Connection(uri=uri)
        self.assertEqual(connection.conn_type, 'scheme')
        self.assertEqual(connection.host, 'host/location')
        self.assertEqual(connection.schema, 'schema')
        self.assertEqual(connection.login, 'user')
        self.assertEqual(connection.password, 'password')
        self.assertEqual(connection.port, 1234)
        self.assertDictEqual(connection.extra_dejson, {'extra1': 'a value',
                                                       'extra2': '/path/'})

    def test_connection_from_uri_with_colon_in_hostname(self):
        """Percent-encoded colons stay in the host and do not affect port parsing."""
        uri = ('scheme://user:password@host%2flocation%3ax%3ay:1234/schema?'
               'extra1=a%20value&extra2=%2fpath%2f')
        connection = Connection(uri=uri)
        self.assertEqual(connection.conn_type, 'scheme')
        self.assertEqual(connection.host, 'host/location:x:y')
        self.assertEqual(connection.schema, 'schema')
        self.assertEqual(connection.login, 'user')
        self.assertEqual(connection.password, 'password')
        self.assertEqual(connection.port, 1234)
        self.assertDictEqual(connection.extra_dejson, {'extra1': 'a value',
                                                       'extra2': '/path/'})

    def test_connection_from_uri_with_encoded_password(self):
        """A percent-encoded password is decoded."""
        uri = 'scheme://user:password%20with%20space@host%2flocation%3ax%3ay:1234/schema'
        connection = Connection(uri=uri)
        self.assertEqual(connection.conn_type, 'scheme')
        self.assertEqual(connection.host, 'host/location:x:y')
        self.assertEqual(connection.schema, 'schema')
        self.assertEqual(connection.login, 'user')
        self.assertEqual(connection.password, 'password with space')
        self.assertEqual(connection.port, 1234)

    def test_connection_from_uri_with_encoded_user(self):
        """A percent-encoded login is decoded."""
        uri = 'scheme://domain%2fuser:password@host%2flocation%3ax%3ay:1234/schema'
        connection = Connection(uri=uri)
        self.assertEqual(connection.conn_type, 'scheme')
        self.assertEqual(connection.host, 'host/location:x:y')
        self.assertEqual(connection.schema, 'schema')
        self.assertEqual(connection.login, 'domain/user')
        self.assertEqual(connection.password, 'password')
        self.assertEqual(connection.port, 1234)

    def test_connection_from_uri_with_encoded_schema(self):
        """A percent-encoded schema (URI path) is decoded."""
        uri = 'scheme://user:password%20with%20space@host:1234/schema%2ftest'
        connection = Connection(uri=uri)
        self.assertEqual(connection.conn_type, 'scheme')
        self.assertEqual(connection.host, 'host')
        self.assertEqual(connection.schema, 'schema/test')
        self.assertEqual(connection.login, 'user')
        self.assertEqual(connection.password, 'password with space')
        self.assertEqual(connection.port, 1234)

    def test_connection_from_uri_no_schema(self):
        """A URI without a path yields an empty-string schema."""
        uri = 'scheme://user:password%20with%20space@host:1234'
        connection = Connection(uri=uri)
        self.assertEqual(connection.conn_type, 'scheme')
        self.assertEqual(connection.host, 'host')
        self.assertEqual(connection.schema, '')
        self.assertEqual(connection.login, 'user')
        self.assertEqual(connection.password, 'password with space')
        self.assertEqual(connection.port, 1234)

    def test_connection_from_uri_with_underscore(self):
        """Hyphens in the scheme become underscores in ``conn_type``; extras are decoded."""
        uri = ('google-cloud-platform://?extra__google_cloud_platform__key_'
               'path=%2Fkeys%2Fkey.json&extra__google_cloud_platform__scope='
               'https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform&extra'
               '__google_cloud_platform__project=airflow')
        connection = Connection(uri=uri)
        self.assertEqual(connection.conn_type, 'google_cloud_platform')
        self.assertEqual(connection.host, '')
        self.assertEqual(connection.schema, '')
        self.assertIsNone(connection.login)
        self.assertIsNone(connection.password)
        self.assertDictEqual(connection.extra_dejson, dict(
            extra__google_cloud_platform__key_path='/keys/key.json',
            extra__google_cloud_platform__project='airflow',
            extra__google_cloud_platform__scope='https://www.googleapis.com/'
                                                'auth/cloud-platform'))

    def test_connection_from_uri_without_authinfo(self):
        """Without a ``user:password@`` part, login and password are None."""
        uri = 'scheme://host:1234'
        connection = Connection(uri=uri)
        self.assertEqual(connection.conn_type, 'scheme')
        self.assertEqual(connection.host, 'host')
        self.assertEqual(connection.schema, '')
        self.assertIsNone(connection.login)
        self.assertIsNone(connection.password)
        self.assertEqual(connection.port, 1234)

    def test_connection_from_uri_with_path(self):
        """A percent-encoded path used as the host is decoded with its case preserved."""
        uri = 'scheme://%2FTmP%2F:1234'
        connection = Connection(uri=uri)
        self.assertEqual(connection.conn_type, 'scheme')
        self.assertEqual(connection.host, '/TmP/')
        self.assertEqual(connection.schema, '')
        self.assertIsNone(connection.login)
        self.assertIsNone(connection.password)
        self.assertEqual(connection.port, 1234)

    @parameterized.expand(
        [
            # Empty login is kept as '' ...
            (
                "http://:password@host:80/database",
                ConnectionParts(
                    conn_type="http", login='', password="password", host="host", port=80, schema="database"
                ),
            ),
            # ... while an empty password becomes None.
            (
                "http://user:@host:80/database",
                ConnectionParts(
                    conn_type="http", login="user", password=None, host="host", port=80, schema="database"
                ),
            ),
            (
                "http://user:password@/database",
                ConnectionParts(
                    conn_type="http", login="user", password="password", host="", port=None, schema="database"
                ),
            ),
            (
                "http://user:password@host:80/",
                ConnectionParts(
                    conn_type="http", login="user", password="password", host="host", port=80, schema=""
                ),
            ),
            (
                "http://user:password@/",
                ConnectionParts(
                    conn_type="http", login="user", password="password", host="", port=None, schema=""
                ),
            ),
            # The "postgresql" scheme maps to conn_type "postgres"; the host is
            # a percent-encoded filesystem path containing colons.
            (
                "postgresql://user:password@%2Ftmp%2Fz6rqdzqh%2Fexample%3Awest1%3Atestdb/testdb",
                ConnectionParts(
                    conn_type="postgres",
                    login="user",
                    password="password",
                    host="/tmp/z6rqdzqh/example:west1:testdb",
                    port=None,
                    schema="testdb",
                ),
            ),
            (
                "postgresql://user@%2Ftmp%2Fz6rqdzqh%2Fexample%3Aeurope-west1%3Atestdb/testdb",
                ConnectionParts(
                    conn_type="postgres",
                    login="user",
                    password=None,
                    host="/tmp/z6rqdzqh/example:europe-west1:testdb",
                    port=None,
                    schema="testdb",
                ),
            ),
            (
                "postgresql://%2Ftmp%2Fz6rqdzqh%2Fexample%3Aeurope-west1%3Atestdb",
                ConnectionParts(
                    conn_type="postgres",
                    login=None,
                    password=None,
                    host="/tmp/z6rqdzqh/example:europe-west1:testdb",
                    port=None,
                    schema="",
                ),
            ),
        ]
    )
    def test_connection_from_with_auth_info(self, uri, uri_parts):
        """Each URI is parsed into exactly the expected ``ConnectionParts``."""
        connection = Connection(uri=uri)
        self.assertEqual(connection.conn_type, uri_parts.conn_type)
        self.assertEqual(connection.login, uri_parts.login)
        self.assertEqual(connection.password, uri_parts.password)
        self.assertEqual(connection.host, uri_parts.host)
        self.assertEqual(connection.port, uri_parts.port)
        self.assertEqual(connection.schema, uri_parts.schema)