[AIRFLOW-5889] Make polling for AWS Batch job status more resilient (#6765)

- errors in polling for job status should not fail
  the airflow task when the polling hits an API throttle
  limit; polling should detect those cases and retry a
  few times to get the job status, only failing the task
  when the job description cannot be retrieved
- added typing for the BatchProtocol method return
  types, based on the botocore.client.Batch types
- applied trivial format consistency using black, i.e.
  $ black -t py36 -l 96 {files}
This commit is contained in:
Darren Weber 2019-12-12 03:30:43 -08:00 коммит произвёл Ash Berlin-Taylor
Родитель 6882d355b9
Коммит 479ee63921
2 изменённых файлов: 300 добавлений и 139 удалений

Просмотреть файл

@ -23,6 +23,9 @@ from random import randint
from time import sleep
from typing import Optional
import botocore.exceptions
import botocore.waiter
from airflow.contrib.hooks.aws_hook import AwsHook
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@ -31,16 +34,22 @@ from airflow.utils.decorators import apply_defaults
class BatchProtocol(Protocol):
def submit_job(self, jobName, jobQueue, jobDefinition, containerOverrides):
"""
.. seealso:: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html
"""
def describe_jobs(self, jobs) -> dict:
...
def get_waiter(self, x: str):
def get_waiter(self, x: str) -> botocore.waiter.Waiter:
...
def describe_jobs(self, jobs):
def submit_job(
self, jobName, jobQueue, jobDefinition, arrayProperties, parameters, containerOverrides
) -> dict:
...
def terminate_job(self, jobId: str, reason: str):
def terminate_job(self, jobId: str, reason: str) -> dict:
...
@ -72,6 +81,8 @@ class AWSBatchOperator(BaseOperator):
:param max_retries: exponential backoff retries while waiter is not
merged, 4200 = 48 hours
:type max_retries: int
:param status_retries: number of retries to get job description (status), 10
:type status_retries: int
:param aws_conn_id: connection id of AWS credentials / region name. If None,
credential boto3 strategy will be used
(http://boto3.readthedocs.io/en/latest/guide/configuration.html).
@ -81,14 +92,33 @@ class AWSBatchOperator(BaseOperator):
:type region_name: str
"""
ui_color = '#c3dae0'
client = None # type: Optional[BatchProtocol]
arn = None # type: Optional[str]
template_fields = ('job_name', 'overrides', 'parameters',)
MAX_RETRIES = 4200
STATUS_RETRIES = 10
ui_color = "#c3dae0"
client = None # type: BatchProtocol
arn = None # type: str
template_fields = (
"job_name",
"overrides",
"parameters",
)
@apply_defaults
def __init__(self, job_name, job_definition, job_queue, overrides, array_properties=None,
parameters=None, max_retries=4200, aws_conn_id=None, region_name=None, **kwargs):
def __init__(
self,
job_name,
job_definition,
job_queue,
overrides,
array_properties=None,
parameters=None,
max_retries=MAX_RETRIES,
status_retries=STATUS_RETRIES,
aws_conn_id=None,
region_name=None,
**kwargs,
):
super().__init__(**kwargs)
self.job_name = job_name
@ -100,6 +130,7 @@ class AWSBatchOperator(BaseOperator):
self.array_properties = array_properties or {}
self.parameters = parameters
self.max_retries = max_retries
self.status_retries = status_retries
self.jobId = None # pylint: disable=invalid-name
self.jobName = None # pylint: disable=invalid-name
@ -108,37 +139,36 @@ class AWSBatchOperator(BaseOperator):
def execute(self, context):
self.log.info(
'Running AWS Batch Job - Job definition: %s - on queue %s',
self.job_definition, self.job_queue
"Running AWS Batch Job - Job definition: %s - on queue %s",
self.job_definition,
self.job_queue,
)
self.log.info('AWSBatchOperator overrides: %s', self.overrides)
self.log.info("AWSBatchOperator overrides: %s", self.overrides)
self.client = self.hook.get_client_type(
'batch',
region_name=self.region_name
)
self.client = self.hook.get_client_type("batch", region_name=self.region_name)
try:
response = self.client.submit_job(
jobName=self.job_name,
jobQueue=self.job_queue,
jobDefinition=self.job_definition,
arrayProperties=self.array_properties,
parameters=self.parameters,
containerOverrides=self.overrides)
containerOverrides=self.overrides,
)
self.log.info('AWS Batch Job started: %s', response)
self.log.info("AWS Batch Job started: %s", response)
self.jobId = response['jobId']
self.jobName = response['jobName']
self.jobId = response["jobId"]
self.jobName = response["jobName"]
self._wait_for_task_ended()
self._check_success_task()
self.log.info('AWS Batch Job has been successfully executed: %s', response)
self.log.info("AWS Batch Job has been successfully executed: %s", response)
except Exception as e:
self.log.info('AWS Batch Job has failed executed')
self.log.info("AWS Batch Job has failed executed")
raise AirflowException(e)
def _wait_for_task_ended(self):
@ -152,64 +182,110 @@ class AWSBatchOperator(BaseOperator):
* docs.aws.amazon.com/general/latest/gr/api-retries.html
"""
try:
waiter = self.client.get_waiter('job_execution_complete')
waiter = self.client.get_waiter("job_execution_complete")
waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow
waiter.wait(jobs=[self.jobId])
except ValueError:
# If waiter not available use expo
self._poll_for_task_ended()
# Allow a batch job some time to spin up. A random interval
# decreases the chances of exceeding an AWS API throttle
# limit when there are many concurrent tasks.
pause = randint(5, 30)
def _poll_for_task_ended(self):
"""
Poll for job status
retries = 1
while retries <= self.max_retries:
self.log.info('AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds',
self.jobId, retries, self.max_retries, pause)
sleep(pause)
* docs.aws.amazon.com/general/latest/gr/api-retries.html
"""
# Allow a batch job some time to spin up. A random interval
# decreases the chances of exceeding an AWS API throttle
# limit when there are many concurrent tasks.
pause = randint(5, 30)
tries = 0
while tries < self.max_retries:
tries += 1
self.log.info(
"AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds",
self.jobId,
tries,
self.max_retries,
pause,
)
sleep(pause)
response = self._get_job_description()
jobs = response.get("jobs")
status = jobs[-1]["status"] # check last job status
self.log.info("AWS Batch job (%s) status: %s", self.jobId, status)
# status options: 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED'
if status in ["SUCCEEDED", "FAILED"]:
break
pause = 1 + pow(tries * 0.3, 2)
def _get_job_description(self) -> Optional[dict]:
"""
Get job description
* https://docs.aws.amazon.com/batch/latest/APIReference/API_DescribeJobs.html
* https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html
"""
tries = 0
while tries < self.status_retries:
tries += 1
try:
response = self.client.describe_jobs(jobs=[self.jobId])
status = response['jobs'][-1]['status']
self.log.info('AWS Batch job (%s) status: %s', self.jobId, status)
if status in ['SUCCEEDED', 'FAILED']:
break
if response and response.get("jobs"):
return response
else:
self.log.error("Job description has no jobs (%s): %s", self.jobId, response)
except botocore.exceptions.ClientError as err:
response = err.response
self.log.error("Job description error (%s): %s", self.jobId, response)
if tries < self.status_retries:
error = response.get("Error", {})
if error.get("Code") == "TooManyRequestsException":
pause = randint(1, 10) # avoid excess requests with a random pause
self.log.info(
"AWS Batch job (%s) status retry (%d of %d) in the next %.2f seconds",
self.jobId,
tries,
self.status_retries,
pause,
)
sleep(pause)
continue
retries += 1
pause = 1 + pow(retries * 0.3, 2)
msg = "Failed to get job description ({})".format(self.jobId)
raise AirflowException(msg)
def _check_success_task(self):
response = self.client.describe_jobs(
jobs=[self.jobId],
)
"""
Check the final status of the batch job; the job status options are:
'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED'
"""
response = self._get_job_description()
jobs = response.get("jobs")
self.log.info('AWS Batch stopped, check status: %s', response)
if len(response.get('jobs')) < 1:
raise AirflowException('No job found for {}'.format(response))
matching_jobs = [job for job in jobs if job["jobId"] == self.jobId]
if not matching_jobs:
raise AirflowException(
"Job ({}) has no job description {}".format(self.jobId, response)
)
for job in response['jobs']:
job_status = job['status']
if job_status == 'FAILED':
reason = job['statusReason']
raise AirflowException('Job failed with status {}'.format(reason))
elif job_status in [
'SUBMITTED',
'PENDING',
'RUNNABLE',
'STARTING',
'RUNNING'
]:
raise AirflowException(
'This task is still pending {}'.format(job_status))
job = matching_jobs[0]
self.log.info("AWS Batch stopped, check status: %s", job)
job_status = job["status"]
if job_status == "FAILED":
reason = job["statusReason"]
raise AirflowException("Job ({}) failed with status {}".format(self.jobId, reason))
elif job_status in ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING"]:
raise AirflowException(
"Job ({}) is still pending {}".format(self.jobId, job_status)
)
def get_hook(self):
return AwsHook(
aws_conn_id=self.aws_conn_id
)
return AwsHook(aws_conn_id=self.aws_conn_id)
def on_kill(self):
response = self.client.terminate_job(
jobId=self.jobId,
reason='Task killed by the user')
response = self.client.terminate_job(jobId=self.jobId, reason="Task killed by the user")
self.log.info(response)

Просмотреть файл

@ -21,72 +21,84 @@
import sys
import unittest
import botocore.exceptions
from airflow.contrib.operators.awsbatch_operator import AWSBatchOperator
from airflow.exceptions import AirflowException
from tests.compat import mock
JOB_NAME = "51455483-c62c-48ac-9b88-53a6a725baa3"
JOB_ID = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
RESPONSE_WITHOUT_FAILURES = {
"jobName": "51455483-c62c-48ac-9b88-53a6a725baa3",
"jobId": "8ba9d676-4108-4474-9dca-8bbac1da9b19"
"jobName": JOB_NAME,
"jobId": JOB_ID,
}
class TestAWSBatchOperator(unittest.TestCase):
@mock.patch('airflow.contrib.operators.awsbatch_operator.AwsHook')
MAX_RETRIES = 2
STATUS_RETRIES = 3
@mock.patch("airflow.contrib.operators.awsbatch_operator.AwsHook")
def setUp(self, aws_hook_mock):
self.aws_hook_mock = aws_hook_mock
self.batch = AWSBatchOperator(
task_id='task',
job_name='51455483-c62c-48ac-9b88-53a6a725baa3',
job_queue='queue',
job_definition='hello-world',
max_retries=5,
task_id="task",
job_name=JOB_NAME,
job_queue="queue",
job_definition="hello-world",
max_retries=self.MAX_RETRIES,
status_retries=self.STATUS_RETRIES,
parameters=None,
overrides={},
array_properties=None,
aws_conn_id=None,
region_name='eu-west-1')
region_name="eu-west-1",
)
def test_init(self):
self.assertEqual(self.batch.job_name, '51455483-c62c-48ac-9b88-53a6a725baa3')
self.assertEqual(self.batch.job_queue, 'queue')
self.assertEqual(self.batch.job_definition, 'hello-world')
self.assertEqual(self.batch.max_retries, 5)
self.assertEqual(self.batch.job_name, JOB_NAME)
self.assertEqual(self.batch.job_queue, "queue")
self.assertEqual(self.batch.job_definition, "hello-world")
self.assertEqual(self.batch.max_retries, self.MAX_RETRIES)
self.assertEqual(self.batch.status_retries, self.STATUS_RETRIES)
self.assertEqual(self.batch.parameters, None)
self.assertEqual(self.batch.overrides, {})
self.assertEqual(self.batch.array_properties, {})
self.assertEqual(self.batch.region_name, 'eu-west-1')
self.assertEqual(self.batch.region_name, "eu-west-1")
self.assertEqual(self.batch.aws_conn_id, None)
self.assertEqual(self.batch.hook, self.aws_hook_mock.return_value)
self.aws_hook_mock.assert_called_once_with(aws_conn_id=None)
def test_template_fields_overrides(self):
self.assertEqual(self.batch.template_fields, ('job_name', 'overrides', 'parameters',))
self.assertEqual(self.batch.template_fields, ("job_name", "overrides", "parameters",))
@mock.patch.object(AWSBatchOperator, '_wait_for_task_ended')
@mock.patch.object(AWSBatchOperator, '_check_success_task')
@mock.patch.object(AWSBatchOperator, "_wait_for_task_ended")
@mock.patch.object(AWSBatchOperator, "_check_success_task")
def test_execute_without_failures(self, check_mock, wait_mock):
client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
client_mock.submit_job.return_value = RESPONSE_WITHOUT_FAILURES
self.batch.execute(None)
self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('batch',
region_name='eu-west-1')
self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
"batch", region_name="eu-west-1"
)
client_mock.submit_job.assert_called_once_with(
jobQueue='queue',
jobName='51455483-c62c-48ac-9b88-53a6a725baa3',
jobQueue="queue",
jobName=JOB_NAME,
containerOverrides={},
jobDefinition='hello-world',
jobDefinition="hello-world",
arrayProperties={},
parameters=None
parameters=None,
)
wait_mock.assert_called_once_with()
check_mock.assert_called_once_with()
self.assertEqual(self.batch.jobId, '8ba9d676-4108-4474-9dca-8bbac1da9b19')
self.assertEqual(self.batch.jobId, JOB_ID)
def test_execute_with_failures(self):
client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
@ -95,122 +107,195 @@ class TestAWSBatchOperator(unittest.TestCase):
with self.assertRaises(AirflowException):
self.batch.execute(None)
self.aws_hook_mock.return_value.get_client_type.assert_called_once_with('batch',
region_name='eu-west-1')
self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
"batch", region_name="eu-west-1"
)
client_mock.submit_job.assert_called_once_with(
jobQueue='queue',
jobName='51455483-c62c-48ac-9b88-53a6a725baa3',
jobQueue="queue",
jobName=JOB_NAME,
containerOverrides={},
jobDefinition='hello-world',
jobDefinition="hello-world",
arrayProperties={},
parameters=None
parameters=None,
)
def test_wait_end_tasks(self):
client_mock = mock.Mock()
self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
self.batch.jobId = JOB_ID
self.batch.client = client_mock
self.batch._wait_for_task_ended()
client_mock.get_waiter.assert_called_once_with('job_execution_complete')
client_mock.get_waiter.return_value.wait.assert_called_once_with(
jobs=['8ba9d676-4108-4474-9dca-8bbac1da9b19']
)
client_mock.get_waiter.assert_called_once_with("job_execution_complete")
client_mock.get_waiter.return_value.wait.assert_called_once_with(jobs=[JOB_ID])
self.assertEqual(sys.maxsize, client_mock.get_waiter.return_value.config.max_attempts)
@mock.patch("airflow.contrib.operators.awsbatch_operator.randint")
def test_poll_job_status_success(self, mock_randint):
client_mock = mock.Mock()
self.batch.jobId = JOB_ID
self.batch.client = client_mock
mock_randint.return_value = 0 # don't pause in unit tests
client_mock.get_waiter.return_value.wait.side_effect = ValueError()
client_mock.describe_jobs.return_value = {
"jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}]
}
self.batch._wait_for_task_ended()
client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])
@mock.patch("airflow.contrib.operators.awsbatch_operator.randint")
def test_poll_job_status_running(self, mock_randint):
client_mock = mock.Mock()
self.batch.jobId = JOB_ID
self.batch.client = client_mock
mock_randint.return_value = 0 # don't pause in unit tests
client_mock.get_waiter.return_value.wait.side_effect = ValueError()
client_mock.describe_jobs.return_value = {
"jobs": [{"jobId": JOB_ID, "status": "RUNNING"}]
}
self.batch._wait_for_task_ended()
# self.assertEqual(client_mock.describe_jobs.call_count, self.STATUS_RETRIES)
client_mock.describe_jobs.assert_called_with(jobs=[JOB_ID])
self.assertEqual(client_mock.describe_jobs.call_count, self.MAX_RETRIES)
@mock.patch("airflow.contrib.operators.awsbatch_operator.randint")
def test_poll_job_status_hit_api_throttle(self, mock_randint):
client_mock = mock.Mock()
self.batch.jobId = JOB_ID
self.batch.client = client_mock
mock_randint.return_value = 0 # don't pause in unit tests
client_mock.describe_jobs.side_effect = botocore.exceptions.ClientError(
error_response={"Error": {"Code": "TooManyRequestsException"}},
operation_name="get job description",
)
with self.assertRaises(Exception) as e:
self.batch._poll_for_task_ended()
self.assertIn("Failed to get job description", str(e.exception))
client_mock.describe_jobs.assert_called_with(jobs=[JOB_ID])
self.assertEqual(client_mock.describe_jobs.call_count, self.STATUS_RETRIES)
def test_check_success_tasks_raises(self):
client_mock = mock.Mock()
self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
self.batch.jobId = JOB_ID
self.batch.client = client_mock
client_mock.describe_jobs.return_value = {
'jobs': []
}
client_mock.describe_jobs.return_value = {"jobs": []}
with self.assertRaises(Exception) as e:
self.batch._check_success_task()
# Ordering of str(dict) is not guaranteed.
self.assertIn('No job found for ', str(e.exception))
self.assertIn("Failed to get job description", str(e.exception))
def test_check_success_tasks_raises_failed(self):
client_mock = mock.Mock()
self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
self.batch.jobId = JOB_ID
self.batch.client = client_mock
client_mock.describe_jobs.return_value = {
'jobs': [{
'status': 'FAILED',
'statusReason': 'This is an error reason',
'attempts': [{
'exitCode': 1
}]
}]
"jobs": [
{
"jobId": JOB_ID,
"status": "FAILED",
"statusReason": "This is an error reason",
"attempts": [{"exitCode": 1}],
}
]
}
with self.assertRaises(Exception) as e:
self.batch._check_success_task()
# Ordering of str(dict) is not guaranteed.
self.assertIn('Job failed with status ', str(e.exception))
self.assertIn("Job ({}) failed with status ".format(JOB_ID), str(e.exception))
def test_check_success_tasks_raises_pending(self):
client_mock = mock.Mock()
self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
self.batch.jobId = JOB_ID
self.batch.client = client_mock
client_mock.describe_jobs.return_value = {
'jobs': [{
'status': 'RUNNABLE'
}]
"jobs": [{"jobId": JOB_ID, "status": "RUNNABLE"}]
}
with self.assertRaises(Exception) as e:
self.batch._check_success_task()
# Ordering of str(dict) is not guaranteed.
self.assertIn('This task is still pending ', str(e.exception))
self.assertIn("Job ({}) is still pending".format(JOB_ID), str(e.exception))
def test_check_success_tasks_raises_multiple(self):
client_mock = mock.Mock()
self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
self.batch.jobId = JOB_ID
self.batch.client = client_mock
client_mock.describe_jobs.return_value = {
'jobs': [{
'status': 'FAILED',
'statusReason': 'This is an error reason',
'attempts': [{
'exitCode': 1
}, {
'exitCode': 10
}]
}]
"jobs": [
{
"jobId": JOB_ID,
"status": "FAILED",
"statusReason": "This is an error reason",
"attempts": [{"exitCode": 1}, {"exitCode": 10}],
}
]
}
with self.assertRaises(Exception) as e:
self.batch._check_success_task()
# Ordering of str(dict) is not guaranteed.
self.assertIn('Job failed with status ', str(e.exception))
self.assertIn("Job ({}) failed with status ".format(JOB_ID), str(e.exception))
def test_check_success_task_not_raises(self):
client_mock = mock.Mock()
self.batch.jobId = '8ba9d676-4108-4474-9dca-8bbac1da9b19'
self.batch.jobId = JOB_ID
self.batch.client = client_mock
client_mock.describe_jobs.return_value = {
'jobs': [{
'status': 'SUCCEEDED'
}]
"jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}]
}
self.batch._check_success_task()
# Ordering of str(dict) is not guaranteed.
client_mock.describe_jobs.assert_called_once_with(jobs=['8ba9d676-4108-4474-9dca-8bbac1da9b19'])
client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])
def test_check_success_task_raises_without_jobs(self):
client_mock = mock.Mock()
self.batch.jobId = JOB_ID
self.batch.client = client_mock
client_mock.describe_jobs.return_value = {"jobs": []}
with self.assertRaises(Exception) as e:
self.batch._check_success_task()
client_mock.describe_jobs.assert_called_with(jobs=[JOB_ID])
self.assertEqual(client_mock.describe_jobs.call_count, self.STATUS_RETRIES)
self.assertIn("Failed to get job description", str(e.exception))
def test_kill_job(self):
client_mock = mock.Mock()
self.batch.jobId = JOB_ID
self.batch.client = client_mock
client_mock.terminate_job.return_value = {}
self.batch.on_kill()
client_mock.terminate_job.assert_called_once_with(
jobId=JOB_ID, reason="Task killed by the user"
)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()