diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index 9080a3a7e0..d7d1292cf8 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -125,7 +125,7 @@ class KubeConfig: # pylint: disable=too-many-instance-attributes return int(val) -class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin): +class KubernetesJobWatcher(LoggingMixin): """Watches for Kubernetes jobs""" def __init__(self, @@ -142,6 +142,31 @@ class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin): self.watcher_queue = watcher_queue self.resource_version = resource_version self.kube_config = kube_config + self.watcher_process = multiprocessing.Process(target=self.run, args=()) + + def start(self): + """ + Start the watcher process + """ + self.watcher_process.start() + + def is_alive(self): + """ + Check if the watcher process is alive + """ + self.watcher_process.is_alive() + + def join(self): + """ + Join watcher process + """ + self.watcher_process.join() + + def terminate(self): + """ + Terminate watcher process + """ + self.watcher_process.terminate() def run(self) -> None: """Performs watching""" diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py index 9bd1868022..038ec28e38 100644 --- a/tests/executors/test_kubernetes_executor.py +++ b/tests/executors/test_kubernetes_executor.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. # +import multiprocessing import random import re import string @@ -30,7 +31,9 @@ from tests.test_utils.config import conf_vars try: from kubernetes.client.rest import ApiException - from airflow.executors.kubernetes_executor import AirflowKubernetesScheduler, KubernetesExecutor + from airflow.executors.kubernetes_executor import ( + AirflowKubernetesScheduler, KubernetesExecutor, KubernetesJobWatcher, + ) from airflow.kubernetes import pod_generator from airflow.kubernetes.pod_generator import PodGenerator from airflow.utils.state import State @@ -38,6 +41,56 @@ except ImportError: AirflowKubernetesScheduler = None # type: ignore +class TestKubernetesJobWatcher(unittest.TestCase): + def setUp(self) -> None: + self.watcher_queue = multiprocessing.Manager().Queue() + self.watcher = KubernetesJobWatcher( + namespace="namespace", + multi_namespace_mode=False, + watcher_queue=self.watcher_queue, + resource_version="0", + worker_uuid="0", + kube_config=None, + ) + + def test_running_task(self): + self.watcher.process_status( + pod_id="pod_id", + namespace="namespace", + status="Running", + annotations={"foo": "bar"}, + resource_version="5", + event={"type": "ADDED"} + ) + self.assertTrue(self.watcher_queue.empty()) + + def test_succeeded_task(self): + self.watcher.process_status( + pod_id="pod_id", + namespace="namespace", + status="Succeeded", + annotations={"foo": "bar"}, + resource_version="5", + event={"type": "ADDED"} + ) + result = self.watcher_queue.get_nowait() + self.assertEqual(('pod_id', 'namespace', None, {'foo': 'bar'}, '5'), result) + self.assertTrue(self.watcher_queue.empty()) + + def test_failed_task(self): + self.watcher.process_status( + pod_id="pod_id", + namespace="namespace", + status="Failed", + annotations={"foo": "bar"}, + resource_version="5", + event={"type": "ADDED"} + ) + result = self.watcher_queue.get_nowait() + self.assertEqual(('pod_id', 'namespace', "failed", {'foo': 'bar'}, '5'), result) + self.assertTrue(self.watcher_queue.empty()) + + # pylint: disable=unused-argument class TestAirflowKubernetesScheduler(unittest.TestCase): @staticmethod