This commit is contained in:
Omair Khan 2020-08-21 15:58:21 +05:30 коммит произвёл GitHub
Родитель 2f552233f5
Коммит 1e371864cc
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 153 добавлений и 81 удалений

Просмотреть файл

@ -14,12 +14,13 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from flask import current_app from flask import current_app, request
from marshmallow import ValidationError
from sqlalchemy import func from sqlalchemy import func
from airflow import DAG from airflow import DAG
from airflow.api_connexion import security from airflow.api_connexion import security
from airflow.api_connexion.exceptions import NotFound from airflow.api_connexion.exceptions import BadRequest, NotFound
from airflow.api_connexion.parameters import check_limit, format_parameters from airflow.api_connexion.parameters import check_limit, format_parameters
from airflow.api_connexion.schemas.dag_schema import ( from airflow.api_connexion.schemas.dag_schema import (
DAGCollection, dag_detail_schema, dag_schema, dags_collection_schema, DAGCollection, dag_detail_schema, dag_schema, dags_collection_schema,
@ -70,8 +71,19 @@ def get_dags(session, limit, offset=0):
@security.requires_authentication @security.requires_authentication
def patch_dag(): @provide_session
def patch_dag(session, dag_id):
""" """
Update the specific DAG Update the specific DAG
""" """
raise NotImplementedError("Not implemented yet.") dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one_or_none()
if not dag:
raise NotFound(f"Dag with id: '{dag_id}' not found")
try:
patch_body = dag_schema.load(request.json, session=session)
except ValidationError as err:
raise BadRequest("Invalid Dag schema", detail=str(err.messages))
for key, value in patch_body.items():
setattr(dag, key, value)
session.commit()
return dag_schema.dump(dag)

Просмотреть файл

@ -1171,7 +1171,6 @@ components:
nullable: true nullable: true
schedule_interval: schedule_interval:
$ref: '#/components/schemas/ScheduleInterval' $ref: '#/components/schemas/ScheduleInterval'
readOnly: true
tags: tags:
type: array type: array
nullable: true nullable: true
@ -1938,6 +1937,7 @@ components:
# Common data type # Common data type
ScheduleInterval: ScheduleInterval:
readOnly: true
oneOf: oneOf:
- $ref: '#/components/schemas/TimeDelta' - $ref: '#/components/schemas/TimeDelta'
- $ref: '#/components/schemas/RelativeDelta' - $ref: '#/components/schemas/RelativeDelta'

Просмотреть файл

@ -83,7 +83,7 @@ class RelativeDeltaSchema(Schema):
class CronExpressionSchema(Schema): class CronExpressionSchema(Schema):
"""Cron expression schema""" """Cron expression schema"""
objectType = fields.Constant("CronExpression", data_key="__type", required=True) objectType = fields.Constant("CronExpression", data_key="__type")
value = fields.String(required=True) value = fields.String(required=True)
@marshmallow.post_load @marshmallow.post_load

Просмотреть файл

@ -43,12 +43,12 @@ class DAGSchema(SQLAlchemySchema):
dag_id = auto_field(dump_only=True) dag_id = auto_field(dump_only=True)
root_dag_id = auto_field(dump_only=True) root_dag_id = auto_field(dump_only=True)
is_paused = auto_field(dump_only=True) is_paused = auto_field()
is_subdag = auto_field(dump_only=True) is_subdag = auto_field(dump_only=True)
fileloc = auto_field(dump_only=True) fileloc = auto_field(dump_only=True)
owners = fields.Method("get_owners", dump_only=True) owners = fields.Method("get_owners", dump_only=True)
description = auto_field(dump_only=True) description = auto_field(dump_only=True)
schedule_interval = fields.Nested(ScheduleIntervalSchema, dump_only=True) schedule_interval = fields.Nested(ScheduleIntervalSchema)
tags = fields.List(fields.Nested(DagTagSchema), dump_only=True) tags = fields.List(fields.Nested(DagTagSchema), dump_only=True)
@staticmethod @staticmethod

Просмотреть файл

@ -18,7 +18,6 @@ import os
import unittest import unittest
from datetime import datetime from datetime import datetime
import pytest
from parameterized import parameterized from parameterized import parameterized
from airflow import DAG from airflow import DAG
@ -77,7 +76,7 @@ class TestDagEndpoint(unittest.TestCase):
dag_model = DagModel( dag_model = DagModel(
dag_id=f"TEST_DAG_{num}", dag_id=f"TEST_DAG_{num}",
fileloc=f"/tmp/dag_{num}.py", fileloc=f"/tmp/dag_{num}.py",
schedule_interval="2 2 * * *" schedule_interval="2 2 * * *",
) )
session.add(dag_model) session.add(dag_model)
@ -90,17 +89,20 @@ class TestGetDag(TestDagEndpoint):
current_response = response.json current_response = response.json
current_response["fileloc"] = "/tmp/test-dag.py" current_response["fileloc"] = "/tmp/test-dag.py"
self.assertEqual({ self.assertEqual(
'dag_id': 'TEST_DAG_1', {
'description': None, "dag_id": "TEST_DAG_1",
'fileloc': '/tmp/test-dag.py', "description": None,
'is_paused': False, "fileloc": "/tmp/test-dag.py",
'is_subdag': False, "is_paused": False,
'owners': [], "is_subdag": False,
'root_dag_id': None, "owners": [],
'schedule_interval': {'__type': 'CronExpression', 'value': '2 2 * * *'}, "root_dag_id": None,
'tags': [] "schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *"},
}, current_response) "tags": [],
},
current_response,
)
def test_should_response_404(self): def test_should_response_404(self):
response = self.client.get("/api/v1/dags/INVALID_DAG", environ_overrides={'REMOTE_USER': "test"}) response = self.client.get("/api/v1/dags/INVALID_DAG", environ_overrides={'REMOTE_USER': "test"})
@ -121,27 +123,27 @@ class TestGetDagDetails(TestDagEndpoint):
) )
assert response.status_code == 200 assert response.status_code == 200
expected = { expected = {
'catchup': True, "catchup": True,
'concurrency': 16, "concurrency": 16,
'dag_id': 'test_dag', "dag_id": "test_dag",
'dag_run_timeout': None, "dag_run_timeout": None,
'default_view': 'tree', "default_view": "tree",
'description': None, "description": None,
'doc_md': 'details', "doc_md": "details",
'fileloc': __file__, "fileloc": __file__,
'is_paused': None, "is_paused": None,
'is_subdag': False, "is_subdag": False,
'orientation': 'LR', "orientation": "LR",
'owners': [], "owners": [],
'schedule_interval': { "schedule_interval": {
'__type': 'TimeDelta', "__type": "TimeDelta",
'days': 1, "days": 1,
'microseconds': 0, "microseconds": 0,
'seconds': 0 "seconds": 0,
}, },
'start_date': '2020-06-15T00:00:00+00:00', "start_date": "2020-06-15T00:00:00+00:00",
'tags': None, "tags": None,
'timezone': "Timezone('UTC')" "timezone": "Timezone('UTC')",
} }
assert response.json == expected assert response.json == expected
@ -155,27 +157,27 @@ class TestGetDagDetails(TestDagEndpoint):
SerializedDagModel.write_dag(self.dag) SerializedDagModel.write_dag(self.dag)
expected = { expected = {
'catchup': True, "catchup": True,
'concurrency': 16, "concurrency": 16,
'dag_id': 'test_dag', "dag_id": "test_dag",
'dag_run_timeout': None, "dag_run_timeout": None,
'default_view': 'tree', "default_view": "tree",
'description': None, "description": None,
'doc_md': 'details', "doc_md": "details",
'fileloc': __file__, "fileloc": __file__,
'is_paused': None, "is_paused": None,
'is_subdag': False, "is_subdag": False,
'orientation': 'LR', "orientation": "LR",
'owners': [], "owners": [],
'schedule_interval': { "schedule_interval": {
'__type': 'TimeDelta', "__type": "TimeDelta",
'days': 1, "days": 1,
'microseconds': 0, "microseconds": 0,
'seconds': 0 "seconds": 0,
}, },
'start_date': '2020-06-15T00:00:00+00:00', "start_date": "2020-06-15T00:00:00+00:00",
'tags': None, "tags": None,
'timezone': "Timezone('UTC')" "timezone": "Timezone('UTC')",
} }
response = client.get( response = client.get(
f"/api/v1/dags/{self.dag_id}/details", environ_overrides={'REMOTE_USER': "test"} f"/api/v1/dags/{self.dag_id}/details", environ_overrides={'REMOTE_USER': "test"}
@ -219,7 +221,6 @@ class TestGetDagDetails(TestDagEndpoint):
class TestGetDags(TestDagEndpoint): class TestGetDags(TestDagEndpoint):
def test_should_response_200(self): def test_should_response_200(self):
self._create_dag_models(2) self._create_dag_models(2)
@ -238,7 +239,10 @@ class TestGetDags(TestDagEndpoint):
"is_subdag": False, "is_subdag": False,
"owners": [], "owners": [],
"root_dag_id": None, "root_dag_id": None,
"schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *"}, "schedule_interval": {
"__type": "CronExpression",
"value": "2 2 * * *",
},
"tags": [], "tags": [],
}, },
{ {
@ -249,7 +253,10 @@ class TestGetDags(TestDagEndpoint):
"is_subdag": False, "is_subdag": False,
"owners": [], "owners": [],
"root_dag_id": None, "root_dag_id": None,
"schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *"}, "schedule_interval": {
"__type": "CronExpression",
"value": "2 2 * * *",
},
"tags": [], "tags": [],
}, },
], ],
@ -264,13 +271,7 @@ class TestGetDags(TestDagEndpoint):
("api/v1/dags?limit=2", ["TEST_DAG_1", "TEST_DAG_10"]), ("api/v1/dags?limit=2", ["TEST_DAG_1", "TEST_DAG_10"]),
( (
"api/v1/dags?offset=5", "api/v1/dags?offset=5",
[ ["TEST_DAG_5", "TEST_DAG_6", "TEST_DAG_7", "TEST_DAG_8", "TEST_DAG_9"],
"TEST_DAG_5",
"TEST_DAG_6",
"TEST_DAG_7",
"TEST_DAG_8",
"TEST_DAG_9",
],
), ),
( (
"api/v1/dags?offset=0", "api/v1/dags?offset=0",
@ -299,10 +300,10 @@ class TestGetDags(TestDagEndpoint):
assert response.status_code == 200 assert response.status_code == 200
dag_ids = [dag["dag_id"] for dag in response.json['dags']] dag_ids = [dag["dag_id"] for dag in response.json["dags"]]
self.assertEqual(expected_dag_ids, dag_ids) self.assertEqual(expected_dag_ids, dag_ids)
self.assertEqual(10, response.json['total_entries']) self.assertEqual(10, response.json["total_entries"])
def test_should_response_200_default_limit(self): def test_should_response_200_default_limit(self):
self._create_dag_models(101) self._create_dag_models(101)
@ -311,8 +312,8 @@ class TestGetDags(TestDagEndpoint):
assert response.status_code == 200 assert response.status_code == 200
self.assertEqual(100, len(response.json['dags'])) self.assertEqual(100, len(response.json["dags"]))
self.assertEqual(101, response.json['total_entries']) self.assertEqual(101, response.json["total_entries"])
def test_should_raises_401_unauthenticated(self): def test_should_raises_401_unauthenticated(self):
response = self.client.get("api/v1/dags") response = self.client.get("api/v1/dags")
@ -321,13 +322,72 @@ class TestGetDags(TestDagEndpoint):
class TestPatchDag(TestDagEndpoint): class TestPatchDag(TestDagEndpoint):
@pytest.mark.skip(reason="Not implemented yet") def test_should_response_200_on_patch_is_paused(self):
def test_should_response_200(self): dag_model = self._create_dag_model()
response = self.client.patch("/api/v1/dags/1", environ_overrides={'REMOTE_USER': "test"}) response = self.client.patch(
assert response.status_code == 200 f"/api/v1/dags/{dag_model.dag_id}",
json={
"is_paused": False,
},
environ_overrides={'REMOTE_USER': "test"}
)
self.assertEqual(response.status_code, 200)
expected_response = {
"dag_id": "TEST_DAG_1",
"description": None,
"fileloc": "/tmp/dag_1.py",
"is_paused": False,
"is_subdag": False,
"owners": [],
"root_dag_id": None,
"schedule_interval": {
"__type": "CronExpression",
"value": "2 2 * * *",
},
"tags": [],
}
self.assertEqual(response.json, expected_response)
def test_should_response_400_on_invalid_request(self):
patch_body = {
"is_paused": True,
"schedule_interval": {
"__type": "CronExpression",
"value": "1 1 * * *",
},
}
dag_model = self._create_dag_model()
response = self.client.patch(f"/api/v1/dags/{dag_model.dag_id}", json=patch_body)
self.assertEqual(response.status_code, 400)
self.assertEqual(response.json, {
'detail': "Property is read-only - 'schedule_interval'",
'status': 400,
'title': 'Bad Request',
'type': 'about:blank'
})
def test_should_response_404(self):
response = self.client.get("/api/v1/dags/INVALID_DAG", environ_overrides={'REMOTE_USER': "test"})
self.assertEqual(response.status_code, 404)
@provide_session
def _create_dag_model(self, session=None):
dag_model = DagModel(
dag_id="TEST_DAG_1",
fileloc="/tmp/dag_1.py",
schedule_interval="2 2 * * *",
is_paused=True
)
session.add(dag_model)
return dag_model
@pytest.mark.skip(reason="Not implemented yet")
def test_should_raises_401_unauthenticated(self): def test_should_raises_401_unauthenticated(self):
response = self.client.patch("/api/v1/dags/1") dag_model = self._create_dag_model()
response = self.client.patch(
f"/api/v1/dags/{dag_model.dag_id}",
json={
"is_paused": False,
},
)
assert_401(response) assert_401(response)