diff --git a/airflow/api_connexion/endpoints/connection_endpoint.py b/airflow/api_connexion/endpoints/connection_endpoint.py index da85f79f6b..43e6d57420 100644 --- a/airflow/api_connexion/endpoints/connection_endpoint.py +++ b/airflow/api_connexion/endpoints/connection_endpoint.py @@ -15,8 +15,15 @@ # specific language governing permissions and limitations # under the License. -# TODO(mik-laj): We have to implement it. -# Do you want to help? Please look at: https://github.com/apache/airflow/issues/8127 +from flask import request + +from airflow.api_connexion import parameters +from airflow.api_connexion.exceptions import NotFound +from airflow.api_connexion.schemas.connection_schema import ( + ConnectionCollection, connection_collection_item_schema, connection_collection_schema, +) +from airflow.models import Connection +from airflow.utils.session import provide_session def delete_connection(): @@ -26,18 +33,34 @@ def delete_connection(): raise NotImplementedError("Not implemented yet.") -def get_connection(): +@provide_session +def get_connection(connection_id, session): """ Get a connection entry """ - raise NotImplementedError("Not implemented yet.") + query = session.query(Connection) + query = query.filter(Connection.conn_id == connection_id) + connection = query.one_or_none() + if connection is None: + raise NotFound("Connection not found") + return connection_collection_item_schema.dump(connection) -def get_connections(): +@provide_session +def get_connections(session): """ Get all connection entries """ - raise NotImplementedError("Not implemented yet.") + offset = request.args.get(parameters.page_offset, 0) + limit = min(int(request.args.get(parameters.page_limit, 100)), 100) + + query = session.query(Connection) + total_entries = query.count() + query = query.offset(offset).limit(limit) + + connections = query.all() + return connection_collection_schema.dump(ConnectionCollection(connections=connections, + total_entries=total_entries)) def patch_connection(): diff --git a/airflow/api_connexion/exceptions.py b/airflow/api_connexion/exceptions.py new file mode 100644 index 0000000000..98d09e6522 --- /dev/null +++ b/airflow/api_connexion/exceptions.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from connexion import ProblemException + + +class NotFound(ProblemException): + """Raise when the object cannot be found""" + def __init__(self, title='Object not found', detail=None): + super().__init__(status=404, title=title, detail=detail) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index a3f895d071..8b751d6498 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -1148,12 +1148,16 @@ components: type: string host: type: string + nullable: true login: type: string + nullable: true schema: type: string + nullable: true port: type: integer + nullable: true ConnectionCollection: type: object @@ -1174,6 +1178,7 @@ components: writeOnly: true extra: type: string + nullable: true DAG: type: object @@ -2106,7 +2111,7 @@ components: in: path name: connection_id schema: - type: integer + type: string required: true description: The Connection ID. diff --git a/airflow/api_connexion/parameters.py b/airflow/api_connexion/parameters.py new file mode 100644 index 0000000000..8f2e915183 --- /dev/null +++ b/airflow/api_connexion/parameters.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Pagination parameters +page_offset = "offset" +page_limit = "limit" diff --git a/airflow/api_connexion/schemas/__init__.py b/airflow/api_connexion/schemas/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/airflow/api_connexion/schemas/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/api_connexion/schemas/connection_schema.py b/airflow/api_connexion/schemas/connection_schema.py new file mode 100644 index 0000000000..387dd7015e --- /dev/null +++ b/airflow/api_connexion/schemas/connection_schema.py @@ -0,0 +1,66 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List, NamedTuple + +from marshmallow import Schema, fields +from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field + +from airflow.models.connection import Connection + + +class ConnectionCollectionItemSchema(SQLAlchemySchema): + """ + Schema for a connection item + """ + + class Meta: + """ Meta """ + model = Connection + + conn_id = auto_field(dump_to='connection_id', load_from='connection_id') + conn_type = auto_field() + host = auto_field() + login = auto_field() + schema = auto_field() + port = auto_field() + + +class ConnectionSchema(ConnectionCollectionItemSchema): # pylint: disable=too-many-ancestors + """ + Connection schema + """ + + password = auto_field(load_only=True) + extra = auto_field() + + +class ConnectionCollection(NamedTuple): + """ List of Connections with meta""" + connections: List[Connection] + total_entries: int + + +class ConnectionCollectionSchema(Schema): + """ Connection Collection Schema""" + connections = fields.List(fields.Nested(ConnectionCollectionItemSchema)) + total_entries = fields.Int() + + +connection_schema = ConnectionSchema() +connection_collection_item_schema = ConnectionCollectionItemSchema() +connection_collection_schema = ConnectionCollectionSchema() diff --git a/airflow/www/app.py b/airflow/www/app.py index 73b19191c2..4eb718f35c 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -28,6 +28,7 @@ import connexion import flask import flask_login import pendulum +from connexion import ProblemException from flask import Flask, session as flask_session from flask_appbuilder import SQLA, AppBuilder from flask_caching import Cache @@ -254,6 +255,7 @@ def create_app(config=None, testing=False, app_name="Airflow"): validate_responses=True, strict_validation=False ) + app.register_error_handler(ProblemException, connexion_app.common_error_handler) init_views(appbuilder) init_plugin_blueprints(app) diff --git a/tests/api_connexion/endpoints/test_connection_endpoint.py b/tests/api_connexion/endpoints/test_connection_endpoint.py index bb191b3bff..0dbc99e6b6 100644 --- a/tests/api_connexion/endpoints/test_connection_endpoint.py +++ b/tests/api_connexion/endpoints/test_connection_endpoint.py @@ -16,9 +16,12 @@ # under the License. import unittest -import pytest +from parameterized import parameterized +from airflow.models import Connection +from airflow.utils.session import create_session, provide_session from airflow.www import app +from tests.test_utils.db import clear_db_connections class TestConnectionEndpoint(unittest.TestCase): @@ -29,38 +32,182 @@ class TestConnectionEndpoint(unittest.TestCase): def setUp(self) -> None: self.client = self.app.test_client() # type:ignore + # we want only the connection created here for this test + with create_session() as session: + session.query(Connection).delete() + + def tearDown(self) -> None: + clear_db_connections() class TestDeleteConnection(TestConnectionEndpoint): - @pytest.mark.skip(reason="Not implemented yet") + @unittest.skip("Not implemented yet") def test_should_response_200(self): response = self.client.delete("/api/v1/connections/1") assert response.status_code == 200 class TestGetConnection(TestConnectionEndpoint): - @pytest.mark.skip(reason="Not implemented yet") - def test_should_response_200(self): - response = self.client.get("/api/v1/connection/1") + + @provide_session + def test_should_response_200(self, session): + connection_model = Connection(conn_id='test-connection-id', + conn_type='mysql', + host='mysql', + login='login', + schema='testschema', + port=80 + ) + session.add(connection_model) + session.commit() + result = session.query(Connection).all() + assert len(result) == 1 + response = self.client.get("/api/v1/connections/test-connection-id") assert response.status_code == 200 + self.assertEqual( + response.json, + { + "connection_id": "test-connection-id", + "conn_type": 'mysql', + "host": 'mysql', + "login": 'login', + 'schema': 'testschema', + 'port': 80 + }, + ) + + def test_should_response_404(self): + response = self.client.get("/api/v1/connections/invalid-connection") + assert response.status_code == 404 + self.assertEqual( + { + 'detail': None, + 'status': 404, + 'title': 'Connection not found', + 'type': 'about:blank' + }, + response.json + ) class TestGetConnections(TestConnectionEndpoint): - @pytest.mark.skip(reason="Not implemented yet") - def test_should_response_200(self): - response = self.client.get("/api/v1/connections/") + + @provide_session + def test_should_response_200(self, session): + connection_model_1 = Connection(conn_id='test-connection-id-1', + conn_type='test_type') + connection_model_2 = Connection(conn_id='test-connection-id-2', + conn_type='test_type') + connections = [connection_model_1, connection_model_2] + session.add_all(connections) + session.commit() + result = session.query(Connection).all() + assert len(result) == 2 + response = self.client.get("/api/v1/connections") assert response.status_code == 200 + self.assertEqual( + response.json, + { + 'connections': [ + { + "connection_id": "test-connection-id-1", + "conn_type": 'test_type', + "host": None, + "login": None, + 'schema': None, + 'port': None + }, + { + "connection_id": "test-connection-id-2", + "conn_type": 'test_type', + "host": None, + "login": None, + 'schema': None, + 'port': None + } + ], + 'total_entries': 2 + } + ) + + +class TestGetConnectionsPagination(TestConnectionEndpoint): + @parameterized.expand( + [ + ("/api/v1/connections?limit=1", ['TEST_CONN_ID1']), + ("/api/v1/connections?limit=2", ['TEST_CONN_ID1', "TEST_CONN_ID2"]), + ( + "/api/v1/connections?offset=5", + [ + "TEST_CONN_ID6", + "TEST_CONN_ID7", + "TEST_CONN_ID8", + "TEST_CONN_ID9", + "TEST_CONN_ID10", + ], + ), + ( + "/api/v1/connections?offset=0", + [ + "TEST_CONN_ID1", + "TEST_CONN_ID2", + "TEST_CONN_ID3", + "TEST_CONN_ID4", + "TEST_CONN_ID5", + "TEST_CONN_ID6", + "TEST_CONN_ID7", + "TEST_CONN_ID8", + "TEST_CONN_ID9", + "TEST_CONN_ID10", + ], + ), + ("/api/v1/connections?limit=1&offset=5", ["TEST_CONN_ID6"]), + ("/api/v1/connections?limit=1&offset=1", ["TEST_CONN_ID2"]), + ( + "/api/v1/connections?limit=2&offset=2", + ["TEST_CONN_ID3", "TEST_CONN_ID4"], + ), + ] + ) + @provide_session + def test_handle_limit_offset(self, url, expected_conn_ids, session): + connections = self._create_connections(10) + session.add_all(connections) + session.commit() + response = self.client.get(url) + assert response.status_code == 200 + self.assertEqual(response.json["total_entries"], 10) + conn_ids = [conn["connection_id"] for conn in response.json["connections"] if conn] + self.assertEqual(conn_ids, expected_conn_ids) + + @provide_session + def test_should_respect_page_size_limit(self, session): + connection_models = self._create_connections(200) + session.add_all(connection_models) + session.commit() + + response = self.client.get("/api/v1/connections?limit=150") + assert response.status_code == 200 + + self.assertEqual(response.json["total_entries"], 200) + self.assertEqual(len(response.json["connections"]), 100) + + def _create_connections(self, count): + return [Connection( + conn_id='TEST_CONN_ID' + str(i), + conn_type='TEST_CONN_TYPE' + str(i) + ) for i in range(1, count + 1)] class TestPatchConnection(TestConnectionEndpoint): - @pytest.mark.skip(reason="Not implemented yet") + @unittest.skip("Not implemented yet") def test_should_response_200(self): response = self.client.patch("/api/v1/connections/1") assert response.status_code == 200 class TestPostConnection(TestConnectionEndpoint): - @pytest.mark.skip(reason="Not implemented yet") + @unittest.skip("Not implemented yet") def test_should_response_200(self): response = self.client.post("/api/v1/connections/") assert response.status_code == 200 diff --git a/tests/api_connexion/schemas/__init__.py b/tests/api_connexion/schemas/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/api_connexion/schemas/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/api_connexion/schemas/test_connection_schema.py b/tests/api_connexion/schemas/test_connection_schema.py new file mode 100644 index 0000000000..4e8876bbb3 --- /dev/null +++ b/tests/api_connexion/schemas/test_connection_schema.py @@ -0,0 +1,210 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest + +from airflow.api_connexion.schemas.connection_schema import ( + ConnectionCollection, connection_collection_item_schema, connection_collection_schema, connection_schema, +) +from airflow.models import Connection +from airflow.utils.session import create_session, provide_session +from tests.test_utils.db import clear_db_connections + + +class TestConnectionCollectionItemSchema(unittest.TestCase): + + def setUp(self) -> None: + with create_session() as session: + session.query(Connection).delete() + + def tearDown(self) -> None: + clear_db_connections() + + @provide_session + def test_serialzie(self, session): + connection_model = Connection( + conn_id='mysql_default', + conn_type='mysql', + host='mysql', + login='login', + schema='testschema', + port=80 + ) + session.add(connection_model) + session.commit() + connection_model = session.query(Connection).first() + deserialized_connection = connection_collection_item_schema.dump(connection_model) + self.assertEqual( + deserialized_connection[0], + { + 'connection_id': "mysql_default", + 'conn_type': 'mysql', + 'host': 'mysql', + 'login': 'login', + 'schema': 'testschema', + 'port': 80 + } + ) + + def test_deserialize(self): + connection_dump_1 = { + 'connection_id': "mysql_default_1", + 'conn_type': 'mysql', + 'host': 'mysql', + 'login': 'login', + 'schema': 'testschema', + 'port': 80 + } + connection_dump_2 = { + 'connection_id': "mysql_default_2" + } + result_1 = connection_collection_item_schema.load(connection_dump_1) + result_2 = connection_collection_item_schema.load(connection_dump_2) + + self.assertEqual( + result_1[0], + { + 'conn_id': "mysql_default_1", + 'conn_type': 'mysql', + 'host': 'mysql', + 'login': 'login', + 'schema': 'testschema', + 'port': 80 + } + ) + self.assertEqual( + result_2[0], + { + 'conn_id': "mysql_default_2", + } + ) + + +class TestConnectionCollectionSchema(unittest.TestCase): + + def setUp(self) -> None: + with create_session() as session: + session.query(Connection).delete() + + def tearDown(self) -> None: + clear_db_connections() + + @provide_session + def test_serialzie(self, session): + connection_model_1 = Connection( + conn_id='mysql_default_1', + conn_type='test-type' + ) + connection_model_2 = Connection( + conn_id='mysql_default_2', + conn_type='test-type2' + ) + connections = [connection_model_1, connection_model_2] + session.add_all(connections) + session.commit() + instance = ConnectionCollection( + connections=connections, + total_entries=2 + ) + deserialized_connections = connection_collection_schema.dump(instance) + self.assertEqual( + deserialized_connections[0], + { + 'connections': [ + { + "connection_id": "mysql_default_1", + "conn_type": "test-type", + "host": None, + "login": None, + 'schema': None, + 'port': None + }, + { + "connection_id": "mysql_default_2", + "conn_type": "test-type2", + "host": None, + "login": None, + 'schema': None, + 'port': None + } + ], + 'total_entries': 2 + } + ) + + +class TestConnectionSchema(unittest.TestCase): + + def setUp(self) -> None: + with create_session() as session: + session.query(Connection).delete() + + def tearDown(self) -> None: + clear_db_connections() + + @provide_session + def test_serialize(self, session): + connection_model = Connection( + conn_id='mysql_default', + conn_type='mysql', + host='mysql', + login='login', + schema='testschema', + port=80, + password='test-password', + extra="{'key':'string'}" + ) + session.add(connection_model) + session.commit() + connection_model = session.query(Connection).first() + deserialized_connection = connection_schema.dump(connection_model) + self.assertEqual( + deserialized_connection[0], + { + 'connection_id': "mysql_default", + 'conn_type': 'mysql', + 'host': 'mysql', + 'login': 'login', + 'schema': 'testschema', + 'port': 80, + 'extra': "{'key':'string'}" + } + ) + + def test_deserialize(self): + den = { + 'connection_id': "mysql_default", + 'conn_type': 'mysql', + 'host': 'mysql', + 'login': 'login', + 'schema': 'testschema', + 'port': 80, + 'extra': "{'key':'string'}" + } + result = connection_schema.load(den) + self.assertEqual( + result[0], + { + 'conn_id': "mysql_default", + 'conn_type': 'mysql', + 'host': 'mysql', + 'login': 'login', + 'schema': 'testschema', + 'port': 80, + 'extra': "{'key':'string'}" + } + )