This commit is contained in:
Omair Khan 2020-07-08 16:42:26 +05:30 коммит произвёл GitHub
Родитель 07b81029eb
Коммит 7a4988a3c7
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 421 добавлений и 201 удалений

Просмотреть файл

@ -14,23 +14,33 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from connexion import NoContent
from flask import request
from sqlalchemy import and_, func
from sqlalchemy import func
from airflow.api_connexion.exceptions import NotFound
from airflow.api_connexion.exceptions import AlreadyExists, NotFound
from airflow.api_connexion.parameters import check_limit, format_datetime, format_parameters
from airflow.api_connexion.schemas.dag_run_schema import (
DAGRunCollection, dagrun_collection_schema, dagrun_schema,
)
from airflow.models import DagRun
from airflow.models import DagModel, DagRun
from airflow.utils.session import provide_session
from airflow.utils.types import DagRunType
def delete_dag_run():
@provide_session
def delete_dag_run(dag_id, dag_run_id, session):
"""
Delete a DAG Run
"""
raise NotImplementedError("Not implemented yet.")
if (
session.query(DagRun)
.filter(and_(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id))
.delete()
== 0
):
raise NotFound(detail=f"DAGRun with DAG ID: '{dag_id}' and DagRun ID: '{dag_run_id}' not found")
return NoContent, 204
@provide_session
@ -55,9 +65,18 @@ def get_dag_run(dag_id, dag_run_id, session):
'limit': check_limit
})
@provide_session
def get_dag_runs(session, dag_id, start_date_gte=None, start_date_lte=None,
execution_date_gte=None, execution_date_lte=None,
end_date_gte=None, end_date_lte=None, offset=None, limit=None):
def get_dag_runs(
session,
dag_id,
start_date_gte=None,
start_date_lte=None,
execution_date_gte=None,
execution_date_lte=None,
end_date_gte=None,
end_date_lte=None,
offset=None,
limit=None,
):
"""
Get all DAG Runs.
"""
@ -65,7 +84,7 @@ def get_dag_runs(session, dag_id, start_date_gte=None, start_date_lte=None,
query = session.query(DagRun)
# This endpoint allows specifying ~ as the dag_id to retrieve DAG Runs for all DAGs.
if dag_id != '~':
if dag_id != "~":
query = query.filter(DagRun.dag_id == dag_id)
# filter start date
@ -93,8 +112,9 @@ def get_dag_runs(session, dag_id, start_date_gte=None, start_date_lte=None,
dag_run = query.order_by(DagRun.id).offset(offset).limit(limit).all()
total_entries = session.query(func.count(DagRun.id)).scalar()
return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run,
total_entries=total_entries))
return dagrun_collection_schema.dump(
DAGRunCollection(dag_runs=dag_run, total_entries=total_entries)
)
def get_dag_runs_batch():
@ -104,8 +124,25 @@ def get_dag_runs_batch():
raise NotImplementedError("Not implemented yet.")
def post_dag_run():
@provide_session
def post_dag_run(dag_id, session):
"""
Trigger a DAG.
"""
raise NotImplementedError("Not implemented yet.")
if not session.query(DagModel).filter(DagModel.dag_id == dag_id).first():
raise NotFound(f"DAG with dag_id: '{dag_id}' not found")
post_body = dagrun_schema.load(request.json, session=session)
dagrun_instance = (
session.query(DagRun)
.filter(and_(DagRun.dag_id == dag_id, DagRun.run_id == post_body["run_id"]))
.first()
)
if not dagrun_instance:
dag_run = DagRun(dag_id=dag_id, run_type=DagRunType.MANUAL.value, **post_body)
session.add(dag_run)
session.commit()
return dagrun_schema.dump(dag_run)
raise AlreadyExists(
detail=f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{post_body['run_id']}' already exists"
)

Просмотреть файл

@ -287,6 +287,34 @@ paths:
'401':
$ref: '#/components/responses/Unauthenticated'
post:
summary: Trigger a DAG Run
operationId: airflow.api_connexion.endpoints.dag_run_endpoint.post_dag_run
tags: [DAGRun]
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/DAGRun'
responses:
'200':
description: Successful response.
content:
application/json:
schema:
$ref: '#/components/schemas/DAGRun'
'400':
$ref: '#/components/responses/BadRequest'
'401':
$ref: '#/components/responses/Unauthenticated'
'409':
$ref: '#/components/responses/AlreadyExists'
'403':
$ref: '#/components/responses/PermissionDenied'
'404':
$ref: '#/components/responses/NotFound'
/dags/~/dagRuns/list:
post:
summary: Get all DAG Runs from all DAGs.
@ -342,33 +370,6 @@ paths:
'404':
$ref: '#/components/responses/NotFound'
post:
summary: Trigger a DAG Run
x-openapi-router-controller: airflow.api_connexion.endpoints.dag_run_endpoint
operationId: post_dag_run
tags: [DAGRun]
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/DAGRun'
responses:
'200':
description: Successful response.
content:
application/json:
schema:
$ref: '#/components/schemas/DAGRun'
'400':
$ref: '#/components/responses/BadRequest'
'401':
$ref: '#/components/responses/Unauthenticated'
'409':
$ref: '#/components/responses/AlreadyExists'
'403':
$ref: '#/components/responses/PermissionDenied'
delete:
summary: Delete a DAG Run
x-openapi-router-controller: airflow.api_connexion.endpoints.dag_run_endpoint
@ -1195,6 +1196,7 @@ components:
If the specified dag_run_id is in use, the creation request fails with an ALREADY_EXISTS error.
This together with DAG_ID are a unique key.
nullable: true
dag_id:
type: string
readOnly: true
@ -1222,6 +1224,7 @@ components:
nullable: true
state:
$ref: '#/components/schemas/DagState'
readOnly: true
external_trigger:
type: boolean
default: true

Просмотреть файл

@ -18,12 +18,14 @@
import json
from typing import List, NamedTuple
from marshmallow import fields
from marshmallow import fields, pre_load
from marshmallow.schema import Schema
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
from airflow.api_connexion.schemas.enum_schemas import DagStateField
from airflow.models.dagrun import DagRun
from airflow.utils import timezone
from airflow.utils.types import DagRunType
class ConfObject(fields.Field):
@ -46,18 +48,32 @@ class DAGRunSchema(SQLAlchemySchema):
class Meta:
""" Meta """
model = DagRun
dateformat = 'iso'
dateformat = "iso"
run_id = auto_field(data_key='dag_run_id')
dag_id = auto_field(dump_only=True)
execution_date = auto_field()
start_date = auto_field(dump_only=True)
end_date = auto_field(dump_only=True)
state = DagStateField()
state = DagStateField(dump_only=True)
external_trigger = auto_field(default=True, dump_only=True)
conf = ConfObject()
@pre_load
def autogenerate(self, data, **kwargs):
"""
Auto generate run_id and execution_date if they are not loaded
"""
if "execution_date" not in data.keys():
data["execution_date"] = str(timezone.utcnow())
if "dag_run_id" not in data.keys():
data["dag_run_id"] = DagRun.generate_run_id(
DagRunType.MANUAL, timezone.parse(data["execution_date"])
)
return data
class DAGRunCollection(NamedTuple):
"""List of DAGRuns with metadata"""
@ -68,6 +84,7 @@ class DAGRunCollection(NamedTuple):
class DAGRunCollectionSchema(Schema):
"""DAGRun Collection schema"""
dag_runs = fields.List(fields.Nested(DAGRunSchema))
total_entries = fields.Int()

Просмотреть файл

@ -17,16 +17,15 @@
import unittest
from datetime import timedelta
import pytest
from parameterized import parameterized
from airflow.models import DagRun
from airflow.models import DagModel, DagRun
from airflow.utils import timezone
from airflow.utils.session import provide_session
from airflow.utils.types import DagRunType
from airflow.www import app
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_runs
from tests.test_utils.db import clear_db_dags, clear_db_runs
class TestDagRunEndpoint(unittest.TestCase):
@ -38,17 +37,18 @@ class TestDagRunEndpoint(unittest.TestCase):
def setUp(self) -> None:
self.client = self.app.test_client() # type:ignore
self.default_time = '2020-06-11T18:00:00+00:00'
self.default_time_2 = '2020-06-12T18:00:00+00:00'
self.default_time = "2020-06-11T18:00:00+00:00"
self.default_time_2 = "2020-06-12T18:00:00+00:00"
clear_db_runs()
clear_db_dags()
def tearDown(self) -> None:
clear_db_runs()
def _create_test_dag_run(self, state='running', extra_dag=False):
def _create_test_dag_run(self, state="running", extra_dag=False):
dagrun_model_1 = DagRun(
dag_id='TEST_DAG_ID',
run_id='TEST_DAG_RUN_ID_1',
dag_id="TEST_DAG_ID",
run_id="TEST_DAG_RUN_ID_1",
run_type=DagRunType.MANUAL.value,
execution_date=timezone.parse(self.default_time),
start_date=timezone.parse(self.default_time),
@ -56,39 +56,64 @@ class TestDagRunEndpoint(unittest.TestCase):
state=state,
)
dagrun_model_2 = DagRun(
dag_id='TEST_DAG_ID',
run_id='TEST_DAG_RUN_ID_2',
dag_id="TEST_DAG_ID",
run_id="TEST_DAG_RUN_ID_2",
run_type=DagRunType.MANUAL.value,
execution_date=timezone.parse(self.default_time_2),
start_date=timezone.parse(self.default_time),
external_trigger=True,
)
if extra_dag:
dagrun_extra = [DagRun(
dag_id='TEST_DAG_ID_' + str(i),
run_id='TEST_DAG_RUN_ID_' + str(i),
run_type=DagRunType.MANUAL.value,
execution_date=timezone.parse(self.default_time_2),
start_date=timezone.parse(self.default_time),
external_trigger=True,
) for i in range(3, 5)]
dagrun_extra = [
DagRun(
dag_id="TEST_DAG_ID_" + str(i),
run_id="TEST_DAG_RUN_ID_" + str(i),
run_type=DagRunType.MANUAL.value,
execution_date=timezone.parse(self.default_time_2),
start_date=timezone.parse(self.default_time),
external_trigger=True,
)
for i in range(3, 5)
]
return [dagrun_model_1, dagrun_model_2] + dagrun_extra
return [dagrun_model_1, dagrun_model_2]
class TestDeleteDagRun(TestDagRunEndpoint):
@pytest.mark.skip(reason="Not implemented yet")
def test_should_response_200(self):
response = self.client.delete("api/v1/dags/TEST_DAG_ID}/dagRuns/TEST_DAG_RUN_ID")
assert response.status_code == 204
@provide_session
def test_should_response_204(self, session):
session.add_all(self._create_test_dag_run())
session.commit()
response = self.client.delete(
"api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1"
)
self.assertEqual(response.status_code, 204)
# Check if the Dag Run is deleted from the database
response = self.client.get("api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1")
self.assertEqual(response.status_code, 404)
def test_should_response_404(self):
response = self.client.delete(
"api/v1/dags/INVALID_DAG_RUN/dagRuns/INVALID_DAG_RUN"
)
self.assertEqual(response.status_code, 404)
self.assertEqual(
response.json,
{
"detail": "DAGRun with DAG ID: 'INVALID_DAG_RUN' and DagRun ID: 'INVALID_DAG_RUN' not found",
"status": 404,
"title": "Object not found",
"type": "about:blank",
},
)
class TestGetDagRun(TestDagRunEndpoint):
@provide_session
def test_should_response_200(self, session):
dagrun_model = DagRun(
dag_id='TEST_DAG_ID',
run_id='TEST_DAG_RUN_ID',
dag_id="TEST_DAG_ID",
run_id="TEST_DAG_RUN_ID",
run_type=DagRunType.MANUAL.value,
execution_date=timezone.parse(self.default_time),
start_date=timezone.parse(self.default_time),
@ -103,14 +128,14 @@ class TestGetDagRun(TestDagRunEndpoint):
self.assertEqual(
response.json,
{
'dag_id': 'TEST_DAG_ID',
'dag_run_id': 'TEST_DAG_RUN_ID',
'end_date': None,
'state': 'running',
'execution_date': self.default_time,
'external_trigger': True,
'start_date': self.default_time,
'conf': {},
"dag_id": "TEST_DAG_ID",
"dag_run_id": "TEST_DAG_RUN_ID",
"end_date": None,
"state": "running",
"execution_date": self.default_time,
"external_trigger": True,
"start_date": self.default_time,
"conf": {},
},
)
@ -118,7 +143,13 @@ class TestGetDagRun(TestDagRunEndpoint):
response = self.client.get("api/v1/dags/invalid-id/dagRuns/invalid-id")
assert response.status_code == 404
self.assertEqual(
{'detail': None, 'status': 404, 'title': 'DAGRun not found', 'type': 'about:blank'}, response.json
{
"detail": None,
"status": 404,
"title": "DAGRun not found",
"type": "about:blank",
},
response.json,
)
@ -137,24 +168,24 @@ class TestGetDagRuns(TestDagRunEndpoint):
{
"dag_runs": [
{
'dag_id': 'TEST_DAG_ID',
'dag_run_id': 'TEST_DAG_RUN_ID_1',
'end_date': None,
'state': 'running',
'execution_date': self.default_time,
'external_trigger': True,
'start_date': self.default_time,
'conf': {},
"dag_id": "TEST_DAG_ID",
"dag_run_id": "TEST_DAG_RUN_ID_1",
"end_date": None,
"state": "running",
"execution_date": self.default_time,
"external_trigger": True,
"start_date": self.default_time,
"conf": {},
},
{
'dag_id': 'TEST_DAG_ID',
'dag_run_id': 'TEST_DAG_RUN_ID_2',
'end_date': None,
'state': 'running',
'execution_date': self.default_time_2,
'external_trigger': True,
'start_date': self.default_time,
'conf': {},
"dag_id": "TEST_DAG_ID",
"dag_run_id": "TEST_DAG_RUN_ID_2",
"end_date": None,
"state": "running",
"execution_date": self.default_time_2,
"external_trigger": True,
"start_date": self.default_time,
"conf": {},
},
],
"total_entries": 2,
@ -164,8 +195,12 @@ class TestGetDagRuns(TestDagRunEndpoint):
@provide_session
def test_should_return_all_with_tilde_as_dag_id(self, session):
dagruns = self._create_test_dag_run(extra_dag=True)
expected_dag_run_ids = ['TEST_DAG_ID', 'TEST_DAG_ID',
"TEST_DAG_ID_3", "TEST_DAG_ID_4"]
expected_dag_run_ids = [
"TEST_DAG_ID",
"TEST_DAG_ID",
"TEST_DAG_ID_3",
"TEST_DAG_ID_4",
]
session.add_all(dagruns)
session.commit()
result = session.query(DagRun).all()
@ -180,7 +215,10 @@ class TestGetDagRunsPagination(TestDagRunEndpoint):
@parameterized.expand(
[
("api/v1/dags/TEST_DAG_ID/dagRuns?limit=1", ["TEST_DAG_RUN_ID1"]),
("api/v1/dags/TEST_DAG_ID/dagRuns?limit=2", ["TEST_DAG_RUN_ID1", "TEST_DAG_RUN_ID2"]),
(
"api/v1/dags/TEST_DAG_ID/dagRuns?limit=2",
["TEST_DAG_RUN_ID1", "TEST_DAG_RUN_ID2"],
),
(
"api/v1/dags/TEST_DAG_ID/dagRuns?offset=5",
[
@ -208,7 +246,10 @@ class TestGetDagRunsPagination(TestDagRunEndpoint):
),
("api/v1/dags/TEST_DAG_ID/dagRuns?limit=1&offset=5", ["TEST_DAG_RUN_ID6"]),
("api/v1/dags/TEST_DAG_ID/dagRuns?limit=1&offset=1", ["TEST_DAG_RUN_ID2"]),
("api/v1/dags/TEST_DAG_ID/dagRuns?limit=2&offset=2", ["TEST_DAG_RUN_ID3", "TEST_DAG_RUN_ID4"],),
(
"api/v1/dags/TEST_DAG_ID/dagRuns?limit=2&offset=2",
["TEST_DAG_RUN_ID3", "TEST_DAG_RUN_ID4"],
),
]
)
@provide_session
@ -245,7 +286,7 @@ class TestGetDagRunsPagination(TestDagRunEndpoint):
response = self.client.get("api/v1/dags/TEST_DAG_ID/dagRuns?limit=180")
assert response.status_code == 200
self.assertEqual(len(response.json['dag_runs']), 150)
self.assertEqual(len(response.json["dag_runs"]), 150)
def _create_dag_runs(self, count):
return [
@ -275,18 +316,30 @@ class TestGetDagRunsPaginationFilters(TestDagRunEndpoint):
(
"api/v1/dags/TEST_DAG_ID/dagRuns?start_date_lte= 2020-06-15T18:00:00+00:00"
"&start_date_gte=2020-06-12T18:00:00Z",
["TEST_START_EXEC_DAY_12", "TEST_START_EXEC_DAY_13",
"TEST_START_EXEC_DAY_14", "TEST_START_EXEC_DAY_15"],
[
"TEST_START_EXEC_DAY_12",
"TEST_START_EXEC_DAY_13",
"TEST_START_EXEC_DAY_14",
"TEST_START_EXEC_DAY_15",
],
),
(
"api/v1/dags/TEST_DAG_ID/dagRuns?execution_date_lte=2020-06-13T18:00:00+00:00",
["TEST_START_EXEC_DAY_10", "TEST_START_EXEC_DAY_11",
"TEST_START_EXEC_DAY_12", "TEST_START_EXEC_DAY_13"],
[
"TEST_START_EXEC_DAY_10",
"TEST_START_EXEC_DAY_11",
"TEST_START_EXEC_DAY_12",
"TEST_START_EXEC_DAY_13",
],
),
(
"api/v1/dags/TEST_DAG_ID/dagRuns?execution_date_gte=2020-06-16T18:00:00+00:00",
["TEST_START_EXEC_DAY_16", "TEST_START_EXEC_DAY_17",
"TEST_START_EXEC_DAY_18", "TEST_START_EXEC_DAY_19"],
[
"TEST_START_EXEC_DAY_16",
"TEST_START_EXEC_DAY_17",
"TEST_START_EXEC_DAY_18",
"TEST_START_EXEC_DAY_19",
],
),
]
)
@ -304,16 +357,16 @@ class TestGetDagRunsPaginationFilters(TestDagRunEndpoint):
def _create_dag_runs(self):
dates = [
'2020-06-10T18:00:00+00:00',
'2020-06-11T18:00:00+00:00',
'2020-06-12T18:00:00+00:00',
'2020-06-13T18:00:00+00:00',
'2020-06-14T18:00:00+00:00',
'2020-06-15T18:00:00Z',
'2020-06-16T18:00:00Z',
'2020-06-17T18:00:00Z',
'2020-06-18T18:00:00Z',
'2020-06-19T18:00:00Z',
"2020-06-10T18:00:00+00:00",
"2020-06-11T18:00:00+00:00",
"2020-06-12T18:00:00+00:00",
"2020-06-13T18:00:00+00:00",
"2020-06-14T18:00:00+00:00",
"2020-06-15T18:00:00Z",
"2020-06-16T18:00:00Z",
"2020-06-17T18:00:00Z",
"2020-06-18T18:00:00Z",
"2020-06-19T18:00:00Z",
]
return [
@ -324,7 +377,7 @@ class TestGetDagRunsPaginationFilters(TestDagRunEndpoint):
execution_date=timezone.parse(dates[i]),
start_date=timezone.parse(dates[i]),
external_trigger=True,
state='success',
state="success",
)
for i in range(len(dates))
]
@ -347,19 +400,135 @@ class TestGetDagRunsEndDateFilters(TestDagRunEndpoint):
)
@provide_session
def test_end_date_gte_lte(self, url, expected_dag_run_ids, session):
dagruns = self._create_test_dag_run('success') # state==success, then end date is today
dagruns = self._create_test_dag_run(
"success"
) # state==success, then end date is today
session.add_all(dagruns)
session.commit()
response = self.client.get(url)
assert response.status_code == 200
self.assertEqual(response.json["total_entries"], 2)
dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json["dag_runs"] if dag_run]
dag_run_ids = [
dag_run["dag_run_id"] for dag_run in response.json["dag_runs"] if dag_run
]
self.assertEqual(dag_run_ids, expected_dag_run_ids)
class TestPostDagRun(TestDagRunEndpoint):
@pytest.mark.skip(reason="Not implemented yet")
def test_should_response_200(self):
response = self.client.post("/dags/TEST_DAG_ID/dagRuns")
assert response.status_code == 200
@parameterized.expand(
[
(
"All fields present",
{
"dag_run_id": "TEST_DAG_RUN",
"execution_date": "2020-06-11T18:00:00+00:00",
},
),
("dag_run_id missing", {"execution_date": "2020-06-11T18:00:00+00:00"}),
("dag_run_id and execution_date missing", {}),
]
)
@provide_session
def test_should_response_200(self, name, request_json, session):
del name
dag_instance = DagModel(dag_id="TEST_DAG_ID")
session.add(dag_instance)
session.commit()
response = self.client.post(
"api/v1/dags/TEST_DAG_ID/dagRuns", json=request_json
)
self.assertEqual(response.status_code, 200)
self.assertEqual(
{
"conf": {},
"dag_id": "TEST_DAG_ID",
"dag_run_id": response.json["dag_run_id"],
"end_date": None,
"execution_date": response.json["execution_date"],
"external_trigger": True,
"start_date": response.json["start_date"],
"state": "running",
},
response.json,
)
def test_response_404(self):
response = self.client.post(
"api/v1/dags/TEST_DAG_ID/dagRuns",
json={"dag_run_id": "TEST_DAG_RUN", "execution_date": self.default_time},
)
self.assertEqual(response.status_code, 404)
self.assertEqual(
{
"detail": None,
"status": 404,
"title": "DAG with dag_id: 'TEST_DAG_ID' not found",
"type": "about:blank",
},
response.json,
)
@parameterized.expand(
[
(
"start_date in request json",
"api/v1/dags/TEST_DAG_ID/dagRuns",
{
"start_date": "2020-06-11T18:00:00+00:00",
"execution_date": "2020-06-12T18:00:00+00:00",
},
{
"detail": "Property is read-only - 'start_date'",
"status": 400,
"title": "Bad Request",
"type": "about:blank",
},
),
(
"state in request json",
"api/v1/dags/TEST_DAG_ID/dagRuns",
{"state": "failed", "execution_date": "2020-06-12T18:00:00+00:00"},
{
"detail": "Property is read-only - 'state'",
"status": 400,
"title": "Bad Request",
"type": "about:blank",
},
),
]
)
@provide_session
def test_response_400(self, name, url, request_json, expected_response, session):
del name
dag_instance = DagModel(dag_id="TEST_DAG_ID")
session.add(dag_instance)
session.commit()
response = self.client.post(url, json=request_json)
self.assertEqual(response.status_code, 400, response.data)
self.assertEqual(expected_response, response.json)
@provide_session
def test_response_409(self, session):
dag_instance = DagModel(dag_id="TEST_DAG_ID")
session.add(dag_instance)
session.add_all(self._create_test_dag_run())
session.commit()
response = self.client.post(
"api/v1/dags/TEST_DAG_ID/dagRuns",
json={
"dag_run_id": "TEST_DAG_RUN_ID_1",
"execution_date": self.default_time,
},
)
self.assertEqual(response.status_code, 409, response.data)
self.assertEqual(
response.json,
{
"detail": "DAGRun with DAG ID: 'TEST_DAG_ID' and "
"DAGRun ID: 'TEST_DAG_RUN_ID_1' already exists",
"status": 409,
"title": "Object already exists",
"type": "about:blank",
},
)

Просмотреть файл

@ -14,11 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
from dateutil.parser import parse
from marshmallow import ValidationError
from parameterized import parameterized
from airflow.api_connexion.schemas.dag_run_schema import (
DAGRunCollection, dagrun_collection_schema, dagrun_schema,
@ -29,27 +28,28 @@ from airflow.utils.session import provide_session
from airflow.utils.types import DagRunType
from tests.test_utils.db import clear_db_runs
DEFAULT_TIME = "2020-06-09T13:59:56.336000+00:00"
class TestDAGRunBase(unittest.TestCase):
def setUp(self) -> None:
clear_db_runs()
self.default_time = "2020-06-09T13:59:56.336000+00:00"
self.default_time = DEFAULT_TIME
def tearDown(self) -> None:
clear_db_runs()
class TestDAGRunSchema(TestDAGRunBase):
@provide_session
def test_serialze(self, session):
dagrun_model = DagRun(run_id='my-dag-run',
run_type=DagRunType.MANUAL.value,
execution_date=timezone.parse(self.default_time),
start_date=timezone.parse(self.default_time),
conf='{"start": "stop"}'
)
dagrun_model = DagRun(
run_id="my-dag-run",
run_type=DagRunType.MANUAL.value,
execution_date=timezone.parse(self.default_time),
start_date=timezone.parse(self.default_time),
conf='{"start": "stop"}',
)
session.add(dagrun_model)
session.commit()
dagrun_model = session.query(DagRun).first()
@ -58,68 +58,63 @@ class TestDAGRunSchema(TestDAGRunBase):
self.assertEqual(
deserialized_dagrun,
{
'dag_id': None,
'dag_run_id': 'my-dag-run',
'end_date': None,
'state': 'running',
'execution_date': self.default_time,
'external_trigger': True,
'start_date': self.default_time,
'conf': {"start": "stop"}
}
"dag_id": None,
"dag_run_id": "my-dag-run",
"end_date": None,
"state": "running",
"execution_date": self.default_time,
"external_trigger": True,
"start_date": self.default_time,
"conf": {"start": "stop"},
},
)
def test_deserialize(self):
# Only dag_run_id, execution_date, state,
# and conf are loaded.
# dag_run_id should be loaded as run_id
serialized_dagrun = {
'dag_run_id': 'my-dag-run',
'state': 'failed',
'execution_date': self.default_time,
'conf': '{"start": "stop"}'
}
@parameterized.expand(
[
( # Conf not provided
{"dag_run_id": "my-dag-run", "execution_date": DEFAULT_TIME},
{"run_id": "my-dag-run", "execution_date": parse(DEFAULT_TIME)},
),
(
{
"dag_run_id": "my-dag-run",
"execution_date": DEFAULT_TIME,
"conf": {"start": "stop"},
},
{
"run_id": "my-dag-run",
"execution_date": parse(DEFAULT_TIME),
"conf": {"start": "stop"},
},
),
]
)
def test_deserialize(self, serialized_dagrun, expected_result):
result = dagrun_schema.load(serialized_dagrun)
self.assertEqual(
result,
{
'run_id': 'my-dag-run',
'execution_date': parse(self.default_time),
'state': 'failed',
'conf': {"start": "stop"}
}
)
self.assertDictEqual(result, expected_result)
def test_deserialize_2(self):
# loading dump_only field raises
serialized_dagrun = {
'dag_id': None,
'dag_run_id': 'my-dag-run',
'end_date': None,
'state': 'failed',
'execution_date': self.default_time,
'external_trigger': True,
'start_date': self.default_time,
'conf': {"start": "stop"}
}
with self.assertRaises(ValidationError):
dagrun_schema.load(serialized_dagrun)
def test_autofill_fields(self):
"""Dag_run_id and execution_date fields are autogenerated if missing"""
serialized_dagrun = {}
result = dagrun_schema.load(serialized_dagrun)
self.assertDictEqual(
result,
{"execution_date": result["execution_date"], "run_id": result["run_id"]},
)
class TestDagRunCollection(TestDAGRunBase):
@provide_session
def test_serialize(self, session):
dagrun_model_1 = DagRun(
run_id='my-dag-run',
run_id="my-dag-run",
execution_date=timezone.parse(self.default_time),
run_type=DagRunType.MANUAL.value,
start_date=timezone.parse(self.default_time),
conf='{"start": "stop"}'
conf='{"start": "stop"}',
)
dagrun_model_2 = DagRun(
run_id='my-dag-run-2',
run_id="my-dag-run-2",
execution_date=timezone.parse(self.default_time),
start_date=timezone.parse(self.default_time),
run_type=DagRunType.MANUAL.value,
@ -127,34 +122,33 @@ class TestDagRunCollection(TestDAGRunBase):
dagruns = [dagrun_model_1, dagrun_model_2]
session.add_all(dagruns)
session.commit()
instance = DAGRunCollection(dag_runs=dagruns,
total_entries=2)
instance = DAGRunCollection(dag_runs=dagruns, total_entries=2)
deserialized_dagruns = dagrun_collection_schema.dump(instance)
self.assertEqual(
deserialized_dagruns,
{
'dag_runs': [
"dag_runs": [
{
'dag_id': None,
'dag_run_id': 'my-dag-run',
'end_date': None,
'execution_date': self.default_time,
'external_trigger': True,
'state': 'running',
'start_date': self.default_time,
'conf': {"start": "stop"}
"dag_id": None,
"dag_run_id": "my-dag-run",
"end_date": None,
"execution_date": self.default_time,
"external_trigger": True,
"state": "running",
"start_date": self.default_time,
"conf": {"start": "stop"},
},
{
'dag_id': None,
'dag_run_id': 'my-dag-run-2',
'end_date': None,
'state': 'running',
'execution_date': self.default_time,
'external_trigger': True,
'start_date': self.default_time,
'conf': {}
}
"dag_id": None,
"dag_run_id": "my-dag-run-2",
"end_date": None,
"state": "running",
"execution_date": self.default_time,
"external_trigger": True,
"start_date": self.default_time,
"conf": {},
},
],
'total_entries': 2
}
"total_entries": 2,
},
)