incubator-airflow/tests/operators/test_branch_operator.py

179 lines
6.3 KiB
Python

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import unittest
from airflow.models import DAG, DagRun, TaskInstance as TI
from airflow.operators.branch_operator import BaseBranchOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.types import DagRunType
# Execution date used for every branch-operator run in this module
# (built with airflow.utils.timezone.datetime).
DEFAULT_DATE = timezone.datetime(year=2016, month=1, day=1)
# Schedule interval of the test DAG: twelve hours.
INTERVAL = datetime.timedelta(seconds=12 * 60 * 60)
class ChooseBranchOne(BaseBranchOperator):
    """Branch operator that always follows the single task ``branch_1``."""

    def choose_branch(self, context):
        """Return the task_id to follow; ``context`` is not consulted."""
        chosen_task_id = 'branch_1'
        return chosen_task_id
class ChooseBranchOneTwo(BaseBranchOperator):
    """Branch operator that follows both ``branch_1`` and ``branch_2``."""

    def choose_branch(self, context):
        """Return a new list of the task_ids to follow; ``context`` is unused."""
        chosen_task_ids = ['branch_1', 'branch_2']
        return chosen_task_ids
class TestBranchOperator(unittest.TestCase):
    """Tests for ``BaseBranchOperator`` subclasses.

    Each test wires a ``make_choice`` branch operator upstream of the
    ``branch_*`` dummy tasks, runs it for DEFAULT_DATE and checks the
    resulting task instance states.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # Delete DagRun / TaskInstance rows left behind by earlier tests so
        # they cannot show up in the queries made below.
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def setUp(self):
        self.dag = DAG('branch_operator_test',
                       default_args={
                           'owner': 'airflow',
                           'start_date': DEFAULT_DATE},
                       schedule_interval=INTERVAL)

        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
        self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
        # Assigned by the individual tests that need them.
        self.branch_3 = None
        self.branch_op = None

    def tearDown(self):
        super().tearDown()

        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    def _assert_ti_states(self, tis, expected, require_all=False):
        """Check every task instance in ``tis`` against ``expected``.

        :param tis: iterable of task instances (a query or a list).
        :param expected: mapping of task_id to the expected state.
        :param require_all: when True, also require that every task_id in
            ``expected`` appears in ``tis``, so the check cannot pass
            vacuously on an empty result.
        :raises AssertionError: on an unexpected task_id, a wrong state, or
            (with ``require_all``) a missing task instance.
        """
        seen = set()
        for ti in tis:
            if ti.task_id not in expected:
                self.fail("Unexpected task instance: {}".format(ti.task_id))
            self.assertEqual(
                ti.state, expected[ti.task_id],
                "Wrong state for task {}".format(ti.task_id))
            seen.add(ti.task_id)
        if require_all:
            self.assertEqual(seen, set(expected))

    def test_without_dag_run(self):
        """Check that branching works defensively when there is no dag run."""
        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(
                TI.dag_id == self.dag.dag_id,
                TI.execution_date == DEFAULT_DATE
            )
            # branch_1 is the chosen branch, so it must not be skipped. Only
            # the rows actually present in the DB are checked here.
            self._assert_ti_states(tis, {
                'make_choice': State.SUCCESS,
                'branch_1': State.NONE,
                'branch_2': State.SKIPPED,
            })

    def test_branch_list_without_dag_run(self):
        """This checks if the BranchOperator supports branching off to a list of tasks."""
        self.branch_op = ChooseBranchOneTwo(task_id='make_choice', dag=self.dag)
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
        self.branch_3.set_upstream(self.branch_op)
        self.dag.clear()

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        with create_session() as session:
            tis = session.query(TI).filter(
                TI.dag_id == self.dag.dag_id,
                TI.execution_date == DEFAULT_DATE
            )
            self._assert_ti_states(tis, {
                "make_choice": State.SUCCESS,
                "branch_1": State.NONE,
                "branch_2": State.NONE,
                "branch_3": State.SKIPPED,
            })

    def test_with_dag_run(self):
        """Check that the branch not chosen is skipped when a dag run exists."""
        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
        self.branch_1.set_upstream(self.branch_op)
        self.branch_2.set_upstream(self.branch_op)
        self.dag.clear()

        dagrun = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING
        )

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dagrun.get_task_instances()
        self._assert_ti_states(tis, {
            'make_choice': State.SUCCESS,
            'branch_1': State.NONE,
            'branch_2': State.SKIPPED,
        }, require_all=True)

    def test_with_skip_in_branch_downstream_dependencies(self):
        """Check that a task downstream of the chosen branch is not skipped."""
        self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
        self.branch_op >> self.branch_1 >> self.branch_2
        self.branch_op >> self.branch_2
        self.dag.clear()

        dagrun = self.dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING
        )

        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

        tis = dagrun.get_task_instances()
        # branch_2 is a direct child of make_choice, but it is also downstream
        # of the chosen branch_1, so it must stay un-skipped.
        self._assert_ti_states(tis, {
            'make_choice': State.SUCCESS,
            'branch_1': State.NONE,
            'branch_2': State.NONE,
        }, require_all=True)