add xcom push for ECSOperator (#12096)
This pushes the last cloudwatch event to xcom when do_xcom_push is True Co-authored-by: Felix Uellendall <feluelle@users.noreply.github.com>
This commit is contained in:
Родитель
a74fdb7307
Коммит
8d42d9ed69
|
@ -17,8 +17,9 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
from collections import deque
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Generator, Optional
|
||||||
|
|
||||||
from botocore.waiter import Waiter
|
from botocore.waiter import Waiter
|
||||||
|
|
||||||
|
@ -197,8 +198,14 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
|
||||||
self._wait_for_task_ended()
|
self._wait_for_task_ended()
|
||||||
|
|
||||||
self._check_success_task()
|
self._check_success_task()
|
||||||
|
|
||||||
self.log.info('ECS Task has been successfully executed')
|
self.log.info('ECS Task has been successfully executed')
|
||||||
|
|
||||||
|
if self.do_xcom_push:
|
||||||
|
return self._last_log_message()
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def _start_task(self):
|
def _start_task(self):
|
||||||
run_opts = {
|
run_opts = {
|
||||||
'cluster': self.cluster,
|
'cluster': self.cluster,
|
||||||
|
@ -260,6 +267,25 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
|
||||||
waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow
|
waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow
|
||||||
waiter.wait(cluster=self.cluster, tasks=[self.arn])
|
waiter.wait(cluster=self.cluster, tasks=[self.arn])
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
def _cloudwatch_log_events(self) -> Generator:
|
||||||
|
if self._aws_logs_enabled():
|
||||||
|
task_id = self.arn.split("/")[-1]
|
||||||
|
stream_name = f"{self.awslogs_stream_prefix}/{task_id}"
|
||||||
|
yield from self.get_logs_hook().get_log_events(self.awslogs_group, stream_name)
|
||||||
|
else:
|
||||||
|
yield from ()
|
||||||
|
|
||||||
|
def _aws_logs_enabled(self):
|
||||||
|
return self.awslogs_group and self.awslogs_stream_prefix
|
||||||
|
|
||||||
|
def _last_log_message(self):
|
||||||
|
try:
|
||||||
|
return deque(self._cloudwatch_log_events(), maxlen=1).pop()["message"]
|
||||||
|
except IndexError:
|
||||||
|
return None
|
||||||
|
|
||||||
def _check_success_task(self) -> None:
|
def _check_success_task(self) -> None:
|
||||||
if not self.client or not self.arn:
|
if not self.client or not self.arn:
|
||||||
return
|
return
|
||||||
|
@ -268,11 +294,7 @@ class ECSOperator(BaseOperator): # pylint: disable=too-many-instance-attributes
|
||||||
self.log.info('ECS Task stopped, check status: %s', response)
|
self.log.info('ECS Task stopped, check status: %s', response)
|
||||||
|
|
||||||
# Get logs from CloudWatch if the awslogs log driver was used
|
# Get logs from CloudWatch if the awslogs log driver was used
|
||||||
if self.awslogs_group and self.awslogs_stream_prefix:
|
for event in self._cloudwatch_log_events():
|
||||||
self.log.info('ECS Task logs output:')
|
|
||||||
task_id = self.arn.split("/")[-1]
|
|
||||||
stream_name = f"{self.awslogs_stream_prefix}/{task_id}"
|
|
||||||
for event in self.get_logs_hook().get_log_events(self.awslogs_group, stream_name):
|
|
||||||
event_dt = datetime.fromtimestamp(event['timestamp'] / 1000.0)
|
event_dt = datetime.fromtimestamp(event['timestamp'] / 1000.0)
|
||||||
self.log.info("[%s] %s", event_dt.isoformat(), event['message'])
|
self.log.info("[%s] %s", event_dt.isoformat(), event['message'])
|
||||||
|
|
||||||
|
|
|
@ -314,3 +314,18 @@ class TestECSOperator(unittest.TestCase):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.ecs.arn, 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55'
|
self.ecs.arn, 'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@mock.patch.object(ECSOperator, '_last_log_message', return_value="Log output")
|
||||||
|
def test_execute_xcom_with_log(self, mock_cloudwatch_log_message):
|
||||||
|
self.ecs.do_xcom_push = True
|
||||||
|
self.assertEqual(self.ecs.execute(None), mock_cloudwatch_log_message.return_value)
|
||||||
|
|
||||||
|
@mock.patch.object(ECSOperator, '_last_log_message', return_value=None)
|
||||||
|
def test_execute_xcom_with_no_log(self, mock_cloudwatch_log_message):
|
||||||
|
self.ecs.do_xcom_push = True
|
||||||
|
self.assertEqual(self.ecs.execute(None), mock_cloudwatch_log_message.return_value)
|
||||||
|
|
||||||
|
@mock.patch.object(ECSOperator, '_last_log_message', return_value="Log output")
|
||||||
|
def test_execute_xcom_disabled(self, mock_cloudwatch_log_message):
|
||||||
|
self.ecs.do_xcom_push = False
|
||||||
|
self.assertEqual(self.ecs.execute(None), None)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче