#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
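
"""Unit tests for core Airflow behaviour: operator argument validation,
templating, timeouts, failure callbacks, and date/time utilities."""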
import multiprocessing
import os
import signal
import unittest
from datetime import timedelta
from time import sleep

from dateutil.relativedelta import relativedelta
from numpy.testing import assert_array_almost_equal

from airflow import settings
from airflow.exceptions import AirflowException, AirflowTaskTimeout
from airflow.hooks.base_hook import BaseHook
from airflow.jobs.local_task_job import LocalTaskJob
from airflow.models import DagBag, DagRun, TaskFail, TaskInstance
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.operators.bash import BashOperator
from airflow.operators.check_operator import CheckOperator, ValueCheckOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python import PythonOperator
from airflow.settings import Session
from airflow.utils.dates import infer_time_unit, round_time, scale_time_units
from airflow.utils.state import State
from airflow.utils.timezone import datetime
from airflow.utils.types import DagRunType
from tests.test_utils.config import conf_vars

DEV_NULL = '/dev/null'
DEFAULT_DATE = datetime(2015, 1, 1)
TEST_DAG_ID = 'unit_tests'


class OperatorSubclass(BaseOperator):
    """
    An operator to test template substitution
    """
    template_fields = ['some_templated_field']

    def __init__(self, some_templated_field, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.some_templated_field = some_templated_field

    def execute(self, context):
        pass


class TestCore(unittest.TestCase):
    default_scheduler_args = {"num_runs": 1}

    def setUp(self):
        self.dagbag = DagBag(
            dag_folder=DEV_NULL, include_examples=True)
        self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
        self.dag = DAG(TEST_DAG_ID, default_args=self.args)
        self.dag_bash = self.dagbag.dags['example_bash_operator']
        self.runme_0 = self.dag_bash.get_task('runme_0')
        self.run_after_loop = self.dag_bash.get_task('run_after_loop')
        self.run_this_last = self.dag_bash.get_task('run_this_last')

    def tearDown(self):
        session = Session()
        session.query(DagRun).filter(
            DagRun.dag_id == TEST_DAG_ID).delete(
            synchronize_session=False)
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == TEST_DAG_ID).delete(
            synchronize_session=False)
        session.query(TaskFail).filter(
            TaskFail.dag_id == TEST_DAG_ID).delete(
            synchronize_session=False)
        session.commit()
        session.close()
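
    # Exercises CheckOperator (fails the task if the query's first row is falsy)
    # and ValueCheckOperator (pass_value=95 with tolerance=0.1 accepts SELECT 100,
    # since 100 falls within 95 +/- 10%).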
    def test_check_operators(self):

        conn_id = "sqlite_default"

        captain_hook = BaseHook.get_hook(conn_id=conn_id)  # quite funny :D
        captain_hook.run("CREATE TABLE operator_test_table (a, b)")
        captain_hook.run("insert into operator_test_table values (1,2)")

        op = CheckOperator(
            task_id='check',
            sql="select count(*) from operator_test_table",
            conn_id=conn_id,
            dag=self.dag)
        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

        op = ValueCheckOperator(
            task_id='value_check',
            pass_value=95,
            tolerance=0.1,
            conn_id=conn_id,
            sql="SELECT 100",
            dag=self.dag)
        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

        captain_hook.run("drop table operator_test_table")

    def test_clear_api(self):
        task = self.dag_bash.tasks[0]
        task.clear(
            start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
            upstream=True, downstream=True)
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.are_dependents_done()

    def test_illegal_args(self):
        """
        Tests that Operators reject illegal arguments
        """
        msg = 'Invalid arguments were passed to BashOperator (task_id: test_illegal_args).'
        with conf_vars({('operators', 'allow_illegal_arguments'): 'True'}):
            with self.assertWarns(PendingDeprecationWarning) as warning:
                BashOperator(
                    task_id='test_illegal_args',
                    bash_command='echo success',
                    dag=self.dag,
                    illegal_argument_1234='hello?')
            assert any(msg in str(w) for w in warning.warnings)

    def test_illegal_args_forbidden(self):
        """
        Tests that operators raise exceptions on illegal arguments when
        illegal arguments are not allowed.
        """
        with self.assertRaises(AirflowException) as ctx:
            BashOperator(
                task_id='test_illegal_args',
                bash_command='echo success',
                dag=self.dag,
                illegal_argument_1234='hello?')
        self.assertIn(
            ('Invalid arguments were passed to BashOperator '
             '(task_id: test_illegal_args).'),
            str(ctx.exception))

    def test_bash_operator(self):
        op = BashOperator(
            task_id='test_bash_operator',
            bash_command="echo success",
            dag=self.dag)
        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    def test_bash_operator_multi_byte_output(self):
        op = BashOperator(
            task_id='test_multi_byte_bash_operator',
            bash_command="echo \u2600",
            dag=self.dag,
            output_encoding='utf-8')
        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
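
    # The sleep duration embeds this test's PID so the spawned `sleep` process
    # can be identified unambiguously in the process table; if it survives the
    # execution timeout, the test kills it and fails.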
    def test_bash_operator_kill(self):
        import psutil
        sleep_time = "100%d" % os.getpid()
        op = BashOperator(
            task_id='test_bash_operator_kill',
            execution_timeout=timedelta(seconds=1),
            bash_command="/bin/bash -c 'sleep %s'" % sleep_time,
            dag=self.dag)
        self.assertRaises(
            AirflowTaskTimeout,
            op.run,
            start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        sleep(2)
        pid = -1
        for proc in psutil.process_iter():
            if proc.cmdline() == ['sleep', sleep_time]:
                pid = proc.pid
        if pid != -1:
            os.kill(pid, signal.SIGTERM)
            self.fail("BashOperator's subprocess still running after stopping on timeout!")
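
    # The on_failure_callback receives the task context; the 'exception' key
    # carries the error that caused the failure.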
    def test_on_failure_callback(self):
        # Mutable container so the callback can record that it was invoked
        # (this predates the use of nonlocal).
        data = {'called': False}

        def check_failure(context, test_case=self):
            data['called'] = True
            error = context.get('exception')
            test_case.assertIsInstance(error, AirflowException)

        op = BashOperator(
            task_id='check_on_failure_callback',
            bash_command="exit 1",
            dag=self.dag,
            on_failure_callback=check_failure)
        self.assertRaises(
            AirflowException,
            op.run,
            start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
        self.assertTrue(data['called'])
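
    # dry_run() is expected to render the templated fields (and, for
    # BashOperator, log the command) without actually executing anything.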
    def test_dryrun(self):
        op = BashOperator(
            task_id='test_dryrun',
            bash_command="echo success",
            dag=self.dag)
        op.dry_run()

    def test_sqlite(self):
        import airflow.providers.sqlite.operators.sqlite
        op = airflow.providers.sqlite.operators.sqlite.SqliteOperator(
            task_id='time_sqlite',
            sql="CREATE TABLE IF NOT EXISTS unitest (dummy VARCHAR(20))",
            dag=self.dag)
        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
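
    # execution_timeout is enforced by raising AirflowTaskTimeout inside the
    # running task once the limit is exceeded.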
    def test_timeout(self):
        op = PythonOperator(
            task_id='test_timeout',
            execution_timeout=timedelta(seconds=1),
            python_callable=lambda: sleep(5),
            dag=self.dag)
        self.assertRaises(
            AirflowTaskTimeout,
            op.run,
            start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
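
    # Values in templates_dict are templated, so "{{ ds }}" should render to the
    # same execution date string the callable receives as `ds`.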
    def test_python_op(self):
        def test_py_op(templates_dict, ds, **kwargs):
            if templates_dict['ds'] != ds:
                raise Exception("failure")

        op = PythonOperator(
            task_id='test_py_op',
            python_callable=test_py_op,
            templates_dict={'ds': "{{ ds }}"},
            dag=self.dag)
        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
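
    # Template rendering recurses into containers: the "{{ ds }}" nested inside
    # the dict/list under some_templated_field should be rendered by execute time.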
    def test_complex_template(self):
        def verify_templated_field(context):
            self.assertEqual(context['ti'].task.some_templated_field['bar'][1],
                             context['ds'])

        op = OperatorSubclass(
            task_id='test_complex_template',
            some_templated_field={
                'foo': '123',
                'bar': ['baz', '{{ ds }}']
            },
            dag=self.dag)
        op.execute = verify_templated_field
        op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    def test_template_non_bool(self):
        """
        Test templates can handle objects with no sense of truthiness
        """

        class NonBoolObject:
            def __len__(self):  # pylint: disable=invalid-length-returned
                return NotImplemented

            def __bool__(self):  # pylint: disable=invalid-bool-returned, bad-option-value
                return NotImplemented

        op = OperatorSubclass(
            task_id='test_bad_template_obj',
            some_templated_field=NonBoolObject(),
            dag=self.dag)
        op.resolve_template_files()
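
    # Spot-checks the standard template context macros (ds, ts, prev/next and
    # yesterday/tomorrow variants) for a daily DAG run on DEFAULT_DATE.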
    def test_task_get_template(self):
        TI = TaskInstance
        ti = TI(
            task=self.runme_0, execution_date=DEFAULT_DATE)
        ti.dag = self.dag_bash
        ti.run(ignore_ti_state=True)
        context = ti.get_template_context()

        # DEFAULT_DATE is 2015-01-01
        self.assertEqual(context['ds'], '2015-01-01')
        self.assertEqual(context['ds_nodash'], '20150101')

        # next_ds is 2015-01-02 as the DAG's schedule interval is daily
        self.assertEqual(context['next_ds'], '2015-01-02')
        self.assertEqual(context['next_ds_nodash'], '20150102')

        # prev_ds is 2014-12-31 as the DAG's schedule interval is daily
        self.assertEqual(context['prev_ds'], '2014-12-31')
        self.assertEqual(context['prev_ds_nodash'], '20141231')

        self.assertEqual(context['ts'], '2015-01-01T00:00:00+00:00')
        self.assertEqual(context['ts_nodash'], '20150101T000000')
        self.assertEqual(context['ts_nodash_with_tz'], '20150101T000000+0000')

        self.assertEqual(context['yesterday_ds'], '2014-12-31')
        self.assertEqual(context['yesterday_ds_nodash'], '20141231')

        self.assertEqual(context['tomorrow_ds'], '2015-01-02')
        self.assertEqual(context['tomorrow_ds_nodash'], '20150102')
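
    # LocalTaskJob drives a single task instance to completion; test_raw_job
    # below runs the TaskInstance directly instead.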
    def test_local_task_job(self):
        TI = TaskInstance
        ti = TI(
            task=self.runme_0, execution_date=DEFAULT_DATE)
        job = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
        job.run()

    def test_raw_job(self):
        TI = TaskInstance
        ti = TI(
            task=self.runme_0, execution_date=DEFAULT_DATE)
        ti.dag = self.dag_bash
        ti.run(ignore_ti_state=True)
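
    # round_time(dt, delta, start) snaps dt to the nearest whole multiple of
    # delta measured from start, as the boundary cases below demonstrate.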
    def test_round_time(self):

        rt1 = round_time(datetime(2015, 1, 1, 6), timedelta(days=1))
        self.assertEqual(datetime(2015, 1, 1, 0, 0), rt1)

        rt2 = round_time(datetime(2015, 1, 2), relativedelta(months=1))
        self.assertEqual(datetime(2015, 1, 1, 0, 0), rt2)

        rt3 = round_time(datetime(2015, 9, 16, 0, 0), timedelta(1), datetime(
            2015, 9, 14, 0, 0))
        self.assertEqual(datetime(2015, 9, 16, 0, 0), rt3)

        rt4 = round_time(datetime(2015, 9, 15, 0, 0), timedelta(1), datetime(
            2015, 9, 14, 0, 0))
        self.assertEqual(datetime(2015, 9, 15, 0, 0), rt4)

        rt5 = round_time(datetime(2015, 9, 14, 0, 0), timedelta(1), datetime(
            2015, 9, 14, 0, 0))
        self.assertEqual(datetime(2015, 9, 14, 0, 0), rt5)

        rt6 = round_time(datetime(2015, 9, 13, 0, 0), timedelta(1), datetime(
            2015, 9, 14, 0, 0))
        self.assertEqual(datetime(2015, 9, 14, 0, 0), rt6)
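
    # infer_time_unit() takes durations in seconds and picks the unit that keeps
    # the scaled values readable (cf. test_scale_time_units below).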
    def test_infer_time_unit(self):

        self.assertEqual('minutes', infer_time_unit([130, 5400, 10]))

        self.assertEqual('seconds', infer_time_unit([110, 50, 10, 100]))

        self.assertEqual('hours', infer_time_unit([100000, 50000, 10000, 20000]))

        self.assertEqual('days', infer_time_unit([200000, 100000]))

    def test_scale_time_units(self):

        # use assert_array_almost_equal from numpy.testing since we are
        # comparing floating point arrays
        arr1 = scale_time_units([130, 5400, 10], 'minutes')
        assert_array_almost_equal(arr1, [2.167, 90.0, 0.167], decimal=3)

        arr2 = scale_time_units([110, 50, 10, 100], 'seconds')
        assert_array_almost_equal(arr2, [110.0, 50.0, 10.0, 100.0], decimal=3)

        arr3 = scale_time_units([100000, 50000, 10000, 20000], 'hours')
        assert_array_almost_equal(arr3, [27.778, 13.889, 2.778, 5.556],
                                  decimal=3)

        arr4 = scale_time_units([200000, 100000], 'days')
        assert_array_almost_equal(arr4, [2.315, 1.157], decimal=3)

    def test_bad_trigger_rule(self):
        with self.assertRaises(AirflowException):
            DummyOperator(
                task_id='test_bad_trigger',
                trigger_rule="non_existent",
                dag=self.dag)

    def test_terminate_task(self):
        """If a task instance's db state gets deleted, it should fail"""
        from airflow.executors.sequential_executor import SequentialExecutor
        TI = TaskInstance
        dag = self.dagbag.dags.get('test_utils')
        task = dag.task_dict.get('sleeps_forever')

        ti = TI(task=task, execution_date=DEFAULT_DATE)
        job = LocalTaskJob(
            task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())

        # Running task instance asynchronously
        proc = multiprocessing.Process(target=job.run)
        proc.start()
        sleep(5)
        settings.engine.dispose()
        session = settings.Session()
        ti.refresh_from_db(session=session)
        # making sure it's actually running
        self.assertEqual(State.RUNNING, ti.state)
        ti = session.query(TI).filter_by(
            dag_id=task.dag_id,
            task_id=task.task_id,
            execution_date=DEFAULT_DATE
        ).one()

        # deleting the instance should result in a failure
        session.delete(ti)
        session.commit()
        # waiting for the async task to finish
        proc.join()

        # making sure that the task ended up as failed
        ti.refresh_from_db(session=session)
        self.assertEqual(State.FAILED, ti.state)
        session.close()

    def test_task_fail_duration(self):
        """If a task fails, the duration should be recorded in TaskFail"""

        op1 = BashOperator(
            task_id='pass_sleepy',
            bash_command='sleep 3',
            dag=self.dag)
        op2 = BashOperator(
            task_id='fail_sleepy',
            bash_command='sleep 5',
            execution_timeout=timedelta(seconds=3),
            retry_delay=timedelta(seconds=0),
            dag=self.dag)
        session = settings.Session()
        try:
            op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
        except Exception:  # pylint: disable=broad-except
            pass
        try:
            op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
        except Exception:  # pylint: disable=broad-except
            pass
        op1_fails = session.query(TaskFail).filter_by(
            task_id='pass_sleepy',
            dag_id=self.dag.dag_id,
            execution_date=DEFAULT_DATE).all()
        op2_fails = session.query(TaskFail).filter_by(
            task_id='fail_sleepy',
            dag_id=self.dag.dag_id,
            execution_date=DEFAULT_DATE).all()

        self.assertEqual(0, len(op1_fails))
        self.assertEqual(1, len(op2_fails))
        self.assertGreaterEqual(sum(f.duration for f in op2_fails), 3)
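
    # For an externally triggered run the schedule interval is ignored when
    # building the context: next_ds/prev_ds collapse to the run's own
    # execution date.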
    def test_externally_triggered_dagrun(self):
        TI = TaskInstance

        # Create the dagrun between two "scheduled" execution dates of the DAG
        execution_date = DEFAULT_DATE + timedelta(days=2)
        execution_ds = execution_date.strftime('%Y-%m-%d')
        execution_ds_nodash = execution_ds.replace('-', '')

        dag = DAG(
            TEST_DAG_ID,
            default_args=self.args,
            schedule_interval=timedelta(weeks=1),
            start_date=DEFAULT_DATE)
        task = DummyOperator(task_id='test_externally_triggered_dag_context',
                             dag=dag)
        dag.create_dagrun(run_type=DagRunType.SCHEDULED,
                          execution_date=execution_date,
                          state=State.RUNNING,
                          external_trigger=True)
        task.run(
            start_date=execution_date, end_date=execution_date)

        ti = TI(task=task, execution_date=execution_date)
        context = ti.get_template_context()

        # next_ds/prev_ds should be the execution date for manually triggered runs
        self.assertEqual(context['next_ds'], execution_ds)
        self.assertEqual(context['next_ds_nodash'], execution_ds_nodash)

        self.assertEqual(context['prev_ds'], execution_ds)
        self.assertEqual(context['prev_ds_nodash'], execution_ds_nodash)