add xcom push for ECSOperator (#12096)
This pushes the last cloudwatch event to xcom when do_xcom_push is True Co-authored-by: Felix Uellendall <feluelle@users.noreply.github.com>
This commit is contained in:
Родитель
a74fdb7307
Коммит
8d42d9ed69
|
@ -17,8 +17,9 @@
|
|||
# under the License.
|
||||
import re
|
||||
import sys
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Generator, Optional
|
||||
|
||||
from botocore.waiter import Waiter
|
||||
|
||||
|
@ -197,8 +198,14 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
|
|||
self._wait_for_task_ended()
|
||||
|
||||
self._check_success_task()
|
||||
|
||||
self.log.info('ECS Task has been successfully executed')
|
||||
|
||||
if self.do_xcom_push:
|
||||
return self._last_log_message()
|
||||
|
||||
return None
|
||||
|
||||
def _start_task(self):
|
||||
run_opts = {
|
||||
'cluster': self.cluster,
|
||||
|
@ -260,6 +267,25 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
|
|||
waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow
|
||||
waiter.wait(cluster=self.cluster, tasks=[self.arn])
|
||||
|
||||
return
|
||||
|
||||
def _cloudwatch_log_events(self) -> Generator:
|
||||
if self._aws_logs_enabled():
|
||||
task_id = self.arn.split("/")[-1]
|
||||
stream_name = f"{self.awslogs_stream_prefix}/{task_id}"
|
||||
yield from self.get_logs_hook().get_log_events(self.awslogs_group, stream_name)
|
||||
else:
|
||||
yield from ()
|
||||
|
||||
def _aws_logs_enabled(self):
|
||||
return self.awslogs_group and self.awslogs_stream_prefix
|
||||
|
||||
def _last_log_message(self):
|
||||
try:
|
||||
return deque(self._cloudwatch_log_events(), maxlen=1).pop()["message"]
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
def _check_success_task(self) -> None:
|
||||
if not self.client or not self.arn:
|
||||
return
|
||||
|
@ -268,11 +294,7 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
|
|||
self.log.info('ECS Task stopped, check status: %s', response)
|
||||
|
||||
# Get logs from CloudWatch if the awslogs log driver was used
|
||||
if self.awslogs_group and self.awslogs_stream_prefix:
|
||||
self.log.info('ECS Task logs output:')
|
||||
task_id = self.arn.split("/")[-1]
|
||||
stream_name = f"{self.awslogs_stream_prefix}/{task_id}"
|
||||
for event in self.get_logs_hook().get_log_events(self.awslogs_group, stream_name):
|
||||
for event in self._cloudwatch_log_events():
|
||||
event_dt = datetime.fromtimestamp(event['timestamp'] / 1000.0)
|
||||
self.log.info("[%s] %s", event_dt.isoformat(), event['message'])
|
||||
|
||||
|
|
|
@ -314,3 +314,18 @@ class TestECSOperator(unittest.TestCase):
|
|||
self.assertEqual(
|
||||
self.ecs.arn, 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55'
|
||||
)
|
||||
|
||||
@mock.patch.object(ECSOperator, '_last_log_message', return_value="Log output")
|
||||
def test_execute_xcom_with_log(self, mock_cloudwatch_log_message):
|
||||
self.ecs.do_xcom_push = True
|
||||
self.assertEqual(self.ecs.execute(None), mock_cloudwatch_log_message.return_value)
|
||||
|
||||
@mock.patch.object(ECSOperator, '_last_log_message', return_value=None)
|
||||
def test_execute_xcom_with_no_log(self, mock_cloudwatch_log_message):
|
||||
self.ecs.do_xcom_push = True
|
||||
self.assertEqual(self.ecs.execute(None), mock_cloudwatch_log_message.return_value)
|
||||
|
||||
@mock.patch.object(ECSOperator, '_last_log_message', return_value="Log output")
|
||||
def test_execute_xcom_disabled(self, mock_cloudwatch_log_message):
|
||||
self.ecs.do_xcom_push = False
|
||||
self.assertEqual(self.ecs.execute(None), None)
|
||||
|
|
Загрузка…
Ссылка в новой задаче