ENH: Enable adding additional env variables from container (#689)

add get_additional_environment_variables to Lightning container to be
able to flexibly add experiments specific env variables, e.g. Azure
mounting env variables for large file experiments... this can be
overridden in the config file directly.
This commit is contained in:
Kenza Bouzid 2022-11-29 19:23:12 +00:00 коммит произвёл GitHub
Родитель a079da888b
Коммит bdb89d0b8d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 38 добавлений и 6 удалений

Просмотреть файл

@ -104,7 +104,6 @@ DEFAULT_ENVIRONMENT_VARIABLES = {
"MKL_SERVICE_FORCE_INTEL": "1",
# Switching to a new software stack in AML for mounting datasets
"RSLEX_DIRECT_VOLUME_MOUNT": "true",
"RSLEX_DIRECT_VOLUME_MOUNT_MAX_CACHE_SIZE": "1",
"DATASET_MOUNT_CACHE_SIZE": "1",
}

Просмотреть файл

@ -252,6 +252,10 @@ class LightningContainer(WorkflowParams,
"""Returns a dictionary of tags that should be added to the AzureML run."""
return {}
def get_additional_environment_variables(self) -> Dict[str, str]:
"""Returns a dictionary of environment variables that should be added to the AzureML run."""
return {}
def on_run_extra_validation_epoch(self) -> None:
if hasattr(self.model, "on_run_extra_validation_epoch"):
assert self._model, "Model is not initialized."

Просмотреть файл

@ -10,7 +10,7 @@ import os
import param
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
import matplotlib
from azureml.core import Workspace, Run
@ -180,6 +180,12 @@ class Runner:
**self.lightning_container.get_additional_aml_run_tags()
}
def additional_environment_variables(self) -> Dict[str, str]:
return {
DEBUG_DDP_ENV_VAR: self.experiment_config.debug_ddp.value,
**self.lightning_container.get_additional_environment_variables()
}
def run(self) -> Tuple[LightningContainer, AzureRunInfo]:
"""
The main entry point for training and testing models from the commandline. This chooses a model to train
@ -221,9 +227,7 @@ class Runner:
entry_script = Path(sys.argv[0]).resolve()
script_params = sys.argv[1:]
# TODO: Update environment variables
environment_variables: Dict[str, Any] = {}
environment_variables[DEBUG_DDP_ENV_VAR] = self.experiment_config.debug_ddp.value
environment_variables = self.additional_environment_variables()
# Get default datastore from the provided workspace. Authentication can take a few seconds, hence only do
# that if we are really submitting to AzureML.

Просмотреть файл

@ -7,7 +7,7 @@ import shutil
import sys
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional
from unittest.mock import patch, MagicMock
from unittest.mock import patch, MagicMock, DEFAULT, create_autospec
import pytest
from _pytest.capture import SysCapture
@ -113,6 +113,31 @@ def test_additional_aml_run_tags(mock_runner: Runner) -> None:
assert "max_epochs" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
def test_additional_environment_variables(mock_runner: Runner) -> None:
model_name = "HelloWorld"
arguments = ["", f"--model={model_name}", "--cluster=foo"]
with patch.multiple(
"health_ml.runner",
submit_to_azure_if_needed=DEFAULT,
check_conda_environment=DEFAULT,
get_workspace=DEFAULT,
get_ml_client=DEFAULT,
) as mocks:
with patch("health_ml.runner.Runner.run_in_situ"):
with patch("health_ml.runner.Runner.parse_and_load_model"):
with patch("health_ml.runner.Runner.validate"):
with patch.object(sys, "argv", arguments):
mock_container = create_autospec(LightningContainer)
mock_container.get_additional_environment_variables = MagicMock(return_value={"foo": "bar"})
mock_runner.lightning_container = mock_container
mock_runner.run()
mocks["submit_to_azure_if_needed"].assert_called_once()
mock_env_vars = mocks["submit_to_azure_if_needed"].call_args[1]["environment_variables"]
assert DEBUG_DDP_ENV_VAR in mock_env_vars
assert "foo" in mock_env_vars
assert mock_env_vars["foo"] == "bar"
def test_run(mock_runner: Runner) -> None:
model_name = "HelloWorld"
arguments = ["", f"--model={model_name}"]