зеркало из https://github.com/microsoft/hi-ml.git
FIX: Fix env variables not passed to AML SDK v2 jobs (#885)
Co-authored-by: Anton Schwaighofer <antonsc@microsoft.com>
This commit is contained in:
Родитель
6b25f43ea0
Коммит
683def950a
|
@ -444,6 +444,7 @@ def effective_experiment_name(experiment_name: Optional[str], entry_script: Opti
|
|||
def submit_run_v2(
|
||||
workspace: Optional[Workspace],
|
||||
environment: EnvironmentV2,
|
||||
environment_variables: Optional[Dict[str, str]] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
input_datasets_v2: Optional[Dict[str, Input]] = None,
|
||||
output_datasets_v2: Optional[Dict[str, Output]] = None,
|
||||
|
@ -467,6 +468,7 @@ def submit_run_v2(
|
|||
|
||||
:param workspace: The AzureML workspace to use.
|
||||
:param environment: An AML v2 Environment object.
|
||||
:param environment_variables: The environment variables that should be set when running in AzureML.
|
||||
:param experiment_name: The name of the experiment that will be used or created. If the experiment name contains
|
||||
characters that are not valid in Azure, those will be removed.
|
||||
:param input_datasets_v2: An optional dictionary of Inputs to pass in to the command.
|
||||
|
@ -547,6 +549,7 @@ def submit_run_v2(
|
|||
inputs=input_datasets_v2,
|
||||
outputs=output_datasets_v2,
|
||||
environment=environment.name + "@latest",
|
||||
environment_variables=environment_variables,
|
||||
compute=compute_target,
|
||||
experiment_name=experiment_name,
|
||||
tags=tags or {},
|
||||
|
@ -995,6 +998,7 @@ def submit_to_azure_if_needed( # type: ignore
|
|||
output_datasets_v2=output_datasets_v2,
|
||||
experiment_name=experiment_name,
|
||||
environment=registered_env,
|
||||
environment_variables=environment_variables,
|
||||
snapshot_root_directory=snapshot_root_directory,
|
||||
entry_script=entry_script,
|
||||
script_params=script_params,
|
||||
|
|
|
@ -800,6 +800,7 @@ def test_submit_run_v2(tmp_path: Path) -> None:
|
|||
inputs=dummy_inputs,
|
||||
outputs=dummy_outputs,
|
||||
environment=dummy_environment_name + "@latest",
|
||||
environment_variables=None,
|
||||
compute=dummy_compute_target,
|
||||
experiment_name=dummy_experiment_name,
|
||||
tags=dummy_tags,
|
||||
|
@ -862,6 +863,7 @@ def test_submit_run_v2(tmp_path: Path) -> None:
|
|||
inputs=dummy_inputs,
|
||||
outputs=dummy_outputs,
|
||||
environment=dummy_environment_name + "@latest",
|
||||
environment_variables=None,
|
||||
compute=dummy_compute_target,
|
||||
experiment_name=dummy_experiment_name,
|
||||
tags=dummy_tags,
|
||||
|
@ -903,6 +905,7 @@ def test_submit_run_v2(tmp_path: Path) -> None:
|
|||
inputs=dummy_inputs,
|
||||
outputs=dummy_outputs,
|
||||
environment=dummy_environment_name + "@latest",
|
||||
environment_variables=None,
|
||||
compute=dummy_compute_target,
|
||||
experiment_name=dummy_experiment_name,
|
||||
tags=dummy_tags,
|
||||
|
@ -1929,26 +1932,62 @@ def test_submitting_script_with_sdk_v2_passes_display_name(tmp_path: Path) -> No
|
|||
# Create a minimal script in a temp folder.
|
||||
test_script = tmp_path / "test_script.py"
|
||||
test_script.write_text("print('hello world')")
|
||||
shared_config_json = get_shared_config_json()
|
||||
conda_env_path = create_empty_conda_env(tmp_path)
|
||||
display_name = "my_display_name"
|
||||
with patch.multiple(
|
||||
"health_azure.himl",
|
||||
get_ml_client=DEFAULT,
|
||||
get_workspace=DEFAULT,
|
||||
create_python_environment_v2=DEFAULT,
|
||||
register_environment_v2=DEFAULT,
|
||||
):
|
||||
with patch("health_azure.himl.command", side_effect=NotImplementedError) as mock_command:
|
||||
with pytest.raises(NotImplementedError):
|
||||
himl.submit_to_azure_if_needed(
|
||||
entry_script=test_script,
|
||||
conda_environment_file=conda_env_path,
|
||||
snapshot_root_directory=tmp_path,
|
||||
submit_to_azureml=True,
|
||||
strictly_aml_v1=False,
|
||||
display_name=display_name,
|
||||
)
|
||||
mock_command.assert_called_once()
|
||||
_, call_kwargs = mock_command.call_args
|
||||
assert call_kwargs.get("display_name") == display_name, "display_name was not passed to command"
|
||||
|
||||
with check_config_json(tmp_path, shared_config_json=shared_config_json), change_working_directory(tmp_path), patch(
|
||||
"health_azure.himl.command", side_effect=ValueError
|
||||
) as mock_command:
|
||||
with pytest.raises(ValueError):
|
||||
himl.submit_to_azure_if_needed(
|
||||
aml_workspace=None,
|
||||
entry_script=test_script,
|
||||
conda_environment_file=conda_env_path,
|
||||
snapshot_root_directory=tmp_path,
|
||||
submit_to_azureml=True,
|
||||
strictly_aml_v1=False,
|
||||
display_name=display_name,
|
||||
)
|
||||
mock_command.assert_called_once()
|
||||
_, call_kwargs = mock_command.call_args
|
||||
assert call_kwargs.get("display_name") == display_name, "display_name was not passed to command"
|
||||
|
||||
def test_submitting_script_with_sdk_v2_passes_environment_variables(tmp_path: Path) -> None:
|
||||
"""
|
||||
Test that submission of a script with SDK v2 passes the environment variables to the "command" function
|
||||
that does the actual submission.
|
||||
"""
|
||||
# Create a minimal script in a temp folder.
|
||||
test_script = tmp_path / "test_script.py"
|
||||
test_script.write_text("print('hello world')")
|
||||
conda_env_path = create_empty_conda_env(tmp_path)
|
||||
environment_variables = {"foo": "bar"}
|
||||
|
||||
with patch.multiple(
|
||||
"health_azure.himl",
|
||||
get_ml_client=DEFAULT,
|
||||
get_workspace=DEFAULT,
|
||||
create_python_environment_v2=DEFAULT,
|
||||
register_environment_v2=DEFAULT,
|
||||
):
|
||||
with patch("health_azure.himl.command", side_effect=NotImplementedError) as mock_command:
|
||||
with pytest.raises(NotImplementedError):
|
||||
himl.submit_to_azure_if_needed(
|
||||
entry_script=test_script,
|
||||
conda_environment_file=conda_env_path,
|
||||
snapshot_root_directory=tmp_path,
|
||||
submit_to_azureml=True,
|
||||
strictly_aml_v1=False,
|
||||
environment_variables=environment_variables,
|
||||
)
|
||||
mock_command.assert_called_once()
|
||||
_, call_kwargs = mock_command.call_args
|
||||
assert "environment_variables" in call_kwargs
|
||||
assert call_kwargs.get("environment_variables") == environment_variables, "environment_variables not passed"
|
||||
|
||||
|
||||
def test_conda_env_missing(tmp_path: Path) -> None:
|
||||
|
|
Загрузка…
Ссылка в новой задаче