179 строки
6.3 KiB
Python
179 строки
6.3 KiB
Python
#
|
|
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
import datetime
|
|
import unittest
|
|
|
|
from airflow.models import DAG, DagRun, TaskInstance as TI
|
|
from airflow.operators.branch_operator import BaseBranchOperator
|
|
from airflow.operators.dummy_operator import DummyOperator
|
|
from airflow.utils import timezone
|
|
from airflow.utils.session import create_session
|
|
from airflow.utils.state import State
|
|
from airflow.utils.types import DagRunType
|
|
|
|
# Execution date shared by every test in this module.
DEFAULT_DATE = timezone.datetime(year=2016, month=1, day=1)
# Schedule interval given to the test DAG built in setUp.
INTERVAL = datetime.timedelta(hours=12)
|
|
|
|
|
|
class ChooseBranchOne(BaseBranchOperator):
    """Branch operator that always follows the single task ``branch_1``."""

    def choose_branch(self, context):
        """Select ``branch_1``; ``context`` is not consulted."""
        selected_task_id = 'branch_1'
        return selected_task_id
|
|
|
|
|
|
class ChooseBranchOneTwo(BaseBranchOperator):
    """Branch operator that follows both ``branch_1`` and ``branch_2``."""

    def choose_branch(self, context):
        """Select a list of two task ids; ``context`` is not consulted."""
        selected_task_ids = ['branch_1', 'branch_2']
        return selected_task_ids
|
|
|
|
|
|
class TestBranchOperator(unittest.TestCase):
    """Tests for ``BaseBranchOperator`` subclasses, run with and without a DagRun."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # Start from empty DagRun / TaskInstance tables so rows left behind by
        # other tests cannot leak into the queries made here.
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        self.dag = DAG('branch_operator_test',
                       default_args={
                           'owner': 'airflow',
                           'start_date': DEFAULT_DATE},
                       schedule_interval=INTERVAL)

        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
        self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
        # Only populated by the tests that need them.
        self.branch_3 = None
        self.branch_op = None

    def tearDown(self):
        super().tearDown()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def _assert_task_states(self, tis, expected, require_all=False):
        """Assert that every task instance in ``tis`` is in its expected state.

        :param tis: iterable of task instances to check.
        :param expected: mapping of ``task_id`` -> expected state.
        :param require_all: when True, additionally require that every
            ``task_id`` in ``expected`` was seen, so the check cannot pass
            vacuously on an empty result.
        :raises AssertionError: on an unknown task id, a wrong state, or
            (with ``require_all``) a missing task instance.
        """
        seen = set()
        for ti in tis:
            if ti.task_id not in expected:
                self.fail(f"Unexpected task instance: {ti.task_id!r}")
            self.assertEqual(
                ti.state,
                expected[ti.task_id],
                msg=f"Task {ti.task_id!r} is in an unexpected state",
            )
            seen.add(ti.task_id)

        if require_all:
            self.assertEqual(seen, set(expected),
                             msg="Not every expected task instance was found")

    def test_without_dag_run(self):
        """This checks the defense against non-existent tasks in a dag run."""
        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(
                TI.dag_id == self.dag.dag_id,
                TI.execution_date == DEFAULT_DATE
            )

            # The query is consumed inside the helper, so it must be called
            # while the session is still open.
            self._assert_task_states(tis, {
                "make_choice": State.SUCCESS,
                # should exist with state None (the chosen branch is not skipped)
                "branch_1": State.NONE,
                "branch_2": State.SKIPPED,
            })

    def test_branch_list_without_dag_run(self):
        """This checks if the BranchOperator supports branching off to a list of tasks."""
        self.branch_op = ChooseBranchOneTwo(task_id='make_choice', dag=self.dag)
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
        self.branch_3.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(
                TI.dag_id == self.dag.dag_id,
                TI.execution_date == DEFAULT_DATE
            )

            self._assert_task_states(tis, {
                "make_choice": State.SUCCESS,
                "branch_1": State.NONE,
                "branch_2": State.NONE,
                "branch_3": State.SKIPPED,
            })

    def test_with_dag_run(self):
        """The unchosen branch is skipped when the operator runs inside a DagRun."""
        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

        dagrun = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING
        )

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        # create_dagrun materialises a task instance for every task, so all
        # of them must be present.
        self._assert_task_states(dagrun.get_task_instances(), {
            "make_choice": State.SUCCESS,
            "branch_1": State.NONE,
            "branch_2": State.SKIPPED,
        }, require_all=True)

    def test_with_skip_in_branch_downstream_dependencies(self):
        """A task downstream of the chosen branch is not skipped, even if it is
        also a direct child of the branch operator."""
        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
        self.branch_op >> self.branch_1 >> self.branch_2
        self.branch_op >> self.branch_2
        self.dag.clear()

        dagrun = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING
        )

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        self._assert_task_states(dagrun.get_task_instances(), {
            "make_choice": State.SUCCESS,
            "branch_1": State.NONE,
            "branch_2": State.NONE,
        }, require_all=True)
|