Mirror of https://github.com/microsoft/hi-ml.git
ENH: Enable hyperparameter search with AML SDK v2 (#650)
Enable hyperparameter search with AML SDK v2
Parent
0b65bd42cf
Commit
89bb3d4a02
@@ -1,4 +1,4 @@
# Hyperparameter Search via Hyperdrive
# Hyperparameter Search via Hyperdrive (AML SDK v1)

[HyperDrive runs](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-tune-hyperparameters)
can start multiple AzureML jobs in parallel. This can be used for tuning hyperparameters, or executing multiple
@@ -27,3 +27,37 @@ submit_to_azure_if_needed(..., hyperdrive_config=hyperdrive_config)

For further examples, please check the [example scripts here](examples.md), and the
[HyperDrive documentation](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-tune-hyperparameters).

# Hyperparameter Search in AML SDK v2

There is no concept of a HyperDriveConfig in AML SDK v2. Instead, hyperparameter search arguments are passed into a
command, and the `sweep` method is then called on that command (see the
[AML docs](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-tune-hyperparameters)). To specify a
hyperparameter search job, you must implement the method `get_parameter_tuning_args` in your Container. This should
return a dictionary of the arguments to be passed in to the command. For example:

```python
def get_parameter_tuning_args(self) -> Dict[str, Any]:
    from azure.ai.ml.entities import Choice
    from health_azure.himl import (MAX_TOTAL_TRIALS_ARG, PARAM_SAMPLING_ARG, SAMPLING_ALGORITHM_ARG,
                                   PRIMARY_METRIC_ARG, GOAL_ARG)

    values = [0.1, 0.5, 0.9]
    argument_name = "learning_rate"
    param_sampling = {argument_name: Choice(values)}
    metric_name = "val/loss"

    hparam_args = {
        MAX_TOTAL_TRIALS_ARG: len(values),
        PARAM_SAMPLING_ARG: param_sampling,
        SAMPLING_ALGORITHM_ARG: "grid",
        PRIMARY_METRIC_ARG: metric_name,
        GOAL_ARG: "Minimize"
    }
    return hparam_args
```

Additional parameters, sampling strategies, limits etc. are described in the link above. Note that each job that is
created will receive an additional command line argument `<argument_name>`, and it is your responsibility to update
your script so that it can parse and use this argument.
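
For instance, a minimal sketch (assuming the sweep argument is called `learning_rate`, as in the example above) could
read the sampled value with `argparse`:

```python
from argparse import ArgumentParser

parser = ArgumentParser()
# Each trial of the sweep job passes its sampled value as --learning_rate=<value>
parser.add_argument("--learning_rate", type=float, default=1e-3)
args, _ = parser.parse_known_args()
print(f"Training with learning rate {args.learning_rate}")
```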

@@ -22,7 +22,7 @@ The `hi-ml` toolbox provides
azure_setup.md
authentication.md
datasets.md
hyperdrive.md
hyperparameter_search.md
lowpriority.md
commandline_tools.md
downloading.md

@@ -18,12 +18,14 @@ from argparse import ArgumentParser
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

from azure.ai.ml import MLClient, Input, Output, command
from azure.ai.ml.constants import AssetTypes, InputOutputModes
from azure.ai.ml.entities import Data, Job
from azure.ai.ml.entities import Data, Job, Command, Sweep
from azure.ai.ml.entities import Environment as EnvironmentV2

from azure.ai.ml.sweep import Choice
from azureml._base_sdk_common import user_agent
from azureml.core import ComputeTarget, Environment, Experiment, Run, RunConfiguration, ScriptRunConfig, Workspace
from azureml.core.runconfig import DockerConfiguration, MpiConfiguration

@@ -55,6 +57,13 @@ RUN_RECOVERY_FILE = "most_recent_run.txt"
SDK_NAME = "innereye"
SDK_VERSION = "2.0"

# hyperparameter search args
PARAM_SAMPLING_ARG = "parameter_sampling"
MAX_TOTAL_TRIALS_ARG = "max_total_trials"
PRIMARY_METRIC_ARG = "primary_metric"
SAMPLING_ALGORITHM_ARG = "sampling_algorithm"
GOAL_ARG = "goal"


@dataclass
class AzureRunInfo:

@@ -266,6 +275,30 @@ def create_grid_hyperdrive_config(values: List[str],
    )


def create_grid_hyperparam_args_v2(values: List[Any],
                                   argument_name: str,
                                   metric_name: str) -> Dict[str, Any]:
    """
    Create a dictionary of arguments to create an Azure ML v2 SDK Sweep job.

    :param values: The list of values to try for the commandline argument given by `argument_name`.
    :param argument_name: The name of the commandline argument that each of the child runs gets, to
        indicate which value they should work on.
    :param metric_name: The name of the metric that the sweep job will compare runs by. Please note that it is
        your responsibility to make sure a metric with this name is logged to the Run in your training script
    :return: A dictionary of arguments and values to pass in to the command job.
    """
    param_sampling = {argument_name: Choice(values)}
    hyperparam_args = {
        MAX_TOTAL_TRIALS_ARG: len(values),
        PARAM_SAMPLING_ARG: param_sampling,
        SAMPLING_ALGORITHM_ARG: "grid",
        PRIMARY_METRIC_ARG: metric_name,
        GOAL_ARG: "Minimize"
    }
    return hyperparam_args


def create_crossval_hyperdrive_config(num_splits: int,
                                      cross_val_index_arg_name: str = "crossval_index",
                                      metric_name: str = "val/loss") -> HyperDriveConfig:

@@ -286,6 +319,24 @@ def create_crossval_hyperdrive_config(num_splits: int,
                                         metric_name=metric_name)


def create_crossval_hyperparam_args_v2(num_splits: int,
                                       cross_val_index_arg_name: str = "crossval_index",
                                       metric_name: str = "val/loss") -> Dict[str, Any]:
    """
    Create a dictionary of arguments to create an Azure ML v2 SDK Sweep job.

    :param num_splits: The number of splits for k-fold cross validation
    :param cross_val_index_arg_name: The name of the commandline argument that each of the child runs gets, to
        indicate which split they should work on.
    :param metric_name: The name of the metric that the sweep job will compare runs by. Please note that it is
        your responsibility to make sure a metric with this name is logged to the Run in your training script
    :return: A dictionary of arguments and values to pass in to the command job.
    """
    return create_grid_hyperparam_args_v2(values=list(map(str, range(num_splits))),
                                          argument_name=cross_val_index_arg_name,
                                          metric_name=metric_name)


def create_script_run(snapshot_root_directory: Optional[Path] = None,
                      entry_script: Optional[PathOrString] = None,
                      script_params: Optional[List[str]] = None) -> ScriptRunConfig:

@@ -370,7 +421,8 @@ def submit_run_v2(workspace: Optional[Workspace],
                  wait_for_completion: bool = False,
                  wait_for_completion_show_output: bool = False,
                  workspace_config_path: Optional[PathOrString] = None,
                  ml_client: Optional[MLClient] = None) -> Job:
                  ml_client: Optional[MLClient] = None,
                  hyperparam_args: Optional[Dict[str, Any]] = None) -> Job:
    """
    Starts a v2 AML Job on a given workspace by submitting a command

@@ -392,8 +444,10 @@ def submit_run_v2(workspace: Optional[Workspace],
        the completion of this run (if True).
    :param wait_for_completion_show_output: If wait_for_completion is True this parameter indicates whether to show the
        run output on sys.stdout.
    :param workspace_config_path:
    :param ml_client:
    :param workspace_config_path: If not provided with an AzureML Workspace, then load one given the information in this
        config
    :param ml_client: An Azure MLClient object for interacting with Azure resources.
    :param hyperparam_args: A dictionary of hyperparameter search args to pass into a sweep job.
    :return: An AzureML Run object.
    """
    if ml_client is None:

@@ -429,20 +483,59 @@ def submit_run_v2(workspace: Optional[Workspace],
    else:
        output_datasets_v2 = {}

    command_job = command(
        code=str(snapshot_root_directory),
        command=cmd,
        inputs=input_datasets_v2,
        outputs=output_datasets_v2,
        environment=environment.name + "@latest",
        compute=compute_target,
        experiment_name=experiment_name,
        environment_variables={
            "JOB_EXECUTION_MODE": "Basic",
            "AZUREML_COMPUTE_USE_COMMON_RUNTIME": "true"
        }
    )
    returned_job = ml_client.jobs.create_or_update(command_job)
    job_to_submit: Union[Command, Sweep]

    if hyperparam_args:
        param_sampling = hyperparam_args[PARAM_SAMPLING_ARG]

        for sample_param, choices in param_sampling.items():
            input_datasets_v2[sample_param] = choices.values[0]
            cmd += f" --{sample_param}=" + "${{inputs." + sample_param + "}}"

        command_job = command(
            code=str(snapshot_root_directory),
            command=cmd,
            inputs=input_datasets_v2,
            outputs=output_datasets_v2,
            environment=environment.name + "@latest",
            compute=compute_target,
            experiment_name=experiment_name,
            environment_variables={
                "JOB_EXECUTION_MODE": "Basic",
            }
        )

        del hyperparam_args[PARAM_SAMPLING_ARG]
        # override command with parameter expressions
        command_job = command_job(
            **param_sampling,
        )

        job_to_submit = command_job.sweep(
            compute=compute_target,  # AML docs suggest setting this here although already passed to command
            **hyperparam_args
        )

        # AML docs state to reset certain properties here which aren't picked up from the
        # underlying command such as experiment name and max_total_trials
        job_to_submit.experiment_name = experiment_name
        job_to_submit.set_limits(max_total_trials=hyperparam_args.get(MAX_TOTAL_TRIALS_ARG, None))

    else:
        job_to_submit = command(
            code=str(snapshot_root_directory),
            command=cmd,
            inputs=input_datasets_v2,
            outputs=output_datasets_v2,
            environment=environment.name + "@latest",
            compute=compute_target,
            experiment_name=experiment_name,
            environment_variables={
                "JOB_EXECUTION_MODE": "Basic",
            }
        )

    returned_job = ml_client.jobs.create_or_update(job_to_submit)
    logging.info(f"URL to job: {returned_job.services['Studio'].endpoint}")  # type: ignore
    return returned_job

@@ -608,6 +701,7 @@ def submit_to_azure_if_needed( # type: ignore
        tags: Optional[Dict[str, str]] = None,
        after_submission: Optional[Callable[[Run], None]] = None,
        hyperdrive_config: Optional[HyperDriveConfig] = None,
        hyperparam_args: Optional[Dict[str, Any]] = None,
        create_output_folders: bool = True,
        strictly_aml_v1: bool = False,
        ) -> AzureRunInfo:  # pragma: no cover

@@ -785,6 +879,7 @@ def submit_to_azure_if_needed( # type: ignore
        registered_env = register_environment_v2(environment, ml_client)
        input_datasets_v2 = create_v2_inputs(ml_client, cleaned_input_datasets)
        output_datasets_v2 = create_v2_outputs(cleaned_output_datasets)

        run = submit_run_v2(workspace=workspace,
                            input_datasets_v2=input_datasets_v2,
                            output_datasets_v2=output_datasets_v2,

@@ -796,7 +891,9 @@ def submit_to_azure_if_needed( # type: ignore
                            compute_target=compute_cluster_name,
                            tags=tags,
                            wait_for_completion=wait_for_completion,
                            wait_for_completion_show_output=wait_for_completion_show_output)
                            wait_for_completion_show_output=wait_for_completion_show_output,
                            hyperparam_args=hyperparam_args
                            )

    if after_submission is not None and strictly_aml_v1:
        after_submission(run)

@@ -27,6 +27,7 @@ from _pytest.capture import CaptureFixture
from azure.ai.ml import Input, Output, MLClient
from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.entities import Data
from azure.ai.ml.sweep import Choice
from azureml._restclient.constants import RunStatus
from azureml.core import ComputeTarget, Environment, RunConfiguration, ScriptRunConfig, Workspace
from azureml.data.azure_storage_datastore import AzureBlobDatastore

@@ -1307,6 +1308,46 @@ def test_create_crossval_hyperdrive_config(_: MagicMock, num_crossval_splits: in
    assert crossval_config._max_total_runs == num_crossval_splits


def test_create_crossval_hyperparam_args_v2() -> None:
    num_splits = 3
    crossval_args = himl.create_crossval_hyperparam_args_v2(num_splits)
    assert isinstance(crossval_args, Dict)
    assert crossval_args[himl.MAX_TOTAL_TRIALS_ARG] == num_splits
    assert isinstance(crossval_args[himl.PARAM_SAMPLING_ARG], Dict)
    assert isinstance(crossval_args[himl.PARAM_SAMPLING_ARG]["crossval_index"], Choice)
    assert crossval_args[himl.PRIMARY_METRIC_ARG] == "val/loss"
    assert crossval_args[himl.SAMPLING_ALGORITHM_ARG] == "grid"
    assert crossval_args[himl.GOAL_ARG] == "Minimize"


def test_create_grid_hyperparam_args_v2() -> None:
    mock_values_float = [0.1, 0.2, 0.5]
    mock_arg_name_float = "float_number"
    mock_metric_name_float = mock_arg_name_float
    hparams_args_float = himl.create_grid_hyperparam_args_v2(mock_values_float, mock_arg_name_float,
                                                             mock_metric_name_float)
    assert isinstance(hparams_args_float, Dict)
    assert hparams_args_float[himl.MAX_TOTAL_TRIALS_ARG] == len(mock_values_float)
    assert isinstance(hparams_args_float[himl.PARAM_SAMPLING_ARG], Dict)
    assert isinstance(hparams_args_float[himl.PARAM_SAMPLING_ARG][mock_arg_name_float], Choice)
    assert hparams_args_float[himl.PRIMARY_METRIC_ARG] == mock_metric_name_float
    assert hparams_args_float[himl.SAMPLING_ALGORITHM_ARG] == "grid"
    assert hparams_args_float[himl.GOAL_ARG] == "Minimize"

    mock_values_str = ["a", "b", "c"]
    mock_arg_name_str = "letter"
    mock_metric_name_str = mock_arg_name_str
    hparam_args_str = himl.create_grid_hyperparam_args_v2(mock_values_str, mock_arg_name_str,
                                                          mock_metric_name_str)
    assert isinstance(hparam_args_str, Dict)
    assert hparam_args_str[himl.MAX_TOTAL_TRIALS_ARG] == len(mock_values_str)
    assert isinstance(hparam_args_str[himl.PARAM_SAMPLING_ARG], Dict)
    assert isinstance(hparam_args_str[himl.PARAM_SAMPLING_ARG][mock_arg_name_str], Choice)
    assert hparam_args_str[himl.PRIMARY_METRIC_ARG] == mock_metric_name_str
    assert hparam_args_str[himl.SAMPLING_ALGORITHM_ARG] == "grid"
    assert hparam_args_str[himl.GOAL_ARG] == "Minimize"


@pytest.mark.fast
@pytest.mark.parametrize("cross_validation_metric_name", [None, "accuracy"])
@patch("sys.argv")

@@ -11,13 +11,14 @@ import re
from enum import Enum, unique
from param import Parameterized
from pathlib import Path
from typing import List, Optional
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

from azureml.train.hyperdrive import HyperDriveConfig

from health_azure import create_crossval_hyperdrive_config
from health_azure.himl import create_grid_hyperdrive_config
from health_azure.himl import (create_grid_hyperdrive_config, create_crossval_hyperparam_args_v2,
                               create_grid_hyperparam_args_v2)
from health_azure.amulet import (ENV_AMLT_PROJECT_NAME, ENV_AMLT_INPUT_OUTPUT,
                                 ENV_AMLT_SNAPSHOT_DIR, ENV_AMLT_AZ_BATCHAI_DIR,
                                 is_amulet_job, get_amulet_aml_working_dir)

@@ -295,6 +296,28 @@ class WorkflowParams(param.Parameterized):
            metric_name="val/loss"
        )

    def get_crossval_hyperparam_args_v2(self) -> Dict[str, Any]:
        """
        Wrapper function to create hyperparameter search arguments specifically for running cross validation
        with AML SDK v2

        :return: A dictionary of hyperparameter search arguments and values.
        """
        return create_crossval_hyperparam_args_v2(num_splits=self.crossval_count,
                                                  cross_val_index_arg_name=self.CROSSVAL_INDEX_ARG_NAME,
                                                  metric_name="val/loss")

    def get_grid_hyperparam_args_v2(self) -> Dict[str, Any]:
        """
        Wrapper function to create hyperparameter search arguments specifically for running grid search
        with AML SDK v2

        :return: A dictionary of hyperparameter search arguments and values.
        """
        return create_grid_hyperparam_args_v2(values=list(map(str, range(self.different_seeds))),
                                              argument_name=self.RANDOM_SEED_ARG_NAME,
                                              metric_name="val/loss")


class DatasetParams(param.Parameterized):
    datastore: str = param.String(default="", doc="Datastore to look for data in")

@@ -94,6 +94,14 @@ class LightningContainer(WorkflowParams,
        raise NotImplementedError("Parameter search is not implemented. Please override 'get_parameter_tuning_config' "
                                  "in your model container.")

    def get_parameter_tuning_args(self) -> Dict[str, Any]:
        """
        Returns a dictionary of hyperparameter argument names and values as expected by an AML SDK v2 job
        to perform hyperparameter search.
        """
        raise NotImplementedError("Parameter search is not implemented. Please override 'get_parameter_tuning_args' "
                                  "in your model container.")

    def update_experiment_config(self, experiment_config: ExperimentConfig) -> None:
        """
        This method allows overriding ExperimentConfig parameters from within a LightningContainer.

@@ -183,6 +191,21 @@ class LightningContainer(WorkflowParams,
            return self.get_different_seeds_hyperdrive_config()
        return None

    def get_hyperparam_args(self) -> Optional[Dict[str, Any]]:
        """
        Returns a dictionary of hyperparameter search arguments that will be passed to an AML v2 command to
        enable either hyperparameter tuning, cross validation, or running with different seeds.

        :return: A dictionary of hyperparameter search arguments and values.
        """
        if self.hyperdrive:
            return self.get_parameter_tuning_args()
        if self.is_crossvalidation_enabled:
            return self.get_crossval_hyperparam_args_v2()
        if self.different_seeds > 0:
            return self.get_grid_hyperparam_args_v2()
        return None

    def load_model_checkpoint(self, checkpoint_path: Path) -> None:
        """
        Load a checkpoint from the given path. We need to define a separate method since pytorch lightning cannot

@@ -251,7 +251,14 @@ class Runner:
                                                all_local_datasets=all_local_datasets,  # type: ignore
                                                datastore=datastore,
                                                use_mounting=use_mounting)
        hyperdrive_config = self.lightning_container.get_hyperdrive_config()

        if self.experiment_config.strictly_aml_v1:
            hyperdrive_config = self.lightning_container.get_hyperdrive_config()
            hyperparam_args = None
        else:
            hyperparam_args = self.lightning_container.get_hyperparam_args()
            hyperdrive_config = None

        if self.experiment_config.cluster and not is_running_in_azure_ml():
            ml_client = get_ml_client()

@@ -278,6 +285,7 @@ class Runner:
            docker_base_image=DEFAULT_DOCKER_BASE_IMAGE,
            docker_shm_size=self.experiment_config.docker_shm_size,
            hyperdrive_config=hyperdrive_config,
            hyperparam_args=hyperparam_args,
            create_output_folders=False,
            after_submission=after_submission_hook,
            tags=self.additional_run_tags(script_params),

@@ -219,7 +219,7 @@ def _test_hyperdrive_submission(mock_runner: Runner,
                                expected_argument_name: str,
                                expected_argument_values: List[str]) -> None:
    model_name = "HelloWorld"
    arguments = ["", f"--model={model_name}", "--cluster=foo", commandline_arg]
    arguments = ["", f"--model={model_name}", "--cluster=foo", commandline_arg, "--strictly_aml_v1=True"]
    # Use a special simplified environment file only for the tests here. Copy that to a temp folder, then let the runner
    # start in that temp folder.
    with change_working_folder_and_add_environment(mock_runner.project_root):