This pushes the last cloudwatch event to xcom when do_xcom_push is True

Co-authored-by: Felix Uellendall <feluelle@users.noreply.github.com>
This commit is contained in:
baxievski 2021-01-11 10:07:10 +01:00 коммит произвёл GitHub
Родитель a74fdb7307
Коммит 8d42d9ed69
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 45 добавлений и 8 удалений

Просмотреть файл

@ -17,8 +17,9 @@
# under the License.
import re
import sys
from collections import deque
from datetime import datetime
from typing import Dict, Optional
from typing import Dict, Generator, Optional
from botocore.waiter import Waiter
@ -197,8 +198,14 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
self._wait_for_task_ended()
self._check_success_task()
self.log.info('ECS Task has been successfully executed')
if self.do_xcom_push:
return self._last_log_message()
return None
def _start_task(self):
run_opts = {
'cluster': self.cluster,
@ -260,6 +267,25 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow
waiter.wait(cluster=self.cluster, tasks=[self.arn])
return
def _cloudwatch_log_events(self) -> Generator:
if self._aws_logs_enabled():
task_id = self.arn.split("/")[-1]
stream_name = f"{self.awslogs_stream_prefix}/{task_id}"
yield from self.get_logs_hook().get_log_events(self.awslogs_group, stream_name)
else:
yield from ()
def _aws_logs_enabled(self):
return self.awslogs_group and self.awslogs_stream_prefix
def _last_log_message(self):
try:
return deque(self._cloudwatch_log_events(), maxlen=1).pop()["message"]
except IndexError:
return None
def _check_success_task(self) -> None:
if not self.client or not self.arn:
return
@ -268,11 +294,7 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
self.log.info('ECS Task stopped, check status: %s', response)
# Get logs from CloudWatch if the awslogs log driver was used
if self.awslogs_group and self.awslogs_stream_prefix:
self.log.info('ECS Task logs output:')
task_id = self.arn.split("/")[-1]
stream_name = f"{self.awslogs_stream_prefix}/{task_id}"
for event in self.get_logs_hook().get_log_events(self.awslogs_group, stream_name):
for event in self._cloudwatch_log_events():
event_dt = datetime.fromtimestamp(event['timestamp'] / 1000.0)
self.log.info("[%s] %s", event_dt.isoformat(), event['message'])

Просмотреть файл

@ -314,3 +314,18 @@ class TestECSOperator(unittest.TestCase):
self.assertEqual(
self.ecs.arn, 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55'
)
@mock.patch.object(ECSOperator, '_last_log_message', return_value="Log output")
def test_execute_xcom_with_log(self, mock_cloudwatch_log_message):
self.ecs.do_xcom_push = True
self.assertEqual(self.ecs.execute(None), mock_cloudwatch_log_message.return_value)
@mock.patch.object(ECSOperator, '_last_log_message', return_value=None)
def test_execute_xcom_with_no_log(self, mock_cloudwatch_log_message):
self.ecs.do_xcom_push = True
self.assertEqual(self.ecs.execute(None), mock_cloudwatch_log_message.return_value)
@mock.patch.object(ECSOperator, '_last_log_message', return_value="Log output")
def test_execute_xcom_disabled(self, mock_cloudwatch_log_message):
self.ecs.do_xcom_push = False
self.assertEqual(self.ecs.execute(None), None)