зеркало из https://github.com/microsoft/hi-ml.git
ENH: Enable adding tags to aml runs from LightningContainer (#603)
Add get_additional_aml_run_tag to LightningContainer to be able to add extra tags to aml runs to facilitate runs comparison and filtering.
This commit is contained in:
Родитель
71066ed7df
Коммит
133c994fbd
|
@ -272,3 +272,6 @@ class HelloWorld(LightningContainer):
|
|||
*super().get_callbacks()]
|
||||
else:
|
||||
return super().get_callbacks()
|
||||
|
||||
def get_additional_aml_run_tags(self) -> Dict[str, str]:
|
||||
return {"max_epochs": str(self.max_epochs)}
|
||||
|
|
|
@ -224,6 +224,10 @@ class LightningContainer(WorkflowParams,
|
|||
argument `experiment`, falling back to the model class name if not set."""
|
||||
return self.experiment or self.model_name
|
||||
|
||||
def get_additional_aml_run_tags(self) -> Dict[str, str]:
|
||||
"""Returns a dictionary of tags that should be added to the AzureML run."""
|
||||
return {}
|
||||
|
||||
|
||||
class LightningModuleWithOptimizer(LightningModule):
|
||||
"""
|
||||
|
|
|
@ -170,7 +170,8 @@ class Runner:
|
|||
"""
|
||||
return {
|
||||
"commandline_args": " ".join(script_params),
|
||||
"tag": self.lightning_container.tag
|
||||
"tag": self.lightning_container.tag,
|
||||
**self.lightning_container.get_additional_aml_run_tags()
|
||||
}
|
||||
|
||||
def run(self) -> Tuple[LightningContainer, AzureRunInfo]:
|
||||
|
|
|
@ -97,6 +97,21 @@ def test_ddp_debug_flag(debug_ddp: DebugDDPOptions, mock_runner: Runner) -> None
|
|||
assert mock_submit_to_azure_if_needed.call_args[1]["environment_variables"][DEBUG_DDP_ENV_VAR] == debug_ddp
|
||||
|
||||
|
||||
def test_additional_aml_run_tags(mock_runner: Runner) -> None:
|
||||
model_name = "HelloWorld"
|
||||
arguments = ["", f"--model={model_name}", "--cluster=foo"]
|
||||
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_azure_if_needed:
|
||||
with patch("health_ml.runner.check_conda_environment"):
|
||||
with patch("health_ml.runner.get_workspace"):
|
||||
with patch("health_ml.runner.Runner.run_in_situ"):
|
||||
with patch.object(sys, "argv", arguments):
|
||||
mock_runner.run()
|
||||
mock_submit_to_azure_if_needed.assert_called_once()
|
||||
assert "commandline_args" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
|
||||
assert "tag" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
|
||||
assert "max_epochs" in mock_submit_to_azure_if_needed.call_args[1]["tags"]
|
||||
|
||||
|
||||
def test_run(mock_runner: Runner) -> None:
|
||||
model_name = "HelloWorld"
|
||||
arguments = ["", f"--model={model_name}"]
|
||||
|
|
Загрузка…
Ссылка в новой задаче