160 строки
5.6 KiB
Python
160 строки
5.6 KiB
Python
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
|
|
|
|
import datetime
|
|
from unittest.mock import ANY, Mock, patch
|
|
|
|
from pytest import raises
|
|
from sqlalchemy.exc import OperationalError
|
|
|
|
from airflow.executors.sequential_executor import SequentialExecutor
|
|
from airflow.jobs.base_job import BaseJob
|
|
from airflow.utils import timezone
|
|
from airflow.utils.session import create_session
|
|
from airflow.utils.state import State
|
|
from tests.test_utils.config import conf_vars
|
|
|
|
|
|
class MockJob(BaseJob):
    """Minimal concrete BaseJob used by the tests below.

    ``_execute`` simply delegates to the callable given at construction time.
    The class declares its own polymorphic identity (``'MockJob'``) for the
    BaseJob mapper.
    """

    __mapper_args__ = {'polymorphic_identity': 'MockJob'}

    def __init__(self, func, **kwargs):
        # Keep the callable on the instance; every other keyword argument is
        # forwarded untouched to BaseJob.
        self.func = func
        super().__init__(**kwargs)

    def _execute(self):
        """Invoke the wrapped callable and return its result."""
        target = self.func
        return target()
|
|
|
|
|
|
class TestBaseJob:
    """Tests for BaseJob run-state handling, liveness checks and heartbeats."""

    def test_state_success(self):
        """A job whose _execute returns normally finishes as SUCCESS with an end date."""
        succeeding_job = MockJob(lambda: True)
        succeeding_job.run()

        assert succeeding_job.state == State.SUCCESS
        assert succeeding_job.end_date is not None

    def test_state_sysexit(self):
        """A job that calls sys.exit(0) is still recorded as SUCCESS."""
        import sys

        exiting_job = MockJob(lambda: sys.exit(0))
        exiting_job.run()

        assert exiting_job.state == State.SUCCESS
        assert exiting_job.end_date is not None

    def test_state_failed(self):
        """An exception raised by _execute propagates out of run() and marks the job FAILED."""

        def abort():
            raise RuntimeError("fail")

        failing_job = MockJob(abort)
        with raises(RuntimeError):
            failing_job.run()

        assert failing_job.state == State.FAILED
        assert failing_job.end_date is not None

    def test_most_recent_job(self):
        """most_recent_job returns the job whose heartbeat is newest, not the older one."""
        with create_session() as session:
            stale_job = MockJob(None, heartrate=10)
            stale_job.latest_heartbeat -= datetime.timedelta(seconds=20)
            fresh_job = MockJob(None, heartrate=10)
            session.add_all([fresh_job, stale_job])
            session.flush()

            assert MockJob.most_recent_job(session=session) == fresh_job

            # Throw away the flushed rows rather than committing them.
            session.rollback()

    def test_is_alive(self):
        """is_alive depends both on the RUNNING state and on how old the heartbeat is."""
        job = MockJob(None, heartrate=10, state=State.RUNNING)
        assert job.is_alive() is True

        # With heartrate=10, a 20s-old heartbeat is expected to count as alive
        # and a 21s-old one not. The one-day case guards against a past
        # regression that read timedelta.seconds (which drops the days
        # component) instead of total_seconds().
        for heartbeat_age, expected_alive in (
            (datetime.timedelta(seconds=20), True),
            (datetime.timedelta(seconds=21), False),
            (datetime.timedelta(days=1), False),
        ):
            job.latest_heartbeat = timezone.utcnow() - heartbeat_age
            assert job.is_alive() is expected_alive

        job.state = State.SUCCESS
        job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=10)
        assert job.is_alive() is False, "Completed jobs even with recent heartbeat should not be alive"

    @patch('airflow.jobs.base_job.create_session')
    def test_heartbeat_failed(self, mock_create_session):
        """A heartbeat whose commit raises OperationalError leaves latest_heartbeat unchanged."""
        previous_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=60)
        with create_session() as session:
            # Code in airflow.jobs.base_job that enters create_session() now
            # receives this mock, spec'd on a real session object.
            mock_session = Mock(spec_set=session, name="MockSession")
            mock_create_session.return_value.__enter__.return_value = mock_session

            job = MockJob(None, heartrate=10, state=State.RUNNING)
            job.latest_heartbeat = previous_heartbeat

            mock_session.commit.side_effect = OperationalError("Force fail", {}, None)

            job.heartbeat()

            assert job.latest_heartbeat == previous_heartbeat, "attribute not updated when heartbeat fails"

    @conf_vars({('scheduler', 'max_tis_per_query'): '100'})
    @patch('airflow.jobs.base_job.ExecutorLoader.get_default_executor')
    @patch('airflow.jobs.base_job.get_hostname')
    @patch('airflow.jobs.base_job.getpass.getuser')
    def test_essential_attr(self, mock_getuser, mock_hostname, mock_default_executor):
        """The constructor fills identifying attributes from config and the patched helpers."""
        sequential_executor = SequentialExecutor()
        mock_hostname.return_value = "test_hostname"
        mock_getuser.return_value = "testuser"
        mock_default_executor.return_value = sequential_executor

        test_job = MockJob(None, heartrate=10, dag_id="example_dag", state=State.RUNNING)

        assert test_job.executor_class == "SequentialExecutor"
        assert test_job.heartrate == 10
        assert test_job.dag_id == "example_dag"
        assert test_job.hostname == "test_hostname"
        assert test_job.max_tis_per_query == 100
        assert test_job.unixname == "testuser"
        assert test_job.state == "running"
        assert test_job.executor == sequential_executor

    def test_heartbeat(self, frozen_sleep, monkeypatch):
        """heartbeat() invokes the callback; an immediate only_if_necessary beat does not."""
        # Replace the module-level sleep used by base_job with the frozen_sleep fixture.
        monkeypatch.setattr('airflow.jobs.base_job.sleep', frozen_sleep)
        with create_session() as session:
            job = MockJob(None, heartrate=10)
            job.latest_heartbeat = timezone.utcnow()
            session.add(job)
            session.commit()

            job.heartbeat_callback = hb_callback = Mock()

            job.heartbeat()
            hb_callback.assert_called_once_with(session=ANY)

            hb_callback.reset_mock()
            job.heartbeat(only_if_necessary=True)
            assert hb_callback.called is False
|