From 99c534e9fafb947b5949b5f3ba66e3dd090be22d Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 12 Jun 2020 22:24:00 +0100 Subject: [PATCH] Further validation that only task commands are run by executors (#9240) --- airflow/executors/kubernetes_executor.py | 7 ++----- tests/executors/test_dask_executor.py | 6 +++--- tests/executors/test_kubernetes_executor.py | 2 +- tests/executors/test_local_executor.py | 21 ++++++++++++++++----- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index bf51a1333c..3b8d61c19c 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -456,11 +456,8 @@ class AirflowKubernetesScheduler(LoggingMixin): key, command, kube_executor_config = next_job dag_id, task_id, execution_date, try_number = key - if isinstance(command, str): - command = [command] - - if command[0] != "airflow": - raise ValueError('The first element of command must be equal to "airflow".') + if command[0:3] != ["airflow", "tasks", "run"]: + raise ValueError('The command must start with ["airflow", "tasks", "run"].') pod = PodGenerator.construct_pod( namespace=self.namespace, diff --git a/tests/executors/test_dask_executor.py b/tests/executors/test_dask_executor.py index 84a92003bf..ab0826b782 100644 --- a/tests/executors/test_dask_executor.py +++ b/tests/executors/test_dask_executor.py @@ -46,12 +46,12 @@ DEFAULT_DATE = timezone.datetime(2017, 1, 1) class TestBaseDask(unittest.TestCase): def assert_tasks_on_executor(self, executor): + + success_command = ['airflow', 'tasks', 'run', '--help'] + fail_command = ['airflow', 'tasks', 'run', 'false'] # start the executor executor.start() - success_command = ['true', 'some_parameter'] - fail_command = ['false', 'some_parameter'] - executor.execute_async(key='success', command=success_command) executor.execute_async(key='fail', command=fail_command) diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py index 27fe4bdc96..40808af59d 100644 --- a/tests/executors/test_kubernetes_executor.py +++ b/tests/executors/test_kubernetes_executor.py @@ -205,7 +205,7 @@ class TestKubernetesExecutor(unittest.TestCase): try_number = 1 kubernetes_executor.execute_async(key=('dag', 'task', datetime.utcnow(), try_number), queue=None, - command='command', + command=['airflow', 'tasks', 'run', 'true', 'some_parameter'], executor_config={}) kubernetes_executor.sync() kubernetes_executor.sync() diff --git a/tests/executors/test_local_executor.py b/tests/executors/test_local_executor.py index 32caca9fcd..7fc909cdf1 100644 --- a/tests/executors/test_local_executor.py +++ b/tests/executors/test_local_executor.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. import datetime +import subprocess import unittest from unittest import mock @@ -27,13 +28,23 @@ class TestLocalExecutor(unittest.TestCase): TEST_SUCCESS_COMMANDS = 5 - def execution_parallelism(self, parallelism=0): + @mock.patch('airflow.executors.local_executor.subprocess.check_call') + def execution_parallelism(self, mock_check_call, parallelism=0): + success_command = ['airflow', 'tasks', 'run', 'true', 'some_parameter'] + fail_command = ['airflow', 'tasks', 'run', 'false'] + + def fake_execute_command(command, close_fds=True): # pylint: disable=unused-argument + if command != success_command: + raise subprocess.CalledProcessError(returncode=1, cmd=command) + else: + return 0 + + mock_check_call.side_effect = fake_execute_command + executor = LocalExecutor(parallelism=parallelism) executor.start() success_key = 'success {}' - success_command = ['true', 'some_parameter'] - fail_command = ['false', 'some_parameter'] self.assertTrue(executor.result_queue.empty()) execution_date = datetime.datetime.now() @@ -61,11 +72,11 @@ class TestLocalExecutor(unittest.TestCase): self.assertEqual(executor.workers_used, expected) def test_execution_unlimited_parallelism(self): - self.execution_parallelism(parallelism=0) + self.execution_parallelism(parallelism=0) # pylint: disable=no-value-for-parameter def test_execution_limited_parallelism(self): test_parallelism = 2 - self.execution_parallelism(parallelism=test_parallelism) + self.execution_parallelism(parallelism=test_parallelism) # pylint: disable=no-value-for-parameter @mock.patch('airflow.executors.local_executor.LocalExecutor.sync') @mock.patch('airflow.executors.base_executor.BaseExecutor.trigger_tasks')