[AIRFLOW-5768] GCP cloud sql don't store ephemeral connection in db (#6440)
This commit is contained in:
Родитель
61775833f7
Коммит
776e24aa05
|
@ -934,57 +934,16 @@ class CloudSqlDatabaseHook(BaseHook):
|
|||
instance_specification += "=tcp:" + str(self.sql_proxy_tcp_port)
|
||||
return instance_specification
|
||||
|
||||
@provide_session
|
||||
def create_connection(self, session: Optional[Session] = None) -> None:
|
||||
def create_connection(self) -> Connection:
|
||||
"""
|
||||
Create connection in the Connection table, according to whether it uses
|
||||
proxy, TCP, UNIX sockets, SSL. Connection ID will be randomly generated.
|
||||
|
||||
:param session: Session of the SQL Alchemy ORM (automatically generated with
|
||||
decorator).
|
||||
Create Connection object, according to whether it uses proxy, TCP, UNIX sockets, SSL.
|
||||
Connection ID will be randomly generated.
|
||||
"""
|
||||
assert session is not None
|
||||
connection = Connection(conn_id=self.db_conn_id)
|
||||
uri = self._generate_connection_uri()
|
||||
self.log.info("Creating connection %s", self.db_conn_id)
|
||||
connection.parse_from_uri(uri)
|
||||
session.add(connection)
|
||||
session.commit()
|
||||
|
||||
@provide_session
|
||||
def retrieve_connection(self, session: Optional[Session] = None) -> Optional[Connection]:
|
||||
"""
|
||||
Retrieves the dynamically created connection from the Connection table.
|
||||
|
||||
:param session: Session of the SQL Alchemy ORM (automatically generated with
|
||||
decorator).
|
||||
"""
|
||||
assert session is not None
|
||||
self.log.info("Retrieving connection %s", self.db_conn_id)
|
||||
connections = session.query(Connection).filter(
|
||||
Connection.conn_id == self.db_conn_id)
|
||||
if connections.count():
|
||||
return connections[0]
|
||||
return None
|
||||
|
||||
@provide_session
|
||||
def delete_connection(self, session: Optional[Session] = None) -> None:
|
||||
"""
|
||||
Delete the dynamically created connection from the Connection table.
|
||||
|
||||
:param session: Session of the SQL Alchemy ORM (automatically generated with
|
||||
decorator).
|
||||
"""
|
||||
assert session is not None
|
||||
self.log.info("Deleting connection %s", self.db_conn_id)
|
||||
connections = session.query(Connection).filter(
|
||||
Connection.conn_id == self.db_conn_id)
|
||||
if connections.count():
|
||||
connection = connections[0]
|
||||
session.delete(connection)
|
||||
session.commit()
|
||||
else:
|
||||
self.log.info("Connection was already deleted!")
|
||||
return connection
|
||||
|
||||
def get_sqlproxy_runner(self) -> CloudSqlProxyRunner:
|
||||
"""
|
||||
|
@ -1006,17 +965,15 @@ class CloudSqlDatabaseHook(BaseHook):
|
|||
gcp_conn_id=self.gcp_conn_id
|
||||
)
|
||||
|
||||
def get_database_hook(self) -> Union[PostgresHook, MySqlHook]:
|
||||
def get_database_hook(self, connection: Connection) -> Union[PostgresHook, MySqlHook]:
|
||||
"""
|
||||
Retrieve database hook. This is the actual Postgres or MySQL database hook
|
||||
that uses proxy or connects directly to the Google Cloud SQL database.
|
||||
"""
|
||||
if self.database_type == 'postgres':
|
||||
self.db_hook = PostgresHook(postgres_conn_id=self.db_conn_id,
|
||||
schema=self.database)
|
||||
self.db_hook = PostgresHook(connection=connection, schema=self.database)
|
||||
else:
|
||||
self.db_hook = MySqlHook(mysql_conn_id=self.db_conn_id,
|
||||
schema=self.database)
|
||||
self.db_hook = MySqlHook(connection=connection, schema=self.database)
|
||||
return self.db_hook
|
||||
|
||||
def cleanup_database_hook(self) -> None:
|
||||
|
|
|
@ -835,13 +835,10 @@ class CloudSqlQueryOperator(BaseOperator):
|
|||
'extra__google_cloud_platform__project')
|
||||
)
|
||||
hook.validate_ssl_certs()
|
||||
hook.create_connection()
|
||||
connection = hook.create_connection()
|
||||
hook.validate_socket_path_length()
|
||||
database_hook = hook.get_database_hook(connection=connection)
|
||||
try:
|
||||
hook.validate_socket_path_length()
|
||||
database_hook = hook.get_database_hook()
|
||||
try:
|
||||
self._execute_query(hook, database_hook)
|
||||
finally:
|
||||
hook.cleanup_database_hook()
|
||||
self._execute_query(hook, database_hook)
|
||||
finally:
|
||||
hook.delete_connection()
|
||||
hook.cleanup_database_hook()
|
||||
|
|
|
@ -47,6 +47,7 @@ class MySqlHook(DbApiHook):
|
|||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.schema = kwargs.pop("schema", None)
|
||||
self.connection = kwargs.pop("connection", None)
|
||||
|
||||
def set_autocommit(self, conn, autocommit):
|
||||
"""
|
||||
|
@ -69,7 +70,7 @@ class MySqlHook(DbApiHook):
|
|||
"""
|
||||
Returns a mysql connection object
|
||||
"""
|
||||
conn = self.get_connection(self.mysql_conn_id)
|
||||
conn = self.connection or self.get_connection(self.mysql_conn_id)
|
||||
|
||||
conn_config = {
|
||||
"user": conn.login,
|
||||
|
|
|
@ -56,6 +56,7 @@ class PostgresHook(DbApiHook):
|
|||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.schema = kwargs.pop("schema", None)
|
||||
self.connection = kwargs.pop("connection", None)
|
||||
|
||||
def _get_cursor(self, raw_cursor):
|
||||
_cursor = raw_cursor.lower()
|
||||
|
@ -68,8 +69,9 @@ class PostgresHook(DbApiHook):
|
|||
raise ValueError('Invalid cursor passed {}'.format(_cursor))
|
||||
|
||||
def get_conn(self):
|
||||
|
||||
conn_id = getattr(self, self.conn_name_attr)
|
||||
conn = self.get_connection(conn_id)
|
||||
conn = self.connection or self.get_connection(conn_id)
|
||||
|
||||
# check for authentication via AWS IAM
|
||||
if conn.extra_dejson.get('iam', False):
|
||||
|
|
|
@ -27,7 +27,6 @@ from parameterized import parameterized
|
|||
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.gcp.hooks.cloud_sql import CloudSqlDatabaseHook, CloudSqlHook
|
||||
from airflow.hooks.base_hook import BaseHook
|
||||
from airflow.models import Connection
|
||||
from tests.compat import PropertyMock, mock
|
||||
from tests.gcp.utils.base_gcp_mock import (
|
||||
|
@ -1066,23 +1065,6 @@ class TestCloudsqlDatabaseHook(unittest.TestCase):
|
|||
err = cm.exception
|
||||
self.assertIn("needs to be set in connection", str(err))
|
||||
|
||||
@mock.patch('airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection')
|
||||
def test_cloudsql_database_hook_create_delete_connection(self, get_connection):
|
||||
connection = Connection()
|
||||
connection.parse_from_uri("http://user:password@host:80/database")
|
||||
connection.set_extra(json.dumps({
|
||||
"location": "test",
|
||||
"instance": "instance",
|
||||
"database_type": "postgres"
|
||||
}))
|
||||
get_connection.return_value = connection
|
||||
hook = CloudSqlDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection',
|
||||
default_gcp_project_id='google_connection')
|
||||
hook.create_connection()
|
||||
self.assertIsNotNone(hook.retrieve_connection())
|
||||
hook.delete_connection()
|
||||
self.assertIsNone(hook.retrieve_connection())
|
||||
|
||||
@mock.patch('airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection')
|
||||
def test_cloudsql_database_hook_get_sqlproxy_runner_no_proxy(self, get_connection):
|
||||
connection = Connection()
|
||||
|
@ -1095,14 +1077,10 @@ class TestCloudsqlDatabaseHook(unittest.TestCase):
|
|||
get_connection.return_value = connection
|
||||
hook = CloudSqlDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection',
|
||||
default_gcp_project_id='google_connection')
|
||||
hook.create_connection()
|
||||
try:
|
||||
with self.assertRaises(AirflowException) as cm:
|
||||
hook.get_sqlproxy_runner()
|
||||
err = cm.exception
|
||||
self.assertIn('Proxy runner can only be retrieved in case of use_proxy = True', str(err))
|
||||
finally:
|
||||
hook.delete_connection()
|
||||
with self.assertRaises(AirflowException) as cm:
|
||||
hook.get_sqlproxy_runner()
|
||||
err = cm.exception
|
||||
self.assertIn('Proxy runner can only be retrieved in case of use_proxy = True', str(err))
|
||||
|
||||
@mock.patch('airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection')
|
||||
def test_cloudsql_database_hook_get_sqlproxy_runner(self, get_connection):
|
||||
|
@ -1119,11 +1097,8 @@ class TestCloudsqlDatabaseHook(unittest.TestCase):
|
|||
hook = CloudSqlDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection',
|
||||
default_gcp_project_id='google_connection')
|
||||
hook.create_connection()
|
||||
try:
|
||||
proxy_runner = hook.get_sqlproxy_runner()
|
||||
self.assertIsNotNone(proxy_runner)
|
||||
finally:
|
||||
hook.delete_connection()
|
||||
proxy_runner = hook.get_sqlproxy_runner()
|
||||
self.assertIsNotNone(proxy_runner)
|
||||
|
||||
@mock.patch('airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection')
|
||||
def test_cloudsql_database_hook_get_database_hook(self, get_connection):
|
||||
|
@ -1137,28 +1112,13 @@ class TestCloudsqlDatabaseHook(unittest.TestCase):
|
|||
get_connection.return_value = connection
|
||||
hook = CloudSqlDatabaseHook(gcp_cloudsql_conn_id='cloudsql_connection',
|
||||
default_gcp_project_id='google_connection')
|
||||
hook.create_connection()
|
||||
try:
|
||||
db_hook = hook.get_database_hook()
|
||||
self.assertIsNotNone(db_hook)
|
||||
finally:
|
||||
hook.delete_connection()
|
||||
connection = hook.create_connection()
|
||||
db_hook = hook.get_database_hook(connection=connection)
|
||||
self.assertIsNotNone(db_hook)
|
||||
|
||||
|
||||
class TestCloudSqlDatabaseHook(unittest.TestCase):
|
||||
|
||||
@staticmethod
|
||||
def _setup_connections(get_connections, uri):
|
||||
gcp_connection = mock.MagicMock()
|
||||
gcp_connection.extra_dejson = mock.MagicMock()
|
||||
gcp_connection.extra_dejson.get.return_value = 'empty_project'
|
||||
cloudsql_connection = Connection()
|
||||
cloudsql_connection.parse_from_uri(uri)
|
||||
cloudsql_connection2 = Connection()
|
||||
cloudsql_connection2.parse_from_uri(uri)
|
||||
get_connections.side_effect = [[gcp_connection], [cloudsql_connection],
|
||||
[cloudsql_connection2]]
|
||||
|
||||
@mock.patch('airflow.contrib.hooks.gcp_sql_hook.CloudSqlDatabaseHook.get_connection')
|
||||
def setUp(self, m):
|
||||
super().setUp()
|
||||
|
@ -1213,235 +1173,144 @@ class TestCloudSqlDatabaseHook(unittest.TestCase):
|
|||
)
|
||||
self.assertEqual(sqlproxy_runner.instance_specification, instance_spec)
|
||||
|
||||
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
|
||||
def test_hook_with_not_too_long_unix_socket_path(self, get_connections):
|
||||
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
|
||||
def test_hook_with_not_too_long_unix_socket_path(self, get_connection):
|
||||
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \
|
||||
"project_id=example-project&location=europe-west1&" \
|
||||
"instance=" \
|
||||
"test_db_with_longname_but_with_limit_of_UNIX_socket&" \
|
||||
"use_proxy=True&sql_proxy_use_tcp=False"
|
||||
self._setup_connections(get_connections, uri)
|
||||
gcp_conn_id = 'google_cloud_default'
|
||||
hook = CloudSqlDatabaseHook(
|
||||
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
|
||||
'extra__google_cloud_platform__project')
|
||||
)
|
||||
hook.create_connection()
|
||||
try:
|
||||
db_hook = hook.get_database_hook()
|
||||
conn = db_hook._get_connections_from_db(db_hook.postgres_conn_id)[0] # pylint: disable=no-member
|
||||
finally:
|
||||
hook.delete_connection()
|
||||
self.assertEqual('postgres', conn.conn_type)
|
||||
self.assertEqual('testdb', conn.schema)
|
||||
get_connection.side_effect = [Connection(uri=uri)]
|
||||
hook = CloudSqlDatabaseHook()
|
||||
connection = hook.create_connection()
|
||||
self.assertEqual('postgres', connection.conn_type)
|
||||
self.assertEqual('testdb', connection.schema)
|
||||
|
||||
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
|
||||
def test_hook_with_correct_parameters_postgres(self, get_connections):
|
||||
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
|
||||
def test_hook_with_correct_parameters_postgres(self, get_connection):
|
||||
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \
|
||||
"project_id=example-project&location=europe-west1&instance=testdb&" \
|
||||
"use_proxy=False&use_ssl=False"
|
||||
self._setup_connections(get_connections, uri)
|
||||
gcp_conn_id = 'google_cloud_default'
|
||||
hook = CloudSqlDatabaseHook(
|
||||
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
|
||||
'extra__google_cloud_platform__project')
|
||||
)
|
||||
hook.create_connection()
|
||||
try:
|
||||
db_hook = hook.get_database_hook()
|
||||
conn = db_hook._get_connections_from_db(db_hook.postgres_conn_id)[0] # pylint: disable=no-member
|
||||
finally:
|
||||
hook.delete_connection()
|
||||
self.assertEqual('postgres', conn.conn_type)
|
||||
self.assertEqual('127.0.0.1', conn.host)
|
||||
self.assertEqual(3200, conn.port)
|
||||
self.assertEqual('testdb', conn.schema)
|
||||
get_connection.side_effect = [Connection(uri=uri)]
|
||||
hook = CloudSqlDatabaseHook()
|
||||
connection = hook.create_connection()
|
||||
self.assertEqual('postgres', connection.conn_type)
|
||||
self.assertEqual('127.0.0.1', connection.host)
|
||||
self.assertEqual(3200, connection.port)
|
||||
self.assertEqual('testdb', connection.schema)
|
||||
|
||||
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
|
||||
def test_hook_with_correct_parameters_postgres_ssl(self, get_connections):
|
||||
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
|
||||
def test_hook_with_correct_parameters_postgres_ssl(self, get_connection):
|
||||
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \
|
||||
"project_id=example-project&location=europe-west1&instance=testdb&" \
|
||||
"use_proxy=False&use_ssl=True&sslcert=/bin/bash&" \
|
||||
"sslkey=/bin/bash&sslrootcert=/bin/bash"
|
||||
self._setup_connections(get_connections, uri)
|
||||
gcp_conn_id = 'google_cloud_default'
|
||||
hook = CloudSqlDatabaseHook(
|
||||
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
|
||||
'extra__google_cloud_platform__project')
|
||||
)
|
||||
hook.create_connection()
|
||||
try:
|
||||
db_hook = hook.get_database_hook()
|
||||
conn = db_hook._get_connections_from_db(db_hook.postgres_conn_id)[0] # pylint: disable=no-member
|
||||
finally:
|
||||
hook.delete_connection()
|
||||
self.assertEqual('postgres', conn.conn_type)
|
||||
self.assertEqual('127.0.0.1', conn.host)
|
||||
self.assertEqual(3200, conn.port)
|
||||
self.assertEqual('testdb', conn.schema)
|
||||
self.assertEqual('/bin/bash', conn.extra_dejson['sslkey'])
|
||||
self.assertEqual('/bin/bash', conn.extra_dejson['sslcert'])
|
||||
self.assertEqual('/bin/bash', conn.extra_dejson['sslrootcert'])
|
||||
get_connection.side_effect = [Connection(uri=uri)]
|
||||
hook = CloudSqlDatabaseHook()
|
||||
connection = hook.create_connection()
|
||||
self.assertEqual('postgres', connection.conn_type)
|
||||
self.assertEqual('127.0.0.1', connection.host)
|
||||
self.assertEqual(3200, connection.port)
|
||||
self.assertEqual('testdb', connection.schema)
|
||||
self.assertEqual('/bin/bash', connection.extra_dejson['sslkey'])
|
||||
self.assertEqual('/bin/bash', connection.extra_dejson['sslcert'])
|
||||
self.assertEqual('/bin/bash', connection.extra_dejson['sslrootcert'])
|
||||
|
||||
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
|
||||
def test_hook_with_correct_parameters_postgres_proxy_socket(self, get_connections):
|
||||
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
|
||||
def test_hook_with_correct_parameters_postgres_proxy_socket(self, get_connection):
|
||||
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \
|
||||
"project_id=example-project&location=europe-west1&instance=testdb&" \
|
||||
"use_proxy=True&sql_proxy_use_tcp=False"
|
||||
self._setup_connections(get_connections, uri)
|
||||
gcp_conn_id = 'google_cloud_default'
|
||||
hook = CloudSqlDatabaseHook(
|
||||
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
|
||||
'extra__google_cloud_platform__project')
|
||||
)
|
||||
hook.create_connection()
|
||||
try:
|
||||
db_hook = hook.get_database_hook()
|
||||
conn = db_hook._get_connections_from_db(db_hook.postgres_conn_id)[0] # pylint: disable=no-member
|
||||
finally:
|
||||
hook.delete_connection()
|
||||
self.assertEqual('postgres', conn.conn_type)
|
||||
self.assertIn('/tmp', conn.host)
|
||||
self.assertIn('example-project:europe-west1:testdb', conn.host)
|
||||
self.assertIsNone(conn.port)
|
||||
self.assertEqual('testdb', conn.schema)
|
||||
get_connection.side_effect = [Connection(uri=uri)]
|
||||
hook = CloudSqlDatabaseHook()
|
||||
connection = hook.create_connection()
|
||||
self.assertEqual('postgres', connection.conn_type)
|
||||
self.assertIn('/tmp', connection.host)
|
||||
self.assertIn('example-project:europe-west1:testdb', connection.host)
|
||||
self.assertIsNone(connection.port)
|
||||
self.assertEqual('testdb', connection.schema)
|
||||
|
||||
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
|
||||
def test_hook_with_correct_parameters_project_id_missing(self, get_connections):
|
||||
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
|
||||
def test_hook_with_correct_parameters_project_id_missing(self, get_connection):
|
||||
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \
|
||||
"location=europe-west1&instance=testdb&" \
|
||||
"use_proxy=False&use_ssl=False"
|
||||
self._setup_connections(get_connections, uri)
|
||||
gcp_conn_id = 'google_cloud_default'
|
||||
hook = CloudSqlDatabaseHook(
|
||||
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
|
||||
'extra__google_cloud_platform__project')
|
||||
)
|
||||
hook.create_connection()
|
||||
try:
|
||||
db_hook = hook.get_database_hook()
|
||||
conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0] # pylint: disable=no-member
|
||||
finally:
|
||||
hook.delete_connection()
|
||||
self.assertEqual('mysql', conn.conn_type)
|
||||
self.assertEqual('127.0.0.1', conn.host)
|
||||
self.assertEqual(3200, conn.port)
|
||||
self.assertEqual('testdb', conn.schema)
|
||||
get_connection.side_effect = [Connection(uri=uri)]
|
||||
hook = CloudSqlDatabaseHook()
|
||||
connection = hook.create_connection()
|
||||
self.assertEqual('mysql', connection.conn_type)
|
||||
self.assertEqual('127.0.0.1', connection.host)
|
||||
self.assertEqual(3200, connection.port)
|
||||
self.assertEqual('testdb', connection.schema)
|
||||
|
||||
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
|
||||
def test_hook_with_correct_parameters_postgres_proxy_tcp(self, get_connections):
|
||||
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
|
||||
def test_hook_with_correct_parameters_postgres_proxy_tcp(self, get_connection):
|
||||
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=postgres&" \
|
||||
"project_id=example-project&location=europe-west1&instance=testdb&" \
|
||||
"use_proxy=True&sql_proxy_use_tcp=True"
|
||||
self._setup_connections(get_connections, uri)
|
||||
gcp_conn_id = 'google_cloud_default'
|
||||
hook = CloudSqlDatabaseHook(
|
||||
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
|
||||
'extra__google_cloud_platform__project')
|
||||
)
|
||||
hook.create_connection()
|
||||
try:
|
||||
db_hook = hook.get_database_hook()
|
||||
conn = db_hook._get_connections_from_db(db_hook.postgres_conn_id)[0] # pylint: disable=no-member
|
||||
finally:
|
||||
hook.delete_connection()
|
||||
self.assertEqual('postgres', conn.conn_type)
|
||||
self.assertEqual('127.0.0.1', conn.host)
|
||||
self.assertNotEqual(3200, conn.port)
|
||||
self.assertEqual('testdb', conn.schema)
|
||||
get_connection.side_effect = [Connection(uri=uri)]
|
||||
hook = CloudSqlDatabaseHook()
|
||||
connection = hook.create_connection()
|
||||
self.assertEqual('postgres', connection.conn_type)
|
||||
self.assertEqual('127.0.0.1', connection.host)
|
||||
self.assertNotEqual(3200, connection.port)
|
||||
self.assertEqual('testdb', connection.schema)
|
||||
|
||||
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
|
||||
def test_hook_with_correct_parameters_mysql(self, get_connections):
|
||||
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
|
||||
def test_hook_with_correct_parameters_mysql(self, get_connection):
|
||||
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \
|
||||
"project_id=example-project&location=europe-west1&instance=testdb&" \
|
||||
"use_proxy=False&use_ssl=False"
|
||||
self._setup_connections(get_connections, uri)
|
||||
gcp_conn_id = 'google_cloud_default'
|
||||
hook = CloudSqlDatabaseHook(
|
||||
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
|
||||
'extra__google_cloud_platform__project')
|
||||
)
|
||||
hook.create_connection()
|
||||
try:
|
||||
db_hook = hook.get_database_hook()
|
||||
conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0] # pylint: disable=no-member
|
||||
finally:
|
||||
hook.delete_connection()
|
||||
self.assertEqual('mysql', conn.conn_type)
|
||||
self.assertEqual('127.0.0.1', conn.host)
|
||||
self.assertEqual(3200, conn.port)
|
||||
self.assertEqual('testdb', conn.schema)
|
||||
get_connection.side_effect = [Connection(uri=uri)]
|
||||
hook = CloudSqlDatabaseHook()
|
||||
connection = hook.create_connection()
|
||||
self.assertEqual('mysql', connection.conn_type)
|
||||
self.assertEqual('127.0.0.1', connection.host)
|
||||
self.assertEqual(3200, connection.port)
|
||||
self.assertEqual('testdb', connection.schema)
|
||||
|
||||
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
|
||||
def test_hook_with_correct_parameters_mysql_ssl(self, get_connections):
|
||||
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
|
||||
def test_hook_with_correct_parameters_mysql_ssl(self, get_connection):
|
||||
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \
|
||||
"project_id=example-project&location=europe-west1&instance=testdb&" \
|
||||
"use_proxy=False&use_ssl=True&sslcert=/bin/bash&" \
|
||||
"sslkey=/bin/bash&sslrootcert=/bin/bash"
|
||||
self._setup_connections(get_connections, uri)
|
||||
get_connection.side_effect = [Connection(uri=uri)]
|
||||
hook = CloudSqlDatabaseHook()
|
||||
connection = hook.create_connection()
|
||||
self.assertEqual('mysql', connection.conn_type)
|
||||
self.assertEqual('127.0.0.1', connection.host)
|
||||
self.assertEqual(3200, connection.port)
|
||||
self.assertEqual('testdb', connection.schema)
|
||||
self.assertEqual('/bin/bash', json.loads(connection.extra_dejson['ssl'])['cert'])
|
||||
self.assertEqual('/bin/bash', json.loads(connection.extra_dejson['ssl'])['key'])
|
||||
self.assertEqual('/bin/bash', json.loads(connection.extra_dejson['ssl'])['ca'])
|
||||
|
||||
gcp_conn_id = 'google_cloud_default'
|
||||
hook = CloudSqlDatabaseHook(
|
||||
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
|
||||
'extra__google_cloud_platform__project')
|
||||
)
|
||||
hook.create_connection()
|
||||
try:
|
||||
db_hook = hook.get_database_hook()
|
||||
conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0] # pylint: disable=no-member
|
||||
finally:
|
||||
hook.delete_connection()
|
||||
self.assertEqual('mysql', conn.conn_type)
|
||||
self.assertEqual('127.0.0.1', conn.host)
|
||||
self.assertEqual(3200, conn.port)
|
||||
self.assertEqual('testdb', conn.schema)
|
||||
self.assertEqual('/bin/bash', json.loads(conn.extra_dejson['ssl'])['cert'])
|
||||
self.assertEqual('/bin/bash', json.loads(conn.extra_dejson['ssl'])['key'])
|
||||
self.assertEqual('/bin/bash', json.loads(conn.extra_dejson['ssl'])['ca'])
|
||||
|
||||
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
|
||||
def test_hook_with_correct_parameters_mysql_proxy_socket(self, get_connections):
|
||||
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
|
||||
def test_hook_with_correct_parameters_mysql_proxy_socket(self, get_connection):
|
||||
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \
|
||||
"project_id=example-project&location=europe-west1&instance=testdb&" \
|
||||
"use_proxy=True&sql_proxy_use_tcp=False"
|
||||
self._setup_connections(get_connections, uri)
|
||||
gcp_conn_id = 'google_cloud_default'
|
||||
hook = CloudSqlDatabaseHook(
|
||||
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
|
||||
'extra__google_cloud_platform__project')
|
||||
)
|
||||
hook.create_connection()
|
||||
try:
|
||||
db_hook = hook.get_database_hook()
|
||||
conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0] # pylint: disable=no-member
|
||||
finally:
|
||||
hook.delete_connection()
|
||||
self.assertEqual('mysql', conn.conn_type)
|
||||
self.assertEqual('localhost', conn.host)
|
||||
self.assertIn('/tmp', conn.extra_dejson['unix_socket'])
|
||||
get_connection.side_effect = [Connection(uri=uri)]
|
||||
hook = CloudSqlDatabaseHook()
|
||||
connection = hook.create_connection()
|
||||
self.assertEqual('mysql', connection.conn_type)
|
||||
self.assertEqual('localhost', connection.host)
|
||||
self.assertIn('/tmp', connection.extra_dejson['unix_socket'])
|
||||
self.assertIn('example-project:europe-west1:testdb',
|
||||
conn.extra_dejson['unix_socket'])
|
||||
self.assertIsNone(conn.port)
|
||||
self.assertEqual('testdb', conn.schema)
|
||||
connection.extra_dejson['unix_socket'])
|
||||
self.assertIsNone(connection.port)
|
||||
self.assertEqual('testdb', connection.schema)
|
||||
|
||||
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
|
||||
def test_hook_with_correct_parameters_mysql_tcp(self, get_connections):
|
||||
@mock.patch("airflow.gcp.hooks.cloud_sql.CloudSqlDatabaseHook.get_connection")
|
||||
def test_hook_with_correct_parameters_mysql_tcp(self, get_connection):
|
||||
uri = "gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&" \
|
||||
"project_id=example-project&location=europe-west1&instance=testdb&" \
|
||||
"use_proxy=True&sql_proxy_use_tcp=True"
|
||||
self._setup_connections(get_connections, uri)
|
||||
gcp_conn_id = 'google_cloud_default'
|
||||
hook = CloudSqlDatabaseHook(
|
||||
default_gcp_project_id=BaseHook.get_connection(gcp_conn_id).extra_dejson.get(
|
||||
'extra__google_cloud_platform__project')
|
||||
)
|
||||
hook.create_connection()
|
||||
try:
|
||||
db_hook = hook.get_database_hook()
|
||||
conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0] # pylint: disable=no-member
|
||||
finally:
|
||||
hook.delete_connection()
|
||||
self.assertEqual('mysql', conn.conn_type)
|
||||
self.assertEqual('127.0.0.1', conn.host)
|
||||
self.assertNotEqual(3200, conn.port)
|
||||
self.assertEqual('testdb', conn.schema)
|
||||
get_connection.side_effect = [Connection(uri=uri)]
|
||||
hook = CloudSqlDatabaseHook()
|
||||
connection = hook.create_connection()
|
||||
self.assertEqual('mysql', connection.conn_type)
|
||||
self.assertEqual('127.0.0.1', connection.host)
|
||||
self.assertNotEqual(3200, connection.port)
|
||||
self.assertEqual('testdb', connection.schema)
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
import json
|
||||
import os
|
||||
import unittest
|
||||
|
||||
|
@ -787,36 +786,3 @@ class TestCloudSqlQueryValidation(unittest.TestCase):
|
|||
operator.execute(None)
|
||||
err = cm.exception
|
||||
self.assertIn("The UNIX socket path length cannot exceed", str(err))
|
||||
|
||||
@mock.patch("airflow.contrib.hooks.gcp_sql_hook.CloudSqlDatabaseHook."
|
||||
"delete_connection")
|
||||
@mock.patch("airflow.contrib.hooks.gcp_sql_hook.CloudSqlDatabaseHook."
|
||||
"get_connection")
|
||||
@mock.patch("airflow.hooks.mysql_hook.MySqlHook.run")
|
||||
@mock.patch("airflow.hooks.base_hook.BaseHook.get_connections")
|
||||
def test_cloudsql_hook_delete_connection_on_exception(
|
||||
self, get_connections, run, get_connection, delete_connection):
|
||||
connection = Connection()
|
||||
connection.parse_from_uri(
|
||||
"gcpcloudsql://user:password@127.0.0.1:3200/testdb?database_type=mysql&"
|
||||
"project_id=example-project&location=europe-west1&instance=testdb&"
|
||||
"use_proxy=False")
|
||||
get_connection.return_value = connection
|
||||
|
||||
db_connection = Connection()
|
||||
db_connection.host = "127.0.0.1"
|
||||
db_connection.set_extra(json.dumps({"project_id": "example-project",
|
||||
"location": "europe-west1",
|
||||
"instance": "testdb",
|
||||
"database_type": "mysql"}))
|
||||
get_connections.return_value = [db_connection]
|
||||
run.side_effect = Exception("Exception when running a query")
|
||||
operator = CloudSqlQueryOperator(
|
||||
sql=['SELECT * FROM TABLE'],
|
||||
task_id='task_id'
|
||||
)
|
||||
with self.assertRaises(Exception) as cm:
|
||||
operator.execute(None)
|
||||
err = cm.exception
|
||||
self.assertEqual("Exception when running a query", str(err))
|
||||
delete_connection.assert_called_once_with()
|
||||
|
|
|
@ -61,6 +61,24 @@ class TestMySqlHookConn(unittest.TestCase):
|
|||
self.assertEqual(kwargs['host'], 'host')
|
||||
self.assertEqual(kwargs['db'], 'schema')
|
||||
|
||||
@mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
|
||||
def test_get_conn_from_connection(self, mock_connect):
|
||||
conn = Connection(login='login-conn', password='password-conn', host='host', schema='schema')
|
||||
hook = MySqlHook(connection=conn)
|
||||
hook.get_conn()
|
||||
mock_connect.assert_called_once_with(
|
||||
user='login-conn', passwd='password-conn', host='host', db='schema', port=3306
|
||||
)
|
||||
|
||||
@mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
|
||||
def test_get_conn_from_connection_with_schema(self, mock_connect):
|
||||
conn = Connection(login='login-conn', password='password-conn', host='host', schema='schema')
|
||||
hook = MySqlHook(connection=conn, schema='schema-override')
|
||||
hook.get_conn()
|
||||
mock_connect.assert_called_once_with(
|
||||
user='login-conn', passwd='password-conn', host='host', db='schema-override', port=3306
|
||||
)
|
||||
|
||||
@mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
|
||||
def test_get_conn_port(self, mock_connect):
|
||||
self.connection.port = 3307
|
||||
|
|
|
@ -76,6 +76,24 @@ class TestPostgresHookConn(unittest.TestCase):
|
|||
with self.assertRaises(ValueError):
|
||||
self.db_hook.get_conn()
|
||||
|
||||
@mock.patch('airflow.hooks.postgres_hook.psycopg2.connect')
|
||||
def test_get_conn_from_connection(self, mock_connect):
|
||||
conn = Connection(login='login-conn', password='password-conn', host='host', schema='schema')
|
||||
hook = PostgresHook(connection=conn)
|
||||
hook.get_conn()
|
||||
mock_connect.assert_called_once_with(
|
||||
user='login-conn', password='password-conn', host='host', dbname='schema', port=None
|
||||
)
|
||||
|
||||
@mock.patch('airflow.hooks.postgres_hook.psycopg2.connect')
|
||||
def test_get_conn_from_connection_with_schema(self, mock_connect):
|
||||
conn = Connection(login='login-conn', password='password-conn', host='host', schema='schema')
|
||||
hook = PostgresHook(connection=conn, schema='schema-override')
|
||||
hook.get_conn()
|
||||
mock_connect.assert_called_once_with(
|
||||
user='login-conn', password='password-conn', host='host', dbname='schema-override', port=None
|
||||
)
|
||||
|
||||
@mock.patch('airflow.hooks.postgres_hook.psycopg2.connect')
|
||||
@mock.patch('airflow.contrib.hooks.aws_hook.AwsHook.get_client_type')
|
||||
def test_get_conn_rds_iam_postgres(self, mock_client, mock_connect):
|
||||
|
|
Загрузка…
Ссылка в новой задаче