ENH: Enable hyperparameter search with AML SDK v2 (#650)

Enable hyperparameter search with AML SDK v2
This commit is contained in:
Melissa Bristow 2022-11-09 15:50:54 +00:00 committed by GitHub
Parent 0b65bd42cf
Commit 89bb3d4a02
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
8 changed files: 252 additions and 26 deletions

View file

@ -1,4 +1,4 @@
# Hyperparameter Search via Hyperdrive
# Hyperparameter Search via Hyperdrive (AML SDK v1)
[HyperDrive runs](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-tune-hyperparameters)
can start multiple AzureML jobs in parallel. This can be used for tuning hyperparameters, or executing multiple
@ -27,3 +27,37 @@ submit_to_azure_if_needed(..., hyperdrive_config=hyperdrive_config)
For further examples, please check the [example scripts here](examples.md), and the
[HyperDrive documentation](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-tune-hyperparameters).
# Hyperparameter Search in AML SDK v2
AML SDK v2 has no concept of a HyperDriveConfig. Instead, hyperparameter search arguments are passed into a command,
and the `sweep` method is then called on it (see the
[AML docs](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-tune-hyperparameters)). To specify a
hyperparameter search job, override the method `get_parameter_tuning_args` in your Container. It should return a
dictionary of the arguments to be passed in to the command. For example:
```python
def get_parameter_tuning_args(self) -> Dict[str, Any]:
    from azure.ai.ml.sweep import Choice

    from health_azure.himl import (MAX_TOTAL_TRIALS_ARG, PARAM_SAMPLING_ARG, SAMPLING_ALGORITHM_ARG,
                                   PRIMARY_METRIC_ARG, GOAL_ARG)

    values = [0.1, 0.5, 0.9]
    argument_name = "learning_rate"
    param_sampling = {argument_name: Choice(values)}
    metric_name = "val/loss"
    hparam_args = {
        MAX_TOTAL_TRIALS_ARG: len(values),
        PARAM_SAMPLING_ARG: param_sampling,
        SAMPLING_ALGORITHM_ARG: "grid",
        PRIMARY_METRIC_ARG: metric_name,
        GOAL_ARG: "Minimize",
    }
    return hparam_args
```
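
When the job is submitted, `submit_run_v2` in `health_azure.himl` passes these values to an AML v2 `command` and then
calls its `sweep` method. The snippet below is a minimal, illustrative sketch of that pattern using the SDK directly;
the environment name, compute target, and `train.py` script are placeholders, not part of hi-ml:

```python
from azure.ai.ml import MLClient, command
from azure.ai.ml.sweep import Choice
from azure.identity import DefaultAzureCredential

ml_client = MLClient.from_config(credential=DefaultAzureCredential())

# Build a command job whose 'learning_rate' input is referenced in the command string.
job = command(
    code=".",
    command="python train.py --learning_rate ${{inputs.learning_rate}}",
    inputs={"learning_rate": 0.1},        # default value, overridden by the sweep
    environment="my-environment@latest",  # placeholder environment name
    compute="my-compute-cluster",         # placeholder compute target
)

# Replace the fixed input with a search space, then turn the command into a sweep job.
job_for_sweep = job(learning_rate=Choice([0.1, 0.5, 0.9]))
sweep_job = job_for_sweep.sweep(
    sampling_algorithm="grid",
    primary_metric="val/loss",
    goal="Minimize",
)
sweep_job.set_limits(max_total_trials=3)

returned_job = ml_client.jobs.create_or_update(sweep_job)
```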
Additional parameters, sampling strategies, limits etc. are described in the AML documentation linked above. Note that
each child job will receive an additional command line argument `--<argument_name>` (here `--learning_rate`), and it is
your responsibility to update your training script so that it can parse and use this argument.
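
For the example above, each trial is launched with an extra `--learning_rate=<value>` argument. A minimal sketch of how
the training script might consume it, assuming it uses `argparse` (the default value and help text are illustrative):

```python
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=1e-3,
                    help="Value chosen by the sweep job for this trial")
args, _ = parser.parse_known_args()

# Use args.learning_rate when configuring the optimizer, and remember to log the
# metric that the sweep compares trials by (here "val/loss") during training.
```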

View file

@ -22,7 +22,7 @@ The `hi-ml` toolbox provides
azure_setup.md
authentication.md
datasets.md
hyperdrive.md
hyperparameter_search.md
lowpriority.md
commandline_tools.md
downloading.md

View file

@ -18,12 +18,14 @@ from argparse import ArgumentParser
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from azure.ai.ml import MLClient, Input, Output, command
from azure.ai.ml.constants import AssetTypes, InputOutputModes
from azure.ai.ml.entities import Data, Job
from azure.ai.ml.entities import Data, Job, Command, Sweep
from azure.ai.ml.entities import Environment as EnvironmentV2
from azure.ai.ml.sweep import Choice
from azureml._base_sdk_common import user_agent
from azureml.core import ComputeTarget, Environment, Experiment, Run, RunConfiguration, ScriptRunConfig, Workspace
from azureml.core.runconfig import DockerConfiguration, MpiConfiguration
@ -55,6 +57,13 @@ RUN_RECOVERY_FILE = "most_recent_run.txt"
SDK_NAME = "innereye"
SDK_VERSION = "2.0"
# hyperparameter search args
PARAM_SAMPLING_ARG = "parameter_sampling"
MAX_TOTAL_TRIALS_ARG = "max_total_trials"
PRIMARY_METRIC_ARG = "primary_metric"
SAMPLING_ALGORITHM_ARG = "sampling_algorithm"
GOAL_ARG = "goal"
@dataclass
class AzureRunInfo:
@ -266,6 +275,30 @@ def create_grid_hyperdrive_config(values: List[str],
)
def create_grid_hyperparam_args_v2(values: List[Any],
                                   argument_name: str,
                                   metric_name: str) -> Dict[str, Any]:
    """
    Create a dictionary of arguments to create an Azure ML v2 SDK Sweep job.

    :param values: The list of values to try for the commandline argument given by `argument_name`.
    :param argument_name: The name of the commandline argument that each of the child runs gets, to
        indicate which value they should work on.
    :param metric_name: The name of the metric that the sweep job will compare runs by. Please note that it is
        your responsibility to make sure a metric with this name is logged to the Run in your training script.
    :return: A dictionary of arguments and values to pass in to the command job.
    """
    param_sampling = {argument_name: Choice(values)}
    hyperparam_args = {
        MAX_TOTAL_TRIALS_ARG: len(values),
        PARAM_SAMPLING_ARG: param_sampling,
        SAMPLING_ALGORITHM_ARG: "grid",
        PRIMARY_METRIC_ARG: metric_name,
        GOAL_ARG: "Minimize"
    }
    return hyperparam_args
def create_crossval_hyperdrive_config(num_splits: int,
cross_val_index_arg_name: str = "crossval_index",
metric_name: str = "val/loss") -> HyperDriveConfig:
@ -286,6 +319,24 @@ def create_crossval_hyperdrive_config(num_splits: int,
metric_name=metric_name)
def create_crossval_hyperparam_args_v2(num_splits: int,
                                       cross_val_index_arg_name: str = "crossval_index",
                                       metric_name: str = "val/loss") -> Dict[str, Any]:
    """
    Create a dictionary of arguments to create an Azure ML v2 SDK Sweep job for cross validation.

    :param num_splits: The number of splits for k-fold cross validation.
    :param cross_val_index_arg_name: The name of the commandline argument that each of the child runs gets, to
        indicate which split they should work on.
    :param metric_name: The name of the metric that the sweep job will compare runs by. Please note that it is
        your responsibility to make sure a metric with this name is logged to the Run in your training script.
    :return: A dictionary of arguments and values to pass in to the command job.
    """
    return create_grid_hyperparam_args_v2(values=list(map(str, range(num_splits))),
                                          argument_name=cross_val_index_arg_name,
                                          metric_name=metric_name)
def create_script_run(snapshot_root_directory: Optional[Path] = None,
entry_script: Optional[PathOrString] = None,
script_params: Optional[List[str]] = None) -> ScriptRunConfig:
@ -370,7 +421,8 @@ def submit_run_v2(workspace: Optional[Workspace],
wait_for_completion: bool = False,
wait_for_completion_show_output: bool = False,
workspace_config_path: Optional[PathOrString] = None,
ml_client: Optional[MLClient] = None) -> Job:
ml_client: Optional[MLClient] = None,
hyperparam_args: Optional[Dict[str, Any]] = None) -> Job:
"""
Starts a v2 AML Job on a given workspace by submitting a command
@ -392,8 +444,10 @@ def submit_run_v2(workspace: Optional[Workspace],
the completion of this run (if True).
:param wait_for_completion_show_output: If wait_for_completion is True this parameter indicates whether to show the
run output on sys.stdout.
:param workspace_config_path:
:param ml_client:
:param workspace_config_path: If not provided with an AzureML Workspace, then load one given the information in this
config
:param ml_client: An Azure MLClient object for interacting with Azure resources.
:param hyperparam_args: A dictionary of hyperparameter search args to pass into a sweep job.
:return: An AzureML Run object.
"""
if ml_client is None:
@ -429,20 +483,59 @@ def submit_run_v2(workspace: Optional[Workspace],
    else:
        output_datasets_v2 = {}

    command_job = command(
        code=str(snapshot_root_directory),
        command=cmd,
        inputs=input_datasets_v2,
        outputs=output_datasets_v2,
        environment=environment.name + "@latest",
        compute=compute_target,
        experiment_name=experiment_name,
        environment_variables={
            "JOB_EXECUTION_MODE": "Basic",
            "AZUREML_COMPUTE_USE_COMMON_RUNTIME": "true"
        }
    )
    returned_job = ml_client.jobs.create_or_update(command_job)

    job_to_submit: Union[Command, Sweep]
    if hyperparam_args:
        param_sampling = hyperparam_args[PARAM_SAMPLING_ARG]

        for sample_param, choices in param_sampling.items():
            input_datasets_v2[sample_param] = choices.values[0]
            cmd += f" --{sample_param}=" + "${{inputs." + sample_param + "}}"

        command_job = command(
            code=str(snapshot_root_directory),
            command=cmd,
            inputs=input_datasets_v2,
            outputs=output_datasets_v2,
            environment=environment.name + "@latest",
            compute=compute_target,
            experiment_name=experiment_name,
            environment_variables={
                "JOB_EXECUTION_MODE": "Basic",
            }
        )
        del hyperparam_args[PARAM_SAMPLING_ARG]

        # override command with parameter expressions
        command_job = command_job(
            **param_sampling,
        )

        job_to_submit = command_job.sweep(
            compute=compute_target,  # AML docs suggest setting this here although already passed to command
            **hyperparam_args
        )

        # AML docs state to reset certain properties here which aren't picked up from the
        # underlying command, such as experiment name and max_total_trials
        job_to_submit.experiment_name = experiment_name
        job_to_submit.set_limits(max_total_trials=hyperparam_args.get(MAX_TOTAL_TRIALS_ARG, None))

    else:
        job_to_submit = command(
            code=str(snapshot_root_directory),
            command=cmd,
            inputs=input_datasets_v2,
            outputs=output_datasets_v2,
            environment=environment.name + "@latest",
            compute=compute_target,
            experiment_name=experiment_name,
            environment_variables={
                "JOB_EXECUTION_MODE": "Basic",
            }
        )

    returned_job = ml_client.jobs.create_or_update(job_to_submit)
    logging.info(f"URL to job: {returned_job.services['Studio'].endpoint}")  # type: ignore

    return returned_job
@ -608,6 +701,7 @@ def submit_to_azure_if_needed( # type: ignore
tags: Optional[Dict[str, str]] = None,
after_submission: Optional[Callable[[Run], None]] = None,
hyperdrive_config: Optional[HyperDriveConfig] = None,
hyperparam_args: Optional[Dict[str, Any]] = None,
create_output_folders: bool = True,
strictly_aml_v1: bool = False,
) -> AzureRunInfo: # pragma: no cover
@ -785,6 +879,7 @@ def submit_to_azure_if_needed( # type: ignore
registered_env = register_environment_v2(environment, ml_client)
input_datasets_v2 = create_v2_inputs(ml_client, cleaned_input_datasets)
output_datasets_v2 = create_v2_outputs(cleaned_output_datasets)
run = submit_run_v2(workspace=workspace,
input_datasets_v2=input_datasets_v2,
output_datasets_v2=output_datasets_v2,
@ -796,7 +891,9 @@ def submit_to_azure_if_needed( # type: ignore
compute_target=compute_cluster_name,
tags=tags,
wait_for_completion=wait_for_completion,
wait_for_completion_show_output=wait_for_completion_show_output)
wait_for_completion_show_output=wait_for_completion_show_output,
hyperparam_args=hyperparam_args
)
if after_submission is not None and strictly_aml_v1:
after_submission(run)

View file

@ -27,6 +27,7 @@ from _pytest.capture import CaptureFixture
from azure.ai.ml import Input, Output, MLClient
from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.entities import Data
from azure.ai.ml.sweep import Choice
from azureml._restclient.constants import RunStatus
from azureml.core import ComputeTarget, Environment, RunConfiguration, ScriptRunConfig, Workspace
from azureml.data.azure_storage_datastore import AzureBlobDatastore
@ -1307,6 +1308,46 @@ def test_create_crossval_hyperdrive_config(_: MagicMock, num_crossval_splits: in
assert crossval_config._max_total_runs == num_crossval_splits
def test_create_crossval_hyperparam_args_v2() -> None:
    num_splits = 3
    crossval_args = himl.create_crossval_hyperparam_args_v2(num_splits)
    assert isinstance(crossval_args, Dict)
    assert crossval_args[himl.MAX_TOTAL_TRIALS_ARG] == num_splits
    assert isinstance(crossval_args[himl.PARAM_SAMPLING_ARG], Dict)
    assert isinstance(crossval_args[himl.PARAM_SAMPLING_ARG]["crossval_index"], Choice)
    assert crossval_args[himl.PRIMARY_METRIC_ARG] == "val/loss"
    assert crossval_args[himl.SAMPLING_ALGORITHM_ARG] == "grid"
    assert crossval_args[himl.GOAL_ARG] == "Minimize"


def test_create_grid_hyperparam_args_v2() -> None:
    mock_values_float = [0.1, 0.2, 0.5]
    mock_arg_name_float = "float_number"
    mock_metric_name_float = mock_arg_name_float
    hparams_args_float = himl.create_grid_hyperparam_args_v2(mock_values_float, mock_arg_name_float,
                                                             mock_metric_name_float)
    assert isinstance(hparams_args_float, Dict)
    assert hparams_args_float[himl.MAX_TOTAL_TRIALS_ARG] == len(mock_values_float)
    assert isinstance(hparams_args_float[himl.PARAM_SAMPLING_ARG], Dict)
    assert isinstance(hparams_args_float[himl.PARAM_SAMPLING_ARG][mock_arg_name_float], Choice)
    assert hparams_args_float[himl.PRIMARY_METRIC_ARG] == mock_metric_name_float
    assert hparams_args_float[himl.SAMPLING_ALGORITHM_ARG] == "grid"
    assert hparams_args_float[himl.GOAL_ARG] == "Minimize"

    mock_values_str = ["a", "b", "c"]
    mock_arg_name_str = "letter"
    mock_metric_name_str = mock_arg_name_str
    hparam_args_str = himl.create_grid_hyperparam_args_v2(mock_values_str, mock_arg_name_str,
                                                          mock_metric_name_str)
    assert isinstance(hparam_args_str, Dict)
    assert hparam_args_str[himl.MAX_TOTAL_TRIALS_ARG] == len(mock_values_str)
    assert isinstance(hparam_args_str[himl.PARAM_SAMPLING_ARG], Dict)
    assert isinstance(hparam_args_str[himl.PARAM_SAMPLING_ARG][mock_arg_name_str], Choice)
    assert hparam_args_str[himl.PRIMARY_METRIC_ARG] == mock_metric_name_str
    assert hparam_args_str[himl.SAMPLING_ALGORITHM_ARG] == "grid"
    assert hparam_args_str[himl.GOAL_ARG] == "Minimize"
@pytest.mark.fast
@pytest.mark.parametrize("cross_validation_metric_name", [None, "accuracy"])
@patch("sys.argv")

View file

@ -11,13 +11,14 @@ import re
from enum import Enum, unique
from param import Parameterized
from pathlib import Path
from typing import List, Optional
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from azureml.train.hyperdrive import HyperDriveConfig
from health_azure import create_crossval_hyperdrive_config
from health_azure.himl import create_grid_hyperdrive_config
from health_azure.himl import (create_grid_hyperdrive_config, create_crossval_hyperparam_args_v2,
create_grid_hyperparam_args_v2)
from health_azure.amulet import (ENV_AMLT_PROJECT_NAME, ENV_AMLT_INPUT_OUTPUT,
ENV_AMLT_SNAPSHOT_DIR, ENV_AMLT_AZ_BATCHAI_DIR,
is_amulet_job, get_amulet_aml_working_dir)
@ -295,6 +296,28 @@ class WorkflowParams(param.Parameterized):
metric_name="val/loss"
)
    def get_crossval_hyperparam_args_v2(self) -> Dict[str, Any]:
        """
        Wrapper function to create hyperparameter search arguments specifically for running cross validation
        with AML SDK v2.

        :return: A dictionary of hyperparameter search arguments and values.
        """
        return create_crossval_hyperparam_args_v2(num_splits=self.crossval_count,
                                                  cross_val_index_arg_name=self.CROSSVAL_INDEX_ARG_NAME,
                                                  metric_name="val/loss")

    def get_grid_hyperparam_args_v2(self) -> Dict[str, Any]:
        """
        Wrapper function to create hyperparameter search arguments specifically for running grid search
        with AML SDK v2.

        :return: A dictionary of hyperparameter search arguments and values.
        """
        return create_grid_hyperparam_args_v2(values=list(map(str, range(self.different_seeds))),
                                              argument_name=self.RANDOM_SEED_ARG_NAME,
                                              metric_name="val/loss")
class DatasetParams(param.Parameterized):
datastore: str = param.String(default="", doc="Datastore to look for data in")

View file

@ -94,6 +94,14 @@ class LightningContainer(WorkflowParams,
raise NotImplementedError("Parameter search is not implemented. Please override 'get_parameter_tuning_config' "
"in your model container.")
    def get_parameter_tuning_args(self) -> Dict[str, Any]:
        """
        Returns a dictionary of hyperparameter argument names and values, as expected by an AML SDK v2 job,
        to perform hyperparameter search.
        """
        raise NotImplementedError("Parameter search is not implemented. Please override 'get_parameter_tuning_args' "
                                  "in your model container.")
def update_experiment_config(self, experiment_config: ExperimentConfig) -> None:
"""
This method allows overriding ExperimentConfig parameters from within a LightningContainer.
@ -183,6 +191,21 @@ class LightningContainer(WorkflowParams,
return self.get_different_seeds_hyperdrive_config()
return None
    def get_hyperparam_args(self) -> Optional[Dict[str, Any]]:
        """
        Returns a dictionary of hyperparameter search arguments that will be passed to an AML v2 command to
        enable either hyperparameter tuning, cross validation, or running with different seeds.

        :return: A dictionary of hyperparameter search arguments and values.
        """
        if self.hyperdrive:
            return self.get_parameter_tuning_args()
        if self.is_crossvalidation_enabled:
            return self.get_crossval_hyperparam_args_v2()
        if self.different_seeds > 0:
            return self.get_grid_hyperparam_args_v2()
        return None
def load_model_checkpoint(self, checkpoint_path: Path) -> None:
"""
Load a checkpoint from the given path. We need to define a separate method since pytorch lightning cannot

View file

@ -251,7 +251,14 @@ class Runner:
all_local_datasets=all_local_datasets, # type: ignore
datastore=datastore,
use_mounting=use_mounting)
        hyperdrive_config = self.lightning_container.get_hyperdrive_config()

        if self.experiment_config.strictly_aml_v1:
            hyperdrive_config = self.lightning_container.get_hyperdrive_config()
            hyperparam_args = None
        else:
            hyperparam_args = self.lightning_container.get_hyperparam_args()
            hyperdrive_config = None
if self.experiment_config.cluster and not is_running_in_azure_ml():
ml_client = get_ml_client()
@ -278,6 +285,7 @@ class Runner:
docker_base_image=DEFAULT_DOCKER_BASE_IMAGE,
docker_shm_size=self.experiment_config.docker_shm_size,
hyperdrive_config=hyperdrive_config,
hyperparam_args=hyperparam_args,
create_output_folders=False,
after_submission=after_submission_hook,
tags=self.additional_run_tags(script_params),

View file

@ -219,7 +219,7 @@ def _test_hyperdrive_submission(mock_runner: Runner,
expected_argument_name: str,
expected_argument_values: List[str]) -> None:
model_name = "HelloWorld"
arguments = ["", f"--model={model_name}", "--cluster=foo", commandline_arg]
arguments = ["", f"--model={model_name}", "--cluster=foo", commandline_arg, "--strictly_aml_v1=True"]
# Use a special simplified environment file only for the tests here. Copy that to a temp folder, then let the runner
# start in that temp folder.
with change_working_folder_and_add_environment(mock_runner.project_root):