- The dag_run argument is only there for test mocks, and only to access a static method. Removing this simplifies the function, reduces confusion.
- Give optional arguments a default value, reduce indentation of arg list to PEP / Black standard.
- Clean up tests for readability
This commit is contained in:
Martijn Pieters 2020-10-16 20:22:40 +01:00 коммит произвёл GitHub
Родитель 84c70287ee
Коммит 4d611f2ffd
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 26 добавлений и 94 удалений

Просмотреть файл

@ -28,19 +28,17 @@ from airflow.utils.types import DagRunType
def _trigger_dag(
dag_id: str,
dag_bag: DagBag,
dag_run: DagModel,
run_id: Optional[str],
conf: Optional[Union[dict, str]],
execution_date: Optional[datetime],
replace_microseconds: bool,
dag_id: str,
dag_bag: DagBag,
run_id: Optional[str] = None,
conf: Optional[Union[dict, str]] = None,
execution_date: Optional[datetime] = None,
replace_microseconds: bool = True,
) -> List[DagRun]: # pylint: disable=too-many-arguments
"""Triggers DAG run.
:param dag_id: DAG ID
:param dag_bag: DAG Bag model
:param dag_run: DAG Run model
:param run_id: ID of the dag_run
:param conf: configuration
:param execution_date: date of execution
@ -69,7 +67,7 @@ def _trigger_dag(
min_dag_start_date.isoformat()))
run_id = run_id or DagRun.generate_run_id(DagRunType.MANUAL, execution_date)
dag_run = dag_run.find(dag_id=dag_id, run_id=run_id)
dag_run = DagRun.find(dag_id=dag_id, run_id=run_id)
if dag_run:
raise DagRunAlreadyExists(
@ -123,10 +121,8 @@ def trigger_dag(
dag_folder=dag_model.fileloc,
read_dags_from_db=read_store_serialized_dags()
)
dag_run = DagRun()
triggers = _trigger_dag(
dag_id=dag_id,
dag_run=dag_run,
dag_bag=dagbag,
run_id=run_id,
conf=conf,

Просмотреть файл

@ -36,23 +36,13 @@ class TestTriggerDag(unittest.TestCase):
def tearDown(self) -> None:
db.clear_db_runs()
@mock.patch('airflow.models.DagRun')
@mock.patch('airflow.models.DagBag')
def test_trigger_dag_dag_not_found(self, dag_bag_mock, dag_run_mock):
dag_bag_mock.dags = []
self.assertRaises(
AirflowException,
_trigger_dag,
'dag_not_found',
dag_bag_mock,
dag_run_mock,
run_id=None,
conf=None,
execution_date=None,
replace_microseconds=True,
)
def test_trigger_dag_dag_not_found(self, dag_bag_mock):
dag_bag_mock.dags = {}
with self.assertRaises(AirflowException):
_trigger_dag('dag_not_found', dag_bag_mock)
@mock.patch('airflow.models.DagRun')
@mock.patch('airflow.api.common.experimental.trigger_dag.DagRun', spec=DagRun)
@mock.patch('airflow.models.DagBag')
def test_trigger_dag_dag_run_exist(self, dag_bag_mock, dag_run_mock):
dag_id = "dag_run_exist"
@ -60,65 +50,38 @@ class TestTriggerDag(unittest.TestCase):
dag_bag_mock.dags = [dag_id]
dag_bag_mock.get_dag.return_value = dag
dag_run_mock.find.return_value = DagRun()
self.assertRaises(
AirflowException,
_trigger_dag,
dag_id,
dag_bag_mock,
dag_run_mock,
run_id=None,
conf=None,
execution_date=None,
replace_microseconds=True,
)
with self.assertRaises(AirflowException):
_trigger_dag(dag_id, dag_bag_mock)
@mock.patch('airflow.models.DAG')
@mock.patch('airflow.models.DagRun')
@mock.patch('airflow.api.common.experimental.trigger_dag.DagRun', spec=DagRun)
@mock.patch('airflow.models.DagBag')
def test_trigger_dag_include_subdags(self, dag_bag_mock, dag_run_mock, dag_mock):
dag_id = "trigger_dag"
dag_bag_mock.dags = [dag_id]
dag_bag_mock.get_dag.return_value = dag_mock
dag_run_mock.find.return_value = None
dag1 = mock.MagicMock()
dag1.subdags = []
dag2 = mock.MagicMock()
dag2.subdags = []
dag1 = mock.MagicMock(subdags=[])
dag2 = mock.MagicMock(subdags=[])
dag_mock.subdags = [dag1, dag2]
triggers = _trigger_dag(
dag_id,
dag_bag_mock,
dag_run_mock,
run_id=None,
conf=None,
execution_date=None,
replace_microseconds=True)
triggers = _trigger_dag(dag_id, dag_bag_mock)
self.assertEqual(3, len(triggers))
@mock.patch('airflow.models.DAG')
@mock.patch('airflow.models.DagRun')
@mock.patch('airflow.api.common.experimental.trigger_dag.DagRun', spec=DagRun)
@mock.patch('airflow.models.DagBag')
def test_trigger_dag_include_nested_subdags(self, dag_bag_mock, dag_run_mock, dag_mock):
dag_id = "trigger_dag"
dag_bag_mock.dags = [dag_id]
dag_bag_mock.get_dag.return_value = dag_mock
dag_run_mock.find.return_value = None
dag1 = mock.MagicMock()
dag1.subdags = []
dag2 = mock.MagicMock()
dag2.subdags = [dag1]
dag1 = mock.MagicMock(subdags=[])
dag2 = mock.MagicMock(subdags=[dag1])
dag_mock.subdags = [dag1, dag2]
triggers = _trigger_dag(
dag_id,
dag_bag_mock,
dag_run_mock,
run_id=None,
conf=None,
execution_date=None,
replace_microseconds=True)
triggers = _trigger_dag(dag_id, dag_bag_mock)
self.assertEqual(3, len(triggers))
@ -128,19 +91,9 @@ class TestTriggerDag(unittest.TestCase):
dag = DAG(dag_id, default_args={'start_date': timezone.datetime(2016, 9, 5, 10, 10, 0)})
dag_bag_mock.dags = [dag_id]
dag_bag_mock.get_dag.return_value = dag
dag_run = DagRun()
self.assertRaises(
ValueError,
_trigger_dag,
dag_id,
dag_bag_mock,
dag_run,
run_id=None,
conf=None,
execution_date=timezone.datetime(2015, 7, 5, 10, 10, 0),
replace_microseconds=True,
)
with self.assertRaises(ValueError):
_trigger_dag(dag_id, dag_bag_mock, execution_date=timezone.datetime(2015, 7, 5, 10, 10, 0))
@mock.patch('airflow.models.DagBag')
def test_trigger_dag_with_valid_start_date(self, dag_bag_mock):
@ -149,17 +102,8 @@ class TestTriggerDag(unittest.TestCase):
dag_bag_mock.dags = [dag_id]
dag_bag_mock.get_dag.return_value = dag
dag_bag_mock.dags_hash = {}
dag_run = DagRun()
triggers = _trigger_dag(
dag_id,
dag_bag_mock,
dag_run,
run_id=None,
conf=None,
execution_date=timezone.datetime(2018, 7, 5, 10, 10, 0),
replace_microseconds=True,
)
triggers = _trigger_dag(dag_id, dag_bag_mock, execution_date=timezone.datetime(2018, 7, 5, 10, 10, 0))
assert len(triggers) == 1
@ -174,17 +118,9 @@ class TestTriggerDag(unittest.TestCase):
dag = DAG(dag_id)
dag_bag_mock.dags = [dag_id]
dag_bag_mock.get_dag.return_value = dag
dag_run = DagRun()
dag_bag_mock.dags_hash = {}
triggers = _trigger_dag(
dag_id,
dag_bag_mock,
dag_run,
run_id=None,
conf=conf,
execution_date=None,
replace_microseconds=True)
triggers = _trigger_dag(dag_id, dag_bag_mock, conf=conf)
self.assertEqual(triggers[0].conf, expected_conf)