This pushes the last cloudwatch event to xcom when do_xcom_push is True

Co-authored-by: Felix Uellendall <feluelle@users.noreply.github.com>
This commit is contained in:
baxievski 2021-01-11 10:07:10 +01:00 коммит произвёл GitHub
Родитель a74fdb7307
Коммит 8d42d9ed69
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 45 добавлений и 8 удалений

Просмотреть файл

@ -17,8 +17,9 @@
# under the License. # under the License.
import re import re
import sys import sys
from collections import deque
from datetime import datetime from datetime import datetime
from typing import Dict, Optional from typing import Dict, Generator, Optional
from botocore.waiter import Waiter from botocore.waiter import Waiter
@ -197,8 +198,14 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
self._wait_for_task_ended() self._wait_for_task_ended()
self._check_success_task() self._check_success_task()
self.log.info('ECS Task has been successfully executed') self.log.info('ECS Task has been successfully executed')
if self.do_xcom_push:
return self._last_log_message()
return None
def _start_task(self): def _start_task(self):
run_opts = { run_opts = {
'cluster': self.cluster, 'cluster': self.cluster,
@ -260,6 +267,25 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow
waiter.wait(cluster=self.cluster, tasks=[self.arn]) waiter.wait(cluster=self.cluster, tasks=[self.arn])
return
def _cloudwatch_log_events(self) -> Generator:
if self._aws_logs_enabled():
task_id = self.arn.split("/")[-1]
stream_name = f"{self.awslogs_stream_prefix}/{task_id}"
yield from self.get_logs_hook().get_log_events(self.awslogs_group, stream_name)
else:
yield from ()
def _aws_logs_enabled(self):
return self.awslogs_group and self.awslogs_stream_prefix
def _last_log_message(self):
try:
return deque(self._cloudwatch_log_events(), maxlen=1).pop()["message"]
except IndexError:
return None
def _check_success_task(self) -> None: def _check_success_task(self) -> None:
if not self.client or not self.arn: if not self.client or not self.arn:
return return
@ -268,11 +294,7 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
self.log.info('ECS Task stopped, check status: %s', response) self.log.info('ECS Task stopped, check status: %s', response)
# Get logs from CloudWatch if the awslogs log driver was used # Get logs from CloudWatch if the awslogs log driver was used
if self.awslogs_group and self.awslogs_stream_prefix: for event in self._cloudwatch_log_events():
self.log.info('ECS Task logs output:')
task_id = self.arn.split("/")[-1]
stream_name = f"{self.awslogs_stream_prefix}/{task_id}"
for event in self.get_logs_hook().get_log_events(self.awslogs_group, stream_name):
event_dt = datetime.fromtimestamp(event['timestamp'] / 1000.0) event_dt = datetime.fromtimestamp(event['timestamp'] / 1000.0)
self.log.info("[%s] %s", event_dt.isoformat(), event['message']) self.log.info("[%s] %s", event_dt.isoformat(), event['message'])

Просмотреть файл

@ -314,3 +314,18 @@ class TestECSOperator(unittest.TestCase):
self.assertEqual( self.assertEqual(
self.ecs.arn, 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55' self.ecs.arn, 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55'
) )
@mock.patch.object(ECSOperator, '_last_log_message', return_value="Log output")
def test_execute_xcom_with_log(self, mock_cloudwatch_log_message):
self.ecs.do_xcom_push = True
self.assertEqual(self.ecs.execute(None), mock_cloudwatch_log_message.return_value)
@mock.patch.object(ECSOperator, '_last_log_message', return_value=None)
def test_execute_xcom_with_no_log(self, mock_cloudwatch_log_message):
self.ecs.do_xcom_push = True
self.assertEqual(self.ecs.execute(None), mock_cloudwatch_log_message.return_value)
@mock.patch.object(ECSOperator, '_last_log_message', return_value="Log output")
def test_execute_xcom_disabled(self, mock_cloudwatch_log_message):
self.ecs.do_xcom_push = False
self.assertEqual(self.ecs.execute(None), None)