FIX: Fix env variables not passed to AML SDK v2 jobs (#885)

Co-authored-by: Anton Schwaighofer <antonsc@microsoft.com>
This commit is contained in:
Fernando Pérez-García 2023-05-11 08:52:18 +01:00 коммит произвёл GitHub
Родитель 6b25f43ea0
Коммит 683def950a
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файла: 60 добавлений и 17 удалений

Просмотреть файл

@@ -444,6 +444,7 @@ def effective_experiment_name(experiment_name: Optional[str], entry_script: Opti
def submit_run_v2(
workspace: Optional[Workspace],
environment: EnvironmentV2,
environment_variables: Optional[Dict[str, str]] = None,
experiment_name: Optional[str] = None,
input_datasets_v2: Optional[Dict[str, Input]] = None,
output_datasets_v2: Optional[Dict[str, Output]] = None,
@@ -467,6 +468,7 @@ def submit_run_v2(
:param workspace: The AzureML workspace to use.
:param environment: An AML v2 Environment object.
:param environment_variables: The environment variables that should be set when running in AzureML.
:param experiment_name: The name of the experiment that will be used or created. If the experiment name contains
characters that are not valid in Azure, those will be removed.
:param input_datasets_v2: An optional dictionary of Inputs to pass in to the command.
@@ -547,6 +549,7 @@ def submit_run_v2(
inputs=input_datasets_v2,
outputs=output_datasets_v2,
environment=environment.name + "@latest",
environment_variables=environment_variables,
compute=compute_target,
experiment_name=experiment_name,
tags=tags or {},
@@ -995,6 +998,7 @@ def submit_to_azure_if_needed( # type: ignore
output_datasets_v2=output_datasets_v2,
experiment_name=experiment_name,
environment=registered_env,
environment_variables=environment_variables,
snapshot_root_directory=snapshot_root_directory,
entry_script=entry_script,
script_params=script_params,

Просмотреть файл

@@ -800,6 +800,7 @@ def test_submit_run_v2(tmp_path: Path) -> None:
inputs=dummy_inputs,
outputs=dummy_outputs,
environment=dummy_environment_name + "@latest",
environment_variables=None,
compute=dummy_compute_target,
experiment_name=dummy_experiment_name,
tags=dummy_tags,
@@ -862,6 +863,7 @@ def test_submit_run_v2(tmp_path: Path) -> None:
inputs=dummy_inputs,
outputs=dummy_outputs,
environment=dummy_environment_name + "@latest",
environment_variables=None,
compute=dummy_compute_target,
experiment_name=dummy_experiment_name,
tags=dummy_tags,
@@ -903,6 +905,7 @@ def test_submit_run_v2(tmp_path: Path) -> None:
inputs=dummy_inputs,
outputs=dummy_outputs,
environment=dummy_environment_name + "@latest",
environment_variables=None,
compute=dummy_compute_target,
experiment_name=dummy_experiment_name,
tags=dummy_tags,
@@ -1929,26 +1932,62 @@ def test_submitting_script_with_sdk_v2_passes_display_name(tmp_path: Path) -> No
# Create a minimal script in a temp folder.
test_script = tmp_path / "test_script.py"
test_script.write_text("print('hello world')")
shared_config_json = get_shared_config_json()
conda_env_path = create_empty_conda_env(tmp_path)
display_name = "my_display_name"
with patch.multiple(
"health_azure.himl",
get_ml_client=DEFAULT,
get_workspace=DEFAULT,
create_python_environment_v2=DEFAULT,
register_environment_v2=DEFAULT,
):
with patch("health_azure.himl.command", side_effect=NotImplementedError) as mock_command:
with pytest.raises(NotImplementedError):
himl.submit_to_azure_if_needed(
entry_script=test_script,
conda_environment_file=conda_env_path,
snapshot_root_directory=tmp_path,
submit_to_azureml=True,
strictly_aml_v1=False,
display_name=display_name,
)
mock_command.assert_called_once()
_, call_kwargs = mock_command.call_args
assert call_kwargs.get("display_name") == display_name, "display_name was not passed to command"
with check_config_json(tmp_path, shared_config_json=shared_config_json), change_working_directory(tmp_path), patch(
"health_azure.himl.command", side_effect=ValueError
) as mock_command:
with pytest.raises(ValueError):
himl.submit_to_azure_if_needed(
aml_workspace=None,
entry_script=test_script,
conda_environment_file=conda_env_path,
snapshot_root_directory=tmp_path,
submit_to_azureml=True,
strictly_aml_v1=False,
display_name=display_name,
)
mock_command.assert_called_once()
_, call_kwargs = mock_command.call_args
assert call_kwargs.get("display_name") == display_name, "display_name was not passed to command"
def test_submitting_script_with_sdk_v2_passes_environment_variables(tmp_path: Path) -> None:
    """
    Check that the environment variables handed to `submit_to_azure_if_needed` arrive unchanged in the keyword
    arguments of the SDK v2 "command" function when submitting with `strictly_aml_v1=False`.
    """
    expected_variables = {"foo": "bar"}
    # Write a trivial script and an empty Conda environment file into the temp folder.
    script_path = tmp_path / "test_script.py"
    script_path.write_text("print('hello world')")
    conda_file = create_empty_conda_env(tmp_path)
    # Swap the workspace/client/environment helpers in himl for default mocks.
    patched_helpers = patch.multiple(
        "health_azure.himl",
        get_ml_client=DEFAULT,
        get_workspace=DEFAULT,
        create_python_environment_v2=DEFAULT,
        register_environment_v2=DEFAULT,
    )
    # The mocked "command" raises, so submission stops as soon as it has been called.
    patched_command = patch("health_azure.himl.command", side_effect=NotImplementedError)
    with patched_helpers, patched_command as mock_command:
        with pytest.raises(NotImplementedError):
            himl.submit_to_azure_if_needed(
                entry_script=script_path,
                conda_environment_file=conda_file,
                snapshot_root_directory=tmp_path,
                submit_to_azureml=True,
                strictly_aml_v1=False,
                environment_variables=expected_variables,
            )
    # The mock keeps its call record after the patch is undone.
    mock_command.assert_called_once()
    submitted_kwargs = mock_command.call_args[1]
    assert "environment_variables" in submitted_kwargs
    assert submitted_kwargs.get("environment_variables") == expected_variables, "environment_variables not passed"
def test_conda_env_missing(tmp_path: Path) -> None: