Add readonly connection API endpoints (#9095)
* add connection schema with tests * add endpoints for connection * update patch * update endpoint methods * add readonly connection endpoints * improve base schema and add tests * update spec set connection id to string * update type hint * improve base schema * add pre_load processing to return data to normal * remove pre_load processing as it is not needed * readonly endpoints * improve code * handle exception and improve code * improve pagination test * add nullable to spec and improve code * remove base and add parameterized tests * add pagination limit test
This commit is contained in:
Родитель
0682e784b1
Коммит
ecbb366e63
|
@ -15,8 +15,15 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
# TODO(mik-laj): We have to implement it.
|
from flask import request
|
||||||
# Do you want to help? Please look at: https://github.com/apache/airflow/issues/8127
|
|
||||||
|
from airflow.api_connexion import parameters
|
||||||
|
from airflow.api_connexion.exceptions import NotFound
|
||||||
|
from airflow.api_connexion.schemas.connection_schema import (
|
||||||
|
ConnectionCollection, connection_collection_item_schema, connection_collection_schema,
|
||||||
|
)
|
||||||
|
from airflow.models import Connection
|
||||||
|
from airflow.utils.session import provide_session
|
||||||
|
|
||||||
|
|
||||||
def delete_connection():
|
def delete_connection():
|
||||||
|
@ -26,18 +33,34 @@ def delete_connection():
|
||||||
raise NotImplementedError("Not implemented yet.")
|
raise NotImplementedError("Not implemented yet.")
|
||||||
|
|
||||||
|
|
||||||
def get_connection():
|
@provide_session
|
||||||
|
def get_connection(connection_id, session):
|
||||||
"""
|
"""
|
||||||
Get a connection entry
|
Get a connection entry
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("Not implemented yet.")
|
query = session.query(Connection)
|
||||||
|
query = query.filter(Connection.conn_id == connection_id)
|
||||||
|
connection = query.one_or_none()
|
||||||
|
if connection is None:
|
||||||
|
raise NotFound("Connection not found")
|
||||||
|
return connection_collection_item_schema.dump(connection)
|
||||||
|
|
||||||
|
|
||||||
def get_connections():
|
@provide_session
|
||||||
|
def get_connections(session):
|
||||||
"""
|
"""
|
||||||
Get all connection entries
|
Get all connection entries
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("Not implemented yet.")
|
offset = request.args.get(parameters.page_offset, 0)
|
||||||
|
limit = min(int(request.args.get(parameters.page_limit, 100)), 100)
|
||||||
|
|
||||||
|
query = session.query(Connection)
|
||||||
|
total_entries = query.count()
|
||||||
|
query = query.offset(offset).limit(limit)
|
||||||
|
|
||||||
|
connections = query.all()
|
||||||
|
return connection_collection_schema.dump(ConnectionCollection(connections=connections,
|
||||||
|
total_entries=total_entries))
|
||||||
|
|
||||||
|
|
||||||
def patch_connection():
|
def patch_connection():
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
from connexion import ProblemException
|
||||||
|
|
||||||
|
|
||||||
|
class NotFound(ProblemException):
|
||||||
|
"""Raise when the object cannot be found"""
|
||||||
|
def __init__(self, title='Object not found', detail=None):
|
||||||
|
super().__init__(status=404, title=title, detail=detail)
|
|
@ -1148,12 +1148,16 @@ components:
|
||||||
type: string
|
type: string
|
||||||
host:
|
host:
|
||||||
type: string
|
type: string
|
||||||
|
nullable: true
|
||||||
login:
|
login:
|
||||||
type: string
|
type: string
|
||||||
|
nullable: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
|
nullable: true
|
||||||
port:
|
port:
|
||||||
type: integer
|
type: integer
|
||||||
|
nullable: true
|
||||||
|
|
||||||
ConnectionCollection:
|
ConnectionCollection:
|
||||||
type: object
|
type: object
|
||||||
|
@ -1174,6 +1178,7 @@ components:
|
||||||
writeOnly: true
|
writeOnly: true
|
||||||
extra:
|
extra:
|
||||||
type: string
|
type: string
|
||||||
|
nullable: true
|
||||||
|
|
||||||
DAG:
|
DAG:
|
||||||
type: object
|
type: object
|
||||||
|
@ -2106,7 +2111,7 @@ components:
|
||||||
in: path
|
in: path
|
||||||
name: connection_id
|
name: connection_id
|
||||||
schema:
|
schema:
|
||||||
type: integer
|
type: string
|
||||||
required: true
|
required: true
|
||||||
description: The Connection ID.
|
description: The Connection ID.
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
# Pagination parameters
|
||||||
|
page_offset = "offset"
|
||||||
|
page_limit = "limit"
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
|
@ -0,0 +1,66 @@
|
||||||
|
#
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
from typing import List, NamedTuple
|
||||||
|
|
||||||
|
from marshmallow import Schema, fields
|
||||||
|
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
|
||||||
|
|
||||||
|
from airflow.models.connection import Connection
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionCollectionItemSchema(SQLAlchemySchema):
|
||||||
|
"""
|
||||||
|
Schema for a connection item
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
""" Meta """
|
||||||
|
model = Connection
|
||||||
|
|
||||||
|
conn_id = auto_field(dump_to='connection_id', load_from='connection_id')
|
||||||
|
conn_type = auto_field()
|
||||||
|
host = auto_field()
|
||||||
|
login = auto_field()
|
||||||
|
schema = auto_field()
|
||||||
|
port = auto_field()
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionSchema(ConnectionCollectionItemSchema): # pylint: disable=too-many-ancestors
|
||||||
|
"""
|
||||||
|
Connection schema
|
||||||
|
"""
|
||||||
|
|
||||||
|
password = auto_field(load_only=True)
|
||||||
|
extra = auto_field()
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionCollection(NamedTuple):
|
||||||
|
""" List of Connections with meta"""
|
||||||
|
connections: List[Connection]
|
||||||
|
total_entries: int
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionCollectionSchema(Schema):
|
||||||
|
""" Connection Collection Schema"""
|
||||||
|
connections = fields.List(fields.Nested(ConnectionCollectionItemSchema))
|
||||||
|
total_entries = fields.Int()
|
||||||
|
|
||||||
|
|
||||||
|
connection_schema = ConnectionSchema()
|
||||||
|
connection_collection_item_schema = ConnectionCollectionItemSchema()
|
||||||
|
connection_collection_schema = ConnectionCollectionSchema()
|
|
@ -28,6 +28,7 @@ import connexion
|
||||||
import flask
|
import flask
|
||||||
import flask_login
|
import flask_login
|
||||||
import pendulum
|
import pendulum
|
||||||
|
from connexion import ProblemException
|
||||||
from flask import Flask, session as flask_session
|
from flask import Flask, session as flask_session
|
||||||
from flask_appbuilder import SQLA, AppBuilder
|
from flask_appbuilder import SQLA, AppBuilder
|
||||||
from flask_caching import Cache
|
from flask_caching import Cache
|
||||||
|
@ -254,6 +255,7 @@ def create_app(config=None, testing=False, app_name="Airflow"):
|
||||||
validate_responses=True,
|
validate_responses=True,
|
||||||
strict_validation=False
|
strict_validation=False
|
||||||
)
|
)
|
||||||
|
app.register_error_handler(ProblemException, connexion_app.common_error_handler)
|
||||||
|
|
||||||
init_views(appbuilder)
|
init_views(appbuilder)
|
||||||
init_plugin_blueprints(app)
|
init_plugin_blueprints(app)
|
||||||
|
|
|
@ -16,9 +16,12 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
from airflow.models import Connection
|
||||||
|
from airflow.utils.session import create_session, provide_session
|
||||||
from airflow.www import app
|
from airflow.www import app
|
||||||
|
from tests.test_utils.db import clear_db_connections
|
||||||
|
|
||||||
|
|
||||||
class TestConnectionEndpoint(unittest.TestCase):
|
class TestConnectionEndpoint(unittest.TestCase):
|
||||||
|
@ -29,38 +32,182 @@ class TestConnectionEndpoint(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.client = self.app.test_client() # type:ignore
|
self.client = self.app.test_client() # type:ignore
|
||||||
|
# we want only the connection created here for this test
|
||||||
|
with create_session() as session:
|
||||||
|
session.query(Connection).delete()
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
clear_db_connections()
|
||||||
|
|
||||||
|
|
||||||
class TestDeleteConnection(TestConnectionEndpoint):
|
class TestDeleteConnection(TestConnectionEndpoint):
|
||||||
@pytest.mark.skip(reason="Not implemented yet")
|
@unittest.skip("Not implemented yet")
|
||||||
def test_should_response_200(self):
|
def test_should_response_200(self):
|
||||||
response = self.client.delete("/api/v1/connections/1")
|
response = self.client.delete("/api/v1/connections/1")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
class TestGetConnection(TestConnectionEndpoint):
|
class TestGetConnection(TestConnectionEndpoint):
|
||||||
@pytest.mark.skip(reason="Not implemented yet")
|
|
||||||
def test_should_response_200(self):
|
@provide_session
|
||||||
response = self.client.get("/api/v1/connection/1")
|
def test_should_response_200(self, session):
|
||||||
|
connection_model = Connection(conn_id='test-connection-id',
|
||||||
|
conn_type='mysql',
|
||||||
|
host='mysql',
|
||||||
|
login='login',
|
||||||
|
schema='testschema',
|
||||||
|
port=80
|
||||||
|
)
|
||||||
|
session.add(connection_model)
|
||||||
|
session.commit()
|
||||||
|
result = session.query(Connection).all()
|
||||||
|
assert len(result) == 1
|
||||||
|
response = self.client.get("/api/v1/connections/test-connection-id")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
self.assertEqual(
|
||||||
|
response.json,
|
||||||
|
{
|
||||||
|
"connection_id": "test-connection-id",
|
||||||
|
"conn_type": 'mysql',
|
||||||
|
"host": 'mysql',
|
||||||
|
"login": 'login',
|
||||||
|
'schema': 'testschema',
|
||||||
|
'port': 80
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_should_response_404(self):
|
||||||
|
response = self.client.get("/api/v1/connections/invalid-connection")
|
||||||
|
assert response.status_code == 404
|
||||||
|
self.assertEqual(
|
||||||
|
{
|
||||||
|
'detail': None,
|
||||||
|
'status': 404,
|
||||||
|
'title': 'Connection not found',
|
||||||
|
'type': 'about:blank'
|
||||||
|
},
|
||||||
|
response.json
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestGetConnections(TestConnectionEndpoint):
|
class TestGetConnections(TestConnectionEndpoint):
|
||||||
@pytest.mark.skip(reason="Not implemented yet")
|
|
||||||
def test_should_response_200(self):
|
@provide_session
|
||||||
response = self.client.get("/api/v1/connections/")
|
def test_should_response_200(self, session):
|
||||||
|
connection_model_1 = Connection(conn_id='test-connection-id-1',
|
||||||
|
conn_type='test_type')
|
||||||
|
connection_model_2 = Connection(conn_id='test-connection-id-2',
|
||||||
|
conn_type='test_type')
|
||||||
|
connections = [connection_model_1, connection_model_2]
|
||||||
|
session.add_all(connections)
|
||||||
|
session.commit()
|
||||||
|
result = session.query(Connection).all()
|
||||||
|
assert len(result) == 2
|
||||||
|
response = self.client.get("/api/v1/connections")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
self.assertEqual(
|
||||||
|
response.json,
|
||||||
|
{
|
||||||
|
'connections': [
|
||||||
|
{
|
||||||
|
"connection_id": "test-connection-id-1",
|
||||||
|
"conn_type": 'test_type',
|
||||||
|
"host": None,
|
||||||
|
"login": None,
|
||||||
|
'schema': None,
|
||||||
|
'port': None
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"connection_id": "test-connection-id-2",
|
||||||
|
"conn_type": 'test_type',
|
||||||
|
"host": None,
|
||||||
|
"login": None,
|
||||||
|
'schema': None,
|
||||||
|
'port': None
|
||||||
|
}
|
||||||
|
],
|
||||||
|
'total_entries': 2
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetConnectionsPagination(TestConnectionEndpoint):
|
||||||
|
@parameterized.expand(
|
||||||
|
[
|
||||||
|
("/api/v1/connections?limit=1", ['TEST_CONN_ID1']),
|
||||||
|
("/api/v1/connections?limit=2", ['TEST_CONN_ID1', "TEST_CONN_ID2"]),
|
||||||
|
(
|
||||||
|
"/api/v1/connections?offset=5",
|
||||||
|
[
|
||||||
|
"TEST_CONN_ID6",
|
||||||
|
"TEST_CONN_ID7",
|
||||||
|
"TEST_CONN_ID8",
|
||||||
|
"TEST_CONN_ID9",
|
||||||
|
"TEST_CONN_ID10",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"/api/v1/connections?offset=0",
|
||||||
|
[
|
||||||
|
"TEST_CONN_ID1",
|
||||||
|
"TEST_CONN_ID2",
|
||||||
|
"TEST_CONN_ID3",
|
||||||
|
"TEST_CONN_ID4",
|
||||||
|
"TEST_CONN_ID5",
|
||||||
|
"TEST_CONN_ID6",
|
||||||
|
"TEST_CONN_ID7",
|
||||||
|
"TEST_CONN_ID8",
|
||||||
|
"TEST_CONN_ID9",
|
||||||
|
"TEST_CONN_ID10",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
("/api/v1/connections?limit=1&offset=5", ["TEST_CONN_ID6"]),
|
||||||
|
("/api/v1/connections?limit=1&offset=1", ["TEST_CONN_ID2"]),
|
||||||
|
(
|
||||||
|
"/api/v1/connections?limit=2&offset=2",
|
||||||
|
["TEST_CONN_ID3", "TEST_CONN_ID4"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
@provide_session
|
||||||
|
def test_handle_limit_offset(self, url, expected_conn_ids, session):
|
||||||
|
connections = self._create_connections(10)
|
||||||
|
session.add_all(connections)
|
||||||
|
session.commit()
|
||||||
|
response = self.client.get(url)
|
||||||
|
assert response.status_code == 200
|
||||||
|
self.assertEqual(response.json["total_entries"], 10)
|
||||||
|
conn_ids = [conn["connection_id"] for conn in response.json["connections"] if conn]
|
||||||
|
self.assertEqual(conn_ids, expected_conn_ids)
|
||||||
|
|
||||||
|
@provide_session
|
||||||
|
def test_should_respect_page_size_limit(self, session):
|
||||||
|
connection_models = self._create_connections(200)
|
||||||
|
session.add_all(connection_models)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
response = self.client.get("/api/v1/connections?limit=150")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
self.assertEqual(response.json["total_entries"], 200)
|
||||||
|
self.assertEqual(len(response.json["connections"]), 100)
|
||||||
|
|
||||||
|
def _create_connections(self, count):
|
||||||
|
return [Connection(
|
||||||
|
conn_id='TEST_CONN_ID' + str(i),
|
||||||
|
conn_type='TEST_CONN_TYPE' + str(i)
|
||||||
|
) for i in range(1, count + 1)]
|
||||||
|
|
||||||
|
|
||||||
class TestPatchConnection(TestConnectionEndpoint):
|
class TestPatchConnection(TestConnectionEndpoint):
|
||||||
@pytest.mark.skip(reason="Not implemented yet")
|
@unittest.skip("Not implemented yet")
|
||||||
def test_should_response_200(self):
|
def test_should_response_200(self):
|
||||||
response = self.client.patch("/api/v1/connections/1")
|
response = self.client.patch("/api/v1/connections/1")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
class TestPostConnection(TestConnectionEndpoint):
|
class TestPostConnection(TestConnectionEndpoint):
|
||||||
@pytest.mark.skip(reason="Not implemented yet")
|
@unittest.skip("Not implemented yet")
|
||||||
def test_should_response_200(self):
|
def test_should_response_200(self):
|
||||||
response = self.client.post("/api/v1/connections/")
|
response = self.client.post("/api/v1/connections/")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
|
@ -0,0 +1,210 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from airflow.api_connexion.schemas.connection_schema import (
|
||||||
|
ConnectionCollection, connection_collection_item_schema, connection_collection_schema, connection_schema,
|
||||||
|
)
|
||||||
|
from airflow.models import Connection
|
||||||
|
from airflow.utils.session import create_session, provide_session
|
||||||
|
from tests.test_utils.db import clear_db_connections
|
||||||
|
|
||||||
|
|
||||||
|
class TestConnectionCollectionItemSchema(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
with create_session() as session:
|
||||||
|
session.query(Connection).delete()
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
clear_db_connections()
|
||||||
|
|
||||||
|
@provide_session
|
||||||
|
def test_serialzie(self, session):
|
||||||
|
connection_model = Connection(
|
||||||
|
conn_id='mysql_default',
|
||||||
|
conn_type='mysql',
|
||||||
|
host='mysql',
|
||||||
|
login='login',
|
||||||
|
schema='testschema',
|
||||||
|
port=80
|
||||||
|
)
|
||||||
|
session.add(connection_model)
|
||||||
|
session.commit()
|
||||||
|
connection_model = session.query(Connection).first()
|
||||||
|
deserialized_connection = connection_collection_item_schema.dump(connection_model)
|
||||||
|
self.assertEqual(
|
||||||
|
deserialized_connection[0],
|
||||||
|
{
|
||||||
|
'connection_id': "mysql_default",
|
||||||
|
'conn_type': 'mysql',
|
||||||
|
'host': 'mysql',
|
||||||
|
'login': 'login',
|
||||||
|
'schema': 'testschema',
|
||||||
|
'port': 80
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_deserialize(self):
|
||||||
|
connection_dump_1 = {
|
||||||
|
'connection_id': "mysql_default_1",
|
||||||
|
'conn_type': 'mysql',
|
||||||
|
'host': 'mysql',
|
||||||
|
'login': 'login',
|
||||||
|
'schema': 'testschema',
|
||||||
|
'port': 80
|
||||||
|
}
|
||||||
|
connection_dump_2 = {
|
||||||
|
'connection_id': "mysql_default_2"
|
||||||
|
}
|
||||||
|
result_1 = connection_collection_item_schema.load(connection_dump_1)
|
||||||
|
result_2 = connection_collection_item_schema.load(connection_dump_2)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
result_1[0],
|
||||||
|
{
|
||||||
|
'conn_id': "mysql_default_1",
|
||||||
|
'conn_type': 'mysql',
|
||||||
|
'host': 'mysql',
|
||||||
|
'login': 'login',
|
||||||
|
'schema': 'testschema',
|
||||||
|
'port': 80
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
result_2[0],
|
||||||
|
{
|
||||||
|
'conn_id': "mysql_default_2",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConnectionCollectionSchema(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
with create_session() as session:
|
||||||
|
session.query(Connection).delete()
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
clear_db_connections()
|
||||||
|
|
||||||
|
@provide_session
|
||||||
|
def test_serialzie(self, session):
|
||||||
|
connection_model_1 = Connection(
|
||||||
|
conn_id='mysql_default_1',
|
||||||
|
conn_type='test-type'
|
||||||
|
)
|
||||||
|
connection_model_2 = Connection(
|
||||||
|
conn_id='mysql_default_2',
|
||||||
|
conn_type='test-type2'
|
||||||
|
)
|
||||||
|
connections = [connection_model_1, connection_model_2]
|
||||||
|
session.add_all(connections)
|
||||||
|
session.commit()
|
||||||
|
instance = ConnectionCollection(
|
||||||
|
connections=connections,
|
||||||
|
total_entries=2
|
||||||
|
)
|
||||||
|
deserialized_connections = connection_collection_schema.dump(instance)
|
||||||
|
self.assertEqual(
|
||||||
|
deserialized_connections[0],
|
||||||
|
{
|
||||||
|
'connections': [
|
||||||
|
{
|
||||||
|
"connection_id": "mysql_default_1",
|
||||||
|
"conn_type": "test-type",
|
||||||
|
"host": None,
|
||||||
|
"login": None,
|
||||||
|
'schema': None,
|
||||||
|
'port': None
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"connection_id": "mysql_default_2",
|
||||||
|
"conn_type": "test-type2",
|
||||||
|
"host": None,
|
||||||
|
"login": None,
|
||||||
|
'schema': None,
|
||||||
|
'port': None
|
||||||
|
}
|
||||||
|
],
|
||||||
|
'total_entries': 2
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestConnectionSchema(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
with create_session() as session:
|
||||||
|
session.query(Connection).delete()
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
clear_db_connections()
|
||||||
|
|
||||||
|
@provide_session
|
||||||
|
def test_serialize(self, session):
|
||||||
|
connection_model = Connection(
|
||||||
|
conn_id='mysql_default',
|
||||||
|
conn_type='mysql',
|
||||||
|
host='mysql',
|
||||||
|
login='login',
|
||||||
|
schema='testschema',
|
||||||
|
port=80,
|
||||||
|
password='test-password',
|
||||||
|
extra="{'key':'string'}"
|
||||||
|
)
|
||||||
|
session.add(connection_model)
|
||||||
|
session.commit()
|
||||||
|
connection_model = session.query(Connection).first()
|
||||||
|
deserialized_connection = connection_schema.dump(connection_model)
|
||||||
|
self.assertEqual(
|
||||||
|
deserialized_connection[0],
|
||||||
|
{
|
||||||
|
'connection_id': "mysql_default",
|
||||||
|
'conn_type': 'mysql',
|
||||||
|
'host': 'mysql',
|
||||||
|
'login': 'login',
|
||||||
|
'schema': 'testschema',
|
||||||
|
'port': 80,
|
||||||
|
'extra': "{'key':'string'}"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_deserialize(self):
|
||||||
|
den = {
|
||||||
|
'connection_id': "mysql_default",
|
||||||
|
'conn_type': 'mysql',
|
||||||
|
'host': 'mysql',
|
||||||
|
'login': 'login',
|
||||||
|
'schema': 'testschema',
|
||||||
|
'port': 80,
|
||||||
|
'extra': "{'key':'string'}"
|
||||||
|
}
|
||||||
|
result = connection_schema.load(den)
|
||||||
|
self.assertEqual(
|
||||||
|
result[0],
|
||||||
|
{
|
||||||
|
'conn_id': "mysql_default",
|
||||||
|
'conn_type': 'mysql',
|
||||||
|
'host': 'mysql',
|
||||||
|
'login': 'login',
|
||||||
|
'schema': 'testschema',
|
||||||
|
'port': 80,
|
||||||
|
'extra': "{'key':'string'}"
|
||||||
|
}
|
||||||
|
)
|
Загрузка…
Ссылка в новой задаче