[AIRFLOW-5889] Make polling for AWS Batch job status more resillient (#6765)
- errors in polling for job status should not fail the airflow task when the polling hits an API throttle limit; polling should detect those cases and retry a few times to get the job status, only failing the task when the job description cannot be retrieved - added typing for the BatchProtocol method return types, based on the botocore.client.Batch types - applied trivial format consistency using black, i.e. $ black -t py36 -l 96 {files}
This commit is contained in:
Родитель
6882d355b9
Коммит
479ee63921
|
@ -23,6 +23,9 @@ from random import randint
|
|||
from time import sleep
|
||||
from typing import Optional
|
||||
|
||||
import botocore.exceptions
|
||||
import botocore.waiter
|
||||
|
||||
from airflow.contrib.hooks.aws_hook import AwsHook
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.models import BaseOperator
|
||||
|
@ -31,16 +34,22 @@ from airflow.utils.decorators import apply_defaults
|
|||
|
||||
|
||||
class BatchProtocol(Protocol):
|
||||
def submit_job(self, jobName, jobQueue, jobDefinition, containerOverrides):
|
||||
"""
|
||||
.. seealso:: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html
|
||||
"""
|
||||
|
||||
def describe_jobs(self, jobs) -> dict:
|
||||
...
|
||||
|
||||
def get_waiter(self, x: str):
|
||||
def get_waiter(self, x: str) -> botocore.waiter.Waiter:
|
||||
...
|
||||
|
||||
def describe_jobs(self, jobs):
|
||||
def submit_job(
|
||||
self, jobName, jobQueue, jobDefinition, arrayProperties, parameters, containerOverrides
|
||||
) -> dict:
|
||||
...
|
||||
|
||||
def terminate_job(self, jobId: str, reason: str):
|
||||
def terminate_job(self, jobId: str, reason: str) -> dict:
|
||||
...
|
||||
|
||||
|
||||
|
@ -72,6 +81,8 @@ class AWSBatchOperator(BaseOperator):
|
|||
:param max_retries: exponential backoff retries while waiter is not
|
||||
merged, 4200 = 48 hours
|
||||
:type max_retries: int
|
||||
:param status_retries: number of retries to get job description (status), 10
|
||||
:type status_retries: int
|
||||
:param aws_conn_id: connection id of AWS credentials / region name. If None,
|
||||
credential boto3 strategy will be used
|
||||
(http://boto3.readthedocs.io/en/latest/guide/configuration.html).
|
||||
|
@ -81,14 +92,33 @@ class AWSBatchOperator(BaseOperator):
|
|||
:type region_name: str
|
||||
"""
|
||||
|
||||
ui_color = '#c3dae0'
|
||||
client = None # type: Optional[BatchProtocol]
|
||||
arn = None # type: Optional[str]
|
||||
template_fields = ('job_name', 'overrides', 'parameters',)
|
||||
MAX_RETRIES = 4200
|
||||
STATUS_RETRIES = 10
|
||||
|
||||
ui_color = "#c3dae0"
|
||||
client = None # type: BatchProtocol
|
||||
arn = None # type: str
|
||||
template_fields = (
|
||||
"job_name",
|
||||
"overrides",
|
||||
"parameters",
|
||||
)
|
||||
|
||||
@apply_defaults
|
||||
def __init__(self, job_name, job_definition, job_queue, overrides, array_properties=None,
|
||||
parameters=None, max_retries=4200, aws_conn_id=None, region_name=None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
job_name,
|
||||
job_definition,
|
||||
job_queue,
|
||||
overrides,
|
||||
array_properties=None,
|
||||
parameters=None,
|
||||
max_retries=MAX_RETRIES,
|
||||
status_retries=STATUS_RETRIES,
|
||||
aws_conn_id=None,
|
||||
region_name=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.job_name = job_name
|
||||
|
@ -100,6 +130,7 @@ class AWSBatchOperator(BaseOperator):
|
|||
self.array_properties = array_properties or {}
|
||||
self.parameters = parameters
|
||||
self.max_retries = max_retries
|
||||
self.status_retries = status_retries
|
||||
|
||||
self.jobId = None # pylint: disable=invalid-name
|
||||
self.jobName = None # pylint: disable=invalid-name
|
||||
|
@ -108,37 +139,36 @@ class AWSBatchOperator(BaseOperator):
|
|||
|
||||
def execute(self, context):
|
||||
self.log.info(
|
||||
'Running AWS Batch Job - Job definition: %s - on queue %s',
|
||||
self.job_definition, self.job_queue
|
||||
"Running AWS Batch Job - Job definition: %s - on queue %s",
|
||||
self.job_definition,
|
||||
self.job_queue,
|
||||
)
|
||||
self.log.info('AWSBatchOperator overrides: %s', self.overrides)
|
||||
self.log.info("AWSBatchOperator overrides: %s", self.overrides)
|
||||
|
||||
self.client = self.hook.get_client_type(
|
||||
'batch',
|
||||
region_name=self.region_name
|
||||
)
|
||||
self.client = self.hook.get_client_type("batch", region_name=self.region_name)
|
||||
|
||||
try:
|
||||
|
||||
response = self.client.submit_job(
|
||||
jobName=self.job_name,
|
||||
jobQueue=self.job_queue,
|
||||
jobDefinition=self.job_definition,
|
||||
arrayProperties=self.array_properties,
|
||||
parameters=self.parameters,
|
||||
containerOverrides=self.overrides)
|
||||
containerOverrides=self.overrides,
|
||||
)
|
||||
|
||||
self.log.info('AWS Batch Job started: %s', response)
|
||||
self.log.info("AWS Batch Job started: %s", response)
|
||||
|
||||
self.jobId = response['jobId']
|
||||
self.jobName = response['jobName']
|
||||
self.jobId = response["jobId"]
|
||||
self.jobName = response["jobName"]
|
||||
|
||||
self._wait_for_task_ended()
|
||||
|
||||
self._check_success_task()
|
||||
|
||||
self.log.info('AWS Batch Job has been successfully executed: %s', response)
|
||||
self.log.info("AWS Batch Job has been successfully executed: %s", response)
|
||||
except Exception as e:
|
||||
self.log.info('AWS Batch Job has failed executed')
|
||||
self.log.info("AWS Batch Job has failed executed")
|
||||
raise AirflowException(e)
|
||||
|
||||
def _wait_for_task_ended(self):
|
||||
|
@ -152,64 +182,110 @@ class AWSBatchOperator(BaseOperator):
|
|||
* docs.aws.amazon.com/general/latest/gr/api-retries.html
|
||||
"""
|
||||
try:
|
||||
waiter = self.client.get_waiter('job_execution_complete')
|
||||
waiter = self.client.get_waiter("job_execution_complete")
|
||||
waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow
|
||||
waiter.wait(jobs=[self.jobId])
|
||||
except ValueError:
|
||||
# If waiter not available use expo
|
||||
self._poll_for_task_ended()
|
||||
|
||||
# Allow a batch job some time to spin up. A random interval
|
||||
# decreases the chances of exceeding an AWS API throttle
|
||||
# limit when there are many concurrent tasks.
|
||||
pause = randint(5, 30)
|
||||
def _poll_for_task_ended(self):
|
||||
"""
|
||||
Poll for job status
|
||||
|
||||
retries = 1
|
||||
while retries <= self.max_retries:
|
||||
self.log.info('AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds',
|
||||
self.jobId, retries, self.max_retries, pause)
|
||||
sleep(pause)
|
||||
* docs.aws.amazon.com/general/latest/gr/api-retries.html
|
||||
"""
|
||||
# Allow a batch job some time to spin up. A random interval
|
||||
# decreases the chances of exceeding an AWS API throttle
|
||||
# limit when there are many concurrent tasks.
|
||||
pause = randint(5, 30)
|
||||
|
||||
tries = 0
|
||||
while tries < self.max_retries:
|
||||
tries += 1
|
||||
self.log.info(
|
||||
"AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds",
|
||||
self.jobId,
|
||||
tries,
|
||||
self.max_retries,
|
||||
pause,
|
||||
)
|
||||
sleep(pause)
|
||||
|
||||
response = self._get_job_description()
|
||||
jobs = response.get("jobs")
|
||||
status = jobs[-1]["status"] # check last job status
|
||||
self.log.info("AWS Batch job (%s) status: %s", self.jobId, status)
|
||||
|
||||
# status options: 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED'
|
||||
if status in ["SUCCEEDED", "FAILED"]:
|
||||
break
|
||||
|
||||
pause = 1 + pow(tries * 0.3, 2)
|
||||
|
||||
def _get_job_description(self) -> Optional[dict]:
|
||||
"""
|
||||
Get job description
|
||||
|
||||
* https://docs.aws.amazon.com/batch/latest/APIReference/API_DescribeJobs.html
|
||||
* https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html
|
||||
"""
|
||||
tries = 0
|
||||
while tries < self.status_retries:
|
||||
tries += 1
|
||||
try:
|
||||
response = self.client.describe_jobs(jobs=[self.jobId])
|
||||
status = response['jobs'][-1]['status']
|
||||
self.log.info('AWS Batch job (%s) status: %s', self.jobId, status)
|
||||
if status in ['SUCCEEDED', 'FAILED']:
|
||||
break
|
||||
if response and response.get("jobs"):
|
||||
return response
|
||||
else:
|
||||
self.log.error("Job description has no jobs (%s): %s", self.jobId, response)
|
||||
except botocore.exceptions.ClientError as err:
|
||||
response = err.response
|
||||
self.log.error("Job description error (%s): %s", self.jobId, response)
|
||||
if tries < self.status_retries:
|
||||
error = response.get("Error", {})
|
||||
if error.get("Code") == "TooManyRequestsException":
|
||||
pause = randint(1, 10) # avoid excess requests with a random pause
|
||||
self.log.info(
|
||||
"AWS Batch job (%s) status retry (%d of %d) in the next %.2f seconds",
|
||||
self.jobId,
|
||||
tries,
|
||||
self.status_retries,
|
||||
pause,
|
||||
)
|
||||
sleep(pause)
|
||||
continue
|
||||
|
||||
retries += 1
|
||||
pause = 1 + pow(retries * 0.3, 2)
|
||||
msg = "Failed to get job description ({})".format(self.jobId)
|
||||
raise AirflowException(msg)
|
||||
|
||||
def _check_success_task(self):
|
||||
response = self.client.describe_jobs(
|
||||
jobs=[self.jobId],
|
||||
)
|
||||
"""
|
||||
Check the final status of the batch job; the job status options are:
|
||||
'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED'
|
||||
"""
|
||||
response = self._get_job_description()
|
||||
jobs = response.get("jobs")
|
||||
|
||||
self.log.info('AWS Batch stopped, check status: %s', response)
|
||||
if len(response.get('jobs')) < 1:
|
||||
raise AirflowException('No job found for {}'.format(response))
|
||||
matching_jobs = [job for job in jobs if job["jobId"] == self.jobId]
|
||||
if not matching_jobs:
|
||||
raise AirflowException(
|
||||
"Job ({}) has no job description {}".format(self.jobId, response)
|
||||
)
|
||||
|
||||
for job in response['jobs']:
|
||||
job_status = job['status']
|
||||
if job_status == 'FAILED':
|
||||
reason = job['statusReason']
|
||||
raise AirflowException('Job failed with status {}'.format(reason))
|
||||
elif job_status in [
|
||||
'SUBMITTED',
|
||||
'PENDING',
|
||||
'RUNNABLE',
|
||||
'STARTING',
|
||||
'RUNNING'
|
||||
]:
|
||||
raise AirflowException(
|
||||
'This task is still pending {}'.format(job_status))
|
||||
job = matching_jobs[0]
|
||||
self.log.info("AWS Batch stopped, check status: %s", job)
|
||||
job_status = job["status"]
|
||||
if job_status == "FAILED":
|
||||
reason = job["statusReason"]
|
||||
raise AirflowException("Job ({}) failed with status {}".format(self.jobId, reason))
|
||||
elif job_status in ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING"]:
|
||||
raise AirflowException(
|
||||
"Job ({}) is still pending {}".format(self.jobId, job_status)
|
||||
)
|
||||
|
||||
def get_hook(self):
|
||||
return AwsHook(
|
||||
aws_conn_id=self.aws_conn_id
|
||||
)
|
||||
return AwsHook(aws_conn_id=self.aws_conn_id)
|
||||
|
||||
def on_kill(self):
|
||||
response = self.client.terminate_job(
|
||||
jobId=self.jobId,
|
||||
reason='Task killed by the user')
|
||||
|
||||
response = self.client.terminate_job(jobId=self.jobId, reason="Task killed by the user")
|
||||
self.log.info(response)
|
||||
|
|
|
@ -21,72 +21,84 @@
|
|||
import sys
|
||||
import unittest
|
||||
|
||||
import botocore.exceptions
|
||||
|
||||
from airflow.contrib.operators.awsbatch_operator import AWSBatchOperator
|
||||
from airflow.exceptions import AirflowException
|
||||
from tests.compat import mock
|
||||
|
||||
JOB_NAME = "51455483-c62c-48ac-9b88-53a6a725baa3"
|
||||
JOB_ID = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
|
||||
|
||||
RESPONSE_WITHOUT_FAILURES = {
|
||||
"jobName": "51455483-c62c-48ac-9b88-53a6a725baa3",
|
||||
"jobId": "8ba9d676-4108-4474-9dca-8bbac1da9b19"
|
||||
"jobName": JOB_NAME,
|
||||
"jobId": JOB_ID,
|
||||
}
|
||||
|
||||
|
||||
class TestAWSBatchOperator(unittest.TestCase):
|
||||
|
||||
@mock.patch('airflow.contrib.operators.awsbatch_operator.AwsHook')
|
||||
MAX_RETRIES = 2
|
||||
STATUS_RETRIES = 3
|
||||
|
||||
@mock.patch("airflow.contrib.operators.awsbatch_operator.AwsHook")
|
||||
def setUp(self, aws_hook_mock):
|
||||
self.aws_hook_mock = aws_hook_mock
|
||||
self.batch = AWSBatchOperator(
|
||||
task_id='task',
|
||||
job_name='51455483-c62c-48ac-9b88-53a6a725baa3',
|
||||
job_queue='queue',
|
||||
job_definition='hello-world',
|
||||
max_retries=5,
|
||||
task_id="task",
|
||||
job_name=JOB_NAME,
|
||||
job_queue="queue",
|
||||
job_definition="hello-world",
|
||||
max_retries=self.MAX_RETRIES,
|
||||
status_retries=self.STATUS_RETRIES,
|
||||
parameters=None,
|
||||
overrides={},
|
||||
array_properties=None,
|
||||
aws_conn_id=None,
|
||||
region_name='eu-west-1')
|
||||
region_name="eu-west-1",
|
||||
)
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(self.batch.job_name, '51455483-c62c-48ac-9b88-53a6a725baa3')
|
||||
self.assertEqual(self.batch.job_queue, 'queue')
|
||||
self.assertEqual(self.batch.job_definition, 'hello-world')
|
||||
self.assertEqual(self.batch.max_retries, 5)
|
||||
self.assertEqual(self.batch.job_name, JOB_NAME)
|
||||
self.assertEqual(self.batch.job_queue, "queue")
|
||||
self.assertEqual(self.batch.job_definition, "hello-world")
|
||||
self.assertEqual(self.batch.max_retries, self.MAX_RETRIES)
|
||||
self.assertEqual(self.batch.status_retries, self.STATUS_RETRIES)
|
||||
self.assertEqual(self.batch.parameters, None)
|
||||
self.assertEqual(self.batch.overrides, {})
|
||||
self.assertEqual(self.batch.array_properties, {})
|
||||
self.assertEqual(self.batch.region_name, 'eu-west-1')
|
||||
self.assertEqual(self.batch.region_name, "eu-west-1")
|
||||
self.assertEqual(self.batch.aws_conn_id, None)
|
||||
self.assertEqual(self.batch.hook, self.aws_hook_mock.return_value)
|
||||
|
||||
self.aws_hook_mock.assert_called_once_with(aws_conn_id=None)
|
||||
|
||||
def test_template_fields_overrides(self):
|
||||
self.assertEqual(self.batch.template_fields, ('job_name', 'overrides', 'parameters',))
|
||||
self.assertEqual(self.batch.template_fields, ("job_name", "overrides", "parameters",))
|
||||
|
||||
@mock.patch.object(AWSBatchOperator, '_wait_for_task_ended')
|
||||
@mock.patch.object(AWSBatchOperator, '_check_success_task')
|
||||
@mock.patch.object(AWSBatchOperator, "_wait_for_task_ended")
|
||||
@mock.patch.object(AWSBatchOperator, "_check_success_task")
|
||||
def test_execute_without_failures(self, check_mock, wait_mock):
|
||||
client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
|
||||
client_mock.submit_job.return_value = RESPONSE_WITHOUT_FAILURES
|
||||
|
||||
self.batch.execute(None)
|
||||
|
||||
self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('batch',
|
||||
region_name='eu-west-1')
|
||||
self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
|
||||
"batch", region_name="eu-west-1"
|
||||
)
|
||||
client_mock.submit_job.assert_called_once_with(
|
||||
jobQueue='queue',
|
||||
jobName='51455483-c62c-48ac-9b88-53a6a725baa3',
|
||||
jobQueue="queue",
|
||||
jobName=JOB_NAME,
|
||||
containerOverrides={},
|
||||
jobDefinition='hello-world',
|
||||
jobDefinition="hello-world",
|
||||
arrayProperties={},
|
||||
parameters=None
|
||||
parameters=None,
|
||||
)
|
||||
|
||||
wait_mock.assert_called_once_with()
|
||||
check_mock.assert_called_once_with()
|
||||
self.assertEqual(self.batch.jobId, '8ba9d676-4108-4474-9dca-8bbac1da9b19')
|
||||
self.assertEqual(self.batch.jobId, JOB_ID)
|
||||
|
||||
def test_execute_with_failures(self):
|
||||
client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
|
||||
|
@ -95,122 +107,195 @@ class TestAWSBatchOperator(unittest.TestCase):
|
|||
with self.assertRaises(AirflowException):
|
||||
self.batch.execute(None)
|
||||
|
||||
self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('batch',
|
||||
region_name='eu-west-1')
|
||||
self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
|
||||
"batch", region_name="eu-west-1"
|
||||
)
|
||||
client_mock.submit_job.assert_called_once_with(
|
||||
jobQueue='queue',
|
||||
jobName='51455483-c62c-48ac-9b88-53a6a725baa3',
|
||||
jobQueue="queue",
|
||||
jobName=JOB_NAME,
|
||||
containerOverrides={},
|
||||
jobDefinition='hello-world',
|
||||
jobDefinition="hello-world",
|
||||
arrayProperties={},
|
||||
parameters=None
|
||||
parameters=None,
|
||||
)
|
||||
|
||||
def test_wait_end_tasks(self):
|
||||
client_mock = mock.Mock()
|
||||
self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
|
||||
self.batch.jobId = JOB_ID
|
||||
self.batch.client = client_mock
|
||||
|
||||
self.batch._wait_for_task_ended()
|
||||
|
||||
client_mock.get_waiter.assert_called_once_with('job_execution_complete')
|
||||
client_mock.get_waiter.return_value.wait.assert_called_once_with(
|
||||
jobs=['8ba9d676-4108-4474-9dca-8bbac1da9b19']
|
||||
)
|
||||
client_mock.get_waiter.assert_called_once_with("job_execution_complete")
|
||||
client_mock.get_waiter.return_value.wait.assert_called_once_with(jobs=[JOB_ID])
|
||||
self.assertEqual(sys.maxsize, client_mock.get_waiter.return_value.config.max_attempts)
|
||||
|
||||
@mock.patch("airflow.contrib.operators.awsbatch_operator.randint")
|
||||
def test_poll_job_status_success(self, mock_randint):
|
||||
client_mock = mock.Mock()
|
||||
self.batch.jobId = JOB_ID
|
||||
self.batch.client = client_mock
|
||||
|
||||
mock_randint.return_value = 0 # don't pause in unit tests
|
||||
client_mock.get_waiter.return_value.wait.side_effect = ValueError()
|
||||
client_mock.describe_jobs.return_value = {
|
||||
"jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}]
|
||||
}
|
||||
|
||||
self.batch._wait_for_task_ended()
|
||||
|
||||
client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])
|
||||
|
||||
@mock.patch("airflow.contrib.operators.awsbatch_operator.randint")
|
||||
def test_poll_job_status_running(self, mock_randint):
|
||||
client_mock = mock.Mock()
|
||||
self.batch.jobId = JOB_ID
|
||||
self.batch.client = client_mock
|
||||
|
||||
mock_randint.return_value = 0 # don't pause in unit tests
|
||||
client_mock.get_waiter.return_value.wait.side_effect = ValueError()
|
||||
client_mock.describe_jobs.return_value = {
|
||||
"jobs": [{"jobId": JOB_ID, "status": "RUNNING"}]
|
||||
}
|
||||
|
||||
self.batch._wait_for_task_ended()
|
||||
|
||||
# self.assertEqual(client_mock.describe_jobs.call_count, self.STATUS_RETRIES)
|
||||
client_mock.describe_jobs.assert_called_with(jobs=[JOB_ID])
|
||||
self.assertEqual(client_mock.describe_jobs.call_count, self.MAX_RETRIES)
|
||||
|
||||
@mock.patch("airflow.contrib.operators.awsbatch_operator.randint")
|
||||
def test_poll_job_status_hit_api_throttle(self, mock_randint):
|
||||
client_mock = mock.Mock()
|
||||
self.batch.jobId = JOB_ID
|
||||
self.batch.client = client_mock
|
||||
|
||||
mock_randint.return_value = 0 # don't pause in unit tests
|
||||
client_mock.describe_jobs.side_effect = botocore.exceptions.ClientError(
|
||||
error_response={"Error": {"Code": "TooManyRequestsException"}},
|
||||
operation_name="get job description",
|
||||
)
|
||||
|
||||
with self.assertRaises(Exception) as e:
|
||||
self.batch._poll_for_task_ended()
|
||||
|
||||
self.assertIn("Failed to get job description", str(e.exception))
|
||||
client_mock.describe_jobs.assert_called_with(jobs=[JOB_ID])
|
||||
self.assertEqual(client_mock.describe_jobs.call_count, self.STATUS_RETRIES)
|
||||
|
||||
def test_check_success_tasks_raises(self):
|
||||
client_mock = mock.Mock()
|
||||
self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
|
||||
self.batch.jobId = JOB_ID
|
||||
self.batch.client = client_mock
|
||||
|
||||
client_mock.describe_jobs.return_value = {
|
||||
'jobs': []
|
||||
}
|
||||
client_mock.describe_jobs.return_value = {"jobs": []}
|
||||
|
||||
with self.assertRaises(Exception) as e:
|
||||
self.batch._check_success_task()
|
||||
|
||||
# Ordering of str(dict) is not guaranteed.
|
||||
self.assertIn('No job found for ', str(e.exception))
|
||||
self.assertIn("Failed to get job description", str(e.exception))
|
||||
|
||||
def test_check_success_tasks_raises_failed(self):
|
||||
client_mock = mock.Mock()
|
||||
self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
|
||||
self.batch.jobId = JOB_ID
|
||||
self.batch.client = client_mock
|
||||
|
||||
client_mock.describe_jobs.return_value = {
|
||||
'jobs': [{
|
||||
'status': 'FAILED',
|
||||
'statusReason': 'This is an error reason',
|
||||
'attempts': [{
|
||||
'exitCode': 1
|
||||
}]
|
||||
}]
|
||||
"jobs": [
|
||||
{
|
||||
"jobId": JOB_ID,
|
||||
"status": "FAILED",
|
||||
"statusReason": "This is an error reason",
|
||||
"attempts": [{"exitCode": 1}],
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with self.assertRaises(Exception) as e:
|
||||
self.batch._check_success_task()
|
||||
|
||||
# Ordering of str(dict) is not guaranteed.
|
||||
self.assertIn('Job failed with status ', str(e.exception))
|
||||
self.assertIn("Job ({}) failed with status ".format(JOB_ID), str(e.exception))
|
||||
|
||||
def test_check_success_tasks_raises_pending(self):
|
||||
client_mock = mock.Mock()
|
||||
self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
|
||||
self.batch.jobId = JOB_ID
|
||||
self.batch.client = client_mock
|
||||
|
||||
client_mock.describe_jobs.return_value = {
|
||||
'jobs': [{
|
||||
'status': 'RUNNABLE'
|
||||
}]
|
||||
"jobs": [{"jobId": JOB_ID, "status": "RUNNABLE"}]
|
||||
}
|
||||
|
||||
with self.assertRaises(Exception) as e:
|
||||
self.batch._check_success_task()
|
||||
|
||||
# Ordering of str(dict) is not guaranteed.
|
||||
self.assertIn('This task is still pending ', str(e.exception))
|
||||
self.assertIn("Job ({}) is still pending".format(JOB_ID), str(e.exception))
|
||||
|
||||
def test_check_success_tasks_raises_multiple(self):
|
||||
client_mock = mock.Mock()
|
||||
self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
|
||||
self.batch.jobId = JOB_ID
|
||||
self.batch.client = client_mock
|
||||
|
||||
client_mock.describe_jobs.return_value = {
|
||||
'jobs': [{
|
||||
'status': 'FAILED',
|
||||
'statusReason': 'This is an error reason',
|
||||
'attempts': [{
|
||||
'exitCode': 1
|
||||
}, {
|
||||
'exitCode': 10
|
||||
}]
|
||||
}]
|
||||
"jobs": [
|
||||
{
|
||||
"jobId": JOB_ID,
|
||||
"status": "FAILED",
|
||||
"statusReason": "This is an error reason",
|
||||
"attempts": [{"exitCode": 1}, {"exitCode": 10}],
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with self.assertRaises(Exception) as e:
|
||||
self.batch._check_success_task()
|
||||
|
||||
# Ordering of str(dict) is not guaranteed.
|
||||
self.assertIn('Job failed with status ', str(e.exception))
|
||||
self.assertIn("Job ({}) failed with status ".format(JOB_ID), str(e.exception))
|
||||
|
||||
def test_check_success_task_not_raises(self):
|
||||
client_mock = mock.Mock()
|
||||
self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
|
||||
self.batch.jobId = JOB_ID
|
||||
self.batch.client = client_mock
|
||||
|
||||
client_mock.describe_jobs.return_value = {
|
||||
'jobs': [{
|
||||
'status': 'SUCCEEDED'
|
||||
}]
|
||||
"jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}]
|
||||
}
|
||||
|
||||
self.batch._check_success_task()
|
||||
|
||||
# Ordering of str(dict) is not guaranteed.
|
||||
client_mock.describe_jobs.assert_called_once_with(jobs=['8ba9d676-4108-4474-9dca-8bbac1da9b19'])
|
||||
client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])
|
||||
|
||||
def test_check_success_task_raises_without_jobs(self):
|
||||
client_mock = mock.Mock()
|
||||
self.batch.jobId = JOB_ID
|
||||
self.batch.client = client_mock
|
||||
|
||||
client_mock.describe_jobs.return_value = {"jobs": []}
|
||||
|
||||
with self.assertRaises(Exception) as e:
|
||||
self.batch._check_success_task()
|
||||
|
||||
client_mock.describe_jobs.assert_called_with(jobs=[JOB_ID])
|
||||
self.assertEqual(client_mock.describe_jobs.call_count, self.STATUS_RETRIES)
|
||||
self.assertIn("Failed to get job description", str(e.exception))
|
||||
|
||||
def test_kill_job(self):
|
||||
client_mock = mock.Mock()
|
||||
self.batch.jobId = JOB_ID
|
||||
self.batch.client = client_mock
|
||||
|
||||
client_mock.terminate_job.return_value = {}
|
||||
|
||||
self.batch.on_kill()
|
||||
|
||||
client_mock.terminate_job.assert_called_once_with(
|
||||
jobId=JOB_ID, reason="Task killed by the user"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Загрузка…
Ссылка в новой задаче