Clean up _trigger_dag function (#11584)
- The dag_run argument is only there for test mocks, and only to access a static method. Removing this simplifies the function, reduces confusion. - Give optional arguments a default value, reduce indentation of arg list to PEP / Black standard. - Clean up tests for readability
This commit is contained in:
Родитель
84c70287ee
Коммит
4d611f2ffd
|
@ -28,19 +28,17 @@ from airflow.utils.types import DagRunType
|
|||
|
||||
|
||||
def _trigger_dag(
|
||||
dag_id: str,
|
||||
dag_bag: DagBag,
|
||||
dag_run: DagModel,
|
||||
run_id: Optional[str],
|
||||
conf: Optional[Union[dict, str]],
|
||||
execution_date: Optional[datetime],
|
||||
replace_microseconds: bool,
|
||||
dag_id: str,
|
||||
dag_bag: DagBag,
|
||||
run_id: Optional[str] = None,
|
||||
conf: Optional[Union[dict, str]] = None,
|
||||
execution_date: Optional[datetime] = None,
|
||||
replace_microseconds: bool = True,
|
||||
) -> List[DagRun]: # pylint: disable=too-many-arguments
|
||||
"""Triggers DAG run.
|
||||
|
||||
:param dag_id: DAG ID
|
||||
:param dag_bag: DAG Bag model
|
||||
:param dag_run: DAG Run model
|
||||
:param run_id: ID of the dag_run
|
||||
:param conf: configuration
|
||||
:param execution_date: date of execution
|
||||
|
@ -69,7 +67,7 @@ def _trigger_dag(
|
|||
min_dag_start_date.isoformat()))
|
||||
|
||||
run_id = run_id or DagRun.generate_run_id(DagRunType.MANUAL, execution_date)
|
||||
dag_run = dag_run.find(dag_id=dag_id, run_id=run_id)
|
||||
dag_run = DagRun.find(dag_id=dag_id, run_id=run_id)
|
||||
|
||||
if dag_run:
|
||||
raise DagRunAlreadyExists(
|
||||
|
@ -123,10 +121,8 @@ def trigger_dag(
|
|||
dag_folder=dag_model.fileloc,
|
||||
read_dags_from_db=read_store_serialized_dags()
|
||||
)
|
||||
dag_run = DagRun()
|
||||
triggers = _trigger_dag(
|
||||
dag_id=dag_id,
|
||||
dag_run=dag_run,
|
||||
dag_bag=dagbag,
|
||||
run_id=run_id,
|
||||
conf=conf,
|
||||
|
|
|
@ -36,23 +36,13 @@ class TestTriggerDag(unittest.TestCase):
|
|||
def tearDown(self) -> None:
|
||||
db.clear_db_runs()
|
||||
|
||||
@mock.patch('airflow.models.DagRun')
|
||||
@mock.patch('airflow.models.DagBag')
|
||||
def test_trigger_dag_dag_not_found(self, dag_bag_mock, dag_run_mock):
|
||||
dag_bag_mock.dags = []
|
||||
self.assertRaises(
|
||||
AirflowException,
|
||||
_trigger_dag,
|
||||
'dag_not_found',
|
||||
dag_bag_mock,
|
||||
dag_run_mock,
|
||||
run_id=None,
|
||||
conf=None,
|
||||
execution_date=None,
|
||||
replace_microseconds=True,
|
||||
)
|
||||
def test_trigger_dag_dag_not_found(self, dag_bag_mock):
|
||||
dag_bag_mock.dags = {}
|
||||
with self.assertRaises(AirflowException):
|
||||
_trigger_dag('dag_not_found', dag_bag_mock)
|
||||
|
||||
@mock.patch('airflow.models.DagRun')
|
||||
@mock.patch('airflow.api.common.experimental.trigger_dag.DagRun', spec=DagRun)
|
||||
@mock.patch('airflow.models.DagBag')
|
||||
def test_trigger_dag_dag_run_exist(self, dag_bag_mock, dag_run_mock):
|
||||
dag_id = "dag_run_exist"
|
||||
|
@ -60,65 +50,38 @@ class TestTriggerDag(unittest.TestCase):
|
|||
dag_bag_mock.dags = [dag_id]
|
||||
dag_bag_mock.get_dag.return_value = dag
|
||||
dag_run_mock.find.return_value = DagRun()
|
||||
self.assertRaises(
|
||||
AirflowException,
|
||||
_trigger_dag,
|
||||
dag_id,
|
||||
dag_bag_mock,
|
||||
dag_run_mock,
|
||||
run_id=None,
|
||||
conf=None,
|
||||
execution_date=None,
|
||||
replace_microseconds=True,
|
||||
)
|
||||
with self.assertRaises(AirflowException):
|
||||
_trigger_dag(dag_id, dag_bag_mock)
|
||||
|
||||
@mock.patch('airflow.models.DAG')
|
||||
@mock.patch('airflow.models.DagRun')
|
||||
@mock.patch('airflow.api.common.experimental.trigger_dag.DagRun', spec=DagRun)
|
||||
@mock.patch('airflow.models.DagBag')
|
||||
def test_trigger_dag_include_subdags(self, dag_bag_mock, dag_run_mock, dag_mock):
|
||||
dag_id = "trigger_dag"
|
||||
dag_bag_mock.dags = [dag_id]
|
||||
dag_bag_mock.get_dag.return_value = dag_mock
|
||||
dag_run_mock.find.return_value = None
|
||||
dag1 = mock.MagicMock()
|
||||
dag1.subdags = []
|
||||
dag2 = mock.MagicMock()
|
||||
dag2.subdags = []
|
||||
dag1 = mock.MagicMock(subdags=[])
|
||||
dag2 = mock.MagicMock(subdags=[])
|
||||
dag_mock.subdags = [dag1, dag2]
|
||||
|
||||
triggers = _trigger_dag(
|
||||
dag_id,
|
||||
dag_bag_mock,
|
||||
dag_run_mock,
|
||||
run_id=None,
|
||||
conf=None,
|
||||
execution_date=None,
|
||||
replace_microseconds=True)
|
||||
triggers = _trigger_dag(dag_id, dag_bag_mock)
|
||||
|
||||
self.assertEqual(3, len(triggers))
|
||||
|
||||
@mock.patch('airflow.models.DAG')
|
||||
@mock.patch('airflow.models.DagRun')
|
||||
@mock.patch('airflow.api.common.experimental.trigger_dag.DagRun', spec=DagRun)
|
||||
@mock.patch('airflow.models.DagBag')
|
||||
def test_trigger_dag_include_nested_subdags(self, dag_bag_mock, dag_run_mock, dag_mock):
|
||||
dag_id = "trigger_dag"
|
||||
dag_bag_mock.dags = [dag_id]
|
||||
dag_bag_mock.get_dag.return_value = dag_mock
|
||||
dag_run_mock.find.return_value = None
|
||||
dag1 = mock.MagicMock()
|
||||
dag1.subdags = []
|
||||
dag2 = mock.MagicMock()
|
||||
dag2.subdags = [dag1]
|
||||
dag1 = mock.MagicMock(subdags=[])
|
||||
dag2 = mock.MagicMock(subdags=[dag1])
|
||||
dag_mock.subdags = [dag1, dag2]
|
||||
|
||||
triggers = _trigger_dag(
|
||||
dag_id,
|
||||
dag_bag_mock,
|
||||
dag_run_mock,
|
||||
run_id=None,
|
||||
conf=None,
|
||||
execution_date=None,
|
||||
replace_microseconds=True)
|
||||
triggers = _trigger_dag(dag_id, dag_bag_mock)
|
||||
|
||||
self.assertEqual(3, len(triggers))
|
||||
|
||||
|
@ -128,19 +91,9 @@ class TestTriggerDag(unittest.TestCase):
|
|||
dag = DAG(dag_id, default_args={'start_date': timezone.datetime(2016, 9, 5, 10, 10, 0)})
|
||||
dag_bag_mock.dags = [dag_id]
|
||||
dag_bag_mock.get_dag.return_value = dag
|
||||
dag_run = DagRun()
|
||||
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
_trigger_dag,
|
||||
dag_id,
|
||||
dag_bag_mock,
|
||||
dag_run,
|
||||
run_id=None,
|
||||
conf=None,
|
||||
execution_date=timezone.datetime(2015, 7, 5, 10, 10, 0),
|
||||
replace_microseconds=True,
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
_trigger_dag(dag_id, dag_bag_mock, execution_date=timezone.datetime(2015, 7, 5, 10, 10, 0))
|
||||
|
||||
@mock.patch('airflow.models.DagBag')
|
||||
def test_trigger_dag_with_valid_start_date(self, dag_bag_mock):
|
||||
|
@ -149,17 +102,8 @@ class TestTriggerDag(unittest.TestCase):
|
|||
dag_bag_mock.dags = [dag_id]
|
||||
dag_bag_mock.get_dag.return_value = dag
|
||||
dag_bag_mock.dags_hash = {}
|
||||
dag_run = DagRun()
|
||||
|
||||
triggers = _trigger_dag(
|
||||
dag_id,
|
||||
dag_bag_mock,
|
||||
dag_run,
|
||||
run_id=None,
|
||||
conf=None,
|
||||
execution_date=timezone.datetime(2018, 7, 5, 10, 10, 0),
|
||||
replace_microseconds=True,
|
||||
)
|
||||
triggers = _trigger_dag(dag_id, dag_bag_mock, execution_date=timezone.datetime(2018, 7, 5, 10, 10, 0))
|
||||
|
||||
assert len(triggers) == 1
|
||||
|
||||
|
@ -174,17 +118,9 @@ class TestTriggerDag(unittest.TestCase):
|
|||
dag = DAG(dag_id)
|
||||
dag_bag_mock.dags = [dag_id]
|
||||
dag_bag_mock.get_dag.return_value = dag
|
||||
dag_run = DagRun()
|
||||
|
||||
dag_bag_mock.dags_hash = {}
|
||||
|
||||
triggers = _trigger_dag(
|
||||
dag_id,
|
||||
dag_bag_mock,
|
||||
dag_run,
|
||||
run_id=None,
|
||||
conf=conf,
|
||||
execution_date=None,
|
||||
replace_microseconds=True)
|
||||
triggers = _trigger_dag(dag_id, dag_bag_mock, conf=conf)
|
||||
|
||||
self.assertEqual(triggers[0].conf, expected_conf)
|
||||
|
|
Загрузка…
Ссылка в новой задаче