зеркало из https://github.com/microsoft/hi-ml.git
ENH: Enable adding additional env variables from container (#689)
add get_additional_environment_variables to Lightning container to be able to flexibly add experiments specific env variables, e.g. Azure mounting env variables for large file experiments... this can be overridden in the config file directly.
This commit is contained in:
Родитель
a079da888b
Коммит
bdb89d0b8d
|
@ -104,7 +104,6 @@ DEFAULT_ENVIRONMENT_VARIABLES = {
|
|||
"MKL_SERVICE_FORCE_INTEL": "1",
|
||||
# Switching to a new software stack in AML for mounting datasets
|
||||
"RSLEX_DIRECT_VOLUME_MOUNT": "true",
|
||||
"RSLEX_DIRECT_VOLUME_MOUNT_MAX_CACHE_SIZE": "1",
|
||||
"DATASET_MOUNT_CACHE_SIZE": "1",
|
||||
}
|
||||
|
||||
|
|
|
@ -252,6 +252,10 @@ class LightningContainer(WorkflowParams,
|
|||
"""Returns a dictionary of tags that should be added to the AzureML run."""
|
||||
return {}
|
||||
|
||||
def get_additional_environment_variables(self) -> Dict[str, str]:
|
||||
"""Returns a dictionary of environment variables that should be added to the AzureML run."""
|
||||
return {}
|
||||
|
||||
def on_run_extra_validation_epoch(self) -> None:
|
||||
if hasattr(self.model, "on_run_extra_validation_epoch"):
|
||||
assert self._model, "Model is not initialized."
|
||||
|
|
|
@ -10,7 +10,7 @@ import os
|
|||
import param
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import matplotlib
|
||||
from azureml.core import Workspace, Run
|
||||
|
@ -180,6 +180,12 @@ class Runner:
|
|||
**self.lightning_container.get_additional_aml_run_tags()
|
||||
}
|
||||
|
||||
def additional_environment_variables(self) -> Dict[str, str]:
|
||||
return {
|
||||
DEBUG_DDP_ENV_VAR: self.experiment_config.debug_ddp.value,
|
||||
**self.lightning_container.get_additional_environment_variables()
|
||||
}
|
||||
|
||||
def run(self) -> Tuple[LightningContainer, AzureRunInfo]:
|
||||
"""
|
||||
The main entry point for training and testing models from the commandline. This chooses a model to train
|
||||
|
@ -221,9 +227,7 @@ class Runner:
|
|||
entry_script = Path(sys.argv[0]).resolve()
|
||||
script_params = sys.argv[1:]
|
||||
|
||||
# TODO: Update environment variables
|
||||
environment_variables: Dict[str, Any] = {}
|
||||
environment_variables[DEBUG_DDP_ENV_VAR] = self.experiment_config.debug_ddp.value
|
||||
environment_variables = self.additional_environment_variables()
|
||||
|
||||
# Get default datastore from the provided workspace. Authentication can take a few seconds, hence only do
|
||||
# that if we are really submitting to AzureML.
|
||||
|
|
|
@ -7,7 +7,7 @@ import shutil
|
|||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import patch, MagicMock, DEFAULT, create_autospec
|
||||
|
||||
import pytest
|
||||
from _pytest.capture import SysCapture
|
||||
|
@ -113,6 +113,31 @@ def test_additional_aml_run_tags(mock_runner: Runner) -> None:
|
|||
assert "max_epochs" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
|
||||
|
||||
|
||||
def test_additional_environment_variables(mock_runner: Runner) -> None:
|
||||
model_name = "HelloWorld"
|
||||
arguments = ["", f"--model={model_name}", "--cluster=foo"]
|
||||
with patch.multiple(
|
||||
"health_ml.runner",
|
||||
submit_to_azure_if_needed=DEFAULT,
|
||||
check_conda_environment=DEFAULT,
|
||||
get_workspace=DEFAULT,
|
||||
get_ml_client=DEFAULT,
|
||||
) as mocks:
|
||||
with patch("health_ml.runner.Runner.run_in_situ"):
|
||||
with patch("health_ml.runner.Runner.parse_and_load_model"):
|
||||
with patch("health_ml.runner.Runner.validate"):
|
||||
with patch.object(sys, "argv", arguments):
|
||||
mock_container = create_autospec(LightningContainer)
|
||||
mock_container.get_additional_environment_variables = MagicMock(return_value={"foo": "bar"})
|
||||
mock_runner.lightning_container = mock_container
|
||||
mock_runner.run()
|
||||
mocks["submit_to_azure_if_needed"].assert_called_once()
|
||||
mock_env_vars = mocks["submit_to_azure_if_needed"].call_args[1]["environment_variables"]
|
||||
assert DEBUG_DDP_ENV_VAR in mock_env_vars
|
||||
assert "foo" in mock_env_vars
|
||||
assert mock_env_vars["foo"] == "bar"
|
||||
|
||||
|
||||
def test_run(mock_runner: Runner) -> None:
|
||||
model_name = "HelloWorld"
|
||||
arguments = ["", f"--model={model_name}"]
|
||||
|
|
Загрузка…
Ссылка в новой задаче