diff --git a/mlos_bench/mlos_bench/environments/remote/remote_env.py b/mlos_bench/mlos_bench/environments/remote/remote_env.py index 1c025a335d..c3535d1a6a 100644 --- a/mlos_bench/mlos_bench/environments/remote/remote_env.py +++ b/mlos_bench/mlos_bench/environments/remote/remote_env.py @@ -9,6 +9,7 @@ e.g. Application Environment """ import logging +import re from datetime import datetime from typing import Dict, Iterable, Optional, Tuple @@ -32,6 +33,8 @@ class RemoteEnv(ScriptEnv): e.g. Application Environment """ + _RE_SPECIAL = re.compile(r"\W+") + def __init__( # pylint: disable=too-many-arguments self, *, @@ -72,6 +75,7 @@ class RemoteEnv(ScriptEnv): ) self._wait_boot = self.config.get("wait_boot", False) + self._command_prefix = "mlos-" + self._RE_SPECIAL.sub("-", self.name).lower() + "-" assert self._service is not None and isinstance( self._service, SupportsRemoteExec @@ -116,7 +120,7 @@ class RemoteEnv(ScriptEnv): if self._script_setup: _LOG.info("Set up the remote environment: %s", self) - (status, _timestamp, _output) = self._remote_exec(self._script_setup) + (status, _timestamp, _output) = self._remote_exec("setup", self._script_setup) _LOG.info("Remote set up complete: %s :: %s", self, status) self._is_ready = status.is_succeeded() else: @@ -145,7 +149,7 @@ class RemoteEnv(ScriptEnv): if not (status.is_ready() and self._script_run): return result - (status, timestamp, output) = self._remote_exec(self._script_run) + (status, timestamp, output) = self._remote_exec("run", self._script_run) if status.is_succeeded() and output is not None: output = self._extract_stdout_results(output.get("stdout", "")) _LOG.info("Remote run complete: %s :: %s = %s", self, status, output) @@ -155,16 +159,22 @@ class RemoteEnv(ScriptEnv): """Clean up and shut down the remote environment.""" if self._script_teardown: _LOG.info("Remote teardown: %s", self) - (status, _timestamp, _output) = self._remote_exec(self._script_teardown) + (status, _timestamp, _output) = self._remote_exec("teardown", self._script_teardown) _LOG.info("Remote teardown complete: %s :: %s", self, status) super().teardown() - def _remote_exec(self, script: Iterable[str]) -> Tuple[Status, datetime, Optional[dict]]: + def _remote_exec( + self, + command_name: str, + script: Iterable[str], + ) -> Tuple[Status, datetime, Optional[dict]]: """ Run a script on the remote host. Parameters ---------- + command_name : str + Name of the command to be executed on the remote host. script : [str] List of commands to be executed on the remote host. @@ -175,10 +185,14 @@ class RemoteEnv(ScriptEnv): Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT} """ env_params = self._get_env_params() - _LOG.debug("Submit script: %s with %s", self, env_params) + command_name = self._command_prefix + command_name + _LOG.debug("Submit command: %s with %s", command_name, env_params) (status, output) = self._remote_exec_service.remote_exec( script, - config=self._params, + config={ + **self._params, + "commandName": command_name, + }, env_params=env_params, ) _LOG.debug("Script submitted: %s %s :: %s", self, status, output) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py index 3d390645f5..b62ede5fab 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py @@ -6,6 +6,7 @@ import json import logging +from datetime import datetime from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import requests @@ -99,15 +100,25 @@ class AzureVMService( "?api-version=2022-03-01" ) - # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/run-command + # From: + # https://learn.microsoft.com/en-us/rest/api/compute/virtual-machine-run-commands/create-or-update _URL_REXEC_RUN = ( "https://management.azure.com" "/subscriptions/{subscription}" "/resourceGroups/{resource_group}" "/providers/Microsoft.Compute" "/virtualMachines/{vm_name}" - "/runCommand" - "?api-version=2022-03-01" + "/runcommands/{command_name}" + "?api-version=2024-07-01" + ) + _URL_REXEC_RESULT = ( + "https://management.azure.com" + "/subscriptions/{subscription}" + "/resourceGroups/{resource_group}" + "/providers/Microsoft.Compute" + "/virtualMachines/{vm_name}" + "/runcommands/{command_name}" + "?$expand=instanceView&api-version=2024-07-01" ) def __init__( @@ -231,6 +242,28 @@ class AzureVMService( params.setdefault(f"{params['vmName']}-deployment") return self._wait_while(self._check_operation_status, Status.RUNNING, params) + def wait_remote_exec_operation(self, params: dict) -> Tuple["Status", dict]: + """ + Waits for a pending remote execution on an Azure VM to resolve to SUCCEEDED or + FAILED. Return TIMED_OUT when timing out. + + Parameters + ---------- + params: dict + Flat dictionary of (key, value) pairs of tunable parameters. + Must have the "asyncResultsUrl" key to get the results. + If the key is not present, return Status.PENDING. + + Returns + ------- + result : (Status, dict) + A pair of Status and result. + Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT} + Result is info on the operation runtime if SUCCEEDED, otherwise {}. + """ + _LOG.info("Wait for run command %s on VM %s", params["commandName"], params["vmName"]) + return self._wait_while(self._check_remote_exec_status, Status.RUNNING, params) + def wait_os_operation(self, params: dict) -> Tuple["Status", dict]: return self.wait_host_operation(params) @@ -481,6 +514,8 @@ class AzureVMService( "subscription", "resourceGroup", "vmName", + "commandName", + "location", ], ) @@ -488,21 +523,28 @@ class AzureVMService( _LOG.info("Run a script on VM: %s\n %s", config["vmName"], "\n ".join(script)) json_req = { - "commandId": "RunShellScript", - "script": list(script), - "parameters": [{"name": key, "value": val} for (key, val) in env_params.items()], + "location": config["location"], + "properties": { + "source": {"script": "; ".join(script)}, + "protectedParameters": [ + {"name": key, "value": val} for (key, val) in env_params.items() + ], + "timeoutInSeconds": int(self._poll_timeout), + "asyncExecution": True, + }, } url = self._URL_REXEC_RUN.format( subscription=config["subscription"], resource_group=config["resourceGroup"], vm_name=config["vmName"], + command_name=config["commandName"], ) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Request: POST %s\n%s", url, json.dumps(json_req, indent=2)) + _LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2)) - response = requests.post( + response = requests.put( url, json=json_req, headers=self._get_headers(), @@ -518,19 +560,73 @@ class AzureVMService( else: _LOG.info("Response: %s", response) - if response.status_code == 200: - # TODO: extract the results from JSON response - return (Status.SUCCEEDED, config) - elif response.status_code == 202: + if response.status_code in {200, 201}: + results_url = self._URL_REXEC_RESULT.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + command_name=config["commandName"], + ) return ( Status.PENDING, - {**config, "asyncResultsUrl": response.headers.get("Azure-AsyncOperation")}, + {**config, "asyncResultsUrl": results_url}, ) else: _LOG.error("Response: %s :: %s", response, response.text) - # _LOG.error("Bad Request:\n%s", response.request.body) return (Status.FAILED, {}) + def _check_remote_exec_status(self, params: dict) -> Tuple[Status, dict]: + """ + Checks the status of a pending remote execution on an Azure VM. + + Parameters + ---------- + params: dict + Flat dictionary of (key, value) pairs of tunable parameters. + Must have the "asyncResultsUrl" key to get the results. + If the key is not present, return Status.PENDING. + + Returns + ------- + result : (Status, dict) + A pair of Status and result. + Status is one of {PENDING, RUNNING, SUCCEEDED, FAILED} + Result is info on the operation runtime if SUCCEEDED, otherwise {}. + """ + url = params.get("asyncResultsUrl") + if url is None: + return Status.PENDING, {} + + session = self._get_session(params) + try: + response = session.get(url, timeout=self._request_timeout) + except requests.exceptions.ReadTimeout: + _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url) + return Status.RUNNING, {} + except requests.exceptions.RequestException as ex: + _LOG.exception("Error in request checking operation status", exc_info=ex) + return (Status.FAILED, {}) + + if _LOG.isEnabledFor(logging.DEBUG): + _LOG.debug( + "Response: %s\n%s", + response, + json.dumps(response.json(), indent=2) if response.content else "", + ) + + if response.status_code == 200: + output = response.json() + execution_state = ( + output.get("properties", {}).get("instanceView", {}).get("executionState") + ) + if execution_state in {"Running", "Pending"}: + return Status.RUNNING, {} + elif execution_state == "Succeeded": + return Status.SUCCEEDED, output + + _LOG.error("Response: %s :: %s", response, response.text) + return Status.FAILED, {} + def get_remote_exec_results(self, config: dict) -> Tuple[Status, dict]: """ Get the results of the asynchronously running command. @@ -547,13 +643,34 @@ class AzureVMService( result : (Status, dict) A pair of Status and result. Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT} - A dict can have an "stdout" key with the remote output. + A dict can have an "stdout" key with the remote output + and an "stderr" key for errors / warnings. """ _LOG.info("Check the results on VM: %s", config.get("vmName")) - (status, result) = self.wait_host_operation(config) + (status, result) = self.wait_remote_exec_operation(config) _LOG.debug("Result: %s :: %s", status, result) if not status.is_succeeded(): # TODO: Extract the telemetry and status from stdout, if available return (status, result) - val = result.get("properties", {}).get("output", {}).get("value", []) - return (status, {"stdout": val[0].get("message", "")} if val else {}) + + output = result.get("properties", {}).get("instanceView", {}) + exit_code = output.get("exitCode") + execution_state = output.get("executionState") + outputs = output.get("output", "").strip().split("\n") + errors = output.get("error", "").strip().split("\n") + + if execution_state == "Succeeded" and exit_code == 0: + status = Status.SUCCEEDED + else: + status = Status.FAILED + + return ( + status, + { + "stdout": outputs, + "stderr": errors, + "exitCode": exit_code, + "startTimestamp": datetime.fromisoformat(output["startTime"]), + "endTimestamp": datetime.fromisoformat(output["endTime"]), + }, + ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py index 988177180e..240a32e4c5 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py @@ -5,6 +5,7 @@ """Tests for mlos_bench.services.remote.azure.azure_vm_services.""" from copy import deepcopy +from datetime import datetime, timezone from unittest.mock import MagicMock, patch import pytest @@ -273,8 +274,8 @@ def test_wait_vm_operation_retry( @pytest.mark.parametrize( ("http_status_code", "operation_status"), [ - (200, Status.SUCCEEDED), - (202, Status.PENDING), + (200, Status.PENDING), + (201, Status.PENDING), (401, Status.FAILED), (404, Status.FAILED), ], @@ -291,16 +292,18 @@ def test_remote_exec_status( mock_response = MagicMock() mock_response.status_code = http_status_code - mock_response.json = MagicMock( - return_value={ - "fake response": "body as json to dict", - } - ) - mock_requests.post.return_value = mock_response + mock_response.json.return_value = { + "fake response": "body as json to dict", + } + mock_requests.put.return_value = mock_response status, _ = azure_vm_service_remote_exec_only.remote_exec( script, - config={"vmName": "test-vm"}, + config={ + "vmName": "test-vm", + "commandName": "TEST_COMMAND", + "location": "TEST_LOCATION", + }, env_params={}, ) @@ -308,7 +311,7 @@ def test_remote_exec_status( @patch("mlos_bench.services.remote.azure.azure_vm_services.requests") -def test_remote_exec_headers_output( +def test_remote_exec_output( mock_requests: MagicMock, azure_vm_service_remote_exec_only: AzureVMService, ) -> None: @@ -318,18 +321,22 @@ def test_remote_exec_headers_output( script = ["command_1", "command_2"] mock_response = MagicMock() - mock_response.status_code = 202 + mock_response.status_code = 201 mock_response.headers = {"Azure-AsyncOperation": async_url_value} mock_response.json = MagicMock( return_value={ "fake response": "body as json to dict", } ) - mock_requests.post.return_value = mock_response + mock_requests.put.return_value = mock_response _, cmd_output = azure_vm_service_remote_exec_only.remote_exec( script, - config={"vmName": "test-vm"}, + config={ + "vmName": "test-vm", + "commandName": "TEST_COMMAND", + "location": "TEST_LOCATION", + }, env_params={ "param_1": 123, "param_2": "abc", @@ -337,12 +344,18 @@ def test_remote_exec_headers_output( ) assert async_url_key in cmd_output - assert cmd_output[async_url_key] == async_url_value - assert mock_requests.post.call_args[1]["json"] == { - "commandId": "RunShellScript", - "script": script, - "parameters": [{"name": "param_1", "value": 123}, {"name": "param_2", "value": "abc"}], + assert mock_requests.put.call_args[1]["json"] == { + "location": "TEST_LOCATION", + "properties": { + "source": {"script": "; ".join(script)}, + "protectedParameters": [ + {"name": "param_1", "value": 123}, + {"name": "param_2", "value": "abc"}, + ], + "timeoutInSeconds": 2, + "asyncExecution": True, + }, } @@ -353,14 +366,23 @@ def test_remote_exec_headers_output( Status.SUCCEEDED, { "properties": { - "output": { - "value": [ - {"message": "DUMMY_STDOUT_STDERR"}, - ] + "instanceView": { + "output": "DUMMY_STDOUT\n", + "error": "DUMMY_STDERR\n", + "executionState": "Succeeded", + "exitCode": 0, + "startTime": "2024-01-01T00:00:00+00:00", + "endTime": "2024-01-01T00:01:00+00:00", } } }, - {"stdout": "DUMMY_STDOUT_STDERR"}, + { + "stdout": ["DUMMY_STDOUT"], + "stderr": ["DUMMY_STDERR"], + "exitCode": 0, + "startTimestamp": datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc), + "endTimestamp": datetime(2024, 1, 1, 0, 1, 0, tzinfo=timezone.utc), + }, ), (Status.PENDING, {}, {}), (Status.FAILED, {}, {}), @@ -373,15 +395,21 @@ def test_get_remote_exec_results( results_output: dict, ) -> None: """Test getting the results of the remote execution on Azure.""" - params = {"asyncResultsUrl": "DUMMY_ASYNC_URL"} + params = { + "asyncResultsUrl": "DUMMY_ASYNC_URL", + } - mock_wait_host_operation = MagicMock() - mock_wait_host_operation.return_value = (operation_status, wait_output) - # azure_vm_service.wait_host_operation = mock_wait_host_operation - setattr(azure_vm_service_remote_exec_only, "wait_host_operation", mock_wait_host_operation) + mock_wait_remote_exec_operation = MagicMock() + mock_wait_remote_exec_operation.return_value = (operation_status, wait_output) + # azure_vm_service.wait_remote_exec_operation = mock_wait_remote_exec_operation + setattr( + azure_vm_service_remote_exec_only, + "wait_remote_exec_operation", + mock_wait_remote_exec_operation, + ) status, cmd_output = azure_vm_service_remote_exec_only.get_remote_exec_results(params) assert status == operation_status - assert mock_wait_host_operation.call_args[0][0] == params + assert mock_wait_remote_exec_operation.call_args[0][0] == params assert cmd_output == results_output