ENH: Enable adding tags to aml runs from LightningContainer (#603)

Add get_additional_aml_run_tag to LightningContainer to be able to add extra tags to aml runs to facilitate runs comparison and filtering.
This commit is contained in:
Kenza Bouzid 2022-09-14 11:39:32 +01:00 коммит произвёл GitHub
Родитель 71066ed7df
Коммит 133c994fbd
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 24 добавлений и 1 удалений

Просмотреть файл

@ -272,3 +272,6 @@ class HelloWorld(LightningContainer):
*super().get_callbacks()]
else:
return super().get_callbacks()
def get_additional_aml_run_tags(self) -> Dict[str, str]:
return {"max_epochs": str(self.max_epochs)}

Просмотреть файл

@ -224,6 +224,10 @@ class LightningContainer(WorkflowParams,
argument `experiment`, falling back to the model class name if not set."""
return self.experiment or self.model_name
def get_additional_aml_run_tags(self) -> Dict[str, str]:
"""Returns a dictionary of tags that should be added to the AzureML run."""
return {}
class LightningModuleWithOptimizer(LightningModule):
"""

Просмотреть файл

@ -170,7 +170,8 @@ class Runner:
"""
return {
"commandline_args": " ".join(script_params),
"tag": self.lightning_container.tag
"tag": self.lightning_container.tag,
**self.lightning_container.get_additional_aml_run_tags()
}
def run(self) -> Tuple[LightningContainer, AzureRunInfo]:

Просмотреть файл

@ -97,6 +97,21 @@ def test_ddp_debug_flag(debug_ddp: DebugDDPOptions, mock_runner: Runner) -> None
assert mock_submit_to_azure_if_needed.call_args[1]["environment_variables"][DEBUG_DDP_ENV_VAR] == debug_ddp
def test_additional_aml_run_tags(mock_runner: Runner) -> None:
model_name = "HelloWorld"
arguments = ["", f"--model={model_name}", "--cluster=foo"]
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
with patch("health_ml.runner.check_conda_environment"):
with patch("health_ml.runner.get_workspace"):
with patch("health_ml.runner.Runner.run_in_situ"):
with patch.object(sys, "argv", arguments):
mock_runner.run()
mock_submit_to_azure_if_needed.assert_called_once()
assert "commandline_args" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
assert "tag" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
assert "max_epochs" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
def test_run(mock_runner: Runner) -> None:
model_name = "HelloWorld"
arguments = ["", f"--model={model_name}"]