Add readonly connection API endpoints (#9095)

* add connection schema with tests

* add endpoints for connection

* update patch

* update endpoint methods

* add readonly connection endpoints

* improve base schema and add tests

* update spec set connection id to string

* update type hint

* improve base schema

* add pre_load processing to return data to normal

* remove pre_load processing as it is not needed

* readonly endpoints

* improve code

* handle exception and improve code

* improve pagination test

* add nullable to spec and improve code

* remove base and add parameterized tests

* add pagination limit test
This commit is contained in:
Ephraim Anierobi 2020-06-11 20:53:36 +01:00 коммит произвёл GitHub
Родитель 0682e784b1
Коммит ecbb366e63
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
10 изменённых файлов: 546 добавлений и 17 удалений

Просмотреть файл

@ -15,8 +15,15 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# TODO(mik-laj): We have to implement it. from flask import request
# Do you want to help? Please look at: https://github.com/apache/airflow/issues/8127
from airflow.api_connexion import parameters
from airflow.api_connexion.exceptions import NotFound
from airflow.api_connexion.schemas.connection_schema import (
ConnectionCollection, connection_collection_item_schema, connection_collection_schema,
)
from airflow.models import Connection
from airflow.utils.session import provide_session
def delete_connection(): def delete_connection():
@ -26,18 +33,34 @@ def delete_connection():
raise NotImplementedError("Not implemented yet.") raise NotImplementedError("Not implemented yet.")
def get_connection(): @provide_session
def get_connection(connection_id, session):
""" """
Get a connection entry Get a connection entry
""" """
raise NotImplementedError("Not implemented yet.") query = session.query(Connection)
query = query.filter(Connection.conn_id == connection_id)
connection = query.one_or_none()
if connection is None:
raise NotFound("Connection not found")
return connection_collection_item_schema.dump(connection)
def get_connections(): @provide_session
def get_connections(session):
""" """
Get all connection entries Get all connection entries
""" """
raise NotImplementedError("Not implemented yet.") offset = request.args.get(parameters.page_offset, 0)
limit = min(int(request.args.get(parameters.page_limit, 100)), 100)
query = session.query(Connection)
total_entries = query.count()
query = query.offset(offset).limit(limit)
connections = query.all()
return connection_collection_schema.dump(ConnectionCollection(connections=connections,
total_entries=total_entries))
def patch_connection(): def patch_connection():

Просмотреть файл

@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from connexion import ProblemException
class NotFound(ProblemException):
"""Raise when the object cannot be found"""
def __init__(self, title='Object not found', detail=None):
super().__init__(status=404, title=title, detail=detail)

Просмотреть файл

@ -1148,12 +1148,16 @@ components:
type: string type: string
host: host:
type: string type: string
nullable: true
login: login:
type: string type: string
nullable: true
schema: schema:
type: string type: string
nullable: true
port: port:
type: integer type: integer
nullable: true
ConnectionCollection: ConnectionCollection:
type: object type: object
@ -1174,6 +1178,7 @@ components:
writeOnly: true writeOnly: true
extra: extra:
type: string type: string
nullable: true
DAG: DAG:
type: object type: object
@ -2106,7 +2111,7 @@ components:
in: path in: path
name: connection_id name: connection_id
schema: schema:
type: integer type: string
required: true required: true
description: The Connection ID. description: The Connection ID.

Просмотреть файл

@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Pagination parameters
page_offset = "offset"
page_limit = "limit"

Просмотреть файл

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

Просмотреть файл

@ -0,0 +1,66 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import List, NamedTuple
from marshmallow import Schema, fields
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
from airflow.models.connection import Connection
class ConnectionCollectionItemSchema(SQLAlchemySchema):
"""
Schema for a connection item
"""
class Meta:
""" Meta """
model = Connection
conn_id = auto_field(dump_to='connection_id', load_from='connection_id')
conn_type = auto_field()
host = auto_field()
login = auto_field()
schema = auto_field()
port = auto_field()
class ConnectionSchema(ConnectionCollectionItemSchema): # pylint: disable=too-many-ancestors
"""
Connection schema
"""
password = auto_field(load_only=True)
extra = auto_field()
class ConnectionCollection(NamedTuple):
""" List of Connections with meta"""
connections: List[Connection]
total_entries: int
class ConnectionCollectionSchema(Schema):
""" Connection Collection Schema"""
connections = fields.List(fields.Nested(ConnectionCollectionItemSchema))
total_entries = fields.Int()
connection_schema = ConnectionSchema()
connection_collection_item_schema = ConnectionCollectionItemSchema()
connection_collection_schema = ConnectionCollectionSchema()

Просмотреть файл

@ -28,6 +28,7 @@ import connexion
import flask import flask
import flask_login import flask_login
import pendulum import pendulum
from connexion import ProblemException
from flask import Flask, session as flask_session from flask import Flask, session as flask_session
from flask_appbuilder import SQLA, AppBuilder from flask_appbuilder import SQLA, AppBuilder
from flask_caching import Cache from flask_caching import Cache
@ -254,6 +255,7 @@ def create_app(config=None, testing=False, app_name="Airflow"):
validate_responses=True, validate_responses=True,
strict_validation=False strict_validation=False
) )
app.register_error_handler(ProblemException, connexion_app.common_error_handler)
init_views(appbuilder) init_views(appbuilder)
init_plugin_blueprints(app) init_plugin_blueprints(app)

Просмотреть файл

@ -16,9 +16,12 @@
# under the License. # under the License.
import unittest import unittest
import pytest from parameterized import parameterized
from airflow.models import Connection
from airflow.utils.session import create_session, provide_session
from airflow.www import app from airflow.www import app
from tests.test_utils.db import clear_db_connections
class TestConnectionEndpoint(unittest.TestCase): class TestConnectionEndpoint(unittest.TestCase):
@ -29,38 +32,182 @@ class TestConnectionEndpoint(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.client = self.app.test_client() # type:ignore self.client = self.app.test_client() # type:ignore
# we want only the connection created here for this test
with create_session() as session:
session.query(Connection).delete()
def tearDown(self) -> None:
clear_db_connections()
class TestDeleteConnection(TestConnectionEndpoint): class TestDeleteConnection(TestConnectionEndpoint):
@pytest.mark.skip(reason="Not implemented yet") @unittest.skip("Not implemented yet")
def test_should_response_200(self): def test_should_response_200(self):
response = self.client.delete("/api/v1/connections/1") response = self.client.delete("/api/v1/connections/1")
assert response.status_code == 200 assert response.status_code == 200
class TestGetConnection(TestConnectionEndpoint): class TestGetConnection(TestConnectionEndpoint):
@pytest.mark.skip(reason="Not implemented yet")
def test_should_response_200(self): @provide_session
response = self.client.get("/api/v1/connection/1") def test_should_response_200(self, session):
connection_model = Connection(conn_id='test-connection-id',
conn_type='mysql',
host='mysql',
login='login',
schema='testschema',
port=80
)
session.add(connection_model)
session.commit()
result = session.query(Connection).all()
assert len(result) == 1
response = self.client.get("/api/v1/connections/test-connection-id")
assert response.status_code == 200 assert response.status_code == 200
self.assertEqual(
response.json,
{
"connection_id": "test-connection-id",
"conn_type": 'mysql',
"host": 'mysql',
"login": 'login',
'schema': 'testschema',
'port': 80
},
)
def test_should_response_404(self):
response = self.client.get("/api/v1/connections/invalid-connection")
assert response.status_code == 404
self.assertEqual(
{
'detail': None,
'status': 404,
'title': 'Connection not found',
'type': 'about:blank'
},
response.json
)
class TestGetConnections(TestConnectionEndpoint): class TestGetConnections(TestConnectionEndpoint):
@pytest.mark.skip(reason="Not implemented yet")
def test_should_response_200(self): @provide_session
response = self.client.get("/api/v1/connections/") def test_should_response_200(self, session):
connection_model_1 = Connection(conn_id='test-connection-id-1',
conn_type='test_type')
connection_model_2 = Connection(conn_id='test-connection-id-2',
conn_type='test_type')
connections = [connection_model_1, connection_model_2]
session.add_all(connections)
session.commit()
result = session.query(Connection).all()
assert len(result) == 2
response = self.client.get("/api/v1/connections")
assert response.status_code == 200 assert response.status_code == 200
self.assertEqual(
response.json,
{
'connections': [
{
"connection_id": "test-connection-id-1",
"conn_type": 'test_type',
"host": None,
"login": None,
'schema': None,
'port': None
},
{
"connection_id": "test-connection-id-2",
"conn_type": 'test_type',
"host": None,
"login": None,
'schema': None,
'port': None
}
],
'total_entries': 2
}
)
class TestGetConnectionsPagination(TestConnectionEndpoint):
@parameterized.expand(
[
("/api/v1/connections?limit=1", ['TEST_CONN_ID1']),
("/api/v1/connections?limit=2", ['TEST_CONN_ID1', "TEST_CONN_ID2"]),
(
"/api/v1/connections?offset=5",
[
"TEST_CONN_ID6",
"TEST_CONN_ID7",
"TEST_CONN_ID8",
"TEST_CONN_ID9",
"TEST_CONN_ID10",
],
),
(
"/api/v1/connections?offset=0",
[
"TEST_CONN_ID1",
"TEST_CONN_ID2",
"TEST_CONN_ID3",
"TEST_CONN_ID4",
"TEST_CONN_ID5",
"TEST_CONN_ID6",
"TEST_CONN_ID7",
"TEST_CONN_ID8",
"TEST_CONN_ID9",
"TEST_CONN_ID10",
],
),
("/api/v1/connections?limit=1&offset=5", ["TEST_CONN_ID6"]),
("/api/v1/connections?limit=1&offset=1", ["TEST_CONN_ID2"]),
(
"/api/v1/connections?limit=2&offset=2",
["TEST_CONN_ID3", "TEST_CONN_ID4"],
),
]
)
@provide_session
def test_handle_limit_offset(self, url, expected_conn_ids, session):
connections = self._create_connections(10)
session.add_all(connections)
session.commit()
response = self.client.get(url)
assert response.status_code == 200
self.assertEqual(response.json["total_entries"], 10)
conn_ids = [conn["connection_id"] for conn in response.json["connections"] if conn]
self.assertEqual(conn_ids, expected_conn_ids)
@provide_session
def test_should_respect_page_size_limit(self, session):
connection_models = self._create_connections(200)
session.add_all(connection_models)
session.commit()
response = self.client.get("/api/v1/connections?limit=150")
assert response.status_code == 200
self.assertEqual(response.json["total_entries"], 200)
self.assertEqual(len(response.json["connections"]), 100)
def _create_connections(self, count):
return [Connection(
conn_id='TEST_CONN_ID' + str(i),
conn_type='TEST_CONN_TYPE' + str(i)
) for i in range(1, count + 1)]
class TestPatchConnection(TestConnectionEndpoint): class TestPatchConnection(TestConnectionEndpoint):
@pytest.mark.skip(reason="Not implemented yet") @unittest.skip("Not implemented yet")
def test_should_response_200(self): def test_should_response_200(self):
response = self.client.patch("/api/v1/connections/1") response = self.client.patch("/api/v1/connections/1")
assert response.status_code == 200 assert response.status_code == 200
class TestPostConnection(TestConnectionEndpoint): class TestPostConnection(TestConnectionEndpoint):
@pytest.mark.skip(reason="Not implemented yet") @unittest.skip("Not implemented yet")
def test_should_response_200(self): def test_should_response_200(self):
response = self.client.post("/api/v1/connections/") response = self.client.post("/api/v1/connections/")
assert response.status_code == 200 assert response.status_code == 200

Просмотреть файл

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

Просмотреть файл

@ -0,0 +1,210 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
from airflow.api_connexion.schemas.connection_schema import (
ConnectionCollection, connection_collection_item_schema, connection_collection_schema, connection_schema,
)
from airflow.models import Connection
from airflow.utils.session import create_session, provide_session
from tests.test_utils.db import clear_db_connections
class TestConnectionCollectionItemSchema(unittest.TestCase):
def setUp(self) -> None:
with create_session() as session:
session.query(Connection).delete()
def tearDown(self) -> None:
clear_db_connections()
@provide_session
def test_serialzie(self, session):
connection_model = Connection(
conn_id='mysql_default',
conn_type='mysql',
host='mysql',
login='login',
schema='testschema',
port=80
)
session.add(connection_model)
session.commit()
connection_model = session.query(Connection).first()
deserialized_connection = connection_collection_item_schema.dump(connection_model)
self.assertEqual(
deserialized_connection[0],
{
'connection_id': "mysql_default",
'conn_type': 'mysql',
'host': 'mysql',
'login': 'login',
'schema': 'testschema',
'port': 80
}
)
def test_deserialize(self):
connection_dump_1 = {
'connection_id': "mysql_default_1",
'conn_type': 'mysql',
'host': 'mysql',
'login': 'login',
'schema': 'testschema',
'port': 80
}
connection_dump_2 = {
'connection_id': "mysql_default_2"
}
result_1 = connection_collection_item_schema.load(connection_dump_1)
result_2 = connection_collection_item_schema.load(connection_dump_2)
self.assertEqual(
result_1[0],
{
'conn_id': "mysql_default_1",
'conn_type': 'mysql',
'host': 'mysql',
'login': 'login',
'schema': 'testschema',
'port': 80
}
)
self.assertEqual(
result_2[0],
{
'conn_id': "mysql_default_2",
}
)
class TestConnectionCollectionSchema(unittest.TestCase):
def setUp(self) -> None:
with create_session() as session:
session.query(Connection).delete()
def tearDown(self) -> None:
clear_db_connections()
@provide_session
def test_serialzie(self, session):
connection_model_1 = Connection(
conn_id='mysql_default_1',
conn_type='test-type'
)
connection_model_2 = Connection(
conn_id='mysql_default_2',
conn_type='test-type2'
)
connections = [connection_model_1, connection_model_2]
session.add_all(connections)
session.commit()
instance = ConnectionCollection(
connections=connections,
total_entries=2
)
deserialized_connections = connection_collection_schema.dump(instance)
self.assertEqual(
deserialized_connections[0],
{
'connections': [
{
"connection_id": "mysql_default_1",
"conn_type": "test-type",
"host": None,
"login": None,
'schema': None,
'port': None
},
{
"connection_id": "mysql_default_2",
"conn_type": "test-type2",
"host": None,
"login": None,
'schema': None,
'port': None
}
],
'total_entries': 2
}
)
class TestConnectionSchema(unittest.TestCase):
def setUp(self) -> None:
with create_session() as session:
session.query(Connection).delete()
def tearDown(self) -> None:
clear_db_connections()
@provide_session
def test_serialize(self, session):
connection_model = Connection(
conn_id='mysql_default',
conn_type='mysql',
host='mysql',
login='login',
schema='testschema',
port=80,
password='test-password',
extra="{'key':'string'}"
)
session.add(connection_model)
session.commit()
connection_model = session.query(Connection).first()
deserialized_connection = connection_schema.dump(connection_model)
self.assertEqual(
deserialized_connection[0],
{
'connection_id': "mysql_default",
'conn_type': 'mysql',
'host': 'mysql',
'login': 'login',
'schema': 'testschema',
'port': 80,
'extra': "{'key':'string'}"
}
)
def test_deserialize(self):
den = {
'connection_id': "mysql_default",
'conn_type': 'mysql',
'host': 'mysql',
'login': 'login',
'schema': 'testschema',
'port': 80,
'extra': "{'key':'string'}"
}
result = connection_schema.load(den)
self.assertEqual(
result[0],
{
'conn_id': "mysql_default",
'conn_type': 'mysql',
'host': 'mysql',
'login': 'login',
'schema': 'testschema',
'port': 80,
'extra': "{'key':'string'}"
}
)