Move to azure VM managed run commands (#853)

A comparison of the existing action run command API vs new managed
version:
https://learn.microsoft.com/en-us/azure/virtual-machines/run-command-overview

The main benefits of moving to the new managed version are:
- Much longer remote execution timeouts can be set beyond the current 90
mins. (Supposedly supports up to days)
- Returns the exit code of script ran.
- Can run multiple managed run commands in parallel.
- Supports storing large script output into blob storage (not used in
this PR)

The first point is really needed for steps like the benchbase loading
phase, where loading TPCC scale factor 200+ is a struggle due to hitting
the previous 90 min timeout limit.

This should be possible by setting `"pollTimeout": 5400` to much higher
values as needed in the relevant VM service configs.

---------

Co-authored-by: Eu Jing Chua <eujingchua@microsoft.com>
Co-authored-by: Brian Kroth <bpkroth@users.noreply.github.com>
Co-authored-by: Sergiy Matusevych <sergiym@microsoft.com>
This commit is contained in:
Eu Jing Chua 2024-10-03 15:05:05 -07:00 коммит произвёл GitHub
Родитель c69b7da1b5
Коммит a85df21e1b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
3 изменённых файлов: 212 добавлений и 53 удалений

Просмотреть файл

@ -9,6 +9,7 @@ e.g. Application Environment
"""
import logging
import re
from datetime import datetime
from typing import Dict, Iterable, Optional, Tuple
@ -32,6 +33,8 @@ class RemoteEnv(ScriptEnv):
e.g. Application Environment
"""
_RE_SPECIAL = re.compile(r"\W+")
def __init__( # pylint: disable=too-many-arguments
self,
*,
@ -72,6 +75,7 @@ class RemoteEnv(ScriptEnv):
)
self._wait_boot = self.config.get("wait_boot", False)
self._command_prefix = "mlos-" + self._RE_SPECIAL.sub("-", self.name).lower() + "-"
assert self._service is not None and isinstance(
self._service, SupportsRemoteExec
@ -116,7 +120,7 @@ class RemoteEnv(ScriptEnv):
if self._script_setup:
_LOG.info("Set up the remote environment: %s", self)
(status, _timestamp, _output) = self._remote_exec(self._script_setup)
(status, _timestamp, _output) = self._remote_exec("setup", self._script_setup)
_LOG.info("Remote set up complete: %s :: %s", self, status)
self._is_ready = status.is_succeeded()
else:
@ -145,7 +149,7 @@ class RemoteEnv(ScriptEnv):
if not (status.is_ready() and self._script_run):
return result
(status, timestamp, output) = self._remote_exec(self._script_run)
(status, timestamp, output) = self._remote_exec("run", self._script_run)
if status.is_succeeded() and output is not None:
output = self._extract_stdout_results(output.get("stdout", ""))
_LOG.info("Remote run complete: %s :: %s = %s", self, status, output)
@ -155,16 +159,22 @@ class RemoteEnv(ScriptEnv):
"""Clean up and shut down the remote environment."""
if self._script_teardown:
_LOG.info("Remote teardown: %s", self)
(status, _timestamp, _output) = self._remote_exec(self._script_teardown)
(status, _timestamp, _output) = self._remote_exec("teardown", self._script_teardown)
_LOG.info("Remote teardown complete: %s :: %s", self, status)
super().teardown()
def _remote_exec(self, script: Iterable[str]) -> Tuple[Status, datetime, Optional[dict]]:
def _remote_exec(
self,
command_name: str,
script: Iterable[str],
) -> Tuple[Status, datetime, Optional[dict]]:
"""
Run a script on the remote host.
Parameters
----------
command_name : str
Name of the command to be executed on the remote host.
script : [str]
List of commands to be executed on the remote host.
@ -175,10 +185,14 @@ class RemoteEnv(ScriptEnv):
Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
"""
env_params = self._get_env_params()
_LOG.debug("Submit script: %s with %s", self, env_params)
command_name = self._command_prefix + command_name
_LOG.debug("Submit command: %s with %s", command_name, env_params)
(status, output) = self._remote_exec_service.remote_exec(
script,
config=self._params,
config={
**self._params,
"commandName": command_name,
},
env_params=env_params,
)
_LOG.debug("Script submitted: %s %s :: %s", self, status, output)

Просмотреть файл

@ -6,6 +6,7 @@
import json
import logging
from datetime import datetime
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import requests
@ -99,15 +100,25 @@ class AzureVMService(
"?api-version=2022-03-01"
)
# From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/run-command
# From:
# https://learn.microsoft.com/en-us/rest/api/compute/virtual-machine-run-commands/create-or-update
_URL_REXEC_RUN = (
"https://management.azure.com"
"/subscriptions/{subscription}"
"/resourceGroups/{resource_group}"
"/providers/Microsoft.Compute"
"/virtualMachines/{vm_name}"
"/runCommand"
"?api-version=2022-03-01"
"/runcommands/{command_name}"
"?api-version=2024-07-01"
)
_URL_REXEC_RESULT = (
"https://management.azure.com"
"/subscriptions/{subscription}"
"/resourceGroups/{resource_group}"
"/providers/Microsoft.Compute"
"/virtualMachines/{vm_name}"
"/runcommands/{command_name}"
"?$expand=instanceView&api-version=2024-07-01"
)
def __init__(
@ -231,6 +242,28 @@ class AzureVMService(
params.setdefault(f"{params['vmName']}-deployment")
return self._wait_while(self._check_operation_status, Status.RUNNING, params)
def wait_remote_exec_operation(self, params: dict) -> Tuple["Status", dict]:
"""
Waits for a pending remote execution on an Azure VM to resolve to SUCCEEDED or
FAILED. Return TIMED_OUT when timing out.
Parameters
----------
params: dict
Flat dictionary of (key, value) pairs of tunable parameters.
Must have the "asyncResultsUrl" key to get the results.
If the key is not present, return Status.PENDING.
Returns
-------
result : (Status, dict)
A pair of Status and result.
Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
Result is info on the operation runtime if SUCCEEDED, otherwise {}.
"""
_LOG.info("Wait for run command %s on VM %s", params["commandName"], params["vmName"])
return self._wait_while(self._check_remote_exec_status, Status.RUNNING, params)
def wait_os_operation(self, params: dict) -> Tuple["Status", dict]:
return self.wait_host_operation(params)
@ -481,6 +514,8 @@ class AzureVMService(
"subscription",
"resourceGroup",
"vmName",
"commandName",
"location",
],
)
@ -488,21 +523,28 @@ class AzureVMService(
_LOG.info("Run a script on VM: %s\n %s", config["vmName"], "\n ".join(script))
json_req = {
"commandId": "RunShellScript",
"script": list(script),
"parameters": [{"name": key, "value": val} for (key, val) in env_params.items()],
"location": config["location"],
"properties": {
"source": {"script": "; ".join(script)},
"protectedParameters": [
{"name": key, "value": val} for (key, val) in env_params.items()
],
"timeoutInSeconds": int(self._poll_timeout),
"asyncExecution": True,
},
}
url = self._URL_REXEC_RUN.format(
subscription=config["subscription"],
resource_group=config["resourceGroup"],
vm_name=config["vmName"],
command_name=config["commandName"],
)
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Request: POST %s\n%s", url, json.dumps(json_req, indent=2))
_LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2))
response = requests.post(
response = requests.put(
url,
json=json_req,
headers=self._get_headers(),
@ -518,19 +560,73 @@ class AzureVMService(
else:
_LOG.info("Response: %s", response)
if response.status_code == 200:
# TODO: extract the results from JSON response
return (Status.SUCCEEDED, config)
elif response.status_code == 202:
if response.status_code in {200, 201}:
results_url = self._URL_REXEC_RESULT.format(
subscription=config["subscription"],
resource_group=config["resourceGroup"],
vm_name=config["vmName"],
command_name=config["commandName"],
)
return (
Status.PENDING,
{**config, "asyncResultsUrl": response.headers.get("Azure-AsyncOperation")},
{**config, "asyncResultsUrl": results_url},
)
else:
_LOG.error("Response: %s :: %s", response, response.text)
# _LOG.error("Bad Request:\n%s", response.request.body)
return (Status.FAILED, {})
def _check_remote_exec_status(self, params: dict) -> Tuple[Status, dict]:
"""
Checks the status of a pending remote execution on an Azure VM.
Parameters
----------
params: dict
Flat dictionary of (key, value) pairs of tunable parameters.
Must have the "asyncResultsUrl" key to get the results.
If the key is not present, return Status.PENDING.
Returns
-------
result : (Status, dict)
A pair of Status and result.
Status is one of {PENDING, RUNNING, SUCCEEDED, FAILED}
Result is info on the operation runtime if SUCCEEDED, otherwise {}.
"""
url = params.get("asyncResultsUrl")
if url is None:
return Status.PENDING, {}
session = self._get_session(params)
try:
response = session.get(url, timeout=self._request_timeout)
except requests.exceptions.ReadTimeout:
_LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url)
return Status.RUNNING, {}
except requests.exceptions.RequestException as ex:
_LOG.exception("Error in request checking operation status", exc_info=ex)
return (Status.FAILED, {})
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug(
"Response: %s\n%s",
response,
json.dumps(response.json(), indent=2) if response.content else "",
)
if response.status_code == 200:
output = response.json()
execution_state = (
output.get("properties", {}).get("instanceView", {}).get("executionState")
)
if execution_state in {"Running", "Pending"}:
return Status.RUNNING, {}
elif execution_state == "Succeeded":
return Status.SUCCEEDED, output
_LOG.error("Response: %s :: %s", response, response.text)
return Status.FAILED, {}
def get_remote_exec_results(self, config: dict) -> Tuple[Status, dict]:
"""
Get the results of the asynchronously running command.
@ -547,13 +643,34 @@ class AzureVMService(
result : (Status, dict)
A pair of Status and result.
Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
A dict can have an "stdout" key with the remote output.
A dict can have an "stdout" key with the remote output
and an "stderr" key for errors / warnings.
"""
_LOG.info("Check the results on VM: %s", config.get("vmName"))
(status, result) = self.wait_host_operation(config)
(status, result) = self.wait_remote_exec_operation(config)
_LOG.debug("Result: %s :: %s", status, result)
if not status.is_succeeded():
# TODO: Extract the telemetry and status from stdout, if available
return (status, result)
val = result.get("properties", {}).get("output", {}).get("value", [])
return (status, {"stdout": val[0].get("message", "")} if val else {})
output = result.get("properties", {}).get("instanceView", {})
exit_code = output.get("exitCode")
execution_state = output.get("executionState")
outputs = output.get("output", "").strip().split("\n")
errors = output.get("error", "").strip().split("\n")
if execution_state == "Succeeded" and exit_code == 0:
status = Status.SUCCEEDED
else:
status = Status.FAILED
return (
status,
{
"stdout": outputs,
"stderr": errors,
"exitCode": exit_code,
"startTimestamp": datetime.fromisoformat(output["startTime"]),
"endTimestamp": datetime.fromisoformat(output["endTime"]),
},
)

Просмотреть файл

@ -5,6 +5,7 @@
"""Tests for mlos_bench.services.remote.azure.azure_vm_services."""
from copy import deepcopy
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
@ -273,8 +274,8 @@ def test_wait_vm_operation_retry(
@pytest.mark.parametrize(
("http_status_code", "operation_status"),
[
(200, Status.SUCCEEDED),
(202, Status.PENDING),
(200, Status.PENDING),
(201, Status.PENDING),
(401, Status.FAILED),
(404, Status.FAILED),
],
@ -291,16 +292,18 @@ def test_remote_exec_status(
mock_response = MagicMock()
mock_response.status_code = http_status_code
mock_response.json = MagicMock(
return_value={
"fake response": "body as json to dict",
}
)
mock_requests.post.return_value = mock_response
mock_response.json.return_value = {
"fake response": "body as json to dict",
}
mock_requests.put.return_value = mock_response
status, _ = azure_vm_service_remote_exec_only.remote_exec(
script,
config={"vmName": "test-vm"},
config={
"vmName": "test-vm",
"commandName": "TEST_COMMAND",
"location": "TEST_LOCATION",
},
env_params={},
)
@ -308,7 +311,7 @@ def test_remote_exec_status(
@patch("mlos_bench.services.remote.azure.azure_vm_services.requests")
def test_remote_exec_headers_output(
def test_remote_exec_output(
mock_requests: MagicMock,
azure_vm_service_remote_exec_only: AzureVMService,
) -> None:
@ -318,18 +321,22 @@ def test_remote_exec_headers_output(
script = ["command_1", "command_2"]
mock_response = MagicMock()
mock_response.status_code = 202
mock_response.status_code = 201
mock_response.headers = {"Azure-AsyncOperation": async_url_value}
mock_response.json = MagicMock(
return_value={
"fake response": "body as json to dict",
}
)
mock_requests.post.return_value = mock_response
mock_requests.put.return_value = mock_response
_, cmd_output = azure_vm_service_remote_exec_only.remote_exec(
script,
config={"vmName": "test-vm"},
config={
"vmName": "test-vm",
"commandName": "TEST_COMMAND",
"location": "TEST_LOCATION",
},
env_params={
"param_1": 123,
"param_2": "abc",
@ -337,12 +344,18 @@ def test_remote_exec_headers_output(
)
assert async_url_key in cmd_output
assert cmd_output[async_url_key] == async_url_value
assert mock_requests.post.call_args[1]["json"] == {
"commandId": "RunShellScript",
"script": script,
"parameters": [{"name": "param_1", "value": 123}, {"name": "param_2", "value": "abc"}],
assert mock_requests.put.call_args[1]["json"] == {
"location": "TEST_LOCATION",
"properties": {
"source": {"script": "; ".join(script)},
"protectedParameters": [
{"name": "param_1", "value": 123},
{"name": "param_2", "value": "abc"},
],
"timeoutInSeconds": 2,
"asyncExecution": True,
},
}
@ -353,14 +366,23 @@ def test_remote_exec_headers_output(
Status.SUCCEEDED,
{
"properties": {
"output": {
"value": [
{"message": "DUMMY_STDOUT_STDERR"},
]
"instanceView": {
"output": "DUMMY_STDOUT\n",
"error": "DUMMY_STDERR\n",
"executionState": "Succeeded",
"exitCode": 0,
"startTime": "2024-01-01T00:00:00+00:00",
"endTime": "2024-01-01T00:01:00+00:00",
}
}
},
{"stdout": "DUMMY_STDOUT_STDERR"},
{
"stdout": ["DUMMY_STDOUT"],
"stderr": ["DUMMY_STDERR"],
"exitCode": 0,
"startTimestamp": datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
"endTimestamp": datetime(2024, 1, 1, 0, 1, 0, tzinfo=timezone.utc),
},
),
(Status.PENDING, {}, {}),
(Status.FAILED, {}, {}),
@ -373,15 +395,21 @@ def test_get_remote_exec_results(
results_output: dict,
) -> None:
"""Test getting the results of the remote execution on Azure."""
params = {"asyncResultsUrl": "DUMMY_ASYNC_URL"}
params = {
"asyncResultsUrl": "DUMMY_ASYNC_URL",
}
mock_wait_host_operation = MagicMock()
mock_wait_host_operation.return_value = (operation_status, wait_output)
# azure_vm_service.wait_host_operation = mock_wait_host_operation
setattr(azure_vm_service_remote_exec_only, "wait_host_operation", mock_wait_host_operation)
mock_wait_remote_exec_operation = MagicMock()
mock_wait_remote_exec_operation.return_value = (operation_status, wait_output)
# azure_vm_service.wait_remote_exec_operation = mock_wait_remote_exec_operation
setattr(
azure_vm_service_remote_exec_only,
"wait_remote_exec_operation",
mock_wait_remote_exec_operation,
)
status, cmd_output = azure_vm_service_remote_exec_only.get_remote_exec_results(params)
assert status == operation_status
assert mock_wait_host_operation.call_args[0][0] == params
assert mock_wait_remote_exec_operation.call_args[0][0] == params
assert cmd_output == results_output