From 1e371864cc628524c9e28a7b69d370a7c05467a2 Mon Sep 17 00:00:00 2001 From: Omair Khan Date: Fri, 21 Aug 2020 15:58:21 +0530 Subject: [PATCH] Add update endpoint for DAG (#9101) (#9740) --- .../api_connexion/endpoints/dag_endpoint.py | 20 +- airflow/api_connexion/openapi/v1.yaml | 2 +- .../api_connexion/schemas/common_schema.py | 2 +- airflow/api_connexion/schemas/dag_schema.py | 4 +- .../endpoints/test_dag_endpoint.py | 206 +++++++++++------- 5 files changed, 153 insertions(+), 81 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index a743e5191d..27634885f8 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -14,12 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from flask import current_app +from flask import current_app, request +from marshmallow import ValidationError from sqlalchemy import func from airflow import DAG from airflow.api_connexion import security -from airflow.api_connexion.exceptions import NotFound +from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.parameters import check_limit, format_parameters from airflow.api_connexion.schemas.dag_schema import ( DAGCollection, dag_detail_schema, dag_schema, dags_collection_schema, @@ -70,8 +71,19 @@ def get_dags(session, limit, offset=0): @security.requires_authentication -def patch_dag(): +@provide_session +def patch_dag(session, dag_id): """ Update the specific DAG """ - raise NotImplementedError("Not implemented yet.") + dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one_or_none() + if not dag: + raise NotFound(f"Dag with id: '{dag_id}' not found") + try: + patch_body = dag_schema.load(request.json, session=session) + except ValidationError as err: + raise BadRequest("Invalid Dag schema", detail=str(err.messages)) + for key, value in patch_body.items(): + setattr(dag, key, value) + session.commit() + return dag_schema.dump(dag) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index 5f7e339052..6e859a55cc 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -1171,7 +1171,6 @@ components: nullable: true schedule_interval: $ref: '#/components/schemas/ScheduleInterval' - readOnly: true tags: type: array nullable: true @@ -1938,6 +1937,7 @@ components: # Common data type ScheduleInterval: + readOnly: true oneOf: - $ref: '#/components/schemas/TimeDelta' - $ref: '#/components/schemas/RelativeDelta' diff --git a/airflow/api_connexion/schemas/common_schema.py b/airflow/api_connexion/schemas/common_schema.py index 160e12cc40..27fd413206 100644 --- a/airflow/api_connexion/schemas/common_schema.py +++ b/airflow/api_connexion/schemas/common_schema.py @@ -83,7 +83,7 @@ class RelativeDeltaSchema(Schema): class CronExpressionSchema(Schema): """Cron expression schema""" - objectType = fields.Constant("CronExpression", data_key="__type", required=True) + objectType = fields.Constant("CronExpression", data_key="__type") value = fields.String(required=True) @marshmallow.post_load diff --git a/airflow/api_connexion/schemas/dag_schema.py b/airflow/api_connexion/schemas/dag_schema.py index bae2228403..ec07695eb4 100644 --- a/airflow/api_connexion/schemas/dag_schema.py +++ b/airflow/api_connexion/schemas/dag_schema.py @@ -43,12 +43,12 @@ class DAGSchema(SQLAlchemySchema): dag_id = auto_field(dump_only=True) root_dag_id = auto_field(dump_only=True) - is_paused = auto_field(dump_only=True) + is_paused = auto_field() is_subdag = auto_field(dump_only=True) fileloc = auto_field(dump_only=True) owners = fields.Method("get_owners", dump_only=True) description = auto_field(dump_only=True) - schedule_interval = fields.Nested(ScheduleIntervalSchema, dump_only=True) + schedule_interval = fields.Nested(ScheduleIntervalSchema) tags = fields.List(fields.Nested(DagTagSchema), dump_only=True) @staticmethod diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index 3b25c7a1cc..66c7cf7ec8 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -18,7 +18,6 @@ import os import unittest from datetime import datetime -import pytest from parameterized import parameterized from airflow import DAG @@ -77,7 +76,7 @@ class TestDagEndpoint(unittest.TestCase): dag_model = DagModel( dag_id=f"TEST_DAG_{num}", fileloc=f"/tmp/dag_{num}.py", - schedule_interval="2 2 * * *" + schedule_interval="2 2 * * *", ) session.add(dag_model) @@ -90,17 +89,20 @@ class TestGetDag(TestDagEndpoint): current_response = response.json current_response["fileloc"] = "/tmp/test-dag.py" - self.assertEqual({ - 'dag_id': 'TEST_DAG_1', - 'description': None, - 'fileloc': '/tmp/test-dag.py', - 'is_paused': False, - 'is_subdag': False, - 'owners': [], - 'root_dag_id': None, - 'schedule_interval': {'__type': 'CronExpression', 'value': '2 2 * * *'}, - 'tags': [] - }, current_response) + self.assertEqual( + { + "dag_id": "TEST_DAG_1", + "description": None, + "fileloc": "/tmp/test-dag.py", + "is_paused": False, + "is_subdag": False, + "owners": [], + "root_dag_id": None, + "schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *"}, + "tags": [], + }, + current_response, + ) def test_should_response_404(self): response = self.client.get("/api/v1/dags/INVALID_DAG", environ_overrides={'REMOTE_USER': "test"}) @@ -121,27 +123,27 @@ class TestGetDagDetails(TestDagEndpoint): ) assert response.status_code == 200 expected = { - 'catchup': True, - 'concurrency': 16, - 'dag_id': 'test_dag', - 'dag_run_timeout': None, - 'default_view': 'tree', - 'description': None, - 'doc_md': 'details', - 'fileloc': __file__, - 'is_paused': None, - 'is_subdag': False, - 'orientation': 'LR', - 'owners': [], - 'schedule_interval': { - '__type': 'TimeDelta', - 'days': 1, - 'microseconds': 0, - 'seconds': 0 + "catchup": True, + "concurrency": 16, + "dag_id": "test_dag", + "dag_run_timeout": None, + "default_view": "tree", + "description": None, + "doc_md": "details", + "fileloc": __file__, + "is_paused": None, + "is_subdag": False, + "orientation": "LR", + "owners": [], + "schedule_interval": { + "__type": "TimeDelta", + "days": 1, + "microseconds": 0, + "seconds": 0, }, - 'start_date': '2020-06-15T00:00:00+00:00', - 'tags': None, - 'timezone': "Timezone('UTC')" + "start_date": "2020-06-15T00:00:00+00:00", + "tags": None, + "timezone": "Timezone('UTC')", } assert response.json == expected @@ -155,27 +157,27 @@ class TestGetDagDetails(TestDagEndpoint): SerializedDagModel.write_dag(self.dag) expected = { - 'catchup': True, - 'concurrency': 16, - 'dag_id': 'test_dag', - 'dag_run_timeout': None, - 'default_view': 'tree', - 'description': None, - 'doc_md': 'details', - 'fileloc': __file__, - 'is_paused': None, - 'is_subdag': False, - 'orientation': 'LR', - 'owners': [], - 'schedule_interval': { - '__type': 'TimeDelta', - 'days': 1, - 'microseconds': 0, - 'seconds': 0 + "catchup": True, + "concurrency": 16, + "dag_id": "test_dag", + "dag_run_timeout": None, + "default_view": "tree", + "description": None, + "doc_md": "details", + "fileloc": __file__, + "is_paused": None, + "is_subdag": False, + "orientation": "LR", + "owners": [], + "schedule_interval": { + "__type": "TimeDelta", + "days": 1, + "microseconds": 0, + "seconds": 0, }, - 'start_date': '2020-06-15T00:00:00+00:00', - 'tags': None, - 'timezone': "Timezone('UTC')" + "start_date": "2020-06-15T00:00:00+00:00", + "tags": None, + "timezone": "Timezone('UTC')", } response = client.get( f"/api/v1/dags/{self.dag_id}/details", environ_overrides={'REMOTE_USER': "test"} @@ -219,7 +221,6 @@ class TestGetDagDetails(TestDagEndpoint): class TestGetDags(TestDagEndpoint): - def test_should_response_200(self): self._create_dag_models(2) @@ -238,7 +239,10 @@ class TestGetDags(TestDagEndpoint): "is_subdag": False, "owners": [], "root_dag_id": None, - "schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *"}, + "schedule_interval": { + "__type": "CronExpression", + "value": "2 2 * * *", + }, "tags": [], }, { @@ -249,7 +253,10 @@ class TestGetDags(TestDagEndpoint): "is_subdag": False, "owners": [], "root_dag_id": None, - "schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *"}, + "schedule_interval": { + "__type": "CronExpression", + "value": "2 2 * * *", + }, "tags": [], }, ], @@ -264,13 +271,7 @@ class TestGetDags(TestDagEndpoint): ("api/v1/dags?limit=2", ["TEST_DAG_1", "TEST_DAG_10"]), ( "api/v1/dags?offset=5", - [ - "TEST_DAG_5", - "TEST_DAG_6", - "TEST_DAG_7", - "TEST_DAG_8", - "TEST_DAG_9", - ], + ["TEST_DAG_5", "TEST_DAG_6", "TEST_DAG_7", "TEST_DAG_8", "TEST_DAG_9"], ), ( "api/v1/dags?offset=0", @@ -299,10 +300,10 @@ class TestGetDags(TestDagEndpoint): assert response.status_code == 200 - dag_ids = [dag["dag_id"] for dag in response.json['dags']] + dag_ids = [dag["dag_id"] for dag in response.json["dags"]] self.assertEqual(expected_dag_ids, dag_ids) - self.assertEqual(10, response.json['total_entries']) + self.assertEqual(10, response.json["total_entries"]) def test_should_response_200_default_limit(self): self._create_dag_models(101) @@ -311,8 +312,8 @@ class TestGetDags(TestDagEndpoint): assert response.status_code == 200 - self.assertEqual(100, len(response.json['dags'])) - self.assertEqual(101, response.json['total_entries']) + self.assertEqual(100, len(response.json["dags"])) + self.assertEqual(101, response.json["total_entries"]) def test_should_raises_401_unauthenticated(self): response = self.client.get("api/v1/dags") @@ -321,13 +322,72 @@ class TestGetDags(TestDagEndpoint): class TestPatchDag(TestDagEndpoint): - @pytest.mark.skip(reason="Not implemented yet") - def test_should_response_200(self): - response = self.client.patch("/api/v1/dags/1", environ_overrides={'REMOTE_USER': "test"}) - assert response.status_code == 200 + def test_should_response_200_on_patch_is_paused(self): + dag_model = self._create_dag_model() + response = self.client.patch( + f"/api/v1/dags/{dag_model.dag_id}", + json={ + "is_paused": False, + }, + environ_overrides={'REMOTE_USER': "test"} + ) + self.assertEqual(response.status_code, 200) + expected_response = { + "dag_id": "TEST_DAG_1", + "description": None, + "fileloc": "/tmp/dag_1.py", + "is_paused": False, + "is_subdag": False, + "owners": [], + "root_dag_id": None, + "schedule_interval": { + "__type": "CronExpression", + "value": "2 2 * * *", + }, + "tags": [], + } + self.assertEqual(response.json, expected_response) + + def test_should_response_400_on_invalid_request(self): + patch_body = { + "is_paused": True, + "schedule_interval": { + "__type": "CronExpression", + "value": "1 1 * * *", + }, + } + dag_model = self._create_dag_model() + response = self.client.patch(f"/api/v1/dags/{dag_model.dag_id}", json=patch_body) + self.assertEqual(response.status_code, 400) + self.assertEqual(response.json, { + 'detail': "Property is read-only - 'schedule_interval'", + 'status': 400, + 'title': 'Bad Request', + 'type': 'about:blank' + }) + + def test_should_response_404(self): + response = self.client.get("/api/v1/dags/INVALID_DAG", environ_overrides={'REMOTE_USER': "test"}) + self.assertEqual(response.status_code, 404) + + @provide_session + def _create_dag_model(self, session=None): + dag_model = DagModel( + dag_id="TEST_DAG_1", + fileloc="/tmp/dag_1.py", + schedule_interval="2 2 * * *", + is_paused=True + ) + session.add(dag_model) + return dag_model - @pytest.mark.skip(reason="Not implemented yet") def test_should_raises_401_unauthenticated(self): - response = self.client.patch("/api/v1/dags/1") + dag_model = self._create_dag_model() + response = self.client.patch( + f"/api/v1/dags/{dag_model.dag_id}", + json={ + "is_paused": False, + }, + ) assert_401(response)