[AIRFLOW-5768] GCP cloud sql don't store ephemeral connection in db (#6440)

This commit is contained in:
dstandish 2019-11-14 08:40:29 -08:00 коммит произвёл Jarek Potiuk
Родитель 61775833f7
Коммит 776e24aa05
8 изменённых файлов: 159 добавлений и 331 удалений

Просмотреть файл

@ -934,57 +934,16 @@ class CloudSqlDatabaseHook(BaseHook):
instance_specification += "=tcp:" + str(self.sql_proxy_tcp_port)
return instance_specification
@provide_session
def create_connection(self, session: Optional[Session] = None) -> None:
def create_connection(self) -> Connection:
"""
Create connection in the Connection table, according to whether it uses
proxy, TCP, UNIX sockets, SSL. Connection ID will be randomly generated.
:param session: Session of the SQL Alchemy ORM (automatically generated with
decorator).
Create Connection object, according to whether it uses proxy, TCP, UNIX sockets, SSL.
Connection ID will be randomly generated.
"""
assert session is not None
connection = Connection(conn_id=self.db_conn_id)
uri = self._generate_connection_uri()
self.log.info("Creating connection %s", self.db_conn_id)
connection.parse_from_uri(uri)
session.add(connection)
session.commit()
@provide_session
def retrieve_connection(self, session: Optional[Session] = None) -> Optional[Connection]:
"""
Retrieves the dynamically created connection from the Connection table.
:param session: Session of the SQL Alchemy ORM (automatically generated with
decorator).
"""
assert session is not None
self.log.info("Retrieving connection %s", self.db_conn_id)
connections = session.query(Connection).filter(
Connection.conn_id == self.db_conn_id)
if connections.count():
return connections[0]
return None
@provide_session
def delete_connection(self, session: Optional[Session] = None) -> None:
"""
Delete the dynamically created connection from the Connection table.
:param session: Session of the SQL Alchemy ORM (automatically generated with
decorator).
"""
assert session is not None
self.log.info("Deleting connection %s", self.db_conn_id)
connections = session.query(Connection).filter(
Connection.conn_id == self.db_conn_id)
if connections.count():
connection = connections[0]
session.delete(connection)
session.commit()
else:
self.log.info("Connection was already deleted!")
return connection
def get_sqlproxy_runner(self) -> CloudSqlProxyRunner:
"""
@ -1006,17 +965,15 @@ class CloudSqlDatabaseHook(BaseHook):
gcp_conn_id=self.gcp_conn_id
)
def get_database_hook(self) -> Union[PostgresHook, MySqlHook]:
def get_database_hook(self, connection: Connection) -> Union[PostgresHook, MySqlHook]:
"""
Retrieve database hook. This is the actual Postgres or MySQL database hook
that uses proxy or connects directly to the Google Cloud SQL database.
"""
if self.database_type == 'postgres':
self.db_hook = PostgresHook(postgres_conn_id=self.db_conn_id,
schema=self.database)
self.db_hook = PostgresHook(connection=connection, schema=self.database)
else:
self.db_hook = MySqlHook(mysql_conn_id=self.db_conn_id,
schema=self.database)
self.db_hook = MySqlHook(connection=connection, schema=self.database)
return self.db_hook
def cleanup_database_hook(self) -> None:

Просмотреть файл

@ -835,13 +835,10 @@ class CloudSqlQueryOperator(BaseOperator):
'extra__google_cloud_platform__project')
)
hook.validate_ssl_certs()
hook.create_connection()
connection = hook.create_connection()
hook.validate_socket_path_length()
database_hook = hook.get_database_hook(connection=connection)
try:
hook.validate_socket_path_length()
database_hook = hook.get_database_hook()
try:
self._execute_query(hook, database_hook)
finally:
hook.cleanup_database_hook()
self._execute_query(hook, database_hook)
finally:
hook.delete_connection()
hook.cleanup_database_hook()

Просмотреть файл

@ -47,6 +47,7 @@ class MySqlHook(DbApiHook):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.schema = kwargs.pop("schema", None)
self.connection = kwargs.pop("connection", None)
def set_autocommit(self, conn, autocommit):
"""
@ -69,7 +70,7 @@ class MySqlHook(DbApiHook):
"""
Returns a mysql connection object
"""
conn = self.get_connection(self.mysql_conn_id)
conn = self.connection or self.get_connection(self.mysql_conn_id)
conn_config = {
"user": conn.login,

Просмотреть файл

@ -56,6 +56,7 @@ class PostgresHook(DbApiHook):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.schema = kwargs.pop("schema", None)
self.connection = kwargs.pop("connection", None)
def _get_cursor(self, raw_cursor):
_cursor = raw_cursor.lower()
@ -68,8 +69,9 @@ class PostgresHook(DbApiHook):
raise ValueError('Invalid cursor passed {}'.format(_cursor))
def get_conn(self):
conn_id = getattr(self, self.conn_name_attr)
conn = self.get_connection(conn_id)
conn = self.connection or self.get_connection(conn_id)
# check for authentication via AWS IAM
if conn.extra_dejson.get('iam', False):

Просмотреть файл

@ -27,7 +27,6 @@ from parameterized import parameterized
from airflow.exceptions import AirflowException
from airflow.gcp.hooks.cloud_sql import CloudSqlDatabaseHook, CloudSqlHook
from airflow.hooks.base_hook import BaseHook
from airflow.models import Connection
from tests.compat import PropertyMock, mock
from tests.gcp.utils.base_gcp_mock import (
@ -1066,23 +1065,6 @@ class TestCloudsqlDatabaseHook(unittest.TestCase):
err = cm.exception
self.assertIn("needs to be set in connection", str(err))
@mock.patch('airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection')
def test_cloudsql_database_hook_create_delete_connection(self, get_connection):
connection = Connection()
connection.parse_from_uri("http://user:password@host:80/database")
connection.set_extra(json.dumps({
"location": "test",
"instance": "instance",
"database_type": "postgres"
}))
get_connection.return_value = connection
hook = CloudSqlDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection',
default_gcp_project_id='google_connection')
hook.create_connection()
self.assertIsNotNone(hook.retrieve_connection())
hook.delete_connection()
self.assertIsNone(hook.retrieve_connection())
@mock.patch('airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection')
def test_cloudsql_database_hook_get_sqlproxy_runner_no_proxy(self, get_connection):
connection = Connection()
@ -1095,14 +1077,10 @@ class TestCloudsqlDatabaseHook(unittest.TestCase):
get_connection.return_value = connection
hook = CloudSqlDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection',
default_gcp_project_id='google_connection')
hook.create_connection()
try:
with self.assertRaises(AirflowException) as cm:
hook.get_sqlproxy_runner()
err = cm.exception
self.assertIn('Proxy runner can only be retrieved in case of use_proxy = True', str(err))
finally:
hook.delete_connection()
with self.assertRaises(AirflowException) as cm:
hook.get_sqlproxy_runner()
err = cm.exception
self.assertIn('Proxy runner can only be retrieved in case of use_proxy = True', str(err))
@mock.patch('airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection')
def test_cloudsql_database_hook_get_sqlproxy_runner(self, get_connection):
@ -1119,11 +1097,8 @@ class TestCloudsqlDatabaseHook(unittest.TestCase):
hook = CloudSqlDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection',
default_gcp_project_id='google_connection')
hook.create_connection()
try:
proxy_runner = hook.get_sqlproxy_runner()
self.assertIsNotNone(proxy_runner)
finally:
hook.delete_connection()
proxy_runner = hook.get_sqlproxy_runner()
self.assertIsNotNone(proxy_runner)
@mock.patch('airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection')
def test_cloudsql_database_hook_get_database_hook(self, get_connection):
@ -1137,28 +1112,13 @@ class TestCloudsqlDatabaseHook(unittest.TestCase):
get_connection.return_value = connection
hook = CloudSqlDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection',
default_gcp_project_id='google_connection')
hook.create_connection()
try:
db_hook = hook.get_database_hook()
self.assertIsNotNone(db_hook)
finally:
hook.delete_connection()
connection = hook.create_connection()
db_hook = hook.get_database_hook(connection=connection)
self.assertIsNotNone(db_hook)
class TestCloudSqlDatabaseHook(unittest.TestCase):
@staticmethod
def _setup_connections(get_connections, uri):
gcp_connection = mock.MagicMock()
gcp_connection.extra_dejson = mock.MagicMock()
gcp_connection.extra_dejson.get.return_value = 'empty_project'
cloudsql_connection = Connection()
cloudsql_connection.parse_from_uri(uri)
cloudsql_connection2 = Connection()
cloudsql_connection2.parse_from_uri(uri)
get_connections.side_effect = [[gcp_connection], [cloudsql_connection],
[cloudsql_connection2]]
@mock.patch('airflow.contrib.hooks.gcp_sql_hook.CloudSqlDatabaseHook.get_connection')
def setUp(self, m):
super().setUp()
@ -1213,235 +1173,144 @@ class TestCloudSqlDatabaseHook(unittest.TestCase):
)
self.assertEqual(sqlproxy_runner.instance_specification, instance_spec)
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
def test_hook_with_not_too_long_unix_socket_path(self, get_connections):
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
def test_hook_with_not_too_long_unix_socket_path(self, get_connection):
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \
"project_id=example-project&location=europe-west1&" \
"instance=" \
"test_db_with_longname_but_with_limit_of_UNIX_socket&" \
"use_proxy=True&sql_proxy_use_tcp=False"
self._setup_connections(get_connections, uri)
gcp_conn_id = 'google_cloud_default'
hook = CloudSqlDatabaseHook(
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
'extra__google_cloud_platform__project')
)
hook.create_connection()
try:
db_hook = hook.get_database_hook()
conn = db_hook._get_connections_from_db(db_hook.postgres_conn_id)[0] # pylint: disable=no-member
finally:
hook.delete_connection()
self.assertEqual('postgres', conn.conn_type)
self.assertEqual('testdb', conn.schema)
get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSqlDatabaseHook()
connection = hook.create_connection()
self.assertEqual('postgres', connection.conn_type)
self.assertEqual('testdb', connection.schema)
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
def test_hook_with_correct_parameters_postgres(self, get_connections):
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
def test_hook_with_correct_parameters_postgres(self, get_connection):
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \
"project_id=example-project&location=europe-west1&instance=testdb&" \
"use_proxy=False&use_ssl=False"
self._setup_connections(get_connections, uri)
gcp_conn_id = 'google_cloud_default'
hook = CloudSqlDatabaseHook(
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
'extra__google_cloud_platform__project')
)
hook.create_connection()
try:
db_hook = hook.get_database_hook()
conn = db_hook._get_connections_from_db(db_hook.postgres_conn_id)[0] # pylint: disable=no-member
finally:
hook.delete_connection()
self.assertEqual('postgres', conn.conn_type)
self.assertEqual('127.0.0.1', conn.host)
self.assertEqual(3200, conn.port)
self.assertEqual('testdb', conn.schema)
get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSqlDatabaseHook()
connection = hook.create_connection()
self.assertEqual('postgres', connection.conn_type)
self.assertEqual('127.0.0.1', connection.host)
self.assertEqual(3200, connection.port)
self.assertEqual('testdb', connection.schema)
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
def test_hook_with_correct_parameters_postgres_ssl(self, get_connections):
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
def test_hook_with_correct_parameters_postgres_ssl(self, get_connection):
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \
"project_id=example-project&location=europe-west1&instance=testdb&" \
"use_proxy=False&use_ssl=True&sslcert=/bin/bash&" \
"sslkey=/bin/bash&sslrootcert=/bin/bash"
self._setup_connections(get_connections, uri)
gcp_conn_id = 'google_cloud_default'
hook = CloudSqlDatabaseHook(
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
'extra__google_cloud_platform__project')
)
hook.create_connection()
try:
db_hook = hook.get_database_hook()
conn = db_hook._get_connections_from_db(db_hook.postgres_conn_id)[0] # pylint: disable=no-member
finally:
hook.delete_connection()
self.assertEqual('postgres', conn.conn_type)
self.assertEqual('127.0.0.1', conn.host)
self.assertEqual(3200, conn.port)
self.assertEqual('testdb', conn.schema)
self.assertEqual('/bin/bash', conn.extra_dejson['sslkey'])
self.assertEqual('/bin/bash', conn.extra_dejson['sslcert'])
self.assertEqual('/bin/bash', conn.extra_dejson['sslrootcert'])
get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSqlDatabaseHook()
connection = hook.create_connection()
self.assertEqual('postgres', connection.conn_type)
self.assertEqual('127.0.0.1', connection.host)
self.assertEqual(3200, connection.port)
self.assertEqual('testdb', connection.schema)
self.assertEqual('/bin/bash', connection.extra_dejson['sslkey'])
self.assertEqual('/bin/bash', connection.extra_dejson['sslcert'])
self.assertEqual('/bin/bash', connection.extra_dejson['sslrootcert'])
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
def test_hook_with_correct_parameters_postgres_proxy_socket(self, get_connections):
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
def test_hook_with_correct_parameters_postgres_proxy_socket(self, get_connection):
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \
"project_id=example-project&location=europe-west1&instance=testdb&" \
"use_proxy=True&sql_proxy_use_tcp=False"
self._setup_connections(get_connections, uri)
gcp_conn_id = 'google_cloud_default'
hook = CloudSqlDatabaseHook(
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
'extra__google_cloud_platform__project')
)
hook.create_connection()
try:
db_hook = hook.get_database_hook()
conn = db_hook._get_connections_from_db(db_hook.postgres_conn_id)[0] # pylint: disable=no-member
finally:
hook.delete_connection()
self.assertEqual('postgres', conn.conn_type)
self.assertIn('/tmp', conn.host)
self.assertIn('example-project:europe-west1:testdb', conn.host)
self.assertIsNone(conn.port)
self.assertEqual('testdb', conn.schema)
get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSqlDatabaseHook()
connection = hook.create_connection()
self.assertEqual('postgres', connection.conn_type)
self.assertIn('/tmp', connection.host)
self.assertIn('example-project:europe-west1:testdb', connection.host)
self.assertIsNone(connection.port)
self.assertEqual('testdb', connection.schema)
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
def test_hook_with_correct_parameters_project_id_missing(self, get_connections):
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
def test_hook_with_correct_parameters_project_id_missing(self, get_connection):
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \
"location=europe-west1&instance=testdb&" \
"use_proxy=False&use_ssl=False"
self._setup_connections(get_connections, uri)
gcp_conn_id = 'google_cloud_default'
hook = CloudSqlDatabaseHook(
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
'extra__google_cloud_platform__project')
)
hook.create_connection()
try:
db_hook = hook.get_database_hook()
conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0] # pylint: disable=no-member
finally:
hook.delete_connection()
self.assertEqual('mysql', conn.conn_type)
self.assertEqual('127.0.0.1', conn.host)
self.assertEqual(3200, conn.port)
self.assertEqual('testdb', conn.schema)
get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSqlDatabaseHook()
connection = hook.create_connection()
self.assertEqual('mysql', connection.conn_type)
self.assertEqual('127.0.0.1', connection.host)
self.assertEqual(3200, connection.port)
self.assertEqual('testdb', connection.schema)
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
def test_hook_with_correct_parameters_postgres_proxy_tcp(self, get_connections):
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
def test_hook_with_correct_parameters_postgres_proxy_tcp(self, get_connection):
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \
"project_id=example-project&location=europe-west1&instance=testdb&" \
"use_proxy=True&sql_proxy_use_tcp=True"
self._setup_connections(get_connections, uri)
gcp_conn_id = 'google_cloud_default'
hook = CloudSqlDatabaseHook(
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
'extra__google_cloud_platform__project')
)
hook.create_connection()
try:
db_hook = hook.get_database_hook()
conn = db_hook._get_connections_from_db(db_hook.postgres_conn_id)[0] # pylint: disable=no-member
finally:
hook.delete_connection()
self.assertEqual('postgres', conn.conn_type)
self.assertEqual('127.0.0.1', conn.host)
self.assertNotEqual(3200, conn.port)
self.assertEqual('testdb', conn.schema)
get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSqlDatabaseHook()
connection = hook.create_connection()
self.assertEqual('postgres', connection.conn_type)
self.assertEqual('127.0.0.1', connection.host)
self.assertNotEqual(3200, connection.port)
self.assertEqual('testdb', connection.schema)
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
def test_hook_with_correct_parameters_mysql(self, get_connections):
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
def test_hook_with_correct_parameters_mysql(self, get_connection):
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \
"project_id=example-project&location=europe-west1&instance=testdb&" \
"use_proxy=False&use_ssl=False"
self._setup_connections(get_connections, uri)
gcp_conn_id = 'google_cloud_default'
hook = CloudSqlDatabaseHook(
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
'extra__google_cloud_platform__project')
)
hook.create_connection()
try:
db_hook = hook.get_database_hook()
conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0] # pylint: disable=no-member
finally:
hook.delete_connection()
self.assertEqual('mysql', conn.conn_type)
self.assertEqual('127.0.0.1', conn.host)
self.assertEqual(3200, conn.port)
self.assertEqual('testdb', conn.schema)
get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSqlDatabaseHook()
connection = hook.create_connection()
self.assertEqual('mysql', connection.conn_type)
self.assertEqual('127.0.0.1', connection.host)
self.assertEqual(3200, connection.port)
self.assertEqual('testdb', connection.schema)
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
def test_hook_with_correct_parameters_mysql_ssl(self, get_connections):
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
def test_hook_with_correct_parameters_mysql_ssl(self, get_connection):
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \
"project_id=example-project&location=europe-west1&instance=testdb&" \
"use_proxy=False&use_ssl=True&sslcert=/bin/bash&" \
"sslkey=/bin/bash&sslrootcert=/bin/bash"
self._setup_connections(get_connections, uri)
get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSqlDatabaseHook()
connection = hook.create_connection()
self.assertEqual('mysql', connection.conn_type)
self.assertEqual('127.0.0.1', connection.host)
self.assertEqual(3200, connection.port)
self.assertEqual('testdb', connection.schema)
self.assertEqual('/bin/bash', json.loads(connection.extra_dejson['ssl'])['cert'])
self.assertEqual('/bin/bash', json.loads(connection.extra_dejson['ssl'])['key'])
self.assertEqual('/bin/bash', json.loads(connection.extra_dejson['ssl'])['ca'])
gcp_conn_id = 'google_cloud_default'
hook = CloudSqlDatabaseHook(
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
'extra__google_cloud_platform__project')
)
hook.create_connection()
try:
db_hook = hook.get_database_hook()
conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0] # pylint: disable=no-member
finally:
hook.delete_connection()
self.assertEqual('mysql', conn.conn_type)
self.assertEqual('127.0.0.1', conn.host)
self.assertEqual(3200, conn.port)
self.assertEqual('testdb', conn.schema)
self.assertEqual('/bin/bash', json.loads(conn.extra_dejson['ssl'])['cert'])
self.assertEqual('/bin/bash', json.loads(conn.extra_dejson['ssl'])['key'])
self.assertEqual('/bin/bash', json.loads(conn.extra_dejson['ssl'])['ca'])
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
def test_hook_with_correct_parameters_mysql_proxy_socket(self, get_connections):
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
def test_hook_with_correct_parameters_mysql_proxy_socket(self, get_connection):
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \
"project_id=example-project&location=europe-west1&instance=testdb&" \
"use_proxy=True&sql_proxy_use_tcp=False"
self._setup_connections(get_connections, uri)
gcp_conn_id = 'google_cloud_default'
hook = CloudSqlDatabaseHook(
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
'extra__google_cloud_platform__project')
)
hook.create_connection()
try:
db_hook = hook.get_database_hook()
conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0] # pylint: disable=no-member
finally:
hook.delete_connection()
self.assertEqual('mysql', conn.conn_type)
self.assertEqual('localhost', conn.host)
self.assertIn('/tmp', conn.extra_dejson['unix_socket'])
get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSqlDatabaseHook()
connection = hook.create_connection()
self.assertEqual('mysql', connection.conn_type)
self.assertEqual('localhost', connection.host)
self.assertIn('/tmp', connection.extra_dejson['unix_socket'])
self.assertIn('example-project:europe-west1:testdb',
conn.extra_dejson['unix_socket'])
self.assertIsNone(conn.port)
self.assertEqual('testdb', conn.schema)
connection.extra_dejson['unix_socket'])
self.assertIsNone(connection.port)
self.assertEqual('testdb', connection.schema)
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
def test_hook_with_correct_parameters_mysql_tcp(self, get_connections):
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
def test_hook_with_correct_parameters_mysql_tcp(self, get_connection):
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \
"project_id=example-project&location=europe-west1&instance=testdb&" \
"use_proxy=True&sql_proxy_use_tcp=True"
self._setup_connections(get_connections, uri)
gcp_conn_id = 'google_cloud_default'
hook = CloudSqlDatabaseHook(
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
'extra__google_cloud_platform__project')
)
hook.create_connection()
try:
db_hook = hook.get_database_hook()
conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0] # pylint: disable=no-member
finally:
hook.delete_connection()
self.assertEqual('mysql', conn.conn_type)
self.assertEqual('127.0.0.1', conn.host)
self.assertNotEqual(3200, conn.port)
self.assertEqual('testdb', conn.schema)
get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSqlDatabaseHook()
connection = hook.create_connection()
self.assertEqual('mysql', connection.conn_type)
self.assertEqual('127.0.0.1', connection.host)
self.assertNotEqual(3200, connection.port)
self.assertEqual('testdb', connection.schema)

Просмотреть файл

@ -19,7 +19,6 @@
# pylint: disable=too-many-lines
import json
import os
import unittest
@ -787,36 +786,3 @@ class TestCloudSqlQueryValidation(unittest.TestCase):
operator.execute(None)
err = cm.exception
self.assertIn("The UNIX socket path length cannot exceed", str(err))
@mock.patch("airflow.contrib.hooks.gcp_sql_hook.CloudSqlDatabaseHook."
"delete_connection")
@mock.patch("airflow.contrib.hooks.gcp_sql_hook.CloudSqlDatabaseHook."
"get_connection")
@mock.patch("airflow.hooks.mysql_hook.MySqlHook.run")
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
def test_cloudsql_hook_delete_connection_on_exception(
self, get_connections, run, get_connection, delete_connection):
connection = Connection()
connection.parse_from_uri(
"gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&"
"project_id=example-project&location=europe-west1&instance=testdb&"
"use_proxy=False")
get_connection.return_value = connection
db_connection = Connection()
db_connection.host = "127.0.0.1"
db_connection.set_extra(json.dumps({"project_id": "example-project",
"location": "europe-west1",
"instance": "testdb",
"database_type": "mysql"}))
get_connections.return_value = [db_connection]
run.side_effect = Exception("Exception when running a query")
operator = CloudSqlQueryOperator(
sql=['SELECT * FROM TABLE'],
task_id='task_id'
)
with self.assertRaises(Exception) as cm:
operator.execute(None)
err = cm.exception
self.assertEqual("Exception when running a query", str(err))
delete_connection.assert_called_once_with()

Просмотреть файл

@ -61,6 +61,24 @@ class TestMySqlHookConn(unittest.TestCase):
self.assertEqual(kwargs['host'], 'host')
self.assertEqual(kwargs['db'], 'schema')
@mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
def test_get_conn_from_connection(self, mock_connect):
conn = Connection(login='login-conn', password='password-conn', host='host', schema='schema')
hook = MySqlHook(connection=conn)
hook.get_conn()
mock_connect.assert_called_once_with(
user='login-conn', passwd='password-conn', host='host', db='schema', port=3306
)
@mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
def test_get_conn_from_connection_with_schema(self, mock_connect):
conn = Connection(login='login-conn', password='password-conn', host='host', schema='schema')
hook = MySqlHook(connection=conn, schema='schema-override')
hook.get_conn()
mock_connect.assert_called_once_with(
user='login-conn', passwd='password-conn', host='host', db='schema-override', port=3306
)
@mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
def test_get_conn_port(self, mock_connect):
self.connection.port = 3307

Просмотреть файл

@ -76,6 +76,24 @@ class TestPostgresHookConn(unittest.TestCase):
with self.assertRaises(ValueError):
self.db_hook.get_conn()
@mock.patch('airflow.hooks.postgres_hook.psycopg2.connect')
def test_get_conn_from_connection(self, mock_connect):
conn = Connection(login='login-conn', password='password-conn', host='host', schema='schema')
hook = PostgresHook(connection=conn)
hook.get_conn()
mock_connect.assert_called_once_with(
user='login-conn', password='password-conn', host='host', dbname='schema', port=None
)
@mock.patch('airflow.hooks.postgres_hook.psycopg2.connect')
def test_get_conn_from_connection_with_schema(self, mock_connect):
conn = Connection(login='login-conn', password='password-conn', host='host', schema='schema')
hook = PostgresHook(connection=conn, schema='schema-override')
hook.get_conn()
mock_connect.assert_called_once_with(
user='login-conn', password='password-conn', host='host', dbname='schema-override', port=None
)
@mock.patch('airflow.hooks.postgres_hook.psycopg2.connect')
@mock.patch('airflow.contrib.hooks.aws_hook.AwsHook.get_client_type')
def test_get_conn_rds_iam_postgres(self, mock_client, mock_connect):