incubator-airflow/tests/core/test_core.py

489 строки
18 KiB
Python

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import multiprocessing
import os
import signal
import unittest
from datetime import timedelta
from time import sleep
from dateutil.relativedelta import relativedelta
from numpy.testing import assert_array_almost_equal
from airflow import settings
from airflow.exceptions import AirflowException, AirflowTaskTimeout
from airflow.hooks.base_hook import BaseHook
from airflow.jobs.local_task_job import LocalTaskJob
from airflow.models import DagBag, DagRun, TaskFail, TaskInstance
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.operators.bash import BashOperator
from airflow.operators.check_operator import CheckOperator, ValueCheckOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python import PythonOperator
from airflow.settings import Session
from airflow.utils.dates import infer_time_unit, round_time, scale_time_units
from airflow.utils.state import State
from airflow.utils.timezone import datetime
from airflow.utils.types import DagRunType
from tests.test_utils.config import conf_vars
DEV_NULL = '/dev/null'
DEFAULT_DATE = datetime(2015, 1, 1)
TEST_DAG_ID = 'unit_tests'
class OperatorSubclass(BaseOperator):
"""
An operator to test template substitution
"""
template_fields = ['some_templated_field']
def __init__(self, some_templated_field, *args, **kwargs):
super().__init__(*args, **kwargs)
self.some_templated_field = some_templated_field
def execute(self, context):
pass
class TestCore(unittest.TestCase):
default_scheduler_args = {"num_runs": 1}
def setUp(self):
self.dagbag = DagBag(
dag_folder=DEV_NULL, include_examples=True)
self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
self.dag = DAG(TEST_DAG_ID, default_args=self.args)
self.dag_bash = self.dagbag.dags['example_bash_operator']
self.runme_0 = self.dag_bash.get_task('runme_0')
self.run_after_loop = self.dag_bash.get_task('run_after_loop')
self.run_this_last = self.dag_bash.get_task('run_this_last')
def tearDown(self):
session = Session()
session.query(DagRun).filter(
DagRun.dag_id == TEST_DAG_ID).delete(
synchronize_session=False)
session.query(TaskInstance).filter(
TaskInstance.dag_id == TEST_DAG_ID).delete(
synchronize_session=False)
session.query(TaskFail).filter(
TaskFail.dag_id == TEST_DAG_ID).delete(
synchronize_session=False)
session.commit()
session.close()
def test_check_operators(self):
conn_id = "sqlite_default"
captain_hook = BaseHook.get_hook(conn_id=conn_id) # quite funny :D
captain_hook.run("CREATE TABLE operator_test_table (a, b)")
captain_hook.run("insert into operator_test_table values (1,2)")
op = CheckOperator(
task_id='check',
sql="select count(*) from operator_test_table",
conn_id=conn_id,
dag=self.dag)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
op = ValueCheckOperator(
task_id='value_check',
pass_value=95,
tolerance=0.1,
conn_id=conn_id,
sql="SELECT 100",
dag=self.dag)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
captain_hook.run("drop table operator_test_table")
def test_clear_api(self):
task = self.dag_bash.tasks[0]
task.clear(
start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
upstream=True, downstream=True)
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.are_dependents_done()
def test_illegal_args(self):
"""
Tests that Operators reject illegal arguments
"""
msg = 'Invalid arguments were passed to BashOperator (task_id: test_illegal_args).'
with conf_vars({('operators', 'allow_illegal_arguments'): 'True'}):
with self.assertWarns(PendingDeprecationWarning) as warning:
BashOperator(
task_id='test_illegal_args',
bash_command='echo success',
dag=self.dag,
illegal_argument_1234='hello?')
assert any(msg in str(w) for w in warning.warnings)
def test_illegal_args_forbidden(self):
"""
Tests that operators raise exceptions on illegal arguments when
illegal arguments are not allowed.
"""
with self.assertRaises(AirflowException) as ctx:
BashOperator(
task_id='test_illegal_args',
bash_command='echo success',
dag=self.dag,
illegal_argument_1234='hello?')
self.assertIn(
('Invalid arguments were passed to BashOperator '
'(task_id: test_illegal_args).'),
str(ctx.exception))
def test_bash_operator(self):
op = BashOperator(
task_id='test_bash_operator',
bash_command="echo success",
dag=self.dag)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_bash_operator_multi_byte_output(self):
op = BashOperator(
task_id='test_multi_byte_bash_operator',
bash_command="echo \u2600",
dag=self.dag,
output_encoding='utf-8')
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_bash_operator_kill(self):
import psutil
sleep_time = "100%d" % os.getpid()
op = BashOperator(
task_id='test_bash_operator_kill',
execution_timeout=timedelta(seconds=1),
bash_command="/bin/bash -c 'sleep %s'" % sleep_time,
dag=self.dag)
self.assertRaises(
AirflowTaskTimeout,
op.run,
start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
sleep(2)
pid = -1
for proc in psutil.process_iter():
if proc.cmdline() == ['sleep', sleep_time]:
pid = proc.pid
if pid != -1:
os.kill(pid, signal.SIGTERM)
self.fail("BashOperator's subprocess still running after stopping on timeout!")
def test_on_failure_callback(self):
# Annoying workaround for nonlocal not existing in python 2
data = {'called': False}
def check_failure(context, test_case=self):
data['called'] = True
error = context.get('exception')
test_case.assertIsInstance(error, AirflowException)
op = BashOperator(
task_id='check_on_failure_callback',
bash_command="exit 1",
dag=self.dag,
on_failure_callback=check_failure)
self.assertRaises(
AirflowException,
op.run,
start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
self.assertTrue(data['called'])
def test_dryrun(self):
op = BashOperator(
task_id='test_dryrun',
bash_command="echo success",
dag=self.dag)
op.dry_run()
def test_sqlite(self):
import airflow.providers.sqlite.operators.sqlite
op = airflow.providers.sqlite.operators.sqlite.SqliteOperator(
task_id='time_sqlite',
sql="CREATE TABLE IF NOT EXISTS unitest (dummy VARCHAR(20))",
dag=self.dag)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_timeout(self):
op = PythonOperator(
task_id='test_timeout',
execution_timeout=timedelta(seconds=1),
python_callable=lambda: sleep(5),
dag=self.dag)
self.assertRaises(
AirflowTaskTimeout,
op.run,
start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_python_op(self):
def test_py_op(templates_dict, ds, **kwargs):
if not templates_dict['ds'] == ds:
raise Exception("failure")
op = PythonOperator(
task_id='test_py_op',
python_callable=test_py_op,
templates_dict={'ds': "{{ ds }}"},
dag=self.dag)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_complex_template(self):
def verify_templated_field(context):
self.assertEqual(context['ti'].task.some_templated_field['bar'][1],
context['ds'])
op = OperatorSubclass(
task_id='test_complex_template',
some_templated_field={
'foo': '123',
'bar': ['baz', '{{ ds }}']
},
dag=self.dag)
op.execute = verify_templated_field
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_template_non_bool(self):
"""
Test templates can handle objects with no sense of truthiness
"""
class NonBoolObject:
def __len__(self): # pylint: disable=invalid-length-returned
return NotImplemented
def __bool__(self): # pylint: disable=invalid-bool-returned, bad-option-value
return NotImplemented
op = OperatorSubclass(
task_id='test_bad_template_obj',
some_templated_field=NonBoolObject(),
dag=self.dag)
op.resolve_template_files()
def test_task_get_template(self):
TI = TaskInstance
ti = TI(
task=self.runme_0, execution_date=DEFAULT_DATE)
ti.dag = self.dag_bash
ti.run(ignore_ti_state=True)
context = ti.get_template_context()
# DEFAULT DATE is 2015-01-01
self.assertEqual(context['ds'], '2015-01-01')
self.assertEqual(context['ds_nodash'], '20150101')
# next_ds is 2015-01-02 as the dag interval is daily
self.assertEqual(context['next_ds'], '2015-01-02')
self.assertEqual(context['next_ds_nodash'], '20150102')
# prev_ds is 2014-12-31 as the dag interval is daily
self.assertEqual(context['prev_ds'], '2014-12-31')
self.assertEqual(context['prev_ds_nodash'], '20141231')
self.assertEqual(context['ts'], '2015-01-01T00:00:00+00:00')
self.assertEqual(context['ts_nodash'], '20150101T000000')
self.assertEqual(context['ts_nodash_with_tz'], '20150101T000000+0000')
self.assertEqual(context['yesterday_ds'], '2014-12-31')
self.assertEqual(context['yesterday_ds_nodash'], '20141231')
self.assertEqual(context['tomorrow_ds'], '2015-01-02')
self.assertEqual(context['tomorrow_ds_nodash'], '20150102')
def test_local_task_job(self):
TI = TaskInstance
ti = TI(
task=self.runme_0, execution_date=DEFAULT_DATE)
job = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
job.run()
def test_raw_job(self):
TI = TaskInstance
ti = TI(
task=self.runme_0, execution_date=DEFAULT_DATE)
ti.dag = self.dag_bash
ti.run(ignore_ti_state=True)
def test_round_time(self):
rt1 = round_time(datetime(2015, 1, 1, 6), timedelta(days=1))
self.assertEqual(datetime(2015, 1, 1, 0, 0), rt1)
rt2 = round_time(datetime(2015, 1, 2), relativedelta(months=1))
self.assertEqual(datetime(2015, 1, 1, 0, 0), rt2)
rt3 = round_time(datetime(2015, 9, 16, 0, 0), timedelta(1), datetime(
2015, 9, 14, 0, 0))
self.assertEqual(datetime(2015, 9, 16, 0, 0), rt3)
rt4 = round_time(datetime(2015, 9, 15, 0, 0), timedelta(1), datetime(
2015, 9, 14, 0, 0))
self.assertEqual(datetime(2015, 9, 15, 0, 0), rt4)
rt5 = round_time(datetime(2015, 9, 14, 0, 0), timedelta(1), datetime(
2015, 9, 14, 0, 0))
self.assertEqual(datetime(2015, 9, 14, 0, 0), rt5)
rt6 = round_time(datetime(2015, 9, 13, 0, 0), timedelta(1), datetime(
2015, 9, 14, 0, 0))
self.assertEqual(datetime(2015, 9, 14, 0, 0), rt6)
def test_infer_time_unit(self):
self.assertEqual('minutes', infer_time_unit([130, 5400, 10]))
self.assertEqual('seconds', infer_time_unit([110, 50, 10, 100]))
self.assertEqual('hours', infer_time_unit([100000, 50000, 10000, 20000]))
self.assertEqual('days', infer_time_unit([200000, 100000]))
def test_scale_time_units(self):
# use assert_almost_equal from numpy.testing since we are comparing
# floating point arrays
arr1 = scale_time_units([130, 5400, 10], 'minutes')
assert_array_almost_equal(arr1, [2.167, 90.0, 0.167], decimal=3)
arr2 = scale_time_units([110, 50, 10, 100], 'seconds')
assert_array_almost_equal(arr2, [110.0, 50.0, 10.0, 100.0], decimal=3)
arr3 = scale_time_units([100000, 50000, 10000, 20000], 'hours')
assert_array_almost_equal(arr3, [27.778, 13.889, 2.778, 5.556],
decimal=3)
arr4 = scale_time_units([200000, 100000], 'days')
assert_array_almost_equal(arr4, [2.315, 1.157], decimal=3)
def test_bad_trigger_rule(self):
with self.assertRaises(AirflowException):
DummyOperator(
task_id='test_bad_trigger',
trigger_rule="non_existent",
dag=self.dag)
def test_terminate_task(self):
"""If a task instance's db state get deleted, it should fail"""
from airflow.executors.sequential_executor import SequentialExecutor
TI = TaskInstance
dag = self.dagbag.dags.get('test_utils')
task = dag.task_dict.get('sleeps_forever')
ti = TI(task=task, execution_date=DEFAULT_DATE)
job = LocalTaskJob(
task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
# Running task instance asynchronously
proc = multiprocessing.Process(target=job.run)
proc.start()
sleep(5)
settings.engine.dispose()
session = settings.Session()
ti.refresh_from_db(session=session)
# making sure it's actually running
self.assertEqual(State.RUNNING, ti.state)
ti = session.query(TI).filter_by(
dag_id=task.dag_id,
task_id=task.task_id,
execution_date=DEFAULT_DATE
).one()
# deleting the instance should result in a failure
session.delete(ti)
session.commit()
# waiting for the async task to finish
proc.join()
# making sure that the task ended up as failed
ti.refresh_from_db(session=session)
self.assertEqual(State.FAILED, ti.state)
session.close()
def test_task_fail_duration(self):
"""If a task fails, the duration should be recorded in TaskFail"""
op1 = BashOperator(
task_id='pass_sleepy',
bash_command='sleep 3',
dag=self.dag)
op2 = BashOperator(
task_id='fail_sleepy',
bash_command='sleep 5',
execution_timeout=timedelta(seconds=3),
retry_delay=timedelta(seconds=0),
dag=self.dag)
session = settings.Session()
try:
op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
except Exception: # pylint: disable=broad-except
pass
try:
op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
except Exception: # pylint: disable=broad-except
pass
op1_fails = session.query(TaskFail).filter_by(
task_id='pass_sleepy',
dag_id=self.dag.dag_id,
execution_date=DEFAULT_DATE).all()
op2_fails = session.query(TaskFail).filter_by(
task_id='fail_sleepy',
dag_id=self.dag.dag_id,
execution_date=DEFAULT_DATE).all()
self.assertEqual(0, len(op1_fails))
self.assertEqual(1, len(op2_fails))
self.assertGreaterEqual(sum([f.duration for f in op2_fails]), 3)
def test_externally_triggered_dagrun(self):
TI = TaskInstance
# Create the dagrun between two "scheduled" execution dates of the DAG
execution_date = DEFAULT_DATE + timedelta(days=2)
execution_ds = execution_date.strftime('%Y-%m-%d')
execution_ds_nodash = execution_ds.replace('-', '')
dag = DAG(
TEST_DAG_ID,
default_args=self.args,
schedule_interval=timedelta(weeks=1),
start_date=DEFAULT_DATE)
task = DummyOperator(task_id='test_externally_triggered_dag_context',
dag=dag)
dag.create_dagrun(run_type=DagRunType.SCHEDULED,
execution_date=execution_date,
state=State.RUNNING,
external_trigger=True)
task.run(
start_date=execution_date, end_date=execution_date)
ti = TI(task=task, execution_date=execution_date)
context = ti.get_template_context()
# next_ds/prev_ds should be the execution date for manually triggered runs
self.assertEqual(context['next_ds'], execution_ds)
self.assertEqual(context['next_ds_nodash'], execution_ds_nodash)
self.assertEqual(context['prev_ds'], execution_ds)
self.assertEqual(context['prev_ds_nodash'], execution_ds_nodash)