* First version of runner

* changelog and pyright

* PR comments

* Runs in AML

* More PR comments

* More PR comments

* Remove GenericConfig from himl

* Add more tests

* pytest and pyright

* Add more tests

* Add more tests

* Add documentation

* bug fix align config attributes

* works on config outside of hi-ml

* flake8 mypy pytest

* Add more tests for config loading

* PR comments and add more tests

* Respond to PR comments

* More PR comments

* pytest, mypy, flake8

* fix test running remotely

* debug why test fails on remote build

* debug why test fails on remote build

* debug why test fails on remote build

* fix test

* Fix tests

* debug test failing on azure agent

* Fix test and update docs

* update test

* Update command line tools

* Update coverage report to only read .py files

* debug coverage problem

* debug coverage failing

* debug coverage failing

* debug coverage failing

* PR comments
This commit is contained in:
mebristo 2022-02-03 09:12:58 +00:00 committed by GitHub
Parent a33c1ed07d
Commit c5a0348d36
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
44 changed files with 4215 additions and 341 deletions

8
.coveragerc Normal file
View file

@ -0,0 +1,8 @@
[report]
omit =
**/pytest
**/__init__.py
*/hello_container_2.py
[html]
skip_empty = true

View file

@ -1,4 +1,4 @@
[flake8]
max-line-length = 120
max-complexity = 25
ignore = E731
ignore = E731

View file

@ -15,6 +15,7 @@ the section headers (Added/Changed/...) and incrementing the package version.
### Added
- ([#179](https://github.com/microsoft/hi-ml/pull/179)) Add GaussianBlur and RotationByMultiplesOf90 augmentations. Added torchvision and opencv to
the environment file since they are necessary for the augmentations.
- ([#178](https://github.com/microsoft/hi-ml/pull/178)) Add runner script for running ML experiments
### Changed

View file

@ -111,6 +111,7 @@ combine: pip_test
mkdir -p coverage
cp hi-ml/.coverage coverage/hi-ml-coverage
cp hi-ml-azure/.coverage coverage/hi-ml-azure-coverage
cp .coveragerc coverage/
cd coverage && \
coverage combine hi-ml-coverage hi-ml-azure-coverage && \
coverage html && \

View file

@ -7,8 +7,8 @@ REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
set SOURCEDIR=source
set BUILDDIR=build
if "%1" == "" goto help

View file

@ -3,5 +3,5 @@
.. automodapi:: health_azure
:no-inheritance-diagram:
.. automodapi:: health_ml.utils
.. automodapi:: health_ml
:no-inheritance-diagram:

View file

@ -33,6 +33,7 @@ The `hi-ml` toolbox provides
logging.md
diagnostics.md
runner.md
.. toctree::
:maxdepth: 1

178
docs/source/runner.md Normal file
View file

@ -0,0 +1,178 @@
# Running ML experiments with hi-ml
The hi-ml toolbox is capable of training any PyTorch Lightning (PL) model inside of AzureML, making
use of these features:
- Training on a local GPU machine or inside of AzureML without code changes
- Working with different models in the same codebase, and selecting one by name
- Distributed training in AzureML
- Logging via AzureML's native capabilities
This can be used by invoking the hi-ml runner and providing the name of the container class, like this:
`himl-runner --model=MyContainer`.
There is a fully working example, [HelloContainer](../../hi-ml/src/health_ml/configs/hello_container.py), which
implements a simple 1-dimensional regression model from data stored in a CSV file. You can run it
from the command line with `himl-runner --model=HelloContainer`.
# Running ML experiments in Azure ML
To train in AzureML, add a `--azureml` flag. Use the flag `--cluster` to specify the name of the cluster
in your Workspace that you want to submit the job to. So the whole command would look like:
`himl-runner --model=HelloContainer --cluster=my_cluster_name --azureml`. You can also specify `--num_nodes` if
you wish to distribute the model training.
## Setup - creating your model config file
In order to use these capabilities, you need to implement a class deriving from
`health_ml.lightning_container.LightningContainer`. This class encapsulates everything that is needed for training
with PyTorch Lightning:
For example:
```python
class MyContainer(LightningContainer):
    def __init__(self):
        super().__init__()
        self.azure_datasets = ["folder_name_in_azure_blob_storage"]
        self.local_datasets = [Path("/some/local/path")]
        self.max_epochs = 42

    def create_model(self) -> LightningModule:
        return MyLightningModel()

    def get_data_module(self) -> LightningDataModule:
        return MyDataModule(root_path=self.local_dataset)
```
The `create_model` method needs to return an instance of a subclass of PyTorch Lightning's [LightningModule](
https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html?highlight=lightningmodule
) that has
all the usual PyTorch Lightning methods required for training, such as the `training_step` and `forward` methods. For example:
```python
class MyLightningModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = ...

    def training_step(self, *args, **kwargs):
        ...

    def forward(self, *args, **kwargs):
        ...

    def configure_optimizers(self):
        ...

    def test_step(self, *args, **kwargs):
        ...
```
The `get_data_module` method of the container needs to return a data module (inheriting from a [PyTorch Lightning DataModule](
https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html)) that contains all of the logic for
downloading, preparing and splitting your dataset, as well as methods for wrapping the train, val and test datasets
with [DataLoaders](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader). For example:
```python
class MyDataModule(LightningDataModule):
    def __init__(self, root_path: Path):
        super().__init__()
        # All data should be read from the folder given in self.root_path
        self.root_path = root_path

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        # The data should be read off self.root_path
        train_dataset = ...
        return DataLoader(train_dataset, batch_size=5, num_workers=5)

    def val_dataloader(self, *args, **kwargs) -> DataLoader:
        # The data should be read off self.root_path
        val_dataset = ...
        return DataLoader(val_dataset, batch_size=5, num_workers=5)

    def test_dataloader(self, *args, **kwargs) -> DataLoader:
        # The data should be read off self.root_path
        test_dataset = ...
        return DataLoader(test_dataset, batch_size=5, num_workers=5)
```
So, the **full file** would look like:
```python
from pathlib import Path

from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, LightningDataModule

from health_ml.lightning_container import LightningContainer


class MyLightningModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = ...

    def training_step(self, *args, **kwargs):
        ...

    def forward(self, *args, **kwargs):
        ...

    def configure_optimizers(self):
        ...

    def test_step(self, *args, **kwargs):
        ...


class MyDataModule(LightningDataModule):
    def __init__(self, root_path: Path):
        super().__init__()
        # All data should be read from the folder given in self.root_path
        self.root_path = root_path

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        # The data should be read off self.root_path
        train_dataset = ...
        return DataLoader(train_dataset, batch_size=5, num_workers=5)

    def val_dataloader(self, *args, **kwargs) -> DataLoader:
        # The data should be read off self.root_path
        val_dataset = ...
        return DataLoader(val_dataset, batch_size=5, num_workers=5)

    def test_dataloader(self, *args, **kwargs) -> DataLoader:
        # The data should be read off self.root_path
        test_dataset = ...
        return DataLoader(test_dataset, batch_size=5, num_workers=5)


class MyContainer(LightningContainer):
    def __init__(self):
        super().__init__()
        self.azure_datasets = ["folder_name_in_azure_blob_storage"]
        self.local_datasets = [Path("/some/local/path")]
        self.max_epochs = 42

    def create_model(self) -> LightningModule:
        return MyLightningModel()

    def get_data_module(self) -> LightningDataModule:
        return MyDataModule(root_path=self.local_dataset)
```
By default, config files are looked for in the folder `health_ml.configs`. To specify a config file
that lives elsewhere, use a fully qualified name for the `--model` parameter, e.g. `MyModule.Configs.my_config.py`.
### Outputting files during training
The Lightning model returned by `create_model` needs to write its output files to the current working directory.
When running inside of AzureML, the output folders will be created directly under the project root. If not running inside
AzureML, a folder with a timestamp will be created for all outputs and logs.
When running in AzureML, the folder structure is set up such that all files written
to the current working directory are uploaded to Azure blob storage at the end of the AzureML job. The files
will also be available later via the AzureML UI.
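For example, a model could write a summary file at the end of training; the following is a minimal sketch, where the
hook and the file name are purely illustrative:
```python
from pathlib import Path

from pytorch_lightning import LightningModule


class ModelThatWritesOutputs(LightningModule):
    def on_train_end(self) -> None:
        # Files written to the working directory end up in the run's output folder,
        # and are uploaded to blob storage when running in AzureML.
        summary_file = Path.cwd() / "training_summary.txt"
        summary_file.write_text("Training finished.\n")
```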
### Trainer arguments
All arguments that control the PyTorch Lightning `Trainer` object are defined in the class `TrainerParams`. A
`LightningContainer` object inherits from this class. The most essential one is the `max_epochs` field, which controls
the `max_epochs` argument of the `Trainer`.
For example:
```python
from pytorch_lightning import LightningModule, LightningDataModule

from health_ml.lightning_container import LightningContainer


class MyContainer(LightningContainer):
    def __init__(self):
        super().__init__()
        self.max_epochs = 42

    def create_model(self) -> LightningModule:
        return MyLightningModel()

    def get_data_module(self) -> LightningDataModule:
        return MyDataModule(root_path=self.local_dataset)
```
### Optimizer and LR scheduler arguments
To configure the optimizer and LR scheduler, the Lightning model returned by `create_model` should define its own
`configure_optimizers` method, with the same signature as `LightningModule.configure_optimizers`,
returning a tuple containing the optimizer and LR scheduler objects.
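As a minimal sketch (the Adam optimizer, learning rate and StepLR scheduler below are illustrative choices, not
defaults of the toolbox):
```python
import torch
from pytorch_lightning import LightningModule


class MyLightningModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def configure_optimizers(self):
        # Return optimizers and LR schedulers in one of the formats accepted by PyTorch Lightning,
        # here a tuple of two lists.
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        return [optimizer], [scheduler]
```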

View file

@ -83,7 +83,7 @@ pytest_fast: pip_test call_pytest_fast
# run pytest with coverage on package, and format coverage output as a text file, assuming test requirements already installed
call_pytest_and_coverage:
pytest --cov=health_azure --cov-branch --cov-report=html --cov-report=term-missing --cov-report=xml testazure
pytest --cov=health_azure --cov-branch --cov-report=html --cov-report=xml --cov-report=term-missing --cov-config=.coveragerc testazure
pycobertura show --format text --output coverage.txt coverage.xml
# install test requirements and run pytest coverage

View file

@ -5,7 +5,7 @@
import logging
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union
from azureml.core import Dataset, Datastore, Workspace
from azureml.data import FileDataset, OutputFileDatasetConfig
@ -236,6 +236,61 @@ def _replace_string_datasets(datasets: List[StrOrDatasetConfig],
for d in datasets]
def create_dataset_configs(all_azure_dataset_ids: List[str],
all_dataset_mountpoints: Sequence[PathOrString],
all_local_datasets: List[Optional[Path]],
datastore: Optional[str] = None,
use_mounting: bool = False) -> List[DatasetConfig]:
"""
Sets up all the dataset consumption objects for the datasets provided. The returned list will have one entry per
non-empty Azure dataset ID.
Valid argument combinations:
N azure datasets, 0 or N mount points, 0 or N local datasets
:param all_azure_dataset_ids: The name of all datasets on blob storage that will be used for this run.
:param all_dataset_mountpoints: When using the datasets in AzureML, these are the per-dataset mount points.
:param all_local_datasets: The paths for all local versions of the datasets.
:param datastore: The name of the AzureML datastore that holds the dataset. This can be empty if the AzureML
workspace has only a single datastore, or if the default datastore should be used.
:param use_mounting: If True, the dataset will be "mounted", that is, individual files will be read
or written on-demand over the network. If False, the dataset will be fully downloaded before the job starts
(or, for output datasets, fully uploaded at the end of the job).
:return: A list of DatasetConfig objects, in the same order as datasets were provided in all_azure_dataset_ids,
omitting datasets with an empty name.
"""
datasets: List[DatasetConfig] = []
num_local = len(all_local_datasets)
num_azure = len(all_azure_dataset_ids)
num_mount = len(all_dataset_mountpoints)
if num_azure > 0 and (num_local == 0 or num_local == num_azure) and (num_mount == 0 or num_mount == num_azure):
# Test for valid settings: If we have N azure datasets, the local datasets and mount points need to either
# have exactly the same length, or 0. In the latter case, empty mount points and no local dataset will be
# assumed below.
count = num_azure
elif num_azure == 0 and num_mount == 0:
# No datasets in Azure at all: This is possible for runs that for example download their own data from the web.
# There can be any number of local datasets, but we are not checking that. In MLRunner.setup, there is a check
# that leaves local datasets intact if there are no Azure datasets.
return []
else:
raise ValueError("Invalid dataset setup. You need to specify N entries in azure_datasets and a matching "
"number of local_datasets and dataset_mountpoints")
for i in range(count):
azure_dataset = all_azure_dataset_ids[i] if i < num_azure else ""
if not azure_dataset:
continue
mount_point = all_dataset_mountpoints[i] if i < num_mount else ""
local_dataset = all_local_datasets[i] if i < num_local else None
config = DatasetConfig(name=azure_dataset,
target_folder=mount_point,
local_folder=local_dataset,
use_mounting=use_mounting,
datastore=datastore or "")
datasets.append(config)
return datasets
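As an illustration (the dataset names and local paths are hypothetical, and the import assumes this function lives in
`health_azure.datasets`), two Azure datasets with matching local folders could be set up like this:
```python
from pathlib import Path

from health_azure.datasets import create_dataset_configs

configs = create_dataset_configs(all_azure_dataset_ids=["dataset_1", "dataset_2"],
                                 all_dataset_mountpoints=[],
                                 all_local_datasets=[Path("/data/one"), Path("/data/two")],
                                 use_mounting=True)  # mount rather than download the datasets
```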
def find_workspace_for_local_datasets(aml_workspace: Optional[Workspace],
workspace_config_path: Optional[Path],
dataset_configs: List[DatasetConfig]) -> Optional[Workspace]:

View file

@ -138,6 +138,9 @@ def create_run_configuration(workspace: Workspace,
private_pip_wheel_path=private_pip_wheel_path,
docker_base_image=docker_base_image,
environment_variables=environment_variables)
conda_deps = new_environment.python.conda_dependencies
if conda_deps.get_python_version() is None:
raise ValueError("If specifying a conda environment file, you must specify the python version within it")
registered_env = register_environment(workspace, new_environment)
run_config.environment = registered_env
else:

View file

@ -4,12 +4,14 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import param
import sys
from pathlib import Path
from typing import List
from azureml.core import Run
import health_azure.utils as azure_util
from health_azure.himl import RUN_RECOVERY_FILE
class HimlDownloadConfig(azure_util.AmlRunScriptConfig):
@ -46,7 +48,8 @@ def retrieve_runs(download_config: HimlDownloadConfig) -> List[Run]:
if len(runs) == 0:
raise ValueError(f"Did not find any runs under the given experiment name: {download_config.experiment}")
else:
run_or_recovery_id = azure_util.get_most_recent_run_id(download_config.latest_run_file)
most_recent_run_path = download_config.latest_run_file or Path(RUN_RECOVERY_FILE)
run_or_recovery_id = azure_util.get_most_recent_run_id(most_recent_run_path)
runs = [azure_util.get_aml_run_from_run_id(run_or_recovery_id,
workspace_config_path=download_config.config_file)]
if len(runs) == 0:
@ -57,7 +60,9 @@ def retrieve_runs(download_config: HimlDownloadConfig) -> List[Run]:
def main() -> None: # pragma: no cover
download_config = HimlDownloadConfig.parse_args()
download_config = HimlDownloadConfig()
download_config = azure_util.parse_args_and_update_config(download_config, sys.argv[1:])
output_dir = download_config.output_dir
output_dir.mkdir(exist_ok=True)

View file

@ -108,8 +108,9 @@ class WrappedTensorboard(Tensorboard):
def main() -> None: # pragma: no cover
tb_config = HimlTensorboardConfig()
tb_config = azure_util.parse_args_and_update_config(tb_config, sys.argv[1:])
tb_config = HimlTensorboardConfig.parse_args()
config_path = tb_config.config_file
if not config_path:

View file

@ -10,9 +10,11 @@ import json
import logging
import os
import re
import sys
import tempfile
from argparse import ArgumentParser, OPTIONAL
from argparse import ArgumentParser, OPTIONAL, ArgumentError, _UNRECOGNIZED_ARGS_ATTR, Namespace, SUPPRESS
from collections import defaultdict
from dataclasses import dataclass
from itertools import islice
from pathlib import Path
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
@ -56,6 +58,7 @@ ENV_GLOBAL_RANK = "GLOBAL_RANK"
ENV_LOCAL_RANK = "LOCAL_RANK"
RUN_CONTEXT = Run.get_context()
PARENT_RUN_CONTEXT = getattr(RUN_CONTEXT, "parent", None)
WORKSPACE_CONFIG_JSON = "config.json"
# By default, define several environment variables that work around known issues in the software stack
@ -126,223 +129,314 @@ class GenericConfig(param.Parameterized):
"""
pass
def add_and_validate(self, kwargs: Dict[str, Any], validate: bool = True) -> None:
def set_fields_and_validate(config: param.Parameterized, fields_to_set: Dict[str, Any], validate: bool = True) -> None:
"""
Add further parameters and, if validate is True, validate. We first try set_param, but that
fails when the parameter has a setter.
:param config: The model configuration
:param fields_to_set: A dictionary of key, value pairs where each key represents a parameter to be added
and val represents its value
:param validate: Whether to validate the value of the parameter after adding.
"""
assert isinstance(config, param.Parameterized)
for key, value in fields_to_set.items():
try:
config.set_param(key, value)
except ValueError:
setattr(config, key, value)
if validate:
config.validate()
def create_argparser(config: param.Parameterized) -> ArgumentParser:
"""
Creates an ArgumentParser with all fields of the given config that are overridable.
:param config: The config whose parameters should be used to populate the argument parser
:return: ArgumentParser
"""
assert isinstance(config, param.Parameterized)
parser = ArgumentParser()
_add_overrideable_config_args_to_parser(config, parser)
return parser
def _add_overrideable_config_args_to_parser(config: param.Parameterized, parser: ArgumentParser) -> ArgumentParser:
"""
Adds all overridable fields of the config class to the given argparser.
Fields that are marked as readonly, constant or private are ignored.
:param parser: Parser to add properties to.
"""
def parse_bool(x: str) -> bool:
"""
Add further parameters and, if validate is True, validate. We first try set_param, but that
fails when the parameter has a setter.
Parse a string as a bool. Supported values are case insensitive and one of:
'on', 't', 'true', 'y', 'yes', '1' for True
'off', 'f', 'false', 'n', 'no', '0' for False.
:param kwargs: A dictionary of key, value pairs where each key represents a parameter to be added
and val represents its value
:param validate: Whether to validate the value of the parameter after adding.
:param x: string to test.
:return: Bool value if string valid, otherwise a ValueError is raised.
"""
for key, value in kwargs.items():
try:
self.set_param(key, value)
except ValueError:
setattr(self, key, value)
if validate:
self.validate()
sx = str(x).lower()
if sx in ('on', 't', 'true', 'y', 'yes', '1'):
return True
if sx in ('off', 'f', 'false', 'n', 'no', '0'):
return False
raise ValueError(f"Invalid value {x}, please supply one of True, true, false or False.")
@classmethod
def create_argparser(cls) -> ArgumentParser:
def _get_basic_type(_p: param.Parameter) -> Union[type, Callable]:
"""
Creates an ArgumentParser with all fields of the given argparser that are overridable.
Given a parameter, get its basic Python type, e.g.: param.Boolean -> bool.
Throw exception if it is not supported.
:return: ArgumentParser
:param _p: parameter to get type and nargs for.
:return: Type
"""
parser = ArgumentParser()
cls.add_args(parser)
if isinstance(_p, param.Boolean):
p_type: Callable = parse_bool
elif isinstance(_p, param.Integer):
p_type = lambda x: _p.default if x == "" else int(x)
elif isinstance(_p, param.Number):
p_type = lambda x: _p.default if x == "" else float(x)
elif isinstance(_p, param.String):
p_type = str
elif isinstance(_p, param.List):
p_type = lambda x: [_p.class_(item) for item in x.split(',')]
elif isinstance(_p, param.NumericTuple):
float_or_int = lambda y: int(y) if isinstance(_p, IntTuple) else float(y)
p_type = lambda x: tuple([float_or_int(item) for item in x.split(',')])
elif isinstance(_p, param.ClassSelector):
p_type = _p.class_
elif isinstance(_p, CustomTypeParam):
p_type = _p.from_string
return parser
else:
raise TypeError("Parameter of type: {} is not supported".format(_p))
@classmethod
def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
return p_type
def add_boolean_argument(parser: ArgumentParser, k: str, p: param.Parameter) -> None:
"""
Adds all overridable fields of the current class to the given argparser.
Fields that are marked as readonly, constant or private are ignored.
Add a boolean argument.
If the parameter default is False then allow --flag (to set it True) and --flag=Bool as usual.
If the parameter default is True then allow --no-flag (to set it to False) and --flag=Bool as usual.
:param parser: Parser to add properties to.
:param parser: parser to add a boolean argument to.
:param k: argument name.
:param p: boolean parameter.
"""
if not p.default:
# If the parameter default is False then use nargs="?" (OPTIONAL).
# This means that the argument is optional.
# If it is not supplied, i.e. in the --flag mode, use the "const" value, i.e. True.
# Otherwise, i.e. in the --flag=value mode, try to parse the argument as a bool.
parser.add_argument("--" + k, help=p.doc, type=parse_bool, default=False,
nargs=OPTIONAL, const=True)
else:
# If the parameter default is True then create an exclusive group of arguments.
# Either --flag=value as usual
# Or --no-flag to store False in the parameter k.
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--" + k, help=p.doc, type=parse_bool)
group.add_argument('--no-' + k, dest=k, action='store_false')
parser.set_defaults(**{k: p.default})
def parse_bool(x: str) -> bool:
"""
Parse a string as a bool. Supported values are case insensitive and one of:
'on', 't', 'true', 'y', 'yes', '1' for True
'off', 'f', 'false', 'n', 'no', '0' for False.
for k, p in get_overridable_parameters(config).items():
# param.Booleans need to be handled separately, they are more complicated because they have
# an optional argument.
if isinstance(p, param.Boolean):
add_boolean_argument(parser, k, p)
else:
parser.add_argument("--" + k, help=p.doc, type=_get_basic_type(p), default=p.default)
:param x: string to test.
:return: Bool value if string valid, otherwise a ValueError is raised.
"""
sx = str(x).lower()
if sx in ('on', 't', 'true', 'y', 'yes', '1'):
return True
if sx in ('off', 'f', 'false', 'n', 'no', '0'):
return False
raise ValueError(f"Invalid value {x}, please supply one of True, true, false or False.")
return parser
def _get_basic_type(_p: param.Parameter) -> Union[type, Callable]:
"""
Given a parameter, get its basic Python type, e.g.: param.Boolean -> bool.
Throw exception if it is not supported.
:param _p: parameter to get type and nargs for.
:return: Type
"""
if isinstance(_p, param.Boolean):
p_type: Callable = parse_bool
elif isinstance(_p, param.Integer):
p_type = lambda x: _p.default if x == "" else int(x)
elif isinstance(_p, param.Number):
p_type = lambda x: _p.default if x == "" else float(x)
elif isinstance(_p, param.String):
p_type = str
elif isinstance(_p, param.List):
p_type = lambda x: [_p.class_(item) for item in x.split(',')]
elif isinstance(_p, param.NumericTuple):
float_or_int = lambda y: int(y) if isinstance(_p, IntTuple) else float(y)
p_type = lambda x: tuple([float_or_int(item) for item in x.split(',')])
elif isinstance(_p, param.ClassSelector):
p_type = _p.class_
elif isinstance(_p, CustomTypeParam):
p_type = _p.from_string
@dataclass
class ParserResult:
"""
Stores the results of running an argument parser, broken down into an argument-to-value dictionary, the
arguments that the parser does not recognize, and the overrides (arguments that were set to non-default values).
"""
args: Dict[str, Any]
unknown: List[str]
overrides: Dict[str, Any]
def _create_default_namespace(parser: ArgumentParser) -> Namespace:
"""
Creates an argparse Namespace with all parser-specific default values set.
:param parser: The parser to work with.
:return: the Namespace object
"""
# This is copy/pasted from parser.parse_known_args
namespace = Namespace()
for action in parser._actions:
if action.dest is not SUPPRESS:
if not hasattr(namespace, action.dest):
if action.default is not SUPPRESS:
setattr(namespace, action.dest, action.default)
for dest in parser._defaults:
if not hasattr(namespace, dest):
setattr(namespace, dest, parser._defaults[dest])
return namespace
def parse_arguments(parser: ArgumentParser,
fail_on_unknown_args: bool = False,
args: List[str] = None) -> ParserResult:
"""
Parses a list of commandline arguments with a given parser. Returns results broken down into a full
arguments dictionary, a dictionary of arguments that were set to non-default values, and unknown
arguments.
:param parser: The parser to use
:param fail_on_unknown_args: If True, raise an exception if the parser encounters an argument that it does
not recognize. If False, unrecognized arguments will be ignored, and added to the "unknown" field of
the parser result.
:param args: Arguments to parse. If not given, use those in sys.argv
:return: The parsed arguments, and overrides
"""
if args is None:
args = sys.argv[1:]
# The following code is a slightly modified version of what happens in parser.parse_known_args. This had to be
# copied here because otherwise we would not be able to achieve the priority order that we desire.
namespace = _create_default_namespace(parser)
try:
namespace, unknown = parser._parse_known_args(args, namespace)
if hasattr(namespace, _UNRECOGNIZED_ARGS_ATTR):
unknown.extend(getattr(namespace, _UNRECOGNIZED_ARGS_ATTR))
delattr(namespace, _UNRECOGNIZED_ARGS_ATTR)
except ArgumentError:
parser.print_usage(sys.stderr)
err = sys.exc_info()[1]
parser._print_message(str(err), sys.stderr)
raise
# Parse the arguments a second time, without supplying defaults, to see which arguments actually differ
# from defaults.
namespace_without_defaults, _ = parser._parse_known_args(args, Namespace())
parsed_args = vars(namespace).copy()
overrides = vars(namespace_without_defaults).copy()
if len(unknown) > 0 and fail_on_unknown_args:
raise ValueError(f'Unknown arguments: {unknown}')
return ParserResult(
args=parsed_args,
unknown=unknown,
overrides=overrides,
)
def parse_args_and_update_config(config: Any, args: List[str]) -> Any:
"""
Given a model config and a list of command line arguments, creates an argparser, adds arguments from the config,
parses the list of provided args and updates the config accordingly. Returns the updated config.
:param config: The model configuration
:param args: A list of command line args to parse
:return: The config, updated with the values of the provided args
"""
parser = create_argparser(config)
parser_results = parse_arguments(parser, args=args)
_ = apply_overrides(config, parser_results.args)
return config
def get_overridable_parameters(config: Any) -> Dict[str, param.Parameter]:
"""
Get properties that are not constant, readonly or private (eg: prefixed with an underscore).
:param config: The model configuration
:return: A dictionary of parameter names and their definitions.
"""
assert isinstance(config, param.Parameterized)
return dict((k, v) for k, v in config.params().items()
if reason_not_overridable(v) is None)
def reason_not_overridable(value: param.Parameter) -> Optional[str]:
"""
Given a parameter, check for attributes that denote it is not overrideable (e.g. readonly, constant,
private etc). If such an attribute exists, return a string containing a single-word description of the
reason. Otherwise returns None.
:param value: a parameter value
:return: None if the parameter is overridable; otherwise a one-word string explaining why not.
"""
if value.readonly:
return "readonly"
elif value.constant:
return "constant"
elif is_private_field_name(value.name):
return "private"
elif isinstance(value, param.Callable):
return "callable"
return None
def apply_overrides(config: Any, overrides_to_apply: Optional[Dict[str, Any]], should_validate: bool = False,
keys_to_ignore: Optional[Set[str]] = None) -> Dict[str, Any]:
"""
Applies the provided `values` overrides to the config.
Only properties that are marked as overridable are actually overwritten.
:param config: The model configuration
:param overrides_to_apply: A dictionary mapping from field name to value.
:param should_validate: If true, run the .validate() method after applying overrides.
:param keys_to_ignore: keys to ignore in reporting failed overrides. If None, do not report.
:return: A dictionary with all the fields that were modified.
"""
def _apply(_overrides: Optional[Dict[str, Any]]) -> Dict[str, Any]:
applied: Dict[str, Any] = {}
if _overrides is not None:
overridable_parameters = get_overridable_parameters(config).keys()
for k, v in _overrides.items():
if k in overridable_parameters:
applied[k] = v
setattr(config, k, v)
return applied
actual_overrides = _apply(overrides_to_apply)
if keys_to_ignore is not None:
report_on_overrides(config, overrides_to_apply, keys_to_ignore) # type: ignore
if should_validate:
config.validate()
return actual_overrides
def report_on_overrides(config: Any, overrides_to_apply: Dict[str, Any], keys_to_ignore: Set[str]) -> None:
"""
Logs a warning for every parameter whose value is not as given in "overrides_to_apply", other than those
in keys_to_ignore.
:param config: The model configuration
:param overrides_to_apply: override dictionary, parameter names to values
:param keys_to_ignore: set of dictionary keys not to report on
"""
assert isinstance(config, param.Parameterized)
for key, desired in overrides_to_apply.items():
if key in keys_to_ignore:
continue
actual = getattr(config, key, None)
if actual == desired:
continue
if key not in config.params():
reason = "parameter is undefined"
else:
val = config.params()[key]
reason = reason_not_overridable(val) # type: ignore
if reason is None:
reason = "for UNKNOWN REASONS"
else:
raise TypeError("Parameter of type: {} is not supported".format(_p))
return p_type
def add_boolean_argument(parser: ArgumentParser, k: str, p: param.Parameter) -> None:
"""
Add a boolean argument.
If the parameter default is False then allow --flag (to set it True) and --flag=Bool as usual.
If the parameter default is True then allow --no-flag (to set it to False) and --flag=Bool as usual.
:param parser: parser to add a boolean argument to.
:param k: argument name.
:param p: boolean parameter.
"""
if not p.default:
# If the parameter default is False then use nargs="?" (OPTIONAL).
# This means that the argument is optional.
# If it is not supplied, i.e. in the --flag mode, use the "const" value, i.e. True.
# Otherwise, i.e. in the --flag=value mode, try to parse the argument as a bool.
parser.add_argument("--" + k, help=p.doc, type=parse_bool, default=False,
nargs=OPTIONAL, const=True)
else:
# If the parameter default is True then create an exclusive group of arguments.
# Either --flag=value as usual
# Or --no-flag to store False in the parameter k.
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--" + k, help=p.doc, type=parse_bool)
group.add_argument('--no-' + k, dest=k, action='store_false')
parser.set_defaults(**{k: p.default})
for k, p in cls.get_overridable_parameters().items():
# param.Booleans need to be handled separately, they are more complicated because they have
# an optional argument.
if isinstance(p, param.Boolean):
add_boolean_argument(parser, k, p)
else:
parser.add_argument("--" + k, help=p.doc, type=_get_basic_type(p), default=p.default)
return parser
@classmethod
def parse_args(cls: Type[T], args: Optional[List[str]] = None) -> T:
"""
Creates an argparser based on the params class and parses stdin args (or the args provided)
:param args: The arguments to be parsed
"""
return cls(**vars(cls.create_argparser().parse_args(args))) # type: ignore
@classmethod
def get_overridable_parameters(cls) -> Dict[str, param.Parameter]:
"""
Get properties that are not constant, readonly or private (eg: prefixed with an underscore).
:return: A dictionary of parameter names and their definitions.
"""
return dict((k, v) for k, v in cls.params().items()
if cls.reason_not_overridable(v) is None)
@staticmethod
def reason_not_overridable(value: param.Parameter) -> Optional[str]:
"""
Given a parameter, check for attributes that denote it is not overrideable (e.g. readonly, constant,
private etc). If such an attribute exists, return a string containing a single-word description of the
reason. Otherwise returns None.
:param value: a parameter value
:return: None if the parameter is overridable; otherwise a one-word string explaining why not.
"""
if value.readonly:
return "readonly"
elif value.constant:
return "constant"
elif is_private_field_name(value.name):
return "private"
elif isinstance(value, param.Callable):
return "callable"
return None
def apply_overrides(self, values: Optional[Dict[str, Any]], should_validate: bool = True,
keys_to_ignore: Optional[Set[str]] = None) -> Dict[str, Any]:
"""
Applies the provided `values` overrides to the config.
Only properties that are marked as overridable are actually overwritten.
:param values: A dictionary mapping from field name to value.
:param should_validate: If true, run the .validate() method after applying overrides.
:param keys_to_ignore: keys to ignore in reporting failed overrides. If None, do not report.
:return: A dictionary with all the fields that were modified.
"""
def _apply(_overrides: Optional[Dict[str, Any]]) -> Dict[str, Any]:
applied: Dict[str, Any] = {}
if _overrides is not None:
overridable_parameters = self.get_overridable_parameters().keys()
for k, v in _overrides.items():
if k in overridable_parameters:
applied[k] = v
setattr(self, k, v)
return applied
actual_overrides = _apply(values)
if keys_to_ignore is not None:
self.report_on_overrides(values, keys_to_ignore) # type: ignore
if should_validate:
self.validate()
return actual_overrides
def report_on_overrides(self, values: Dict[str, Any], keys_to_ignore: Optional[Set[str]] = None) -> None:
"""
Logs a warning for every parameter whose value is not as given in "values", other than those
in keys_to_ignore.
:param values: override dictionary, parameter names to values
:param keys_to_ignore: set of dictionary keys not to report on
:return: None
"""
for key, desired in values.items():
# If this isn't an AzureConfig instance, we don't want to warn on keys intended for it.
if keys_to_ignore and (key in keys_to_ignore):
continue
actual = getattr(self, key, None)
if actual == desired:
continue
if key not in self.params():
reason = "parameter is undefined"
else:
val = self.params()[key]
reason = self.reason_not_overridable(val) # type: ignore
if reason is None:
reason = "for UNKNOWN REASONS"
else:
reason = f"parameter is {reason}"
# We could raise an error here instead - to be discussed.
logging.warning(f"Override {key}={desired} failed: {reason} in class {self.__class__.name}")
reason = f"parameter is {reason}"
# We could raise an error here instead - to be discussed.
logging.warning(f"Override {key}={desired} failed: {reason} in class {config.__class__.name}")
def create_from_matching_params(from_object: param.Parameterized, cls_: Type[T]) -> T:
@ -565,6 +659,7 @@ def _find_file(file_name: str, stop_at_pythonpath: bool = True) -> Optional[Path
file_name: str,
stop_at_pythonpath: bool,
pythonpaths: List[Path]) -> Optional[Path]:
logging.debug(f"Searching for file {file_name} in {start_at}")
expected = start_at / file_name
if expected.is_file() and expected.name == file_name:
@ -761,23 +856,72 @@ def _log_conda_dependencies_stats(conda: CondaDependencies, message_prefix: str)
logging.debug(f" {p}")
def merge_conda_files(files: List[Path], result_file: Path) -> None:
def _retrieve_unique_deps(dependencies: List[str], keep_method: str = "first") -> List[str]:
"""
Merges the given Conda environment files using the conda_merge package, and writes the merged file to disk.
Given a list of conda dependencies, which may contain the same package name more than once (with the same or
different versions), returns a list in which each package name occurs only once. The version that is kept for a
duplicated package is determined by keep_method.
:param files: The Conda environment files to read.
:param result_file: The location where the merge results should be written.
:param dependencies: The original list of package names to deduplicate
:param keep_method: The strategy for choosing which package version to keep
:return: a list in which each package name occurs only once
"""
for file in files:
_log_conda_dependencies_stats(CondaDependencies(file), f"Conda environment in {file}")
# This code is a slightly modified version of conda_merge. That code can't be re-used easily
# it defaults to writing to stdout
env_definitions = [conda_merge.read_file(str(f)) for f in files]
unique_deps: Dict[str, Tuple[str, str]] = {}
for dep in dependencies:
dep_parts: List[str] = re.split("(=<|==|=|>=|<|>)", dep)
len_parts = len(dep_parts)
dep_name = dep_parts[0]
if len_parts > 1:
dep_join = ''.join(dep_parts[1:-1])
dep_version = dep_parts[-1]
else:
dep_join = ''
dep_version = ''
if dep_name in unique_deps:
if keep_method == "first":
keep_version, _ = unique_deps[dep_name]
elif keep_method == "last":
keep_version = dep_version
unique_deps[dep_name] = (keep_version, dep_join)
else:
raise ValueError(f"Unrecognised value of 'keep_method': {keep_method}. Accepted values"
f" include: ['first', 'last']")
logging.warning(f"Found duplicate requirements: {dep}. Keeping the {keep_method} "
f"version: {keep_version}")
else:
unique_deps[dep_name] = (dep_version, dep_join)
unique_deps_list = [f"{pkg}{joiner}{vrsn}" for pkg, (vrsn, joiner) in unique_deps.items()]
return unique_deps_list
def merge_conda_files(conda_files: List[Path], result_file: Path, pip_files: List[Path] = None,
pip_clash_keep_method: str = "first") -> None:
"""
Merges the given Conda environment files using the conda_merge package, optionally adds any
dependencies from pip requirements files, and writes the merged file to disk.
:param conda_files: The Conda environment files to read.
:param result_file: The location where the merge results should be written.
:param pip_files: An optional list of one or more pip requirements files including extra dependencies.
:param pip_clash_keep_method: If two or more pip packages are specified with the same name, this determines
which one should be kept. Current options: ['first', 'last']
"""
env_definitions = [conda_merge.read_file(str(f)) for f in conda_files]
unified_definition = {}
NAME = "name"
CHANNELS = "channels"
DEPENDENCIES = "dependencies"
extra_pip_deps = []
for pip_file in pip_files or []:
with open(pip_file, "r") as f_path:
additional_pip_deps = [d for d in f_path.read().split("\n") if d]
extra_pip_deps.extend(additional_pip_deps)
name = conda_merge.merge_names(env.get(NAME) for env in env_definitions)
if name:
unified_definition[NAME] = name
@ -791,12 +935,34 @@ def merge_conda_files(files: List[Path], result_file: Path) -> None:
unified_definition[CHANNELS] = channels
try:
deps = conda_merge.merge_dependencies(env.get(DEPENDENCIES) for env in env_definitions)
deps_to_merge = [env.get(DEPENDENCIES) for env in env_definitions]
if len(extra_pip_deps) > 0:
deps_to_merge.extend([[{"pip": extra_pip_deps}]])
deps = conda_merge.merge_dependencies(deps_to_merge)
# Remove duplicated pip packages from merged dependencies sections. Note that for a package that is
# duplicated, the first value encountered will be retained.
pip_deps_entries = [d for d in deps if isinstance(d, dict) and "pip" in d] # type: ignore
if len(pip_deps_entries) == 0:
raise ValueError("Didn't find a dictionary with the key 'pip' in the list of dependencies")
pip_deps_entry: Dict[str, List[str]] = pip_deps_entries[0]
pip_deps = pip_deps_entry["pip"]
# temporarily remove pip dependencies from deps to be added back after deduplication
deps.remove(pip_deps_entry)
# remove all non-pip duplicates from the list of dependencies
unique_deps = _retrieve_unique_deps(deps, keep_method=pip_clash_keep_method)
unique_pip_deps = _retrieve_unique_deps(pip_deps, keep_method=pip_clash_keep_method)
# finally add back the deduplicated list of dependencies
unique_deps.append({"pip": unique_pip_deps}) # type: ignore
except conda_merge.MergeError:
logging.error("Failed to merge dependencies.")
raise
if deps:
unified_definition[DEPENDENCIES] = deps
if unique_deps:
unified_definition[DEPENDENCIES] = unique_deps
else:
raise ValueError("No dependencies found in any of the conda files.")
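As a usage sketch (the environment and requirements file names are hypothetical), the extended signature can be
called like this:
```python
from pathlib import Path

from health_azure.utils import merge_conda_files

merge_conda_files(conda_files=[Path("env1.yml"), Path("env2.yml")],
                  result_file=Path("merged_env.yml"),
                  pip_files=[Path("extra_requirements.txt")],
                  pip_clash_keep_method="first")  # keep the first version of any duplicated pip package
```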
@ -1230,7 +1396,7 @@ def upload_to_datastore(datastore_name: str, local_data_folder: Path, remote_pat
logging.info(f"Uploaded data to {str(remote_path)}")
class AmlRunScriptConfig(GenericConfig):
class AmlRunScriptConfig(param.Parameterized):
"""
Base config for a script that handles Azure ML Runs, which can be retrieved with either a run id, latest_run_file,
or by giving the experiment name (optionally alongside tags and number of runs to retrieve). A config file path can
@ -1319,6 +1485,16 @@ def is_running_in_azure_ml(aml_run: Run = RUN_CONTEXT) -> bool:
return hasattr(aml_run, 'experiment')
def is_running_on_azure_agent() -> bool:
"""
Determine whether the current code is running on an Azure agent by examining the environment variable
AGENT_OS, which all Azure-hosted agents define.
:return: True if the code appears to be running on an Azure build agent, and False otherwise.
"""
return bool(os.environ.get("AGENT_OS", None))
def torch_barrier() -> None:
"""
This is a barrier to use in distributed jobs. Use it to make all processes that participate in a distributed

View file

@ -10,10 +10,10 @@ import logging
import os
import sys
import time
from argparse import ArgumentParser, Namespace, ArgumentError
from enum import Enum
from pathlib import Path
from random import randint
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from unittest import mock
from unittest.mock import MagicMock, patch
from uuid import uuid4
@ -36,6 +36,7 @@ from testazure.test_himl import RunTarget, render_and_run_test_script
from testazure.utils_testazure import (DEFAULT_IGNORE_FOLDERS, DEFAULT_WORKSPACE, MockRun, change_working_directory,
repository_root)
RUN_ID = uuid4().hex
RUN_NUMBER = 42
EXPERIMENT_NAME = "fancy-experiment"
@ -252,6 +253,16 @@ def test_split_recovery_id(id: str, expected1: str, expected2: str) -> None:
assert util.split_recovery_id(id) == (expected1, expected2)
def test_retrieve_unique_deps() -> None:
deps_with_duplicates = ["package==1.0", "package==1.1", "git+https:www.github.com/something.git"]
dedup_deps = util._retrieve_unique_deps(deps_with_duplicates) # type: ignore
assert dedup_deps == ["package==1.0", "git+https:www.github.com/something.git"]
dedup_deps_keep_last = util._retrieve_unique_deps(deps_with_duplicates, keep_method="last")
assert dedup_deps_keep_last == ["package==1.1", "git+https:www.github.com/something.git"]
def test_merge_conda(
random_folder: Path,
caplog: CaptureFixture,
@ -298,12 +309,10 @@ dependencies:
- pytorch
dependencies:
- conda1=1.0
- conda1=1.1
- conda2=2.0
- conda_both=3.0
- pip:
- azureml-sdk==1.6.0
- azureml-sdk==1.7.0
- bar==2.0
- foo==1.0
""".splitlines()
@ -313,8 +322,30 @@ dependencies:
assert list(conda_dep.conda_channels) == ["defaults", "pytorch"]
# Package version conflicts are not resolved, both versions are retained.
assert list(conda_dep.conda_packages) == ["conda1=1.0", "conda1=1.1", "conda2=2.0", "conda_both=3.0"]
assert list(conda_dep.pip_packages) == ["azureml-sdk==1.6.0", "azureml-sdk==1.7.0", "bar==2.0", "foo==1.0"]
assert list(conda_dep.conda_packages) == ["conda1=1.0", "conda2=2.0", "conda_both=3.0"]
assert list(conda_dep.pip_packages) == ["azureml-sdk==1.6.0", "bar==2.0", "foo==1.0"]
# Assert that extra pip requirements are added correctly
pip_contents = """package1==0.0.1
package2==0.0.1
"""
pip_file = random_folder / "req.txt"
pip_file.write_text(pip_contents)
util.merge_conda_files(files, merged_file, pip_files=[pip_file])
merged_file_text = merged_file.read_text()
assert merged_file_text.splitlines() == """channels:
- defaults
- pytorch
dependencies:
- conda1=1.0
- conda2=2.0
- conda_both=3.0
- pip:
- azureml-sdk==1.6.0
- bar==2.0
- foo==1.0
- package1==0.0.1
- package2==0.0.1""".splitlines()
# Are names merged correctly?
assert "name:" not in merged_file_text
@ -349,9 +380,8 @@ dependencies:
# If there are no dependencies then something is wrong with the conda files or our parsing of them
with mock.patch("health_azure.utils.conda_merge.merge_dependencies") as mock_merge_dependencies:
mock_merge_dependencies.return_value = []
with pytest.raises(ValueError) as e:
with pytest.raises(ValueError):
util.merge_conda_files(files, merged_file)
assert "No dependencies found in any of the conda files" in str(e.value)
@pytest.mark.parametrize(["s", "expected"],
@ -933,12 +963,13 @@ def test_get_run_source(dummy_recovery_id: str,
arguments = ["", "--run", dummy_recovery_id]
with patch.object(sys, "argv", arguments):
run_source = util.AmlRunScriptConfig.parse_args()
script_config = util.AmlRunScriptConfig()
script_config = util.parse_args_and_update_config(script_config, arguments)
if isinstance(run_source.run, List):
assert isinstance(run_source.run[0], str)
if isinstance(script_config.run, List):
assert isinstance(script_config.run[0], str)
else:
assert isinstance(run_source.run, str)
assert isinstance(script_config.run, str)
@pytest.mark.parametrize("overwrite", [True, False])
@ -1036,7 +1067,8 @@ def test_upload_to_datastore(tmp_path: Path, overwrite: bool, show_progress: boo
])
def test_script_config_run_src(arguments: List[str], run_id: Union[str, List[str]]) -> None:
with patch.object(sys, "argv", arguments):
script_config = util.AmlRunScriptConfig.parse_args()
script_config = util.AmlRunScriptConfig()
script_config = util.parse_args_and_update_config(script_config, arguments)
if isinstance(run_id, list):
for script_config_run, expected_run_id in zip(script_config.run, run_id):
@ -1172,7 +1204,65 @@ class IllegalCustomTypeNoValidate(util.CustomTypeParam):
return x
class ParamClass(util.GenericConfig):
class DummyConfig(param.Parameterized):
string_param = param.String()
int_param = param.Integer()
def validate(self) -> None:
assert isinstance(self.string_param, str)
assert isinstance(self.int_param, int)
@pytest.fixture(scope="module")
def dummy_model_config() -> DummyConfig:
string_param = "dummy"
int_param = 1
return DummyConfig(string_param=string_param, int_param=int_param)
def test_add_and_validate(dummy_model_config: DummyConfig) -> None:
new_string_param = "new_dummy"
new_int_param = 2
new_args = {"string_param": new_string_param, "int_param": new_int_param}
util.set_fields_and_validate(dummy_model_config, new_args)
assert dummy_model_config.string_param == new_string_param
assert dummy_model_config.int_param == new_int_param
def test_create_argparse(dummy_model_config: DummyConfig) -> None:
with patch("health_azure.utils._add_overrideable_config_args_to_parser") as mock_add_args:
parser = util.create_argparser(dummy_model_config)
mock_add_args.assert_called_once()
assert isinstance(parser, ArgumentParser)
def test_add_args(dummy_model_config: DummyConfig) -> None:
parser = ArgumentParser()
# assert that calling parse_args on a default ArgumentParser returns an empty Namespace
args = parser.parse_args([])
assert args == Namespace()
# now call _add_overrideable_config_args_to_parser and assert that calling parse_args on the result
# of that is a non-empty Namespace
with patch("health_azure.utils.get_overridable_parameters") as mock_get_overridable_parameters:
mock_get_overridable_parameters.return_value = {"string_param": param.String(default="Hello")}
parser = util._add_overrideable_config_args_to_parser(dummy_model_config, parser)
assert isinstance(parser, ArgumentParser)
args = parser.parse_args([])
assert args != Namespace()
assert args.string_param == "Hello"
def test_parse_args(dummy_model_config: DummyConfig) -> None:
new_string_arg = "dummy_string"
new_args = ["--string_param", new_string_arg]
parser = ArgumentParser()
parser.add_argument("--string_param", type=str, default=None)
parser_result = util.parse_arguments(parser, args=new_args)
assert parser_result.args.get("string_param") == new_string_arg
class ParamClass(param.Parameterized):
name: str = param.String(None, doc="Name")
seed: int = param.Integer(42, doc="Seed")
flag: bool = param.Boolean(False, doc="Flag")
@ -1190,6 +1280,9 @@ class ParamClass(util.GenericConfig):
constant: str = param.String("Nope", constant=True)
other_args = util.ListOrDictParam(None, doc="List or dictionary of other args")
def validate(self) -> None:
pass
class ClassFrom(param.Parameterized):
foo = param.String("foo")
@ -1210,12 +1303,20 @@ class NotParameterized:
foo = 1
@pytest.fixture(scope="module")
def parameterized_config_and_parser() -> Tuple[ParamClass, ArgumentParser]:
parameterized_config = ParamClass()
parser = util.create_argparser(parameterized_config)
return parameterized_config, parser
@pytest.mark.fast
def test_overridable_parameter() -> None:
def test_get_overridable_parameter(parameterized_config_and_parser: Tuple[ParamClass, ArgumentParser]) -> None:
"""
Test to check overridable parameters are correctly identified.
"""
param_dict = ParamClass.get_overridable_parameters()
parameterized_config = parameterized_config_and_parser[0]
param_dict = util.get_overridable_parameters(parameterized_config)
assert "name" in param_dict
assert "flag" in param_dict
assert "not_flag" in param_dict
@ -1235,12 +1336,13 @@ def test_overridable_parameter() -> None:
@pytest.mark.fast
def test_parser_defaults() -> None:
def test_parser_defaults(parameterized_config_and_parser: Tuple[ParamClass, ArgumentParser]) -> None:
"""
Check that default values are created as expected, and that the non-overridable parameters
are omitted.
"""
defaults = vars(ParamClass.create_argparser().parse_args([]))
parameterized_config = parameterized_config_and_parser[0]
defaults = vars(util.create_argparser(parameterized_config).parse_args([]))
assert defaults["seed"] == 42
assert defaults["tuple1"] == (1, 2.3)
assert defaults["int_tuple"] == (1, 1, 1)
@ -1254,17 +1356,20 @@ def test_parser_defaults() -> None:
# upon errors.
def check_parsing_succeeds(arg: List[str], expected_key: str, expected_value: Any) -> None:
parsed = ParamClass.parse_args(arg)
assert getattr(parsed, expected_key) == expected_value
def check_parsing_succeeds(parameterized_config_and_parser: Tuple[ParamClass, ArgumentParser],
arg: List[str],
expected_key: str,
expected_value: Any) -> None:
parameterized_config, parser = parameterized_config_and_parser
parser_result = util.parse_arguments(parser, args=arg)
assert parser_result.args.get(expected_key) == expected_value
def check_parsing_fails(arg: List[str], expected_key: Optional[str] = None, expected_value: Optional[Any] = None
) -> None:
with pytest.raises(SystemExit) as e:
ParamClass.parse_args(arg)
assert e.type == SystemExit
assert e.value.code == 2
def check_parsing_fails(parameterized_config_and_parser: Tuple[ParamClass, ArgumentParser],
arg: List[str]) -> None:
parameterized_config, parser = parameterized_config_and_parser
with pytest.raises(Exception):
util.parse_arguments(parser, args=arg, fail_on_unknown_args=True)
@pytest.mark.fast
@ -1296,14 +1401,18 @@ def check_parsing_fails(arg: List[str], expected_key: Optional[str] = None, expe
(["--other_args={'learning':3"], None, None, False),
(["--other_args=['foo','bar'"], None, None, False)
])
def test_create_parser(args: List[str], expected_key: str, expected_value: Any, expected_pass: bool) -> None:
def test_create_parser(parameterized_config_and_parser: Tuple[ParamClass, ArgumentParser],
args: List[str],
expected_key: str,
expected_value: Any,
expected_pass: bool) -> None:
"""
Check that parse_args works as expected, with both non default and default values.
"""
if expected_pass:
check_parsing_succeeds(args, expected_key, expected_value)
check_parsing_succeeds(parameterized_config_and_parser, args, expected_key, expected_value)
else:
check_parsing_fails(args)
check_parsing_fails(parameterized_config_and_parser, args)
@pytest.mark.fast
@ -1311,73 +1420,106 @@ def test_create_parser(args: List[str], expected_key: str, expected_value: Any,
('on', True), ('t', True), ('true', True), ('y', True), ('yes', True), ('1', True),
('off', False), ('f', False), ('false', False), ('n', False), ('no', False), ('0', False)
])
def test_parsing_bools(flag: str, expected_value: bool) -> None:
def test_parsing_bools(parameterized_config_and_parser: Tuple[ParamClass, ArgumentParser],
flag: str,
expected_value: bool) -> None:
"""
Check all the ways of passing in True and False, with and without the first letter capitalized
"""
check_parsing_succeeds([f"--flag={flag}"], "flag", expected_value)
check_parsing_succeeds([f"--flag={flag.capitalize()}"], "flag", expected_value)
check_parsing_succeeds([f"--not_flag={flag}"], "not_flag", expected_value)
check_parsing_succeeds([f"--not_flag={flag.capitalize()}"], "not_flag", expected_value)
check_parsing_succeeds(parameterized_config_and_parser,
[f"--flag={flag}"],
"flag",
expected_value)
check_parsing_succeeds(parameterized_config_and_parser,
[f"--flag={flag.capitalize()}"],
"flag",
expected_value)
check_parsing_succeeds(parameterized_config_and_parser,
[f"--not_flag={flag}"],
"not_flag",
expected_value)
check_parsing_succeeds(parameterized_config_and_parser,
[f"--not_flag={flag.capitalize()}"],
"not_flag",
expected_value)
@pytest.mark.fast
@patch("health_azure.utils.GenericConfig.report_on_overrides")
@patch("health_azure.utils.GenericConfig.validate")
def test_apply_overrides(mock_validate: MagicMock, mock_report_on_overrides: MagicMock) -> None:
def test_apply_overrides(parameterized_config_and_parser: Tuple[ParamClass, ArgumentParser]) -> None:
"""
Test that overrides are applied correctly, and only to overridable parameters,
Test that overrides are applied correctly, and only to overridable parameters
"""
m = ParamClass()
overrides = {"name": "newName", "int_tuple": (0, 1, 2)}
actual_overrides = m.apply_overrides(overrides)
assert actual_overrides == overrides
assert all([x == i and isinstance(x, int) for i, x in enumerate(m.int_tuple)])
assert m.name == "newName"
# Attempt to change seed and constant, but the latter should be ignored.
change_seed = {"seed": 123}
old_constant = m.constant
changes2 = m.apply_overrides({**change_seed, "constant": "Nothing"}) # type: ignore
assert changes2 == change_seed
assert m.seed == 123
assert m.constant == old_constant
parameterized_config = parameterized_config_and_parser[0]
with patch("health_azure.utils.report_on_overrides") as mock_report_on_overrides:
overrides = {"name": "newName", "int_tuple": (0, 1, 2)}
actual_overrides = util.apply_overrides(parameterized_config, overrides)
assert actual_overrides == overrides
assert all([x == i and isinstance(x, int) for i, x in enumerate(parameterized_config.int_tuple)])
assert parameterized_config.name == "newName"
# Check the call count of mock_validate and check it doesn't increase if should_validate is set to False
# and that setting this flag doesn't affect the outputs
mock_validate_call_count = mock_validate.call_count
actual_overrides = m.apply_overrides(values=overrides, should_validate=False)
assert actual_overrides == overrides
assert mock_validate.call_count == mock_validate_call_count
# Attempt to change seed and constant, but the latter should be ignored.
change_seed = {"seed": 123}
old_constant = parameterized_config.constant
extra_overrides = {**change_seed, "constant": "Nothing"} # type: ignore
changes2 = util.apply_overrides(parameterized_config, overrides_to_apply=extra_overrides) # type: ignore
assert changes2 == change_seed
assert parameterized_config.seed == 123
assert parameterized_config.constant == old_constant
# Check that report_on_overrides has not yet been called, but is called if keys_to_ignore is not None
# and that setting this flag doesn't affect the outputs
assert mock_report_on_overrides.call_count == 0
actual_overrides = m.apply_overrides(values=overrides, keys_to_ignore={"name"})
assert actual_overrides == overrides
assert mock_report_on_overrides.call_count == 1
# Check the call count of mock_validate and check it doesn't increase if should_validate is set to False
# and that setting this flag doesn't affect the outputs
# mock_validate_call_count = mock_validate.call_count
actual_overrides = util.apply_overrides(parameterized_config,
overrides_to_apply=overrides,
should_validate=False)
assert actual_overrides == overrides
# assert mock_validate.call_count == mock_validate_call_count
# Check that report_on_overrides has not yet been called, but is called if keys_to_ignore is not None
# and that setting this flag doesn't affect the outputs
assert mock_report_on_overrides.call_count == 0
actual_overrides = util.apply_overrides(parameterized_config,
overrides_to_apply=overrides,
keys_to_ignore={"name"})
assert actual_overrides == overrides
assert mock_report_on_overrides.call_count == 1
def test_report_on_overrides() -> None:
m = ParamClass()
overrides = {"name": "newName", "int_tuple": (0, 1, 2)}
m.report_on_overrides(overrides)
def test_report_on_overrides(parameterized_config_and_parser: Tuple[ParamClass, ArgumentParser],
caplog: LogCaptureFixture) -> None:
caplog.set_level(logging.WARNING)
parameterized_config = parameterized_config_and_parser[0]
old_logs = caplog.messages
assert len(old_logs) == 0
# the following overrides are expected to cause logged warnings because
# a) parameter 'constant' is constant
# b) parameter 'readonly' is readonly
# c) parameter 'idontexist' is undefined (not the name of a parameter of ParamClass)
overrides = {"constant": "dif_value", "readonly": "new_value", "idontexist": (0, 1, 2)}
keys_to_ignore: Set = set()
util.report_on_overrides(parameterized_config, overrides, keys_to_ignore)
# Expect one warning message per failed override
new_logs = caplog.messages
expected_warnings = len(overrides.keys())
assert len(new_logs) == expected_warnings, f"Expected {expected_warnings} warnings but found: {caplog.records}"
@pytest.mark.fast
@pytest.mark.parametrize("value_idx_0", [1.0, 1])
@pytest.mark.parametrize("value_idx_1", [2.0, 2])
@pytest.mark.parametrize("value_idx_2", [3.0, 3])
def test_int_tuple_validation(value_idx_0: Any, value_idx_1: Any, value_idx_2: Any) -> None:
def test_int_tuple_validation(value_idx_0: Any, value_idx_1: Any, value_idx_2: Any,
parameterized_config_and_parser: Tuple[ParamClass, ArgumentParser]) -> None:
"""
Test integer tuple parameter is validated correctly.
"""
m = ParamClass()
parameterized_config = parameterized_config_and_parser[0]
val = (value_idx_0, value_idx_1, value_idx_2)
if not all([isinstance(x, int) for x in val]):
with pytest.raises(ValueError):
m.int_tuple = (value_idx_0, value_idx_1, value_idx_2)
parameterized_config.int_tuple = (value_idx_0, value_idx_1, value_idx_2)
else:
m.int_tuple = (value_idx_0, value_idx_1, value_idx_2)
parameterized_config.int_tuple = (value_idx_0, value_idx_1, value_idx_2)
@pytest.mark.fast
@@ -1404,42 +1546,24 @@ def test_create_from_matching_params() -> None:
def test_parse_illegal_params() -> None:
with pytest.raises(ValueError) as e:
with pytest.raises(TypeError) as e:
ParamClass(readonly="abc")
assert "cannot be overridden" in str(e.value)
def test_parse_throw_if_unknown() -> None:
with pytest.raises(ValueError) as e:
ParamClass(throw_if_unknown_param=True, idontexist="hello")
assert "parameters do not exist" in str(e.value)
@patch("health_azure.utils.GenericConfig.validate")
def test_config_validate(mock_validate: MagicMock) -> None:
_ = ParamClass(should_validate=False)
assert mock_validate.call_count == 0
_ = ParamClass(should_validate=True)
assert mock_validate.call_count == 1
_ = ParamClass()
assert mock_validate.call_count == 2
assert "cannot be modified" in str(e.value)
def test_config_add_and_validate() -> None:
config = ParamClass.parse_args([])
assert config.name == "ParamClass"
config.add_and_validate({"name": "foo"})
config = ParamClass()
assert config.name.startswith("ParamClass")
util.set_fields_and_validate(config, {"name": "foo"})
assert config.name == "foo"
assert hasattr(config, "new_property") is False
config.add_and_validate({"new_property": "bar"})
util.set_fields_and_validate(config, {"new_property": "bar"})
assert hasattr(config, "new_property") is True
assert config.new_property == "bar"
class IllegalParamClassNoString(util.GenericConfig):
class IllegalParamClassNoString(param.Parameterized):
custom_type_no_from_string = IllegalCustomTypeNoFromString(
None, doc="This should fail since from_string method is missing"
)
@@ -1449,8 +1573,10 @@ def test_cant_parse_param_type() -> None:
"""
Assert that a TypeError is raised when trying to add a custom type with no from_string method as an argument
"""
config = IllegalParamClassNoString()
with pytest.raises(TypeError) as e:
IllegalParamClassNoString.parse_args([])
util.create_argparser(config)
assert "is not supported" in str(e.value)
@@ -1469,33 +1595,39 @@ class EvenNumberParam(util.CustomTypeParam):
return int(x)
class MyScriptConfig(util.AmlRunScriptConfig):
class MyScriptConfig(param.Parameterized):
simple_string: str = param.String(default="")
even_number: int = EvenNumberParam(2, doc="your choice of even number", allow_None=False)
def test_my_script_config() -> None:
even_number = randint(0, 100) * 2
odd_number = even_number + 1
none_number = "None"
def test_parse_args_and_apply_overrides() -> None:
config = MyScriptConfig()
assert config.even_number == 2
assert config.simple_string == ""
config = MyScriptConfig.parse_args(["--even_number", f"{even_number}"])
assert config.even_number == even_number
new_even_number = config.even_number * 2
new_string = config.simple_string + "something_new"
config_w_results = util.parse_args_and_update_config(config, ["--even_number", str(new_even_number),
"--simple_string", new_string])
assert config_w_results.even_number == new_even_number
assert config_w_results.simple_string == new_string
# parsing args with unaccepted values should cause an exception to be raised
odd_number = new_even_number + 1
with pytest.raises(ValueError) as e:
MyScriptConfig.parse_args(["--even_number", f"{odd_number}"])
util.parse_args_and_update_config(config, args=["--even_number", f"{odd_number}"])
assert "not an even number" in str(e.value)
# If parser can't parse type, will raise a SystemExit
with pytest.raises(SystemExit):
MyScriptConfig.parse_args(["--even_number", f"{none_number}"])
none_number = "None"
with pytest.raises(ArgumentError):
util.parse_args_and_update_config(config, args=["--even_number", f"{none_number}"])
# Mock from_string to check test _validate
mock_from_string_none = lambda a, b: None
mock_from_string_none = lambda a, b: None # type: ignore
with patch.object(EvenNumberParam, "from_string", new=mock_from_string_none):
# Check that _validate fails with None value
with pytest.raises(ValueError) as e:
MyScriptConfig.parse_args(["--even_number", f"{none_number}"])
util.parse_args_and_update_config(config, ["--even_number", f"{none_number}"])
assert "must not be None" in str(e.value)


@@ -8,7 +8,7 @@ Test the data input and output functionality
from pathlib import Path
from unittest import mock
from health_azure.utils import PathOrString
from typing import List, Union
from typing import List, Union, Optional
import pytest
from azureml._restclient.exceptions import ServiceException
@@ -19,7 +19,8 @@ from azureml.data.dataset_consumption_config import DatasetConsumptionConfig
from azureml.exceptions._azureml_exception import UserErrorException
from health_azure.datasets import (DatasetConfig, _input_dataset_key, _output_dataset_key,
_replace_string_datasets, get_datastore, get_or_create_dataset)
_replace_string_datasets, get_datastore, get_or_create_dataset,
create_dataset_configs)
from testazure.utils_testazure import DEFAULT_DATASTORE, DEFAULT_WORKSPACE
@@ -205,3 +206,58 @@ def test_dataset_keys() -> None:
assert in1
assert out1
assert in1 != out1
def test_create_dataset_configs() -> None:
azure_datasets: List[str] = []
dataset_mountpoints: List[str] = []
local_datasets: List[Optional[Path]] = []
datastore = None
use_mounting = False
datasets = create_dataset_configs(azure_datasets,
dataset_mountpoints,
local_datasets,
datastore,
use_mounting)
assert datasets == []
# if local_datasets is not empty but azure_datasets still is, expect an empty list
local_datasets = [Path("dummy")]
datasets = create_dataset_configs(azure_datasets,
dataset_mountpoints,
local_datasets,
datastore,
use_mounting)
assert datasets == []
with pytest.raises(Exception) as e:
azure_datasets = ["dummy"]
local_datasets = [Path("another_dummy"), Path("another_extra_dummy")]
create_dataset_configs(azure_datasets,
dataset_mountpoints,
local_datasets,
datastore,
use_mounting)
assert "Invalid dataset setup" in str(e)
az_dataset_name = "dummy"
azure_datasets = [az_dataset_name]
local_datasets = [Path("another_dummy")]
datasets = create_dataset_configs(azure_datasets,
dataset_mountpoints,
local_datasets,
datastore,
use_mounting)
assert len(datasets) == 1
assert isinstance(datasets[0], DatasetConfig)
assert datasets[0].name == az_dataset_name
# If the Azure dataset name is blank (whitespace only), expect an invalid dataset setup error
azure_datasets = [" "]
with pytest.raises(Exception) as e:
create_dataset_configs(azure_datasets,
dataset_mountpoints,
local_datasets,
datastore,
use_mounting)
assert "Invalid dataset setup" in str(e)


@@ -278,6 +278,28 @@ def test_create_run_configuration_correct_env(mock_create_environment: MagicMock
mock_register.assert_called_once()
assert mock_environment_get.call_count == 2
# Assert that a Conda env spec with no python version raises an exception
conda_env_spec = OrderedDict({"name": "dummy_env",
"channels": OrderedList("default"),
"dependencies": OrderedList(["- pip=20.1.1"])})
conda_env_path = tmp_path / "dummy_conda_env_no_python.yml"
with open(conda_env_path, "w+") as f_path:
yaml.dump(conda_env_spec, f_path)
assert conda_env_path.is_file()
with patch.object(mock_environment, "register") as mock_register:
mock_register.return_value = mock_environment
with patch("azureml.core.Environment.get") as mock_environment_get: # type: ignore
mock_environment_get.side_effect = Exception()
with pytest.raises(Exception) as e:
himl.create_run_configuration(mock_workspace,
"dummy_compute_cluster",
conda_environment_file=conda_env_path)
assert "you must specify the python version" in str(e)
# check that when create_run_configuration is called, whatever is returned from register_environment
# is set as the new "environment" attribute of the run config
with patch("health_azure.himl.register_environment") as mock_register_environment:


@@ -82,7 +82,7 @@ pytest_fast: pip_test call_pytest_fast
# run pytest with coverage on package, and format coverage output as a text file, assuming test requirements already installed
call_pytest_and_coverage:
pytest --cov=health_ml --cov-branch --cov-report=html --cov-report=term-missing --cov-report=xml testhiml
pytest --cov=health_ml --cov-branch --cov-report=html --cov-report=xml --cov-report=term-missing --cov-config=.coveragerc testhiml
pycobertura show --format text --output coverage.txt coverage.xml
# install test requirements and run pytest coverage


@@ -3,6 +3,7 @@ jinja2==3.0.2
matplotlib==3.4.3
opencv-python-headless==4.5.1.48
pandas==1.3.4
pytorch-lightning>=1.4.9
torchvision==0.9.0
pytorch-lightning==1.5.5
rpdb==0.1.6
torchvision==0.11.1
torch>=1.8


@@ -86,6 +86,7 @@ setup(
install_requires=install_requires,
entry_points={
'console_scripts': [
'himl-runner = health_ml.runner:main'
]
}
)


@@ -2,3 +2,13 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from health_ml.model_trainer import model_train
from health_ml.run_ml import MLRunner
from health_ml.runner import Runner
__all__ = [
"model_train",
"MLRunner",
"Runner"
]


@@ -0,0 +1,235 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from pytorch_lightning import LightningDataModule, LightningModule
from torchmetrics import MeanAbsoluteError
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import StepLR, _LRScheduler
from torch.utils.data import DataLoader, Dataset
from health_ml.lightning_container import LightningContainer
class HelloDataset(Dataset):
"""
A simple 1dim regression task, read from a data file stored in the test data folder.
"""
# Creating the data file:
# import numpy as np
# import torch
#
# N = 100
# x = torch.rand((N, 1)) * 10
# y = 0.2 * x + 0.1 * torch.randn(x.size())
# xy = torch.cat((x, y), dim=1)
# np.savetxt("health_ml/configs/hellocontainer.csv", xy.numpy(), delimiter=",")
def __init__(self, raw_data: List[List[float]]) -> None:
"""
Creates the 1-dim regression dataset.
:param raw_data: The raw data. This must be numeric data which can be converted into a tensor.
See the static method from_path_and_indexes for an example call.
"""
super().__init__() # type: ignore
self.data = torch.tensor(raw_data, dtype=torch.float)
def __len__(self) -> int:
return self.data.shape[0]
def __getitem__(self, item: int) -> Dict[str, torch.Tensor]:
return {'x': self.data[item][0:1], 'y': self.data[item][1:2]}
@staticmethod
def from_path_and_indexes(
root_folder: Path,
start_index: int,
end_index: int) -> 'HelloDataset':
"""
Static method to instantiate a HelloDataset from the root folder with the start and end indexes.
:param root_folder: The folder in which the data file lives ("hellocontainer.csv")
:param start_index: The first row to read.
:param end_index: The last row to read (exclusive)
:return: A new instance based on the root folder and the start and end indexes.
"""
raw_data = np.loadtxt(root_folder / "hellocontainer.csv", delimiter=",")[start_index:end_index]
return HelloDataset(raw_data)
class HelloDataModule(LightningDataModule):
"""
A data module that gives the training, validation and test data for a simple 1-dim regression task.
"""
def __init__(
self,
root_folder: Path) -> None:
super().__init__()
self.train = HelloDataset.from_path_and_indexes(root_folder, start_index=0, end_index=50)
self.val = HelloDataset.from_path_and_indexes(root_folder, start_index=50, end_index=70)
self.test = HelloDataset.from_path_and_indexes(root_folder, start_index=70, end_index=100)
def prepare_data(self, *args: Any, **kwargs: Any) -> None:
pass
def setup(self, stage: Optional[str] = None) -> None:
pass
def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
return DataLoader(self.train, batch_size=5)
def val_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
return DataLoader(self.val, batch_size=5)
def test_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
return DataLoader(self.test, batch_size=5)
class HelloRegression(LightningModule):
"""
A simple 1-dim regression model.
"""
def __init__(self) -> None:
super().__init__()
self.model = torch.nn.Linear(in_features=1, out_features=1, bias=True) # type: ignore
self.test_mse: List[torch.Tensor] = []
self.test_mae = MeanAbsoluteError()
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
"""
This method is part of the standard PyTorch Lightning interface. For an introduction, please see
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
It runs a forward pass of a tensor through the model.
:param x: The input tensor(s)
:return: The model output.
"""
return self.model(x)
def training_step(self, batch: Dict[str, torch.Tensor], *args: Any, **kwargs: Any) -> torch.Tensor: # type: ignore
"""
This method is part of the standard PyTorch Lightning interface. For an introduction, please see
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
It consumes a minibatch of training data (coming out of the data loader), does forward propagation, and
computes the loss.
:param batch: The batch of training data
:return: The loss value with a computation graph attached.
"""
loss = self.shared_step(batch)
self.log("loss", loss, on_epoch=True, on_step=False)
return loss
def validation_step(self, batch: Dict[str, torch.Tensor], *args: Any, # type: ignore
**kwargs: Any) -> torch.Tensor:
"""
This method is part of the standard PyTorch Lightning interface. For an introduction, please see
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
It consumes a minibatch of validation data (coming out of the data loader), does forward propagation, and
computes the loss.
:param batch: The batch of validation data
:return: The loss value on the validation data.
"""
loss = self.shared_step(batch)
self.log("val_loss", loss, on_epoch=True, on_step=False)
return loss
def shared_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
This is a convenience method to reduce code duplication, because training, validation, and test step share
large amounts of code.
:param batch: The batch of data to process, with input data and targets.
:return: The MSE loss that the model achieved on this batch.
"""
input = batch["x"]
target = batch["y"]
prediction = self.forward(input)
return torch.nn.functional.mse_loss(prediction, target) # type: ignore
def configure_optimizers(self) -> Tuple[List[Optimizer], List[_LRScheduler]]:
"""
This method is part of the standard PyTorch Lightning interface. For an introduction, please see
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
It returns the PyTorch optimizer(s) and learning rate scheduler(s) that should be used for training.
"""
optimizer = Adam(self.parameters(), lr=1e-1)
scheduler = StepLR(optimizer, step_size=20, gamma=0.5)
return [optimizer], [scheduler]
def on_test_epoch_start(self) -> None:
"""
This method is part of the standard PyTorch Lightning interface. For an introduction, please see
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
In this method, you can prepare data structures that need to be in place before evaluating the model on the
test set (that is done in the test_step).
"""
self.test_mse = []
self.test_mae.reset()
def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: # type: ignore
"""
This method is part of the standard PyTorch Lightning interface. For an introduction, please see
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
It evaluates the model in "inference mode" on data coming from the test set. It could, for example,
also write each model prediction to disk.
:param batch: The batch of test data.
:param batch_idx: The index (0, 1, ...) of the batch when the data loader is enumerated.
:return: The loss on the test data.
"""
input = batch["x"]
target = batch["y"]
prediction = self.forward(input)
# This illustrates two ways of computing metrics: Using standard torch
loss = torch.nn.functional.mse_loss(prediction, target) # type: ignore
self.test_mse.append(loss)
# Metrics computed using PyTorch Lightning objects. Note that these will, by default, attempt
# to synchronize across GPUs.
self.test_mae.update(preds=prediction, target=target)
return loss
def on_test_epoch_end(self) -> None:
"""
This method is part of the standard PyTorch Lightning interface. For an introduction, please see
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
In this method, you can finish off anything to do with evaluating the model on the test set,
for example writing aggregate metrics to disk.
"""
average_mse = torch.mean(torch.stack(self.test_mse))
Path("test_mse.txt").write_text(str(average_mse.item()))
Path("test_mae.txt").write_text(str(self.test_mae.compute().item()))
class HelloContainer(LightningContainer):
"""
An example container for using the hi-ml runner. This container has methods
to generate the actual Lightning model, and read out the datamodule that will be used for training.
The number of training epochs is controlled at container level.
You can train this model by running `python health_ml/runner.py --model=HelloContainer` on the local box,
or via `python health_ml/runner.py --model=HelloContainer --azureml=True` in AzureML
"""
def __init__(self) -> None:
super().__init__()
self.local_dataset_dir = Path(__file__).parent
self.max_epochs = 20
# This method must be overridden by any subclass of LightningContainer. It returns the model that you wish to
# train, as a LightningModule
def create_model(self) -> LightningModule:
return HelloRegression()
# This method must be overridden by any subclass of LightningContainer. It returns a data module, which
# in turn contains 3 data loaders for training, validation, and test set.
def get_data_module(self) -> LightningDataModule:
assert self.local_dataset_dir is not None
return HelloDataModule(
root_folder=self.local_dataset_dir) # type: ignore
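The container above can also be exercised without the runner. The following is a hedged, minimal smoke test, not part of the diff, that regenerates the CSV file with the same recipe as in the HelloDataset comment and fits the model for two epochs on CPU with default Trainer settings.
def smoke_test(tmp_folder: Path) -> None:
    # Recreate hellocontainer.csv with the same recipe as in the HelloDataset comment above.
    x = torch.rand((100, 1)) * 10
    y = 0.2 * x + 0.1 * torch.randn(x.size())
    np.savetxt(tmp_folder / "hellocontainer.csv", torch.cat((x, y), dim=1).numpy(), delimiter=",")
    # Fit the example model briefly, writing Lightning outputs into the same temporary folder.
    from pytorch_lightning import Trainer
    trainer = Trainer(max_epochs=2, default_root_dir=str(tmp_folder))
    trainer.fit(HelloRegression(), datamodule=HelloDataModule(root_folder=tmp_folder))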


@@ -0,0 +1,413 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from __future__ import annotations
import logging
from enum import Enum, unique
from pathlib import Path
from typing import List, Optional
import param
from param import Parameterized
from health_azure.utils import RUN_CONTEXT, PathOrString, is_running_in_azure_ml
from health_ml.utils import fixed_paths
from health_ml.utils.common_utils import (CHECKPOINT_FOLDER,
create_unique_timestamp_id,
DEFAULT_AML_UPLOAD_DIR,
DEFAULT_LOGS_DIR_NAME, is_windows, parse_model_id_and_version)
from health_ml.utils.type_annotations import TupleFloat2
@unique
class LRWarmUpType(Enum):
"""
Supported LR warm up types for model training
"""
NoWarmUp = "NoWarmUp"
Linear = "Linear"
@unique
class LRSchedulerType(Enum):
"""
Supported lr scheduler types for model training
"""
Exponential = "Exponential"
Step = "Step"
Polynomial = "Polynomial"
Cosine = "Cosine"
MultiStep = "MultiStep"
@unique
class MultiprocessingStartMethod(Enum):
"""
Different methods for starting data loader processes.
"""
fork = "fork"
forkserver = "forkserver"
spawn = "spawn"
@unique
class OptimizerType(Enum):
"""
Supported optimizers for model training
"""
Adam = "Adam"
AMSGrad = "AMSGrad"
SGD = "SGD"
RMSprop = "RMSprop"
class ExperimentFolderHandler(Parameterized):
"""High level config to abstract the file system related settings for experiments"""
outputs_folder: Path = param.ClassSelector(class_=Path, default=Path(), instantiate=False,
doc="The folder where all training and test outputs should go.")
logs_folder: Path = param.ClassSelector(class_=Path, default=Path(), instantiate=False,
doc="The folder for all log files and Tensorboard event files")
project_root: Path = param.ClassSelector(class_=Path, default=Path(), instantiate=False,
doc="The root folder for the codebase that triggers the training run.")
run_folder: Path = param.ClassSelector(class_=Path, default=Path(), instantiate=False,
doc="The folder that contains outputs and the logs subfolder.")
@staticmethod
def create(project_root: Path,
is_offline_run: bool,
model_name: str,
output_to: Optional[str] = None) -> ExperimentFolderHandler:
"""
Creates a new object that holds output folder configurations. When running inside of AzureML, the output
folders will be directly under the project root. If not running inside AzureML, a folder with a timestamp
will be created for all outputs and logs.
:param project_root: The root folder that contains the code that submitted the present training run.
When running inside the hi-ml repository, it is the git repo root. When consuming hi-ml as a package,
this should be the root of the source code that calls the package.
:param is_offline_run: If true, this is a run outside AzureML. If False, it is inside AzureML.
:param model_name: The name of the model that is trained. This is used to generate a run-specific output
folder.
:param output_to: If provided, the output folders will be created as a subfolder of this argument. If not
given, the output folders will be created inside of the project root.
"""
if not project_root.is_absolute():
raise ValueError(f"The project root is required to be an absolute path, but got {project_root}")
if is_offline_run or output_to:
if output_to:
logging.info(f"All results will be written to the specified output folder {output_to}")
root = Path(output_to).absolute()
else:
logging.info("All results will be written to a subfolder of the project root folder.")
root = project_root.absolute() / DEFAULT_AML_UPLOAD_DIR
timestamp = create_unique_timestamp_id()
run_folder = root / f"{timestamp}_{model_name}"
outputs_folder = run_folder
logs_folder = run_folder / DEFAULT_LOGS_DIR_NAME
else:
logging.info("Running inside AzureML.")
logging.info("All results will be written to a subfolder of the project root folder.")
run_folder = project_root
outputs_folder = project_root / DEFAULT_AML_UPLOAD_DIR
logs_folder = project_root / DEFAULT_LOGS_DIR_NAME
logging.info(f"Run outputs folder: {outputs_folder}")
logging.info(f"Logs folder: {logs_folder}")
return ExperimentFolderHandler(
outputs_folder=outputs_folder,
logs_folder=logs_folder,
project_root=project_root,
run_folder=run_folder
)
class WorkflowParams(param.Parameterized):
"""
This class contains all parameters that affect how the whole training and testing workflow is executed.
"""
random_seed: int = param.Integer(42, doc="The seed to use for all random number generators.")
weights_url: List[str] = param.List(default=[], class_=str,
doc="If provided, a set of urls from which checkpoints will be downloaded"
"and used for inference.")
local_weights_path: List[Path] = param.List(default=[], class_=Path,
doc="A list of checkpoints paths to use for inference, "
"when the job is running outside Azure.")
model_id: str = param.String(default="",
doc="A model id string in the form 'model name:version' "
"to use a registered model for inference.")
multiprocessing_start_method: MultiprocessingStartMethod = \
param.ClassSelector(class_=MultiprocessingStartMethod,
default=(MultiprocessingStartMethod.spawn if is_windows()
else MultiprocessingStartMethod.fork),
doc="Method to be used to start child processes in pytorch. Should be one of forkserver, "
"fork or spawn. If not specified, fork is used on Linux and spawn on Windows. "
"Set to forkserver as a possible remedy for stuck jobs.")
regression_test_folder: Optional[Path] = \
param.ClassSelector(class_=Path, default=None, allow_None=True,
doc="A path to a folder that contains a set of files. At the end of training and "
"model evaluation, all files given in that folder must be present in the job's output "
"folder, and their contents must match exactly. When running in AzureML, you need to "
"ensure that this folder is part of the snapshot that gets uploaded. The path should "
"be relative to the repository root directory.")
def validate(self) -> None:
if sum([bool(param) for param in [self.weights_url, self.local_weights_path, self.model_id]]) > 1:
raise ValueError("Cannot specify more than one of local_weights_path, weights_url or model_id.")
if self.model_id:
parse_model_id_and_version(self.model_id)
@property
def is_running_in_aml(self) -> bool:
"""
Whether the current run is executing inside Azure ML
:return: True if the run is executing inside Azure ML, or False if outside AzureML.
"""
return is_running_in_azure_ml(RUN_CONTEXT)
def get_effective_random_seed(self) -> int:
"""
Returns the random seed set as part of this configuration.
:return:
"""
seed = self.random_seed
return seed
class DatasetParams(param.Parameterized):
azure_datasets: List[str] = param.List(default=[], class_=str,
doc="If provided, the ID of one or more datasets to use when running in"
" AzureML.This dataset must exist as a folder of the same name in the"
" 'datasets' container in the datasets storage account. This dataset"
" will be mounted and made available at the 'local_dataset' path"
" when running in AzureML.")
local_datasets: List[Path] = param.List(default=[], class_=Path,
doc="A list of one or more paths to the dataset to use, when training"
" outside of Azure ML.")
dataset_mountpoints: List[Path] = param.List(default=[], class_=Path,
doc="The path at which the AzureML dataset should be made available "
"via mounting or downloading. This only affects jobs running in "
"AzureML. If empty, use a random mount/download point.")
def validate(self) -> None:
if (not self.azure_datasets) and (not self.local_datasets):
raise ValueError("Either local_datasets or azure_datasets must be set.")
if self.dataset_mountpoints and len(self.azure_datasets) != len(self.dataset_mountpoints):
raise ValueError(f"Expected the number of azure datasets to equal the number of mountpoints, "
f"got datasets [{','.join(self.azure_datasets)}] "
f"and mountpoints [{','.join([str(m) for m in self.dataset_mountpoints])}]")
class OutputParams(param.Parameterized):
output_to: Path = param.ClassSelector(class_=Path, default=Path(),
doc="If provided, the run outputs will be written to the given folder. If "
"not provided, outputs will go into a subfolder of the project root "
"folder.")
file_system_config: ExperimentFolderHandler = param.ClassSelector(default=ExperimentFolderHandler(),
class_=ExperimentFolderHandler,
instantiate=False,
doc="File system related configs")
_model_name: str = param.String("", doc="The human readable name of the model (for example, Liver). This is "
"usually set from the class name.")
@property
def model_name(self) -> str:
"""
Gets the human readable name of the model (e.g., Liver). This is usually set from the class name.
:return: A model name as a string.
"""
return self._model_name
def set_output_to(self, output_to: PathOrString) -> None:
"""
Adjusts the file system settings in the present object such that all outputs are written to the given folder.
:param output_to: The absolute path to a folder that should contain the outputs.
"""
self.output_to = Path(output_to)
self.create_filesystem()
def create_filesystem(self, project_root: Path = fixed_paths.repository_root_directory()) -> None:
"""
Creates new file system settings (outputs folder, logs folder) based on the information stored in the
present object. If any of the folders do not yet exist, they are created.
:param project_root: The root folder for the codebase that triggers the training run.
"""
self.file_system_config = ExperimentFolderHandler.create(
project_root=project_root,
model_name=self.model_name,
is_offline_run=not is_running_in_azure_ml(RUN_CONTEXT),
output_to=str(self.output_to)
)
@property
def outputs_folder(self) -> Path:
"""Gets the full path in which the model outputs should be stored."""
return self.file_system_config.outputs_folder
@property
def logs_folder(self) -> Path:
"""Gets the full path in which the model logs should be stored."""
return self.file_system_config.logs_folder
@property
def checkpoint_folder(self) -> Path:
"""Gets the full path in which the model checkpoints should be stored during training."""
return self.outputs_folder / CHECKPOINT_FOLDER
class OptimizerParams(param.Parameterized):
l_rate: float = param.Number(1e-4, doc="The initial learning rate", bounds=(0, None))
_min_l_rate: float = param.Number(0.0, doc="The minimum learning rate for the Polynomial and Cosine schedulers.",
bounds=(0.0, None))
l_rate_scheduler: LRSchedulerType = param.ClassSelector(default=LRSchedulerType.Polynomial,
class_=LRSchedulerType,
instantiate=False,
doc="Learning rate decay method (Cosine, Polynomial, "
"Step, MultiStep or Exponential)")
l_rate_exponential_gamma: float = param.Number(0.9, doc="Controls the rate of decay for the Exponential "
"LR scheduler.")
l_rate_step_gamma: float = param.Number(0.1, doc="Controls the rate of decay for the "
"Step LR scheduler.")
l_rate_step_step_size: int = param.Integer(50, bounds=(0, None),
doc="The step size for Step LR scheduler")
l_rate_multi_step_gamma: float = param.Number(0.1, doc="Controls the rate of decay for the "
"MultiStep LR scheduler.")
l_rate_multi_step_milestones: Optional[List[int]] = param.List(None, bounds=(1, None),
allow_None=True, class_=int,
doc="The milestones for MultiStep decay.")
l_rate_polynomial_gamma: float = param.Number(1e-4, doc="Controls the rate of decay for the "
"Polynomial LR scheduler.")
l_rate_warmup: LRWarmUpType = param.ClassSelector(default=LRWarmUpType.NoWarmUp, class_=LRWarmUpType,
instantiate=False,
doc="The type of learning rate warm up to use. "
"Can be NoWarmUp (default) or Linear.")
l_rate_warmup_epochs: int = param.Integer(0, bounds=(0, None),
doc="Number of warmup epochs (linear warmup) before the "
"scheduler starts decaying the learning rate. "
"For example, if you are using MultiStepLR with "
"milestones [50, 100, 200] and warmup epochs = 100, warmup "
"will last for 100 epochs and the first decay of LR "
"will happen on epoch 150")
optimizer_type: OptimizerType = param.ClassSelector(default=OptimizerType.Adam, class_=OptimizerType,
instantiate=False, doc="The optimizer_type to use")
opt_eps: float = param.Number(1e-4, doc="The epsilon parameter of RMSprop or Adam")
rms_alpha: float = param.Number(0.9, doc="The alpha parameter of RMSprop")
adam_betas: TupleFloat2 = param.NumericTuple((0.9, 0.999), length=2,
doc="The betas parameter of Adam, default is (0.9, 0.999)")
momentum: float = param.Number(0.6, doc="The momentum parameter of the optimizers")
weight_decay: float = param.Number(1e-4, doc="The weight decay used to control L2 regularization")
def validate(self) -> None:
if len(self.adam_betas) < 2:
raise ValueError(
"The adam_betas parameter should be the coefficients used for computing running averages of "
"gradient and its square")
if self.l_rate_scheduler == LRSchedulerType.MultiStep:
if not self.l_rate_multi_step_milestones:
raise ValueError("Must specify l_rate_multi_step_milestones to use LR scheduler MultiStep")
if sorted(set(self.l_rate_multi_step_milestones)) != self.l_rate_multi_step_milestones:
raise ValueError("l_rate_multi_step_milestones must be a strictly increasing list")
if self.l_rate_multi_step_milestones[0] <= 0:
raise ValueError("l_rate_multi_step_milestones cannot be negative or 0.")
@property
def min_l_rate(self) -> float:
return self._min_l_rate
@min_l_rate.setter
def min_l_rate(self, value: float) -> None:
if value > self.l_rate:
raise ValueError("l_rate must be >= min_l_rate, found: {}, {}".format(self.l_rate, value))
self._min_l_rate = value
class TrainerParams(param.Parameterized):
max_epochs: int = param.Integer(100, bounds=(1, None), doc="Number of epochs to train.")
autosave_every_n_val_epochs: int = param.Integer(1, bounds=(0, None),
doc="Save epoch checkpoints every N validation epochs. "
"If pl_check_val_every_n_epoch > 1, this means that "
"checkpoints are saved every N * pl_check_val_every_n_epoch "
"training epochs.")
detect_anomaly: bool = param.Boolean(False, doc="If true, test gradients for anomalies (NaN or Inf) during "
"training.")
use_mixed_precision: bool = param.Boolean(False, doc="If true, mixed precision training is activated during "
"training.")
max_num_gpus: int = param.Integer(default=-1, doc="The maximum number of GPUs to use. If set to a value < 0, use "
"all available GPUs. In distributed training, this is the "
"maximum number of GPUs per node.")
pl_progress_bar_refresh_rate: Optional[int] = \
param.Integer(default=None,
doc="PyTorch Lightning trainer flag 'progress_bar_refresh_rate': How often to refresh progress "
"bar (in steps). Value 0 disables progress bar. Value None chooses automatically.")
pl_num_sanity_val_steps: int = \
param.Integer(default=0,
doc="PyTorch Lightning trainer flag 'num_sanity_val_steps': Number of validation "
"steps to run before training, to identify possible problems")
pl_deterministic: bool = \
param.Boolean(default=False,
doc="Controls the PyTorch Lightning trainer flags 'deterministic' and 'benchmark'. If "
"'pl_deterministic' is True, results are perfectly reproducible. If False, they are not, but "
"you may see training speed increases.")
pl_find_unused_parameters: bool = \
param.Boolean(default=False,
doc="Controls the PyTorch Lightning flag 'find_unused_parameters' for the DDP plugin. "
"Setting it to True comes with a performance hit.")
pl_limit_train_batches: Optional[int] = \
param.Integer(default=None,
doc="PyTorch Lightning trainer flag 'limit_train_batches': Limit the training dataset to the "
"given number of batches.")
pl_limit_val_batches: Optional[int] = \
param.Integer(default=None,
doc="PyTorch Lightning trainer flag 'limit_val_batches': Limit the validation dataset to the "
"given number of batches.")
pl_profiler: Optional[str] = \
param.String(default=None,
doc="The value to use for the 'profiler' argument for the Lightning trainer. "
"Set to either 'simple', 'advanced', or 'pytorch'")
monitor_gpu: bool = param.Boolean(default=False,
doc="If True, add the GPUStatsMonitor callback to the Lightning trainer object. "
"This will write GPU utilization metrics every 50 batches by default.")
monitor_loading: bool = param.Boolean(default=False,
doc="If True, add the BatchTimeCallback callback to the Lightning trainer "
"object. This will monitor how long individual batches take to load.")
additional_env_files: List[str] = param.List(class_=Path, default=[],
doc="Additional conda environment (.yml) files to merge into the"
" overall environment definition")
@property
def use_gpu(self) -> bool:
"""
Returns True if a GPU is available, and the self.max_num_gpus flag allows it to be used. Returns False
otherwise (i.e., if there is no GPU available, or self.max_num_gpus==0)
"""
if self.max_num_gpus == 0:
return False
from health_ml.utils.common_utils import is_gpu_available
return is_gpu_available()
def num_gpus_per_node(self) -> int:
"""
Computes the number of GPUs to use for each node: the number of GPUs available on the device,
capped at max_num_gpus. Returns 0 if running on a CPU device.
"""
import torch
available_gpus = torch.cuda.device_count() # type: ignore
num_gpus = available_gpus if self.use_gpu else 0
message_suffix = "" if self.use_gpu else ", but not using them because use_gpu == False"
logging.info(f"Number of available GPUs: {available_gpus}{message_suffix}")
if 0 <= self.max_num_gpus < num_gpus:
num_gpus = self.max_num_gpus
logging.info(f"Restricting the number of GPUs to {num_gpus}")
elif self.max_num_gpus > num_gpus:
logging.warning(f"You requested max_num_gpus {self.max_num_gpus} but there are only {num_gpus} available.")
return num_gpus
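As a hedged usage sketch, not part of the diff: creating the output folder layout for a run outside AzureML with the ExperimentFolderHandler defined above. The model name is a placeholder.
def example_folder_layout() -> ExperimentFolderHandler:
    # project_root must be absolute; for offline runs a timestamped, model-specific subfolder is
    # created under the default output directory, with a separate logs subfolder inside it.
    handler = ExperimentFolderHandler.create(project_root=Path.cwd(),
                                             is_offline_run=True,
                                             model_name="MyModel")
    logging.info(f"Outputs: {handler.outputs_folder}, logs: {handler.logs_folder}")
    return handler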


@@ -0,0 +1,13 @@
import param
from typing import Optional
class ExperimentConfig(param.Parameterized):
cluster: Optional[str] = param.String(default=None, allow_None=True,
doc="The name of the GPU or CPU cluster inside the AzureML workspace"
"that should execute the job.")
num_nodes: int = param.Integer(default=1, doc="The number of virtual machines that will be allocated for this "
"job in AzureML.")
model: str = param.String(doc="The fully qualified name of the model to train/test, e.g. "
"mymodule.configs.MyConfig.")
azureml: bool = param.Boolean(False, doc="If True, submit the executing script to run on AzureML.")
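As a hedged illustration, not part of the diff: a command line along the lines of `himl-runner --model=mymodule.configs.MyConfig --cluster=my-cluster --azureml=True` would be expected to populate these fields roughly as below; the module path and cluster name are placeholders.
example_experiment_config = ExperimentConfig(model="mymodule.configs.MyConfig",
                                             cluster="my-cluster",
                                             num_nodes=1,
                                             azureml=True)
# example_experiment_config.azureml is now True, which is what triggers submission to AzureML.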


@@ -0,0 +1,171 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any, Dict, Optional
from pathlib import Path
import param
from azureml.core import ScriptRunConfig
from azureml.train.hyperdrive import HyperDriveConfig
from pytorch_lightning import LightningDataModule, LightningModule
from health_ml.deep_learning_config import DatasetParams, OptimizerParams, OutputParams, TrainerParams, \
WorkflowParams
from health_ml.experiment_config import ExperimentConfig
class LightningContainer(WorkflowParams,
DatasetParams,
OutputParams,
TrainerParams,
OptimizerParams):
"""
A LightningContainer contains all information to train a user-specified PyTorch Lightning model. The model that
should be trained is returned by the `get_model` method. The training data must be returned in the form of
a LightningDataModule, by the `get_data_module` method.
"""
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._model: Optional[LightningModule] = None
self._model_name = type(self).__name__
self.num_nodes = 1
def validate(self) -> None:
WorkflowParams.validate(self)
OptimizerParams.validate(self)
def setup(self) -> None:
"""
This method is called as one of the first operations of the training/testing workflow, before any other
operations on the present object. At the point when called, the datasets are already available in
the locations given by self.local_datasets. Use this method to prepare datasets or data loaders, for example.
"""
pass
def create_model(self) -> LightningModule: # type: ignore
"""
This method must create the actual Lightning model that will be trained. It can read out parameters from the
container and pass them into the model, for example.
"""
pass
def get_data_module(self) -> LightningDataModule:
"""
Gets the data that is used for the training, validation, and test steps.
This should read datasets from the self.local_datasets folder or download from a web location.
The format of the data is not specified any further.
:return: A LightningDataModule
"""
return None # type: ignore
def get_trainer_arguments(self) -> Dict[str, Any]:
"""
Gets additional parameters that will be passed on to the PyTorch Lightning trainer.
"""
return dict()
def get_parameter_search_hyperdrive_config(self, _: ScriptRunConfig) -> HyperDriveConfig: # type: ignore
"""
Parameter search is not implemented. It should be implemented in a sub class if needed.
"""
raise NotImplementedError("Parameter search is not implemented. It should be implemented in"
"a sub class if needed.")
def update_experiment_config(self, experiment_config: ExperimentConfig) -> None:
"""
This method allows overriding ExperimentConfig parameters from within a LightningContainer.
It is called right after the ExperimentConfig and container are initialised.
Be careful when using class parameters to override these values. If the parameter names clash,
CLI values will be consumed by the ExperimentConfig, but container parameters will keep their defaults.
This can be avoided by always using unique parameter names.
Also note that saving a reference to `experiment_config` and updating its attributes at any other
point may lead to unexpected behaviour.
:param experiment_config: The initialised ExperimentConfig whose parameters to override in-place.
"""
pass
def before_training_on_global_rank_zero(self) -> None:
"""
A hook that will be called before starting model training, before creating the Lightning Trainer object.
In distributed training, this is only run on global rank zero (i.e, on the process that runs on node 0, GPU 0).
The order in which hooks are called is: before_training_on_global_rank_zero, before_training_on_local_rank_zero,
before_training_on_all_ranks.
"""
pass
def before_training_on_local_rank_zero(self) -> None:
"""
A hook that will be called before starting model training.
In distributed training, this hook will be called once per node (i.e., whenever the LOCAL_RANK environment
variable is zero).
The order in which hooks are called is: before_training_on_global_rank_zero, before_training_on_local_rank_zero,
before_training_on_all_ranks.
"""
pass
def before_training_on_all_ranks(self) -> None:
"""
A hook that will be called before starting model training.
In distributed training, this hook will be called on all ranks (i.e., once per GPU).
The order in which hooks are called is: before_training_on_global_rank_zero, before_training_on_local_rank_zero,
before_training_on_all_ranks.
"""
pass
# The code from here on does not need to be modified.
@property
def model(self) -> LightningModule:
"""
Returns the PyTorch Lightning module that the present container object manages.
:return: A PyTorch Lightning module
"""
if self._model is None:
raise ValueError("No Lightning module has been set yet.")
return self._model
def create_lightning_module_and_store(self) -> None:
"""
Creates the Lightning model
"""
self._model = self.create_model()
def get_hyperdrive_config(self, run_config: ScriptRunConfig) -> HyperDriveConfig:
"""
Returns the HyperDrive config for parameter search.
:param run_config: The AzureML ScriptRunConfig that will be used for the run.
:return: A HyperDriveConfig object.
"""
return self.get_parameter_search_hyperdrive_config(run_config)
def load_model_checkpoint(self, checkpoint_path: Path) -> None:
"""
Load a checkpoint from the given path. We need to define a separate method since pytorch lightning cannot
access the _model attribute to modify it.
"""
if self._model is None:
raise ValueError("No Lightning module has been set yet.")
self._model = type(self._model).load_from_checkpoint(checkpoint_path=str(checkpoint_path))
def __str__(self) -> str:
"""Returns a string describing the present object, as a list of key: value strings."""
arguments_str = "\nContainer:\n"
# Avoid callable params, the bindings that are printed out can be humongous.
# Avoid dataframes
skip_params = {name for name, value in self.param.params().items()
if isinstance(value, (param.Callable, param.DataFrame))}
for key, value in self.param.get_param_values():
if key not in skip_params:
arguments_str += f"\t{key:40}: {value}\n"
# Print out all other separate vars that are not under the guidance of the params library,
# skipping the two that are introduced by params
skip_vars = {"param", "initialized"}
for key, value in vars(self).items():
if key not in skip_vars and key[0] != "_":
arguments_str += f"\t{key:40}: {value}\n"
return arguments_str
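For reference, a hedged sketch, not part of the diff, of a user-defined container. MyModel and MyDataModule are hypothetical placeholders for a real LightningModule and LightningDataModule; the dictionary returned from get_trainer_arguments shows how extra Lightning callbacks can be handed to the trainer.
class MyContainer(LightningContainer):
    def __init__(self) -> None:
        super().__init__()
        self.max_epochs = 10

    def create_model(self) -> LightningModule:
        return MyModel()  # hypothetical LightningModule defined elsewhere

    def get_data_module(self) -> LightningDataModule:
        return MyDataModule()  # hypothetical LightningDataModule defined elsewhere

    def get_trainer_arguments(self) -> Dict[str, Any]:
        # Extra Lightning callbacks can be handed over via the "callbacks" key.
        from pytorch_lightning.callbacks import ModelCheckpoint
        return {"callbacks": [ModelCheckpoint(save_last=True)]}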


@@ -0,0 +1,207 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
import os
import sys
from pathlib import Path
from typing import Any, List, Tuple, TypeVar
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import GPUStatsMonitor, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin
from health_azure.utils import (ENV_GLOBAL_RANK, ENV_LOCAL_RANK, ENV_NODE_RANK, RUN_CONTEXT, is_global_rank_zero,
is_local_rank_zero, is_running_in_azure_ml)
from health_ml.lightning_container import LightningContainer
from health_ml.utils import AzureMLLogger, AzureMLProgressBar
from health_ml.utils.common_utils import EXPERIMENT_SUMMARY_FILE
from health_ml.utils.lightning_loggers import StoringLogger
TEMP_PREFIX = "temp/"
T = TypeVar('T')
def write_experiment_summary_file(config: Any, outputs_folder: Path) -> None:
"""
Writes the given config to disk in plain text in the default output folder.
"""
output = str(config)
outputs_folder.mkdir(exist_ok=True, parents=True)
dst = outputs_folder / EXPERIMENT_SUMMARY_FILE
dst.write_text(output)
logging.info(output)
def create_lightning_trainer(container: LightningContainer,
num_nodes: int = 1) -> Tuple[Trainer, StoringLogger]:
"""
Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers
and loggers. That includes a diagnostic logger for use in unit tests, which is also returned as the second
return value.
:param container: The container with model and data.
:param num_nodes: The number of nodes to use in distributed training.
:param kwargs: Any additional keyword arguments will be passed to the constructor of Trainer.
:return: A tuple [Trainer object, diagnostic logger]
"""
num_gpus = container.num_gpus_per_node()
effective_num_gpus = num_gpus * num_nodes
strategy = None
if effective_num_gpus == 0:
accelerator = "cpu"
devices = 1
message = "CPU"
else:
accelerator = "gpu"
devices = num_gpus
message = f"{devices} GPU"
if effective_num_gpus > 1:
# Accelerator should be "ddp" when running large models in AzureML (when using DDP_spawn, we get out of
# GPU memory).
# Initialize the DDP plugin. The default for pl_find_unused_parameters is False. If True, the plugin
# prints out lengthy warnings about the performance impact of find_unused_parameters.
strategy = DDPPlugin(find_unused_parameters=container.pl_find_unused_parameters)
message += "s per node with DDP"
logging.info(f"Using {message}")
tensorboard_logger = TensorBoardLogger(save_dir=str(container.logs_folder), name="Lightning", version="")
loggers = [tensorboard_logger, AzureMLLogger(False)]
storing_logger = StoringLogger()
loggers.append(storing_logger)
# Use 32bit precision when running on CPU. Otherwise, make it depend on use_mixed_precision flag.
precision = 32 if num_gpus == 0 else 16 if container.use_mixed_precision else 32
# The next two flags control the settings in torch.backends.cudnn.deterministic and torch.backends.cudnn.benchmark
# https://pytorch.org/docs/stable/notes/randomness.html
# Note that switching to deterministic models can have large performance downside.
if container.pl_deterministic:
deterministic = True
benchmark = False
else:
deterministic = False
benchmark = True
# Get more callbacks
callbacks: List[Any] = []
if container.monitor_loading:
# TODO antonsc: Remove after fixing the callback.
raise NotImplementedError("Monitoring batch loading times has been temporarily disabled.")
# callbacks.append(BatchTimeCallback())
if num_gpus > 0 and container.monitor_gpu:
logging.info("Adding monitoring for GPU utilization")
callbacks.append(GPUStatsMonitor(intra_step_time=True, inter_step_time=True)) # type: ignore
# Add the additional callbacks that were specified in get_trainer_arguments for LightningContainers
additional_args = container.get_trainer_arguments()
# Callbacks can be specified via the "callbacks" argument (the legacy behaviour) or the new get_callbacks method
if "callbacks" in additional_args:
more_callbacks = additional_args.pop("callbacks")
if isinstance(more_callbacks, list):
callbacks.extend(more_callbacks) # type: ignore
else:
callbacks.append(more_callbacks) # type: ignore
is_azureml_run = is_running_in_azure_ml(RUN_CONTEXT)
progress_bar_refresh_rate = container.pl_progress_bar_refresh_rate
if progress_bar_refresh_rate is None:
progress_bar_refresh_rate = 50
logging.info(f"The progress bar refresh rate is not set. Using a default of {progress_bar_refresh_rate}. "
f"To change, modify the pl_progress_bar_refresh_rate field of the container.")
if is_azureml_run:
callbacks.append(AzureMLProgressBar(refresh_rate=progress_bar_refresh_rate,
write_to_logging_info=True,
print_timestamp=False))
else:
callbacks.append(TQDMProgressBar(refresh_rate=progress_bar_refresh_rate))
# Read out additional model-specific args here.
# We probably want to keep essential ones like numgpu and logging.
trainer = Trainer(default_root_dir=str(container.outputs_folder),
deterministic=deterministic,
benchmark=benchmark,
accelerator=accelerator,
strategy=strategy,
max_epochs=container.max_epochs,
# Both these arguments can be integers or floats. If integers, it is the number of batches.
# If float, it's the fraction of batches. We default to 1.0 (processing all batches).
limit_train_batches=container.pl_limit_train_batches or 1.0,
limit_val_batches=container.pl_limit_val_batches or 1.0,
num_sanity_val_steps=container.pl_num_sanity_val_steps,
callbacks=callbacks,
logger=loggers,
num_nodes=num_nodes,
devices=devices,
precision=precision,
sync_batchnorm=True,
detect_anomaly=container.detect_anomaly,
profiler=container.pl_profiler,
**additional_args)
return trainer, storing_logger
def model_train(container: LightningContainer
) -> Tuple[Trainer, StoringLogger]:
"""
The main training loop. It creates the Pytorch model based on the configuration options passed in,
creates a Pytorch Lightning trainer, and trains the model.
If a checkpoint was specified, then it loads the checkpoint before resuming training.
:param container: A container object that holds the training data in PyTorch Lightning format
and the model to train.
:return: A tuple of [Trainer, StoringLogger]. Trainer is the Lightning Trainer object that was used for fitting
the model. The StoringLogger object is returned when training a built-in model; it is None when
fitting other models.
"""
lightning_model = container.model
# resource_monitor: Optional[ResourceMonitor] = None
# Execute some bookkeeping tasks only once if running distributed:
if is_global_rank_zero():
logging.info(f"Model checkpoints are saved at {container.checkpoint_folder}")
write_experiment_summary_file(container,
outputs_folder=container.outputs_folder)
data_module = container.get_data_module()
if is_global_rank_zero():
container.before_training_on_global_rank_zero()
if is_local_rank_zero():
container.before_training_on_local_rank_zero()
container.before_training_on_all_ranks()
# Create the trainer object. Backup the environment variables before doing that, in case we need to run a second
# training in the unit tests.
old_environ = dict(os.environ)
# Set random seeds just before training
seed_everything(container.get_effective_random_seed())
trainer, storing_logger = create_lightning_trainer(container,
num_nodes=container.num_nodes)
rank_info = ", ".join(f"{env}: {os.getenv(env)}"
for env in [ENV_GLOBAL_RANK, ENV_LOCAL_RANK, ENV_NODE_RANK])
logging.info(f"Environment variables: {rank_info}. trainer.global_rank: {trainer.global_rank}")
# get recovery checkpoint if it exists
logging.info("Starting training")
trainer.fit(lightning_model, datamodule=data_module)
trainer.logger.finalize('success')
# DDP will start multiple instances of the runner, one for each GPU. Those should terminate here after training.
# We can now use the global_rank of the Lightning model, rather than environment variables, because DDP has set
# all necessary properties.
if lightning_model.global_rank != 0:
logging.info(f"Terminating training thread with rank {lightning_model.global_rank}.")
sys.exit()
logging.info("Removing redundant checkpoint files.")
# get_best_checkpoint_path(container.checkpoint_folder)
# Lightning modifies a ton of environment variables. If we first run training and then the test suite,
# those environment variables will mislead the training runs in the test suite, and make them crash.
# Hence, restore the original environment after training.
os.environ.clear()
os.environ.update(old_environ)
logging.info("Finished training")
return trainer, storing_logger
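A hedged sketch, not part of the diff, of calling model_train directly for a container, mirroring the setup steps that MLRunner performs in run_ml.py.
def train_container(container: LightningContainer) -> Tuple[Trainer, StoringLogger]:
    container.create_filesystem()                  # create the outputs and logs folder structure
    container.setup()                              # user hook, e.g. for dataset preparation
    container.create_lightning_module_and_store()  # instantiate the LightningModule
    return model_train(container)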


@@ -0,0 +1,164 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional
import torch.multiprocessing
from pytorch_lightning import seed_everything
from health_azure import AzureRunInfo
from health_azure.utils import (ENV_OMPI_COMM_WORLD_RANK, RUN_CONTEXT, create_run_recovery_id,
PARENT_RUN_CONTEXT, is_running_in_azure_ml)
from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
from health_ml.model_trainer import create_lightning_trainer, model_train
from health_ml.utils import fixed_paths
from health_ml.utils.common_utils import (
change_working_directory, logging_section, RUN_RECOVERY_ID_KEY,
EFFECTIVE_RANDOM_SEED_KEY_NAME, RUN_RECOVERY_FROM_ID_KEY_NAME)
from health_ml.utils.lightning_loggers import StoringLogger
from health_ml.utils.type_annotations import PathOrString
def check_dataset_folder_exists(local_dataset: PathOrString) -> Path:
"""
Checks if a folder with a local dataset exists. If it does exist, return the argument converted
to a Path instance. If it does not exist, raise a FileNotFoundError.
:param local_dataset: The dataset folder to check.
:return: The local_dataset argument, converted to a Path.
"""
expected_dir = Path(local_dataset)
if not expected_dir.is_dir():
raise FileNotFoundError(f"The model uses a dataset in {expected_dir}, but that does not exist.")
logging.info(f"Model training will use the local dataset provided in {expected_dir}")
return expected_dir
class MLRunner:
def __init__(self,
experiment_config: ExperimentConfig,
container: LightningContainer,
project_root: Optional[Path] = None) -> None:
"""
Driver class to run an ML experiment. Note that the project root argument MUST be supplied when using hi-ml
as a package!
:param container: The LightningContainer object to use for training.
:param project_root: Project root. This should only be omitted if calling run_ml from the test suite. Supplying
it is crucial when using hi-ml as a package or submodule!
"""
self.container = container
self.experiment_config = experiment_config
self.container.num_nodes = self.experiment_config.num_nodes
self.project_root: Path = project_root or fixed_paths.repository_root_directory()
self.storing_logger: Optional[StoringLogger] = None
self._has_setup_run = False
def setup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None:
"""
Sets the random seeds, calls the setup method on the LightningContainer and then creates the actual
Lightning modules.
:param azure_run_info: When running in AzureML or on a local VM, this contains the paths to the datasets.
This can be missing when running in unit tests, where the local dataset paths are already populated.
"""
if self._has_setup_run:
return
if azure_run_info:
# Set up the paths to the datasets. azure_run_info already has all necessary information, using either
# the provided local datasets for VM runs, or the AzureML mount points when running in AML.
# This must happen before container setup because that could already read datasets.
if len(azure_run_info.input_datasets) > 0:
input_datasets = azure_run_info.input_datasets
assert len(input_datasets) > 0
local_datasets = [
check_dataset_folder_exists(input_dataset) for input_dataset in input_datasets # type: ignore
]
self.container.local_datasets = local_datasets # type: ignore
# Ensure that we use fixed seeds before initializing the PyTorch models
seed_everything(self.container.get_effective_random_seed())
# Creating the folder structure must happen before the LightningModule is created, because the output
# parameters of the container will be copied into the module.
self.container.create_filesystem(self.project_root)
self.container.setup()
self.container.create_lightning_module_and_store()
self._has_setup_run = True
def set_run_tags_from_parent(self) -> None:
"""
Set metadata for the run
"""
assert PARENT_RUN_CONTEXT, "This function should only be called in a Hyperdrive run."
run_tags_parent = PARENT_RUN_CONTEXT.get_tags()
tags_to_copy = [
"tag",
"model_name",
"execution_mode",
"recovered_from",
"friendly_name",
"build_number",
"build_user",
RUN_RECOVERY_FROM_ID_KEY_NAME
]
new_tags = {tag: run_tags_parent.get(tag, "") for tag in tags_to_copy}
new_tags[RUN_RECOVERY_ID_KEY] = create_run_recovery_id(run=RUN_CONTEXT)
new_tags[EFFECTIVE_RANDOM_SEED_KEY_NAME] = str(self.container.get_effective_random_seed())
RUN_CONTEXT.set_tags(new_tags)
def run(self) -> None:
"""
Driver function to run an ML experiment
"""
self.setup()
is_offline_run = not is_running_in_azure_ml(RUN_CONTEXT)
# Get the AzureML context in which the script is running
if not is_offline_run and PARENT_RUN_CONTEXT is not None:
logging.info("Setting tags from parent run.")
self.set_run_tags_from_parent()
# do training
with logging_section("Model training"):
_, storing_logger = model_train(container=self.container)
self.storing_logger = storing_logger
def run_inference_for_lightning_models(self, checkpoint_paths: List[Path]) -> List[Dict[str, float]]:
"""
Run inference on the test set for all models that are specified via a LightningContainer.
:param checkpoint_paths: A list that must contain exactly one checkpoint path to use for inference.
"""
if len(checkpoint_paths) != 1:
raise ValueError(f"This method expects exactly 1 checkpoint for inference, but got {len(checkpoint_paths)}")
# lightning_model = self.container.model
# Run Lightning's built-in test procedure if the `test_step` method has been overridden
logging.info("Running inference via the LightningModule.test_step method")
# Lightning does not cope with having two calls to .fit or .test in the same script. As a workaround for
# now, restrict number of GPUs to 1, meaning that it will not start DDP.
self.container.max_num_gpus = 1
# Without this, the trainer will think it should still operate in multi-node mode, and wrongly start
# searching for Horovod
if ENV_OMPI_COMM_WORLD_RANK in os.environ:
del os.environ[ENV_OMPI_COMM_WORLD_RANK]
# From the training setup, torch still thinks that it should run in a distributed manner,
# and would block on some GPU operations. Hence, clean up distributed training.
if torch.distributed.is_initialized(): # type: ignore
torch.distributed.destroy_process_group() # type: ignore
trainer, _ = create_lightning_trainer(self.container, num_nodes=1)
self.container.load_model_checkpoint(checkpoint_path=checkpoint_paths[0])
# Change the current working directory to ensure that test files go to the right folder
data_module = self.container.get_data_module()
with change_working_directory(self.container.outputs_folder):
results = trainer.test(self.container.model, datamodule=data_module)
return results
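For orientation, here is a minimal sketch of driving MLRunner directly, outside the himl-runner entry point. HelloContainer is the example config shipped with this PR; passing the model name to ExperimentConfig this way is an assumption about its parameters.

from health_ml.configs.hello_container import HelloContainer
from health_ml.experiment_config import ExperimentConfig
from health_ml.run_ml import MLRunner

runner = MLRunner(experiment_config=ExperimentConfig(model="HelloContainer"),
                  container=HelloContainer())
runner.setup()   # seeds RNGs, creates the output folder structure, builds the Lightning module
runner.run()     # calls model_train on the container and stores the StoringLogger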


@ -0,0 +1,299 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import argparse
import logging
import os
import param
import sys
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import matplotlib
# Add hi-ml packages to sys.path so that AML can find them
himl_root = Path(__file__).parent.parent.parent.parent
print(f"Starting the himl runner at {himl_root}")
print(f"health_ml pkg: {himl_root}")
health_ml_pkg = himl_root / "hi-ml" / "src"
health_azure_pkg = himl_root / "hi-ml-azure" / "src"
sys.path.insert(0, str(health_azure_pkg))
sys.path.insert(0, str(health_ml_pkg))
print(f"sys path: {sys.path}")
from health_azure import AzureRunInfo, submit_to_azure_if_needed # noqa: E402
from health_azure.datasets import create_dataset_configs # noqa: E402
from health_azure.utils import (get_workspace, is_local_rank_zero, merge_conda_files, # noqa: E402
set_environment_variables_for_multi_node, create_argparser, parse_arguments,
ParserResult, apply_overrides)
from health_ml.experiment_config import ExperimentConfig # noqa: E402
from health_ml.lightning_container import LightningContainer # noqa: E402
from health_ml.run_ml import MLRunner # noqa: E402
from health_ml.utils import fixed_paths # noqa: E402
from health_ml.utils.common_utils import (get_all_environment_files, # noqa: E402
get_all_pip_requirements_files,
is_linux, logging_to_stdout)
from health_ml.utils.config_loader import ModelConfigLoader # noqa: E402
DEFAULT_DOCKER_BASE_IMAGE = "mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.2-cudnn8-ubuntu18.04"
def initialize_rpdb() -> None:
"""
On Linux only, import and initialize rpdb, to enable remote debugging if necessary.
"""
# rpdb signal trapping does not work on Windows, as there is no SIGTRAP:
if not is_linux():
return
import rpdb
rpdb_port = 4444
rpdb.handle_trap(port=rpdb_port)
# For some reason, os.getpid() does not return the ID of what appears to be the currently running process.
logging.info("rpdb is handling traps. To debug: identify the main runner.py process, then as root: "
f"kill -TRAP <process_id>; nc 127.0.0.1 {rpdb_port}")
def package_setup_and_hacks() -> None:
"""
Set up the Python packages where needed. In particular, reduce the logging level for some of the used
libraries, which are particularly talkative in DEBUG mode. Usually when running in DEBUG mode, we want
diagnostics about the model building itself, but not for the underlying libraries.
It also adds workarounds for known issues in some packages.
"""
# Numba code generation is extremely talkative in DEBUG mode, disable that.
logging.getLogger('numba').setLevel(logging.WARNING)
# Matplotlib is also very talkative in DEBUG mode, filling half of the log file in a PR build.
logging.getLogger('matplotlib').setLevel(logging.INFO)
# Urllib3 prints out connection information for each call to write metrics, etc
logging.getLogger('urllib3').setLevel(logging.INFO)
logging.getLogger('msrest').setLevel(logging.INFO)
# AzureML prints too many details about logging metrics
logging.getLogger('azureml').setLevel(logging.INFO)
# Jupyter notebook report generation
logging.getLogger('papermill').setLevel(logging.INFO)
logging.getLogger('nbconvert').setLevel(logging.INFO)
# This is working around a spurious error message thrown by MKL, see
# https://github.com/pytorch/pytorch/issues/37377
os.environ['MKL_THREADING_LAYER'] = 'GNU'
# Workaround for issues with matplotlib on some X servers, see
# https://stackoverflow.com/questions/45993879/matplot-lib-fatal-io-error-25-inappropriate-ioctl-for-device-on-x
# -server-loc
matplotlib.use('Agg')
def create_runner_parser() -> argparse.ArgumentParser:
"""
Creates a commandline parser, that understands all necessary arguments for training a model
:return: An instance of ArgumentParser with args from ExperimentConfig added
"""
config = ExperimentConfig()
parser = create_argparser(config)
return parser
def additional_run_tags(commandline_args: str) -> Dict[str, str]:
"""
Gets the set of tags from the commandline arguments that will be added to the AzureML run as metadata
:param commandline_args: A string that holds all commandline arguments that were used for the present run.
"""
return {
"commandline_args": commandline_args,
}
class Runner:
"""
This class contains the high-level logic to start a training run: choose a model configuration by name,
submit to AzureML if needed, or otherwise start the actual training and test loop.
:param project_root: The root folder that contains all of the source code that should be executed.
"""
def __init__(self, project_root: Path):
self.project_root = project_root
self.experiment_config: ExperimentConfig = ExperimentConfig()
self.lightning_container: LightningContainer = None # type: ignore
# This field stores the MLRunner object that has been created in the most recent call to the run() method.
self.ml_runner: Optional[MLRunner] = None
def parse_and_load_model(self) -> ParserResult:
"""
Parses the command line arguments, and creates configuration objects for the model itself, and for the
Azure-related parameters. Sets self.experiment_config to its proper values. Returns the
parser output from parsing the model commandline arguments.
:return: ParserResult object containing args, overrides and settings
"""
parser = create_runner_parser()
parser_result = parse_arguments(parser, args=sys.argv[1:])
experiment_config = ExperimentConfig(**parser_result.args)
self.experiment_config = experiment_config
if not experiment_config.model:
raise ValueError("Parameter 'model' needs to be set to specify which model to run.")
print(f"Creating model loader with the following args: {parser_result.args}")
model_config_loader: ModelConfigLoader = ModelConfigLoader(**parser_result.args)
# Create the model as per the "model" commandline option. This is a LightningContainer.
container = model_config_loader.create_model_config_from_name(model_name=experiment_config.model)
# parse overrides and apply
assert isinstance(container, param.Parameterized)
parser_ = create_argparser(container)
# For each parser, feed in the unknown settings from the previous parser. All commandline args should
# be consumed by name, hence fail if there is something that is still unknown.
parser_result_ = parse_arguments(parser_, args=parser_result.unknown)
# Apply the overrides and validate. Overrides can come from either YAML settings or the commandline.
_ = apply_overrides(container, parser_result_.overrides) # type: ignore
container.validate()
self.lightning_container = container
return parser_result_
def run(self) -> Tuple[LightningContainer, AzureRunInfo]:
"""
The main entry point for training and testing models from the commandline. This chooses a model to train
via a commandline argument, runs training or testing, and writes all required info to disk and logs.
:return: a tuple of the LightningContainer object and an AzureRunInfo containing all information about
the present run (whether running in AzureML or not)
"""
# Usually, when we set logging to DEBUG, we want diagnostics about the model
# build itself, but not the tons of debug information that AzureML submissions create.
logging_to_stdout(logging.INFO if is_local_rank_zero() else "ERROR")
initialize_rpdb()
self.parse_and_load_model()
azure_run_info = self.submit_to_azureml_if_needed()
self.run_in_situ(azure_run_info)
return self.lightning_container, azure_run_info
def submit_to_azureml_if_needed(self) -> AzureRunInfo:
"""
Submit a job to AzureML, returning the resulting Run object, or exiting if we were asked to wait for
completion and the Run did not succeed.
:return: an AzureRunInfo object containing all of the details of the present run. If AzureML is not
specified, the attribute 'run' will be None, but the object still contains helpful information
about datasets etc
"""
root_folder = self.project_root
entry_script = Path(sys.argv[0]).resolve()
script_params = sys.argv[1:]
additional_conda_env_files = self.lightning_container.additional_env_files
additional_env_files: Optional[List[Path]]
if additional_conda_env_files is not None:
additional_env_files = [Path(f) for f in additional_conda_env_files]
else:
additional_env_files = None
conda_dependencies_files = get_all_environment_files(self.project_root,
additional_files=additional_env_files)
pip_requirements_files = get_all_pip_requirements_files()
# Merge the project-specific dependencies with the packages and write unified definition
# to temp file. In case of version conflicts, the package version in the outer project is given priority.
temp_conda: Optional[Path] = None
if len(conda_dependencies_files) > 1 or len(pip_requirements_files) > 0:
temp_conda = root_folder / f"temp_environment-{uuid.uuid4().hex[:8]}.yml"
merge_conda_files(conda_dependencies_files, temp_conda, pip_files=pip_requirements_files)
# TODO: Update environment variables
environment_variables: Dict[str, Any] = {}
# get default datastore from provided workspace
workspace = get_workspace()
default_datastore = workspace.get_default_datastore().name
local_datasets = self.lightning_container.local_datasets
all_local_datasets = [Path(p) for p in local_datasets] if len(local_datasets) > 0 else []
input_datasets = \
create_dataset_configs(all_azure_dataset_ids=self.lightning_container.azure_datasets,
all_dataset_mountpoints=self.lightning_container.dataset_mountpoints,
all_local_datasets=all_local_datasets, # type: ignore
datastore=default_datastore)
try:
if self.experiment_config.azureml:
if not self.experiment_config.cluster:
raise ValueError("You need to specify a cluster name via '--cluster NAME' to submit"
"the script to run in AzureML")
azure_run_info = submit_to_azure_if_needed(
entry_script=entry_script,
snapshot_root_directory=root_folder,
script_params=script_params,
conda_environment_file=temp_conda or conda_dependencies_files[0],
aml_workspace=workspace,
compute_cluster_name=self.experiment_config.cluster,
environment_variables=environment_variables,
default_datastore=default_datastore,
experiment_name=self.lightning_container.name, # create_experiment_name(),
input_datasets=input_datasets, # type: ignore
num_nodes=self.experiment_config.num_nodes,
wait_for_completion=False,
ignored_folders=[],
submit_to_azureml=self.experiment_config.azureml,
docker_base_image=DEFAULT_DOCKER_BASE_IMAGE,
tags=additional_run_tags(
commandline_args=" ".join(script_params))
)
else:
azure_run_info = submit_to_azure_if_needed(
input_datasets=input_datasets, # type: ignore
submit_to_azureml=False)
finally:
if temp_conda:
temp_conda.unlink()
# submit_to_azure_if_needed calls sys.exit after submitting to AzureML. We only reach this when running
# the script locally or in AzureML.
return azure_run_info
def run_in_situ(self, azure_run_info: AzureRunInfo) -> None:
"""
Actually run the AzureML job; this method will typically run on an Azure VM.
:param azure_run_info: Contains all information about the present run in AzureML, in particular where the
datasets are mounted.
"""
# Only set the logging level now. Usually, when we set logging to DEBUG, we want diagnostics about the model
# build itself, but not the tons of debug information that AzureML submissions create.
# Suppress the logging from all processes but the one for GPU 0 on each node, to make log files more readable
logging_to_stdout("INFO" if is_local_rank_zero() else "ERROR")
package_setup_and_hacks()
# Set environment variables for multi-node training if needed. This function will terminate early
# if it detects that it is not in a multi-node environment.
set_environment_variables_for_multi_node()
self.ml_runner = MLRunner(
experiment_config=self.experiment_config,
container=self.lightning_container,
project_root=self.project_root)
self.ml_runner.setup(azure_run_info)
self.ml_runner.run()
def run(project_root: Path) -> Tuple[LightningContainer, AzureRunInfo]:
"""
The main entry point for training and testing models from the commandline. This chooses a model to train
via a commandline argument, runs training or testing, and writes all required info to disk and logs.
:param project_root: The root folder that contains all of the source code that should be executed.
:return: If submitting to AzureML, returns the model configuration that was used for training,
including commandline overrides applied (if any). For details on the arguments, see the constructor of Runner.
"""
runner = Runner(project_root)
return runner.run()
def main() -> None:
run(project_root=fixed_paths.repository_root_directory())
if __name__ == '__main__':
main()
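A hedged sketch of launching the Runner programmatically, mirroring what main() does above. The module path health_ml.runner is an assumption; the himl-runner console script wraps the same entry point. Note that submit_to_azureml_if_needed() calls get_workspace() even for local runs, so an AzureML workspace config must be discoverable.

import sys
from pathlib import Path

from health_ml.runner import Runner  # assumed module path

sys.argv = ["runner.py", "--model=HelloContainer"]  # parsed by parse_and_load_model()
container, azure_run_info = Runner(project_root=Path.cwd()).run()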


@ -1,11 +1,264 @@
from typing import Optional
import logging
import os
import sys
import time
from contextlib import contextmanager
from datetime import datetime
from enum import Enum, unique
from pathlib import Path
from typing import Any, Generator, Iterable, List, Optional, Union
import torch
from health_azure.utils import PathOrString
from health_ml.utils import fixed_paths
MAX_PATH_LENGTH = 260
# convert string to None if an empty string or whitespace is provided
empty_string_to_none = lambda x: None if (x is None or len(x.strip()) == 0) else x
string_to_path = lambda x: None if (x is None or len(x.strip()) == 0) else Path(x)
EXPERIMENT_SUMMARY_FILE = "experiment_summary.txt"
CHECKPOINT_FOLDER = "checkpoints"
RUN_RECOVERY_ID_KEY = 'run_recovery_id'
EFFECTIVE_RANDOM_SEED_KEY_NAME = "effective_random_seed"
RUN_RECOVERY_FROM_ID_KEY_NAME = "recovered_from"
DEFAULT_AML_UPLOAD_DIR = "outputs"
DEFAULT_LOGS_DIR_NAME = "logs"
@unique
class ModelExecutionMode(Enum):
"""
Model execution mode
"""
TRAIN = "Train"
TEST = "Test"
VAL = "Val"
def check_is_any_of(message: str, actual: Optional[str], valid: Iterable[Optional[str]]) -> None:
"""
Raises an exception if 'actual' is not any of the given valid values.
:param message: The prefix for the error message.
:param actual: The actual value.
:param valid: The set of valid strings that 'actual' is allowed to take on.
:return:
"""
if actual not in valid:
all_valid = ", ".join(["<None>" if v is None else v for v in valid])
raise ValueError("{} must be one of [{}], but got: {}".format(message, all_valid, actual))
logging_stdout_handler: Optional[logging.StreamHandler] = None
logging_to_file_handler: Optional[logging.StreamHandler] = None
def logging_to_stdout(log_level: Union[int, str] = logging.INFO) -> None:
"""
Instructs the Python logging libraries to start writing logs to stdout up to the given logging level.
Logging will use a timestamp as the prefix, using UTC.
:param log_level: The logging level. All logging messages with a level at or above this level will be written to
stdout. log_level can be numeric, or one of the pre-defined logging strings (INFO, DEBUG, ...).
"""
log_level = standardize_log_level(log_level)
logger = logging.getLogger()
# This function can be called multiple times, in particular in AzureML when we first run a training job and
# then a couple of tests, which also often enable logging. This would then add multiple handlers, and repeated
# logging lines.
global logging_stdout_handler
if not logging_stdout_handler:
print("Setting up logging to stdout.")
# At startup, logging has one handler set, that writes to stderr, with a log level of 0 (logging.NOTSET)
if len(logger.handlers) == 1:
logger.removeHandler(logger.handlers[0])
logging_stdout_handler = logging.StreamHandler(stream=sys.stdout)
_add_formatter(logging_stdout_handler)
logger.addHandler(logging_stdout_handler)
print(f"Setting logging level to {log_level}")
logging_stdout_handler.setLevel(log_level)
logger.setLevel(log_level)
def standardize_log_level(log_level: Union[int, str]) -> int:
"""
:param log_level: integer or string (any casing) version of a log level, e.g. 20 or "INFO".
:return: integer version of the level; throws if the string does not name a level.
"""
if isinstance(log_level, str):
log_level = log_level.upper()
check_is_any_of("log_level", log_level, logging._nameToLevel.keys())
return logging._nameToLevel[log_level]
return log_level
def _add_formatter(handler: logging.StreamHandler) -> None:
"""
Adds a logging formatter that includes the timestamp and the logging level.
"""
formatter = logging.Formatter(fmt="%(asctime)s %(levelname)-8s %(message)s",
datefmt="%Y-%m-%dT%H:%M:%SZ")
# noinspection PyTypeHints
formatter.converter = time.gmtime # type: ignore
handler.setFormatter(formatter)
@contextmanager
def logging_section(gerund: str) -> Generator:
"""
Context manager to print "**** STARTING: ..." and "**** FINISHED: ..." lines around sections of the log,
to help people locate particular sections. Usage:
with logging_section("doing this and that"):
do_this_and_that()
:param gerund: string expressing what happens in this section of the log.
"""
from time import time
logging.info("")
msg = f"**** STARTING: {gerund} "
logging.info(msg + (100 - len(msg)) * "*")
logging.info("")
start_time = time()
yield
elapsed = time() - start_time
logging.info("")
if elapsed >= 3600:
time_expr = f"{elapsed / 3600:0.2f} hours"
elif elapsed >= 60:
time_expr = f"{elapsed / 60:0.2f} minutes"
else:
time_expr = f"{elapsed:0.2f} seconds"
msg = f"**** FINISHED: {gerund} after {time_expr} "
logging.info(msg + (100 - len(msg)) * "*")
logging.info("")
def is_windows() -> bool:
"""
Returns True if the host operating system is Windows.
"""
return os.name == 'nt'
def is_linux() -> bool:
"""
Returns True if the host operating system is a flavour of Linux.
"""
return os.name == 'posix'
def check_properties_are_not_none(obj: Any, ignore: Optional[List[str]] = None) -> None:
"""
Checks to make sure the provided object has no properties that have a None value assigned.
"""
if ignore is None:
ignore = []
none_props = [k for k, v in vars(obj).items() if v is None and k not in ignore]
if len(none_props) > 0:
raise ValueError("Properties had None value: {}".format(none_props))
@contextmanager
def change_working_directory(path_or_str: PathOrString) -> Generator:
"""
Context manager for changing the current working directory to the value provided. Outside the context
manager, the original working directory will be restored.
:param path_or_str: The new directory to change to
:yield: a _GeneratorContextManager object (this object itself is of no use, rather we are interested in
the side effect of the working directory temporarily changing)
"""
new_path = Path(path_or_str).expanduser()
old_path = Path.cwd()
os.chdir(new_path)
yield
os.chdir(str(old_path))
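A minimal usage sketch for the context manager above; the temporary directory used here is only an illustration.

import tempfile
from pathlib import Path

from health_ml.utils.common_utils import change_working_directory

with tempfile.TemporaryDirectory() as tmp:
    with change_working_directory(tmp):
        print(Path.cwd())   # the temporary directory
print(Path.cwd())           # the original working directory, restored on exit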
def _create_generator(seed: Optional[int] = None) -> torch.Generator:
"""
Create Torch generator and sets seed with value if provided, or else with a random seed.
:param seed: Optional seed to set the Generator object with
:return: Torch Generator object
"""
generator = torch.Generator()
if seed is None:
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator.manual_seed(seed)
return generator
def get_all_environment_files(project_root: Path, additional_files: Optional[List[Path]] = None) -> List[Path]:
"""
Returns a list of all Conda environment files that should be used. This is just an
environment.yml file that lives at the project root folder, plus any additional files provided.
:param project_root: The root folder of the code that starts the present training run.
:param additional_files: Optional list of additional environment files to merge
:return: A list containing the root-level repo's Conda environment file (if present), plus any additional files that exist.
"""
env_files = []
project_yaml = project_root / fixed_paths.ENVIRONMENT_YAML_FILE_NAME
if project_yaml.exists():
env_files.append(project_yaml)
if additional_files:
for additional_file in additional_files:
if additional_file.exists():
env_files.append(additional_file)
return env_files
def get_all_pip_requirements_files() -> List[Path]:
"""
If the root level hi-ml directory is available (e.g. it has been installed as a submodule or
downloaded directly into a parent repo), then its pip requirements must be added to any environment
definition. This function returns a list of the necessary pip requirements files. If the hi-ml
root directory does not exist (e.g. hi-ml has been installed as a pip package), this is not necessary
and so this function returns an empty list.
:return: A list of pip requirements files in the hi-ml and hi-ml-azure packages if relevant,
or else an empty list
"""
files = []
himl_root_dir = fixed_paths.himl_root_dir()
if himl_root_dir is not None:
himl_yaml = himl_root_dir / "hi-ml" / "run_requirements.txt"
himl_az_yaml = himl_root_dir / "hi-ml-azure" / "run_requirements.txt"
files.append(himl_yaml)
files.append(himl_az_yaml)
return files
return []
def create_unique_timestamp_id() -> str:
"""
Creates a unique string using the current time in UTC, up to seconds precision, with characters that
are suitable for use in filenames. For example, on 31 Dec 2019 at 11:59:59pm UTC, the result would be
2019-12-31T235959Z.
"""
unique_id = datetime.utcnow().strftime("%Y-%m-%dT%H%M%SZ")
return unique_id
def is_gpu_available() -> bool:
"""
:return: True if a GPU with at least 1 device is available.
"""
return torch.cuda.is_available() and torch.cuda.device_count() > 0 # type: ignore
def parse_model_id_and_version(model_id_and_version: str) -> None:
"""
When using registered models, the model id must have both the id and version present, in the format
model_name:version. This function checks the input model id and raises a ValueError if it is not of the
expected format
"""
if len(model_id_and_version.split(":")) != 2:
raise ValueError(
f"model id should be in the form 'model_name:version', got {model_id_and_version}")


@ -0,0 +1,177 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from __future__ import annotations
import importlib
import inspect
import logging
import sys
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Dict, List, Optional
import param
from importlib._bootstrap import ModuleSpec
from health_azure.utils import PathOrString
from health_ml.lightning_container import LightningContainer
from health_ml.utils import fixed_paths
class ModelConfigLoader(param.Parameterized):
"""
Helper class to manage model config loading.
"""
def __init__(self, **params: Any):
super().__init__(**params)
default_module = self.get_default_search_module()
self.module_search_specs: List[ModuleSpec] = [importlib.util.find_spec(default_module)] # type: ignore
self._find_module_search_specs()
def _find_module_search_specs(self) -> None:
"""
Given the fully qualified model name, append the root folder to the system path (so that the config
file can be discovered) and try to find a spec for the specified module. If found, appends the spec
to self.module_search_specs
"""
model_namespace_parts = self.model.split(".")
if len(model_namespace_parts) == 1:
# config must be in the default path. This is already in module_search_specs so we don't need to do anything
return
else:
# Get the root folder of the fully qualified model name and ensure it is in the path to enable
# discovery of the config file
root_namespace = str(Path(model_namespace_parts[0]).absolute())
if root_namespace not in sys.path:
print(f"Adding {str(root_namespace)} to path")
sys.path.insert(0, str(root_namespace))
# Strip the root folder (now in the path) and the class name from the model namespace, leaving the
# module name - e.g. "mymodule.configs"
model_namespace = ".".join([str(p) for p in model_namespace_parts[1:-1]]) # type: ignore
custom_spec = importlib.util.find_spec(model_namespace) # type: ignore
if custom_spec is None:
raise ValueError(f"Search namespace {model_namespace} was not found.")
self.module_search_specs.append(custom_spec)
@staticmethod
def get_default_search_module() -> str:
from health_ml import configs # type: ignore
return configs.__name__
def create_model_config_from_name(self, model_name: str) -> LightningContainer:
"""
Returns a model configuration for a model of the given name.
The provided search modules are searched recursively for a class member called <model_name>.
:param model_name: Fully qualified name of the model for which to get the configs for - i.e.
mymodule.configs.MyConfig
"""
if not model_name:
raise ValueError("Unable to load a model configuration because the model name is missing.")
# get the class name from the fully qualified name
model_name = model_name.split(".")[-1]
configs: Dict[str, LightningContainer] = {}
def _get_model_config(module_spec: ModuleSpec) -> Optional[LightningContainer]:
"""
Given a module specification, check whether it defines a class with the <model_name> provided, and
if so instantiate that config class. Otherwise, return None.
:param module_spec:
:return: Instantiated model config if it was found.
"""
# noinspection PyBroadException
try:
logging.debug(f"Importing {module_spec.name}")
target_module = importlib.import_module(module_spec.name)
# The "if" clause checks that obj is a class, of the desired name, that is
# defined in this module rather than being imported into it (and hence potentially
# being found twice).
_class = next(obj for name, obj in inspect.getmembers(target_module)
if inspect.isclass(obj)
and name == model_name # noqa: W503
and inspect.getmodule(obj) == target_module) # noqa: W503
logging.info(f"Found class {_class} in file {module_spec.origin}")
# ignore the exception which will occur if the provided module cannot be loaded
# or the loaded module does not have the required class as a member
except Exception as e:
exception_text = str(e)
if exception_text != "":
logging.warning(f"(from attempt to import module {module_spec.name}): {exception_text}")
return None
model_config = _class()
return model_config
def _search_recursively_and_store(module_search_spec: ModuleSpec) -> None:
"""
Given a root namespace, e.g. A.B.C, searches recursively in all child namespaces for a class with
the <model_name> provided. If found, it is instantiated and added to the configs dictionary.
:param module_search_spec:
"""
root_namespace = module_search_spec.name
namespaces_to_search: List[str] = []
if module_search_spec.submodule_search_locations:
logging.debug(f"Searching through {len(module_search_spec.submodule_search_locations)} folders that "
f"match namespace {module_search_spec.name}: "
f"{module_search_spec.submodule_search_locations}")
for root in module_search_spec.submodule_search_locations:
# List all python files in all the dirs under root, except for private dirs (prefixed with .)
all_py_files = [x for x in Path(root).rglob("*.py") if ".." not in str(x)]
for f in all_py_files:
if f.is_file() and "__pycache__" not in str(f) and f.name != "setup.py":
sub_namespace = path_to_namespace(f, root=root)
namespaces_to_search.append(root_namespace + "." + sub_namespace)
elif module_search_spec.origin:
# The module search spec already points to a python file: Search only that.
namespaces_to_search.append(module_search_spec.name)
else:
raise ValueError(f"Unable to process module spec: {module_search_spec}")
for n in namespaces_to_search: # type: ignore
_module_spec = None
# noinspection PyBroadException
try:
_module_spec = find_spec(n) # type: ignore
except Exception:
pass
if _module_spec:
config = _get_model_config(_module_spec)
if config:
configs[n] = config # type: ignore
for search_spec in self.module_search_specs:
_search_recursively_and_store(search_spec)
if len(configs) == 0:
raise ValueError(
f"Model name {model_name} was not found in search namespaces: "
f"{[s.name for s in self.module_search_specs]}.")
elif len(configs) > 1:
raise ValueError(
f"Multiple instances of model name {model_name} were found in namespaces: {configs.keys()}.")
else:
return list(configs.values())[0]
def path_to_namespace(path: Path, root: PathOrString = fixed_paths.repository_root_directory()) -> str:
"""
Given a path (in form R/A/B/C) and an optional root directory R, create a namespace A.B.C.
If root is provided, then path must be a child of root.
:param path: Path to convert to namespace
:param root: Path prefix to remove from namespace (default is project root)
:return:
"""
return ".".join([Path(x).stem for x in path.relative_to(root).parts])


@ -211,7 +211,7 @@ class BatchTimeCallback(Callback):
self.write_and_log_epoch_time(is_training=True)
self.write_and_log_epoch_time(is_training=False)
def on_train_batch_start(self,
def on_train_batch_start(self, # type: ignore
trainer: Trainer,
pl_module: LightningModule,
batch: Any,
@ -229,7 +229,7 @@ class BatchTimeCallback(Callback):
) -> None:
self.batch_start(batch_idx=batch_idx, is_training=False)
def on_train_batch_end(self,
def on_train_batch_end(self, # type: ignore
trainer: Trainer,
pl_module: LightningModule,
outputs: Any,


@ -0,0 +1,58 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Optional
from health_azure.utils import PathOrString
ENVIRONMENT_YAML_FILE_NAME = "environment.yml"
def get_environment_yaml_file() -> Path:
"""
Returns the path where the environment.yml file is located, in the repository root directory.
The function throws an exception if the file is not found
:return: The full path to the environment files.
"""
# The environment file is copied into the package folder in setup.py.
root_dir = repository_root_directory()
env = root_dir / ENVIRONMENT_YAML_FILE_NAME
if not env.exists():
raise ValueError(f"File {ENVIRONMENT_YAML_FILE_NAME} was not found not found in in the repository root"
f"{root_dir}.")
return env
def repository_root_directory(path: Optional[PathOrString] = None) -> Path:
"""
Gets the full path to the root directory that holds the present repository.
:param path: if provided, a relative path to append to the absolute path to the repository root.
:return: The full path to the repository's root directory, with symlinks resolved if any.
"""
root = Path.cwd()
if path:
full_path = root / path
assert full_path.exists(), f"Path {full_path} doesn't exist"
return root / path
else:
return root
def himl_root_dir() -> Optional[Path]:
"""
Attempts to return the path to the top-level hi-ml repo that contains the hi-ml and hi-ml-azure packages.
This top-level repo will only be present if hi-ml has been installed as a git submodule, or the repo has
been directly downloaded. Otherwise (e.g. if hi-ml has been installed as a pip package) returns None.
:return: Path to the himl root dir if it exists, else None
"""
health_ml_root = Path(__file__).parent.parent
if health_ml_root.parent.stem == "site-packages":
return None
himl_root = health_ml_root.parent.parent.parent
assert (himl_root / "hi-ml").is_dir() and (himl_root / "hi-ml-azure").is_dir()
return himl_root
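To illustrate the helpers above (behaviour as written in this file, no assumptions beyond the install layout):

from health_ml.utils import fixed_paths

root = fixed_paths.repository_root_directory()   # simply the current working directory
himl_root = fixed_paths.himl_root_dir()          # None when hi-ml comes from pip (site-packages),
                                                 # otherwise the folder containing hi-ml/ and hi-ml-azure/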


@ -0,0 +1,110 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from typing import Any, Dict, Iterable, List, Optional
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only
from health_ml.utils.type_annotations import DictStrFloat, DictStrFloatOrFloatList
class StoringLogger(LightningLoggerBase):
"""
A Pytorch Lightning logger that simply stores the metrics that are written to it.
Used for diagnostic purposes in unit tests.
"""
def __init__(self) -> None:
super().__init__()
self.results_per_epoch: Dict[int, DictStrFloatOrFloatList] = {}
self.hyperparams: Any = None
# Fields to store diagnostics for unit testing
self.train_diagnostics: List[Any] = []
self.val_diagnostics: List[Any] = []
self.results_without_epoch: List[DictStrFloat] = []
@rank_zero_only
def log_metrics(self, metrics: DictStrFloat, step: Optional[int] = None) -> None:
logging.debug(f"StoringLogger step={step}: {metrics}")
epoch_name = "epoch"
if epoch_name not in metrics:
# Metrics without an "epoch" key are logged during testing, for example
self.results_without_epoch.append(metrics)
return
epoch = int(metrics[epoch_name])
del metrics[epoch_name]
for key, value in metrics.items():
if isinstance(value, int):
metrics[key] = float(value)
if epoch in self.results_per_epoch:
current_results = self.results_per_epoch[epoch]
for key, value in metrics.items():
if key in current_results:
logging.debug(f"StoringLogger: appending results for metric {key}")
current_metrics = current_results[key]
if isinstance(current_metrics, list):
current_metrics.append(value)
else:
current_results[key] = [current_metrics, value]
else:
current_results[key] = value
else:
self.results_per_epoch[epoch] = metrics # type: ignore
@rank_zero_only
def log_hyperparams(self, params: Any) -> None:
self.hyperparams = params
def experiment(self) -> Any:
return None
def name(self) -> Any:
return ""
def version(self) -> int:
return 0
@property
def epochs(self) -> Iterable[int]:
"""
Gets the epochs for which the present object holds any results.
"""
return self.results_per_epoch.keys()
def extract_by_prefix(self, epoch: int, prefix_filter: str = "") -> DictStrFloat:
"""
Reads the set of metrics for a given epoch, filters them to retain only those that have the given prefix,
and returns the filtered ones. This is used to break a set
of results down into those for training data (prefix "Train/") or validation data (prefix "Val/").
:param epoch: The epoch for which results should be read.
:param prefix_filter: If empty string, return all metrics. If not empty, return only those metrics that
have a name starting with `prefix`, and strip off the prefix.
:return: A metrics dictionary.
"""
epoch_results = self.results_per_epoch.get(epoch, None)
if epoch_results is None:
raise KeyError(f"No results are stored for epoch {epoch}")
filtered = {}
for key, value in epoch_results.items():
assert isinstance(key, str), f"All dictionary keys should be strings, but got: {type(key)}"
# Add the metric if either there is no prefix filter (prefix does not matter), or if the prefix
# filter is supplied and really matches the metric name
if (not prefix_filter) or key.startswith(prefix_filter):
stripped_key = key[len(prefix_filter):]
filtered[stripped_key] = value # type: ignore
return filtered # type: ignore
def to_metrics_dicts(self, prefix_filter: str = "") -> Dict[int, DictStrFloat]:
"""
Converts the results stored in the present object into a two-level dictionary, mapping from epoch number to
metric name to metric value. Only metrics where the name starts with the given prefix are retained, and the
prefix is stripped off in the result.
:param prefix_filter: If empty string, return all metrics. If not empty, return only those metrics that
have a name starting with `prefix`, and strip off the prefix.
:return: A dictionary mapping from epoch number to metric name to metric value.
"""
return {epoch: self.extract_by_prefix(epoch, prefix_filter) for epoch in self.epochs}
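A minimal sketch of the StoringLogger bookkeeping outside a Trainer, assuming the rank-zero decorator lets the calls through in a plain single-process script.

from health_ml.utils.lightning_loggers import StoringLogger

logger = StoringLogger()
logger.log_metrics({"epoch": 0, "Train/loss": 0.9, "Val/loss": 1.1})
logger.log_metrics({"epoch": 1, "Train/loss": 0.5, "Val/loss": 0.8})
print(logger.extract_by_prefix(epoch=1, prefix_filter="Train/"))   # {'loss': 0.5}
print(logger.to_metrics_dicts(prefix_filter="Val/"))               # {0: {'loss': 1.1}, 1: {'loss': 0.8}}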


@ -0,0 +1,287 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from __future__ import annotations
import random
from dataclasses import dataclass
from itertools import combinations
from math import ceil
from typing import Any, Dict, Iterable, Optional, Sequence, Set, Tuple
import numpy as np
import pandas as pd
from health_ml.utils import common_utils
from health_ml.utils.common_utils import ModelExecutionMode
@dataclass
class DatasetSplits:
train: pd.DataFrame
val: pd.DataFrame
test: pd.DataFrame
subject_column: Optional[str] = None
group_column: Optional[str] = None
allow_empty: bool = False
def __post_init__(self) -> None:
common_utils.check_properties_are_not_none(self)
def pairwise_intersection(*collections: Iterable) -> Set:
"""
Returns any element that appears in more than one collection
:return: a Set of elements that appear in more than one collection
"""
intersection = set()
for col1, col2 in combinations(map(set, collections), 2):
intersection |= col1 & col2
return intersection
# perform dataset split validity assertions
unique_train, unique_test, unique_val = self.unique_subjects()
intersection = pairwise_intersection(unique_train, unique_test, unique_val)
if len(intersection) != 0:
raise ValueError("Train, Test, and Val splits must have no intersection, found: {}".format(intersection))
if self.group_column is not None:
groups_train = self.train[self.group_column].unique()
groups_test = self.test[self.group_column].unique()
groups_val = self.val[self.group_column].unique()
group_intersection = pairwise_intersection(groups_train, groups_test, groups_val)
if len(group_intersection) != 0:
raise ValueError("Train, Test, and Val splits must have no intersecting groups, found: {}"
.format(group_intersection))
if (not self.allow_empty) and any([len(x) == 0 for x in [unique_train, unique_val]]):
raise ValueError("train_ids({}), val_ids({}) must have at least one value"
.format(len(unique_train), len(unique_val)))
def __str__(self) -> str:
unique_train, unique_test, unique_val = self.unique_subjects()
return f'Train: {len(unique_train)}, Test: {len(unique_test)}, and Val: {len(unique_val)}. ' \
f'Total subjects: {len(unique_train) + len(unique_test) + len(unique_val)}'
def unique_subjects(self) -> Tuple[Any, Any, Any]:
"""
Return a tuple of pandas Series of unique subjects across train, test and validation data splits,
based on self.subject_column
:return: a tuple of pandas Series
"""
return (self.train[self.subject_column].unique(),
self.test[self.subject_column].unique(),
self.val[self.subject_column].unique())
def number_of_subjects(self) -> int:
"""
Returns the sum of unique subjects in the dataset (identified by self.subject_column), summed
over train, test and validation data splits
:return: An integer representing the number of unique subjects
"""
unique_train, unique_test, unique_val = self.unique_subjects()
return len(unique_train) + len(unique_test) + len(unique_val)
def __getitem__(self, mode: ModelExecutionMode) -> pd.DataFrame:
"""
Retrieve either the train, validation or test data in the form of a Pandas dataframe, depending
on the current execution mode
:param mode: The current ModelExecutionMode
:return: A dataframe of the relevant data split
"""
if mode == ModelExecutionMode.TRAIN:
return self.train
elif mode == ModelExecutionMode.TEST:
return self.test
elif mode == ModelExecutionMode.VAL:
return self.val
else:
raise ValueError(f"Model execution mode not recognized: {mode}")
@staticmethod
def get_subject_ranges_for_splits(population: Sequence[str],
proportion_train: float,
proportion_test: float,
proportion_val: float) \
-> Dict[ModelExecutionMode, Set[str]]:
"""
Get mutually exclusive subject ranges for each dataset split (w.r.t. the proportions provided)
ensuring all sets have at least one item in them when possible.
:param population: all subjects
:param proportion_train: proportion for the train set.
:param proportion_test: proportion for the test set.
:param proportion_val: proportion for the validation set.
:return: Train, Test, and Val splits
"""
sum_proportions = proportion_train + proportion_val + proportion_test
if not np.isclose(sum_proportions, 1):
raise ValueError("proportion_train({}) + proportion_val({}) + proportion_test({}) must be ~ 1, found: {}"
.format(proportion_train, proportion_val, proportion_test, sum_proportions))
if not 0 <= proportion_test < 1:
raise ValueError("proportion_test({}) must be in range [0, 1)"
.format(proportion_test))
if not all([0 < x < 1 for x in [proportion_train, proportion_val]]):
raise ValueError("proportion_train({}) and proportion_val({}) must be in range (0, 1)"
.format(proportion_train, proportion_val))
subjects_train, subjects_test, subjects_val = (set(population[0:1]),
set(population[1:2]),
set(population[2:3]))
remaining = list(population[3:])
if proportion_test == 0:
remaining = list(subjects_test) + remaining
subjects_test = set()
subjects_train |= set(remaining[: ceil(len(remaining) * proportion_train)])
if len(subjects_test) > 0:
subjects_test |= set(remaining[len(subjects_train):
len(subjects_train) + ceil(len(remaining) * proportion_test)])
subjects_val |= set(remaining) - (subjects_train | subjects_test)
result = {
ModelExecutionMode.TRAIN: subjects_train,
ModelExecutionMode.TEST: subjects_test,
ModelExecutionMode.VAL: subjects_val
}
return result
@staticmethod
def _from_split_keys(df: pd.DataFrame,
train_keys: Sequence[str],
test_keys: Sequence[str],
val_keys: Sequence[str],
*, # make column names keyword-only arguments to avoid mistakes when providing both
key_column: str,
subject_column: str,
group_column: Optional[str]) -> DatasetSplits:
"""
Takes a slice of values from each data split train/test/val for the provided keys.
:param df: the input DataFrame
:param train_keys: keys for training.
:param test_keys: keys for testing.
:param val_keys: keys for validation.
:param key_column: name of the column the provided keys belong to
:param subject_column: subject id column name
:param group_column: grouping column name; if given, samples from each group will always be
in the same subset (train, val, or test) and cross-validation fold.
:return: Data splits with respected dataset split ids.
"""
train_df = DatasetSplits.get_df_from_ids(df, train_keys, key_column)
test_df = DatasetSplits.get_df_from_ids(df, test_keys, key_column)
val_df = DatasetSplits.get_df_from_ids(df, val_keys, key_column)
return DatasetSplits(train=train_df, test=test_df, val=val_df,
subject_column=subject_column, group_column=group_column)
@staticmethod
def from_proportions(df: pd.DataFrame,
proportion_train: float,
proportion_test: float,
proportion_val: float,
*, # make column names keyword-only arguments to avoid mistakes when providing both
subject_column: str = "",
group_column: Optional[str] = None,
shuffle: bool = True,
random_seed: int = 0) -> DatasetSplits:
"""
Creates a split of a dataset into train, test, and validation set, according to fixed proportions using
the "subject" column in the dataframe, or the group column, if given.
:param df: The dataframe containing all subjects.
:param proportion_train: proportion for the train set.
:param proportion_test: proportion for the test set.
:param subject_column: Subject id column name
:param group_column: grouping column name; if given, samples from each group will always be
in the same subset (train, val, or test) and cross-validation fold.
:param proportion_val: proportion for the validation set.
:param shuffle: If True, the subjects in the dataframe will be shuffled before performing the splits.
:param random_seed: Random seed to be used for shuffling; defaults to 0.
:return:
"""
key_column: str = subject_column if group_column is None else group_column
split_keys = df[key_column].unique()
if shuffle:
# fix the random seed so we can guarantee reproducibility when working with shuffle
random.Random(random_seed).shuffle(split_keys)
ranges = DatasetSplits.get_subject_ranges_for_splits(
split_keys,
proportion_train=proportion_train,
proportion_val=proportion_val,
proportion_test=proportion_test
)
return DatasetSplits._from_split_keys(df,
list(ranges[ModelExecutionMode.TRAIN]),
list(ranges[ModelExecutionMode.TEST]),
list(ranges[ModelExecutionMode.VAL]),
key_column=key_column,
subject_column=subject_column,
group_column=group_column)
@staticmethod
def from_subject_ids(df: pd.DataFrame,
train_ids: Sequence[str],
test_ids: Sequence[str],
val_ids: Sequence[str],
*, # make column names keyword-only arguments to avoid mistakes when providing both
subject_column: str = "",
group_column: Optional[str] = None) -> DatasetSplits:
"""
Assuming a DataFrame with a subject id column, this takes a slice of values from each data split
train/test/val for the provided ids.
:param df: the input DataFrame
:param train_ids: ids for training.
:param test_ids: ids for testing.
:param val_ids: ids for validation.
:param subject_column: subject id column name
:param group_column: grouping column name; if given, samples from each group will always be
in the same subset (train, val, or test) and cross-validation fold.
:return: Data splits with respected dataset split ids.
"""
return DatasetSplits._from_split_keys(df, train_ids, test_ids, val_ids, key_column=subject_column,
subject_column=subject_column, group_column=group_column)
@staticmethod
def from_groups(df: pd.DataFrame,
train_groups: Sequence[str],
test_groups: Sequence[str],
val_groups: Sequence[str],
*, # make column names keyword-only arguments to avoid mistakes when providing both
group_column: str,
subject_column: str = "") -> DatasetSplits:
"""
Assuming a DataFrame with a group column, this takes a slice of values from each data split
train/test/val for the provided groups.
:param df: the input DataFrame
:param train_groups: groups for training.
:param test_groups: groups for testing.
:param val_groups: groups for validation.
:param subject_column: subject id column name
:param group_column: grouping column name; if given, samples from each group will always be
in the same subset (train, val, or test) and cross-validation fold.
:return: Data splits with respected dataset split ids.
"""
return DatasetSplits._from_split_keys(df, train_groups, test_groups, val_groups, key_column=group_column,
subject_column=subject_column, group_column=group_column)
@staticmethod
def get_df_from_ids(df: pd.DataFrame, ids: Sequence[str],
subject_column: str = "") -> pd.DataFrame:
"""
Retrieve a subset dataframe where the subject column is restricted to a sequence of provided ids
:param df: The dataframe to restrict
:param ids: The ids to lookup
:param subject_column: The column to lookup ids in. Defaults to ""
:return: A subset of the dataframe
"""
return df[df[subject_column].isin(ids)]
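A small worked example of from_proportions on a toy dataframe; the module path in the import is an assumption about where this file lands in the package.

import pandas as pd

from health_ml.utils.split_dataset import DatasetSplits  # assumed module path

df = pd.DataFrame({"subject": [str(i) for i in range(10)], "value": range(10)})
splits = DatasetSplits.from_proportions(df,
                                        proportion_train=0.6,
                                        proportion_test=0.2,
                                        proportion_val=0.2,
                                        subject_column="subject")
print(splits)   # Train: 6, Test: 2, and Val: 2. Total subjects: 10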


@ -0,0 +1,12 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Dict, List, Tuple, TypeVar, Union
T = TypeVar('T')
PathOrString = Union[Path, str]
TupleFloat2 = Tuple[float, float]
DictStrFloat = Dict[str, float]
DictStrFloatOrFloatList = Dict[str, Union[float, List[float]]]


@ -0,0 +1,44 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
from unittest.mock import patch
import os
import pytest
from health_ml.utils import common_utils
@pytest.mark.parametrize("os_name, expected_val", [
("nt", True),
("None", False),
("posix", False),
("", False)])
def test_is_windows(os_name: str, expected_val: bool) -> None:
with patch.object(os, "name", new=os_name):
assert common_utils.is_windows() == expected_val
@pytest.mark.parametrize("os_name, expected_val", [
("nt", False),
("None", False),
("posix", True),
("", False)])
def test_is_linux(os_name: str, expected_val: bool) -> None:
with patch.object(os, "name", new=os_name):
assert common_utils.is_linux() == expected_val
def test_change_working_directory(tmp_path: Path) -> None:
"""
Test that change_working_directory temporarily changes the current working directory, and that the context
manager restores the original working directory on exit
"""
orig_cwd_str = str(Path.cwd())
tmp_path_str = str(tmp_path)
assert orig_cwd_str != tmp_path_str
with common_utils.change_working_directory(tmp_path):
assert str(Path.cwd()) == tmp_path_str
# outside of the context, the original working directory should be restored
assert str(Path.cwd()) == orig_cwd_str != tmp_path_str


@ -0,0 +1,139 @@
import shutil
from pathlib import Path
from typing import Any
import pytest
from health_azure.utils import is_running_on_azure_agent
from health_ml.lightning_container import LightningContainer
from health_ml.utils.config_loader import ModelConfigLoader, path_to_namespace
from testhiml.utils.fixed_paths_for_tests import full_ml_test_data_path, tests_root_directory
@pytest.fixture(scope="module")
def config_loader() -> ModelConfigLoader:
return ModelConfigLoader(**{"model": "HelloContainer"})
@pytest.fixture(scope="module")
def hello_config() -> Any:
from health_ml.configs import hello_container # type: ignore
assert Path(hello_container.__file__).exists(), "Can't find hello_container config"
return hello_container
def test_find_module_search_specs(config_loader: ModelConfigLoader) -> None:
# By default, property module_search_specs includes the default config path - health_ml.configs
len_search_specs_before = len(config_loader.module_search_specs)
assert any([m.name == "health_ml.configs" for m in config_loader.module_search_specs])
config_loader._find_module_search_specs()
# nothing should have been added to module_search_specs
assert len(config_loader.module_search_specs) == len_search_specs_before
def test_find_module_search_specs_outside_default_dir() -> None:
if is_running_on_azure_agent():
return
model_name = "NewConfig"
dummy_config_dir = Path.cwd() / "test_configs"
dummy_config_dir.mkdir()
dummy_config_path = dummy_config_dir / "new_config.py"
dummy_config = f"""class {model_name}:
def __init__(self):
pass
"""
dummy_config_path.touch()
dummy_config_path.write_text(dummy_config)
dummy_config_namespace = f"test_configs.new_config.{model_name}"
config_loader2 = ModelConfigLoader(**{"model": f"{dummy_config_namespace}"})
# The root "testhiml" should now be in the system path and the module "outputs" should be in module_search_specs
# this wont be in the previous results, since the default path was used. The default search_spec (health_ml.configs)
# should also be in the results for hte new
assert any([m.name == "new_config" for m in config_loader2.module_search_specs])
assert any([m.name == "health_ml.configs" for m in config_loader2.module_search_specs])
# If the file doesn't exist but the parent module does, the module will still be appended to module_search_specs
# at this stage
config_loader3 = ModelConfigLoader(**{"model": "test_configs.new_config.idontexist"})
assert any([m.name == "new_config" for m in config_loader3.module_search_specs])
# If the parent module doesn't exist, an Exception should be raised
with pytest.raises(Exception) as e:
ModelConfigLoader(**{"model": "testhiml.idontexist.idontexist"})
assert "was not found" in str(e)
shutil.rmtree(dummy_config_dir)
def test_get_default_search_module(config_loader: ModelConfigLoader) -> None:
search_module = config_loader.get_default_search_module()
assert search_module == "health_ml.configs"
def test_create_model_config_from_name(config_loader: ModelConfigLoader, hello_config: Any
) -> None:
# if no model name is given, an exception should be raised
with pytest.raises(Exception) as e:
config_loader.create_model_config_from_name("")
assert "the model name is missing" in str(e)
# if no config is found matching the model name, an exception should be raised
with pytest.raises(Exception) as e:
config_loader.create_model_config_from_name("idontexist")
assert "was not found in search namespaces" in str(e)
# if > 1 config is found matching the model name, an exception should be raised
config_name = "HelloContainer"
hello_config_path = Path(hello_config.__file__)
duplicate_config_file = hello_config_path.parent / "hello_container_2.py"
duplicate_config_file.touch()
shutil.copyfile(str(hello_config_path), str(duplicate_config_file))
with pytest.raises(Exception) as e:
config_loader.create_model_config_from_name(config_name)
assert "Multiple instances of model name " in str(e)
duplicate_config_file.unlink()
# if exactly one config is found, expect a LightningContainer to be returned
container = config_loader.create_model_config_from_name(config_name)
assert isinstance(container, LightningContainer)
assert container.model_name == config_name
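# A minimal usage sketch of the happy path covered above: load a container purely by model
# name, relying on the default search namespace health_ml.configs.
from health_ml.utils.config_loader import ModelConfigLoader


def load_hello_container_example() -> None:
    loader = ModelConfigLoader(model="HelloContainer")
    container = loader.create_model_config_from_name("HelloContainer")
    assert container.model_name == "HelloContainer"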
def test_config_in_dif_location(tmp_path: Path, hello_config: Any) -> None:
himl_root = Path(hello_config.__file__).parent.parent
model_name = "HelloContainer"
new_config_path = himl_root / "hello_container_to_delete.py"
new_config_path.touch()
hello_config_path = Path(hello_config.__file__)
shutil.copyfile(str(hello_config_path), str(new_config_path))
config_loader = ModelConfigLoader(model=model_name)
# Trying to find this config should now cause an exception as it should find it in both "health_ml" and
# in "health_ml.configs"
with pytest.raises(Exception) as e:
config_loader.create_model_config_from_name(model_name)
assert "Multiple instances of model name HelloContainer were found in namespaces: " \
"dict_keys(['health_ml.configs.hello_container', 'health_ml.hello_container_to_delete']) " in str(e)
new_config_path.unlink()
@pytest.mark.parametrize("is_external", [True, False])
def test_path_to_namespace(is_external: bool) -> None:
"""
A test to check conversion from a folder path to a namespace, both for an external folder and for the ML test data folder
"""
tests_root_dir = tests_root_directory()
if is_external:
folder_name = "logs"
full_folder = tests_root_dir / folder_name
assert path_to_namespace(
path=full_folder,
root=tests_root_dir
) == folder_name
else:
assert path_to_namespace(
path=full_ml_test_data_path(),
root=tests_root_dir
) == "ML.test_data"
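# A minimal sketch of what the assertions above imply path_to_namespace does: strip the given
# root from the path and join the remaining parts with dots. This is an assumption-based
# illustration, not the actual implementation.
from pathlib import Path


def path_to_namespace_sketch(path: Path, root: Path) -> str:
    # e.g. path=<root>/ML/test_data, root=<root>  ->  "ML.test_data"
    return ".".join(path.relative_to(root).parts)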


@ -0,0 +1,229 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from random import randint
from unittest.mock import patch, MagicMock
import pytest
from _pytest.logging import LogCaptureFixture
from param import Number
from pathlib import Path
from health_ml.deep_learning_config import DatasetParams, WorkflowParams, OutputParams, OptimizerParams, \
ExperimentFolderHandler, TrainerParams
def test_validate_workflow_params() -> None:
# WorkflowParams cannot be validated with more than one of local_weights_path, weights_url or model_id set
with pytest.raises(ValueError) as ex:
WorkflowParams(local_datasets=Path("foo"),
local_weights_path=[Path("foo")],
weights_url=["bar"]).validate()
assert ex.value.args[0] == "Cannot specify more than one of local_weights_path, weights_url or model_id."
with pytest.raises(ValueError) as ex:
WorkflowParams(local_dataset=Path("foo"),
local_weights_path=[Path("foo")],
model_id="foo:1").validate()
assert ex.value.args[0] == "Cannot specify more than one of local_weights_path, weights_url or model_id."
with pytest.raises(ValueError) as ex:
WorkflowParams(local_dataset=Path("foo"),
weights_url=["foo"],
model_id="foo:1").validate()
assert ex.value.args[0] == "Cannot specify more than one of local_weights_path, weights_url or model_id."
with pytest.raises(ValueError) as ex:
WorkflowParams(local_dataset=Path("foo"),
local_weights_path=[Path("foo")],
weights_url=["foo"],
model_id="foo:1").validate()
assert ex.value.args[0] == "Cannot specify more than one of local_weights_path, weights_url or model_id."
with pytest.raises(ValueError) as ex:
WorkflowParams(local_dataset=Path("foo"),
model_id="foo").validate()
assert "model id should be in the form 'model_name:version'" in ex.value.args[0]
# The following should be okay
WorkflowParams(local_dataset=Path("foo"), local_weights_path=[Path("foo")]).validate()
WorkflowParams(local_dataset=Path("foo"), weights_url=["foo"]).validate()
WorkflowParams(local_dataset=Path("foo"), model_id="foo:1").validate()
def test_workflow_params_get_effective_random_seed() -> None:
params = WorkflowParams(local_dataset=Path("foo"), weights_url=["foo"])
seed = params.get_effective_random_seed()
assert seed == params.random_seed
def test_validate_dataset_params() -> None:
# DatasetParams cannot be validated if neither azure_datasets nor local_datasets is set
with pytest.raises(ValueError) as ex:
DatasetParams(local_datasets=[], azure_datasets=[]).validate()
assert ex.value.args[0] == "Either local_datasets or azure_datasets must be set."
# If azure_datasets or local_datasets is not a list an exception should be raised
with pytest.raises(Exception) as e:
DatasetParams(local_datasets="", azure_datasets=[]).validate()
assert "must be a list" in str(e)
with pytest.raises(Exception) as e:
DatasetParams(local_datasets=[], azure_datasets=None).validate()
assert "must be a list" in str(e)
# local datasets and dataset_mountpoints must be Paths
with pytest.raises(Exception) as e:
DatasetParams(local_datasets=["foo"])
assert "is not an instance of" in str(e)
with pytest.raises(Exception) as e:
DatasetParams(dataset_mountpoints=["foo"])
assert "is not an instance of" in str(e)
# The following should be okay
DatasetParams(local_datasets=[Path("foo")]).validate()
DatasetParams(azure_datasets=["bar"]).validate()
config = DatasetParams(local_datasets=[Path("foo")],
azure_datasets=[""])
config.validate()
assert config.azure_datasets == [""]
config = DatasetParams(azure_datasets=["foo"])
config.validate()
assert len(config.azure_datasets) == 1
config = DatasetParams(local_datasets=[Path("foo")],
azure_datasets=[""])
config.validate()
assert len(config.azure_datasets) == 1
config = DatasetParams(azure_datasets=["foo", "bar"])
config.validate()
assert len(config.azure_datasets) == 2
config = DatasetParams(azure_datasets=["foo"],
dataset_mountpoints=[Path()])
config.validate()
assert config.dataset_mountpoints == [Path()]
config = DatasetParams(azure_datasets=["foo"],
dataset_mountpoints=[Path("foo")])
config.validate()
assert len(config.dataset_mountpoints) == 1
# the number of mountpoints must not be larger than the number of datasets
with pytest.raises(ValueError) as e:
DatasetParams(azure_datasets=["foo"],
dataset_mountpoints=[Path("foo"), Path("bar")]).validate()
assert "Expected the number of azure datasets to equal the number of mountpoints" in str(e)
def test_output_params_set_output_to() -> None:
# output_to must be Path type
with pytest.raises(Exception) as e:
OutputParams(output_to="foo")
assert "must be an instance of Path" in str(e)
old_path = Path()
config = OutputParams(output_to=old_path)
assert config.outputs_folder == old_path
new_path = Path("dummy")
config.set_output_to(new_path)
# create_filesystem gets called inside set_output_to
assert config.output_to == new_path
def test_output_params_create_filesystem(tmp_path: Path) -> None:
# file_system_config must be of type ExperimentFolderHandler
with pytest.raises(Exception) as e:
OutputParams(file_system_config="foo")
assert "value must be an instance of ExperimentFolderHandler" in str(e)
config = OutputParams()
default_file_system_config = config.file_system_config
assert isinstance(default_file_system_config, ExperimentFolderHandler)
assert default_file_system_config.project_root == Path(".")
# Now call create_filesystem with a different path project_root
config.create_filesystem(tmp_path)
new_file_system_config = config.file_system_config
assert new_file_system_config.project_root == tmp_path
def test_validate_optimizer_params() -> None:
# Instantiating OptimizerParams with no non-default values should be ok
config = OptimizerParams()
config.validate()
# assert that passing a string to a param expecting a numeric value causes an Exception to be raised
numeric_params = [k for k, v in config.params().items() if isinstance(v, Number)]
for numeric_param_name in numeric_params:
with pytest.raises(Exception) as e:
config = OptimizerParams()
setattr(config, numeric_param_name, "foo")
config.validate()
# For non-numeric parameters, check that Exceptions are raised when params with invalid types are provided
with pytest.raises(Exception) as e:
OptimizerParams(l_rate_scheduler="foo").validate()
assert "must be an instance of LRSchedulerType" in str(e)
with pytest.raises(Exception) as e:
OptimizerParams(l_rate_multi_step_milestones="foo")
assert "must be a list" in str(e)
with pytest.raises(Exception) as e:
OptimizerParams(l_rate_warmup="foo").validate()
assert "must be an instance of LRWarmUpType" in str(e)
with pytest.raises(Exception) as e:
OptimizerParams(optimizer_type="foo").validate()
assert "must be an instance of OptimizerType" in str(e)
with pytest.raises(Exception) as e:
OptimizerParams(adam_betas="foo").validate()
assert "only takes a tuple value" in str(e)
def test_optimizer_params_min_l_rate() -> None:
config = OptimizerParams()
min_l_rate = config.min_l_rate
assert min_l_rate == config._min_l_rate
def test_trainer_params_use_gpu() -> None:
config = TrainerParams()
for patch_gpu in [False, True]:
with patch("health_ml.utils.common_utils.is_gpu_available") as mock_gpu_available:
mock_gpu_available.return_value = patch_gpu
assert config.use_gpu is patch_gpu
@patch("health_ml.utils.common_utils.is_gpu_available")
def test_trainer_params_num_gpus_per_node(mock_gpu_available: MagicMock, caplog: LogCaptureFixture) -> None:
mock_gpu_available.return_value = True
# if the requested number of gpus is no more than the number available, the requested number should be used
# and a warning logged to let the user know that they may not be using the full capacity
requested_gpus = 3
config = TrainerParams(max_num_gpus=requested_gpus)
random_num_available_gpus = randint(requested_gpus, requested_gpus + 5)
with patch("torch.cuda.device_count") as mock_gpu_count:
mock_gpu_count.return_value = random_num_available_gpus
assert config.num_gpus_per_node() == requested_gpus
message = caplog.messages[-1]
assert f"Restricting the number of GPUs to {requested_gpus}" in message
# if the requested max number of gpus is greater than the number available, expect a warning and the available
# number to be used
requested_gpus = 3
random_num_available_gpus = randint(1, requested_gpus - 1)
with patch("torch.cuda.device_count") as mock_gpu_count:
mock_gpu_count.return_value = random_num_available_gpus
config = TrainerParams(max_num_gpus=requested_gpus)
assert config.num_gpus_per_node() == random_num_available_gpus
message = caplog.messages[-1]
assert f"You requested max_num_gpus {requested_gpus} but there are only {random_num_available_gpus}" \
f" available." in message


@ -0,0 +1,110 @@
from pathlib import Path
from typing import Any, Dict
from unittest.mock import MagicMock, patch, Mock
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import GradientAccumulationScheduler, ModelCheckpoint, ModelSummary, TQDMProgressBar
from health_ml.configs.hello_container import HelloContainer # type: ignore
from health_ml.lightning_container import LightningContainer
from health_ml.model_trainer import (create_lightning_trainer, write_experiment_summary_file, model_train)
from health_ml.utils.common_utils import EXPERIMENT_SUMMARY_FILE
from health_ml.utils.config_loader import ModelConfigLoader
from health_ml.utils.lightning_loggers import StoringLogger
def test_write_experiment_summary_file(tmp_path: Path) -> None:
config = {
"Container": {
"_min_l_rate": 0.0,
"_model_name": "HelloContainer",
"adam_betas": "(0.9, 0.999)",
"azure_datasets": "[]"}
}
expected_args_path = tmp_path / EXPERIMENT_SUMMARY_FILE
write_experiment_summary_file(config, tmp_path)
actual_args = expected_args_path.read_text()
assert actual_args == str(config)
def test_create_lightning_trainer() -> None:
container = LightningContainer()
trainer, storing_logger = create_lightning_trainer(container)
assert trainer.num_gpus == container.num_gpus_per_node()
# by default, trainer's num_nodes is 1
assert trainer.num_nodes == 1
assert trainer.default_root_dir == str(container.outputs_folder)
assert trainer.limit_train_batches == 1.0
assert trainer._detect_anomaly == container.detect_anomaly
assert isinstance(trainer.callbacks[0], TQDMProgressBar)
assert isinstance(trainer.callbacks[1], ModelSummary)
assert isinstance(trainer.callbacks[2], GradientAccumulationScheduler)
assert isinstance(trainer.callbacks[3], ModelCheckpoint)
assert isinstance(storing_logger, StoringLogger)
assert storing_logger.hyperparams is None
assert len(storing_logger.results_per_epoch) == 0
assert len(storing_logger.train_diagnostics) == 0
assert len(storing_logger.val_diagnostics) == 0
assert len(storing_logger.results_without_epoch) == 0
class MyCallback(Callback):
def on_init_start(self, trainer: Trainer) -> None:
print("Starting to init trainer")
def test_create_lightning_trainer_with_callbacks() -> None:
"""
Test that create_lightning_trainer picks up on additional Container callbacks
"""
def _get_trainer_arguments() -> Dict[str, Any]:
callbacks = [MyCallback()]
return {"callbacks": callbacks}
model_name = "HelloContainer"
model_config_loader = ModelConfigLoader(model=model_name)
container = model_config_loader.create_model_config_from_name(model_name)
container.monitor_gpu = False
container.monitor_loading = False
# mock get_trainer_arguments method, since default HelloContainer class doesn't specify any additional callbacks
container.get_trainer_arguments = _get_trainer_arguments # type: ignore
kwargs = container.get_trainer_arguments()
assert "callbacks" in kwargs
# create_lightning_trainer(container, )
trainer, storing_logger = create_lightning_trainer(container)
# expect trainer to have 4 default callbacks: TQDMProgressBar, ModelSummary, GradientAccumulationScheduler
# and ModelCheckpoint, plus any additional callbacks specified in the get_trainer_arguments method
kwarg_callbacks = kwargs.get("callbacks") or []
expected_num_callbacks = len(kwarg_callbacks) + 4
assert len(trainer.callbacks) == expected_num_callbacks, f"Found callbacks: {trainer.callbacks}"
assert any([isinstance(c, MyCallback) for c in trainer.callbacks])
assert isinstance(storing_logger, StoringLogger)
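# A hedged usage sketch of what the mocked get_trainer_arguments above simulates: a container
# subclass can contribute extra callbacks through the trainer-arguments dictionary. The
# subclass and callback names below are hypothetical.
from typing import Any, Dict

from pytorch_lightning import Callback

from health_ml.configs.hello_container import HelloContainer  # type: ignore


class PrintEpochCallback(Callback):
    def on_train_epoch_start(self, trainer, pl_module) -> None:  # type: ignore
        print(f"Starting epoch {trainer.current_epoch}")


class HelloContainerWithCallbacks(HelloContainer):
    def get_trainer_arguments(self) -> Dict[str, Any]:
        # These callbacks are appended to the default ones created by create_lightning_trainer
        return {"callbacks": [PrintEpochCallback()]}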
def test_model_train() -> None:
container = HelloContainer()
container.create_lightning_module_and_store()
with patch.object(container, "get_data_module"):
with patch("health_ml.model_trainer.create_lightning_trainer") as mock_create_trainer:
mock_trainer = MagicMock()
mock_storing_logger = MagicMock()
mock_create_trainer.return_value = mock_trainer, mock_storing_logger
mock_trainer.fit = Mock()
mock_close_logger = Mock()
mock_trainer.logger = MagicMock(close=mock_close_logger)
trainer, storing_logger = model_train(container)
mock_trainer.fit.assert_called_once()
mock_trainer.logger.finalize.assert_called_once()
assert trainer == mock_trainer
assert storing_logger == mock_storing_logger
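# A minimal sketch of the sequence the mocks above verify: model_train presumably builds a
# trainer/logger pair, fits the container's Lightning module on its data module, and then
# finalizes the trainer's logger. The attribute name container.model and the "success" status
# string are assumptions; this is not the actual implementation.
from health_ml.lightning_container import LightningContainer
from health_ml.model_trainer import create_lightning_trainer


def model_train_sketch(container: LightningContainer):
    trainer, storing_logger = create_lightning_trainer(container)
    trainer.fit(container.model, datamodule=container.get_data_module())
    trainer.logger.finalize("success")
    return trainer, storing_logger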


@ -0,0 +1,94 @@
from pathlib import Path
import pytest
from typing import Tuple
from unittest.mock import patch, MagicMock, Mock
from pytorch_lightning import Callback
from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
from health_ml.run_ml import MLRunner
@pytest.fixture
def ml_runner() -> MLRunner:
experiment_config = ExperimentConfig()
container = LightningContainer(num_epochs=1)
return MLRunner(experiment_config=experiment_config, container=container)
def test_ml_runner_setup(ml_runner: MLRunner) -> None:
"""
Check that all the necessary methods get called during setup
"""
assert not ml_runner._has_setup_run
with patch.object(ml_runner, "container", spec=LightningContainer) as mock_container:
with patch("health_ml.run_ml.seed_everything") as mock_seed:
# mock_container.get_effective_random_seed = Mock()
ml_runner.setup()
mock_container.get_effective_random_seed.assert_called_once()
mock_container.setup.assert_called_once()
mock_container.create_lightning_module_and_store.assert_called_once()
assert ml_runner._has_setup_run
mock_seed.assert_called_once()
def test_set_run_tags_from_parent(ml_runner: MLRunner) -> None:
with pytest.raises(AssertionError) as ae:
ml_runner.set_run_tags_from_parent()
assert "should only be called in a Hyperdrive run" in str(ae)
with patch("health_ml.run_ml.PARENT_RUN_CONTEXT") as mock_parent_run_context:
with patch("health_ml.run_ml.RUN_CONTEXT") as mock_run_context:
mock_parent_run_context.get_tags.return_value = {"tag": "dummy_tag"}
ml_runner.set_run_tags_from_parent()
mock_run_context.set_tags.assert_called()
def test_run(ml_runner: MLRunner) -> None:
def _mock_model_train(container: LightningContainer) -> Tuple[str, str]:
return "trainer", dummy_storing_logger
dummy_storing_logger = "storing_logger"
with patch.object(ml_runner, "setup") as mock_setup:
with patch("health_ml.run_ml.model_train", new=_mock_model_train):
ml_runner.run()
mock_setup.assert_called_once()
# expect _mock_model_train to be called and ml_runner.storing_logger to be
# updated accordingly
assert ml_runner.storing_logger == dummy_storing_logger
@patch("health_ml.run_ml.create_lightning_trainer")
def test_run_inference_for_lightning_models(mock_create_trainer: MagicMock, ml_runner: MLRunner,
tmp_path: Path) -> None:
"""
Check that all expected methods are called during inference
"""
mock_trainer = MagicMock()
mock_test_result = [{"result": 1.0}]
mock_trainer.test.return_value = mock_test_result
mock_create_trainer.return_value = mock_trainer, ""
with patch.object(ml_runner, "container") as mock_container:
mock_container.num_gpus_per_node.return_value = 0
mock_container.get_trainer_arguments.return_value = {"callbacks": Callback()}
mock_container.load_model_checkpoint.return_value = Mock()
mock_container.get_data_module.return_value = Mock()
mock_container.pl_progress_bar_refresh_rate = None
mock_container.detect_anomaly = False
mock_container.pl_limit_train_batches = 1.0
mock_container.pl_limit_val_batches = 1.0
mock_container.outputs_folder = tmp_path
checkpoint_paths = [Path("dummy")]
result = ml_runner.run_inference_for_lightning_models(checkpoint_paths)
assert result == mock_test_result
mock_create_trainer.assert_called_once()
mock_container.load_model_checkpoint.assert_called_once()
mock_container.get_data_module.assert_called_once()
mock_trainer.test.assert_called_once()


@ -0,0 +1,115 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import sys
from pathlib import Path
from typing import List, Optional
from unittest.mock import patch, MagicMock
import pytest
from health_azure import AzureRunInfo, DatasetConfig
from health_ml.lightning_container import LightningContainer
from health_ml.runner import Runner
@pytest.fixture
def mock_runner(tmp_path: Path) -> Runner:
return Runner(project_root=tmp_path)
@pytest.mark.parametrize("model_name, cluster, num_nodes, should_raise_value_error", [
("HelloContainer", "dummyCluster", 1, False),
("", "", None, True),
("HelloContainer", "", None, False),
("a", None, 0, True),
(None, "b", 10, True),
("HelloContainer", "b", 10, False)
])
def test_parse_and_load_model(mock_runner: Runner, model_name: Optional[str], cluster: Optional[str],
num_nodes: Optional[int], should_raise_value_error: bool) -> None:
"""
Test that command line args are parsed, a LightningContainer is instantiated with the expected attributes
and a ParserResult object is returned, with the expected attributes. If model_name cannot be found in the
namespace (i.e. the config does not exist) a ValueError should be raised
"""
dummy_args = [""]
if model_name is not None:
dummy_args.append(f"--model={model_name}")
if cluster is not None:
dummy_args.append(f"--cluster={cluster}")
if num_nodes is not None:
dummy_args.append(f"--num_nodes={num_nodes}")
with patch.object(sys, "argv", new=dummy_args):
if should_raise_value_error:
with pytest.raises(ValueError) as ve:
mock_runner.parse_and_load_model()
assert "Parameter 'model' needs to be set" in str(ve)
else:
parser_result = mock_runner.parse_and_load_model()
# if model, cluster or num_nodes are provided in command line args, the corresponding attributes of
# the LightningContainer will be set accordingly and they will be dropped from ParserResult during
# parse_overrides_and_apply
assert parser_result.args.get("model") is None
assert parser_result.args.get("cluster") is None
assert parser_result.args.get("num_nodes") is None
assert isinstance(mock_runner.lightning_container, LightningContainer)
assert mock_runner.lightning_container.initialized
assert mock_runner.lightning_container.model_name == model_name
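# A minimal sketch of the flow exercised above, using only calls that appear in this test:
# construct a Runner and parse command-line style arguments to load a container by name. The
# argument values are illustrative.
import sys
from pathlib import Path
from unittest.mock import patch

from health_ml.runner import Runner


def parse_hello_container_example(tmp_path: Path):
    runner = Runner(project_root=tmp_path)
    with patch.object(sys, "argv", ["", "--model=HelloContainer", "--num_nodes=2"]):
        parser_result = runner.parse_and_load_model()
    return runner.lightning_container, parser_result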
def test_run(mock_runner: Runner) -> None:
model_name = "HelloContainer"
arguments = ["", f"--model={model_name}"]
with patch("health_ml.runner.Runner.run_in_situ") as mock_run_in_situ:
with patch("health_ml.runner.get_workspace"):
with patch.object(sys, "argv", arguments):
model_config, azure_run_info = mock_runner.run()
mock_run_in_situ.assert_called_once()
assert model_config is not None # for pyright
assert model_config.model_name == model_name
assert azure_run_info.run is None
assert len(azure_run_info.input_datasets) == len(azure_run_info.output_datasets) == 0
@patch("health_ml.runner.get_all_environment_files")
@patch("health_ml.runner.get_all_pip_requirements_files")
@patch("health_ml.runner.get_workspace")
def test_submit_to_azureml_if_needed(mock_get_workspace: MagicMock,
mock_get_pip_req_files: MagicMock,
mock_get_env_files: MagicMock,
mock_runner: Runner
) -> None:
def _mock_dont_submit_to_aml(input_datasets: List[DatasetConfig], submit_to_azureml: bool # type: ignore
) -> AzureRunInfo:
datasets_input = [d.target_folder for d in input_datasets] if input_datasets else []
return AzureRunInfo(input_datasets=datasets_input,
output_datasets=[],
mount_contexts=[],
run=None,
is_running_in_azure_ml=False,
output_folder=None, # type: ignore
logs_folder=None) # type: ignore
mock_get_env_files.return_value = []
mock_get_pip_req_files.return_value = []
mock_default_datastore = MagicMock()
mock_default_datastore.name.return_value = "dummy_datastore"
mock_get_workspace.get_default_datastore.return_value = mock_default_datastore
with patch("health_ml.runner.create_dataset_configs") as mock_create_datasets:
mock_create_datasets.return_value = []
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
mock_submit_to_aml.side_effect = _mock_dont_submit_to_aml
mock_runner.lightning_container = LightningContainer()
run_info = mock_runner.submit_to_azureml_if_needed()
assert isinstance(run_info, AzureRunInfo)
assert run_info.input_datasets == []
assert run_info.is_running_in_azure_ml is False
assert run_info.output_folder is None


@ -0,0 +1,37 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import os
from pathlib import Path
from typing import Optional
from health_azure.utils import PathOrString
def tests_root_directory(path: Optional[PathOrString] = None) -> Path:
"""
Gets the full path to the root directory that holds the tests.
If a relative path is provided then concatenate it with the absolute path
to the tests root directory.
:return: The full path to the tests root directory, with symlinks resolved if any.
"""
root = Path(os.path.realpath(__file__)).parent.parent.parent
return root / path if path else root
def full_ml_test_data_path(path: str = "") -> Path:
"""
Takes a relative path inside of the testhiml/ML/test_data folder, and returns its
full absolute path.
:param path: A path relative to the ML/test_data folder
:return: The full absolute path of the argument.
"""
return _full_test_data_path("ML", path)
def _full_test_data_path(prefix: str, suffix: str) -> Path:
root = tests_root_directory()
return root / prefix / "test_data" / suffix


@ -2,6 +2,48 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import pandas as pd
from health_azure.utils import UnitTestWorkspaceWrapper
DEFAULT_WORKSPACE = UnitTestWorkspaceWrapper()
def create_dataset_df() -> pd.DataFrame:
"""
Create a test dataframe for DATASET_CSV_FILE_NAME.
:return: Test dataframe.
"""
dataset_df = pd.DataFrame()
dataset_df['subject'] = list(range(10))
dataset_df['seriesId'] = [f"s{i}" for i in range(10)]
dataset_df['institutionId'] = ["xyz"] * 10
return dataset_df
def create_metrics_df() -> pd.DataFrame:
"""
Create a test dataframe for SUBJECT_METRICS_FILE_NAME.
:return: Test dataframe.
"""
metrics_df = pd.DataFrame()
metrics_df['Patient'] = list(range(10))
metrics_df['Structure'] = ['appendix'] * 10
metrics_df['Dice'] = [0.5 + i * 0.02 for i in range(10)]
return metrics_df
def create_comparison_metrics_df() -> pd.DataFrame:
"""
Create a test dataframe for comparison metrics.
:return: Test dataframe.
"""
comparison_metrics_df = pd.DataFrame()
comparison_metrics_df['Patient'] = list(range(10))
comparison_metrics_df['Structure'] = ['appendix'] * 10
comparison_metrics_df['Dice'] = [0.51 + i * 0.02 for i in range(10)]
return comparison_metrics_df


@ -12,4 +12,18 @@
"reportMissingImports": true,
"reportMissingTypeStubs": false,
"reportPrivateImportUsage": false,
"executionEnvironments": [
{
"root": "hi-ml/src"
},
{
"root": "hi-ml/testhiml"
},
{
"root": "hi-ml-azure/src"
},
{
"root": "hi-ml-azure/testazure"
}
]
}