Mirror of https://github.com/microsoft/hi-ml.git
First version of runner (#178)
* First version of runner * changelog and pyright * PR comments * Runs in AML * More PR comments * More PR comments * Remove GenericConfig from himl * Add more tests * pytest and pyright * Add more tests * Add more tests * Add documentation * bug fix align config attributes * works on config outside of hi-ml * flake8 mypy pytest * Add more tests for config loading * PR comments and add more tests * Respond to PR comments * More PR comments * pytest, mypy, flake8 * fix test running remotely * debug why test fails on remote build * debug why test fails on remote build * debug why test fails on remote build * fix test * Fix tests * debug test failing on azure agent * Fix test and update docs * update test * Update command line tools * Update coverage report to only read .py files * debug coverage problem * debug coverage failing * debug coverage failing * debug coverage failing * PR comments
This commit is contained in:
Parent: a33c1ed07d
Commit: c5a0348d36
@@ -0,0 +1,8 @@
[report]
omit =
    **/pytest
    **/__init__.py
    */hello_container_2.py

[html]
skip_empty = true
.flake8
@@ -1,4 +1,4 @@
[flake8]
max-line-length = 120
max-complexity = 25
ignore = E731
ignore = E731
@@ -15,6 +15,7 @@ the section headers (Added/Changed/...) and incrementing the package version.
### Added
- ([#179](https://github.com/microsoft/hi-ml/pull/179)) Add GaussianBlur and RotationByMultiplesOf90 augmentations. Added torchvision and opencv to
  the environment file since they are necessary for the augmentations.
- ([#178](https://github.com/microsoft/hi-ml/pull/178)) Add runner script for running ML experiments

### Changed
Makefile
@@ -111,6 +111,7 @@ combine: pip_test
	mkdir -p coverage
	cp hi-ml/.coverage coverage/hi-ml-coverage
	cp hi-ml-azure/.coverage coverage/hi-ml-azure-coverage
	cp .coveragerc coverage/
	cd coverage && \
	coverage combine hi-ml-coverage hi-ml-azure-coverage && \
	coverage html && \

@@ -7,8 +7,8 @@ REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
set SOURCEDIR=source
set BUILDDIR=build

if "%1" == "" goto help

@@ -3,5 +3,5 @@
.. automodapi:: health_azure
   :no-inheritance-diagram:

.. automodapi:: health_ml.utils
.. automodapi:: health_ml
   :no-inheritance-diagram:

@@ -33,6 +33,7 @@ The `hi-ml` toolbox provides

   logging.md
   diagnostics.md
   runner.md

.. toctree::
   :maxdepth: 1

@@ -0,0 +1,178 @@
# Running ML experiments with hi-ml

The hi-ml toolbox is capable of training any PyTorch Lightning (PL) model inside of AzureML, making
use of these features:
- Training on a local GPU machine or inside of AzureML without code changes
- Working with different models in the same codebase, and selecting one by name
- Distributed training in AzureML
- Logging via AzureML's native capabilities


This can be used by invoking the hi-ml runner and providing the name of the container class, like this:
`himl-runner --model=MyContainer`.

There is a fully working example [HelloContainer](../../hi-ml/src/health-ml/configs/hello_container.py), that
implements a simple 1-dimensional regression model from data stored in a CSV file. You can run that
from the command line by `himl-runner --model=HelloContainer`.

# Running ML experiments in Azure ML

To train in AzureML, add a `--azureml` flag. Use the flag `--cluster` to specify the name of the cluster
in your Workspace that you want to submit the job to. So the whole command would look like:
`himl-runner --model=HelloContainer --cluster=my_cluster_name --azureml`. You can also specify `--num_nodes` if
you wish to distribute the model training.

## Setup - creating your model config file

In order to use these capabilities, you need to implement a class deriving from
`health_ml.lightning_container.LightningContainer`. This class encapsulates everything that is needed for training
with PyTorch Lightning:

For example:
```python
class MyContainer(LightningContainer):
    def __init__(self):
        super().__init__()
        self.azure_datasets = ["folder_name_in_azure_blob_storage"]
        self.local_datasets = [Path("/some/local/path")]
        self.max_epochs = 42

    def create_model(self) -> LightningModule:
        return MyLightningModel()

    def get_data_module(self) -> LightningDataModule:
        return MyDataModule(root_path=self.local_dataset)
```

The `create_model` method needs to return a subclass of PyTorch Lightning's
[LightningModule](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html?highlight=lightningmodule), that has
all the usual PyTorch Lightning methods required for training, like the `training_step` and `forward` methods. E.g.:
```python
class MyLightningModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = ...
    def training_step(self, *args, **kwargs):
        ...
    def forward(self, *args, **kwargs):
        ...
    def configure_optimizers(self):
        ...
    def test_step(self, *args, **kwargs):
        ...
```

The `get_data_module` method of the container needs to return a DataModule (inheriting from a
[PyTorch Lightning DataModule](https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html)) which contains all of the logic for
downloading, preparing and splitting your dataset, as well as methods for wrapping the train, val and test datasets
respectively with [DataLoaders](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader). E.g.:

```python
class MyDataModule(LightningDataModule):
    def __init__(self, root_path: Path):
        super().__init__()
        # All data should be read from the folder given in self.root_path
        self.root_path = root_path
    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        # The data should be read off self.root_path
        train_dataset = ...
        return DataLoader(train_dataset, batch_size=5, num_workers=5)
    def val_dataloader(self, *args, **kwargs) -> DataLoader:
        # The data should be read off self.root_path
        val_dataset = ...
        return DataLoader(val_dataset, batch_size=5, num_workers=5)
    def test_dataloader(self, *args, **kwargs) -> DataLoader:
        # The data should be read off self.root_path
        test_dataset = ...
        return DataLoader(test_dataset, batch_size=5, num_workers=5)
```

So, the **full file** would look like:
```python
from pathlib import Path
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, LightningDataModule
from health_ml.lightning_container import LightningContainer

class MyLightningModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = ...
    def training_step(self, *args, **kwargs):
        ...
    def forward(self, *args, **kwargs):
        ...
    def configure_optimizers(self):
        ...
    def test_step(self, *args, **kwargs):
        ...

class MyDataModule(LightningDataModule):
    def __init__(self, root_path: Path):
        super().__init__()
        # All data should be read from the folder given in self.root_path
        self.root_path = root_path
    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        # The data should be read off self.root_path
        train_dataset = ...
        return DataLoader(train_dataset, batch_size=5, num_workers=5)
    def val_dataloader(self, *args, **kwargs) -> DataLoader:
        # The data should be read off self.root_path
        val_dataset = ...
        return DataLoader(val_dataset, batch_size=5, num_workers=5)
    def test_dataloader(self, *args, **kwargs) -> DataLoader:
        # The data should be read off self.root_path
        test_dataset = ...
        return DataLoader(test_dataset, batch_size=5, num_workers=5)

class MyContainer(LightningContainer):
    def __init__(self):
        super().__init__()
        self.azure_datasets = ["folder_name_in_azure_blob_storage"]
        self.local_datasets = [Path("/some/local/path")]
        self.max_epochs = 42

    def create_model(self) -> LightningModule:
        return MyLightningModel()

    def get_data_module(self) -> LightningDataModule:
        return MyDataModule(root_path=self.local_dataset)
```

By default, config files will be looked for in the folder "health_ml.configs". To specify config files
that live elsewhere, use a fully qualified name for the parameter `--model` - e.g. "MyModule.Configs.my_config.py"

### Outputting files during training

The Lightning model returned by `create_model` needs to write its output files to the current working directory.
When running inside of AzureML, the output folders will be directly under the project root. If not running inside
AzureML, a folder with a timestamp will be created for all outputs and logs.

When running in AzureML, the folder structure will be set up such that all files written
to the current working directory are later uploaded to Azure blob storage at the end of the AzureML job. The files
will also be available later via the AzureML UI.
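
As a minimal sketch of what that means in practice (the callback hook and file name below are arbitrary
illustrations, not something the toolbox prescribes):
```python
from pathlib import Path
from pytorch_lightning import LightningModule

class MyLightningModel(LightningModule):
    ...

    def on_train_end(self) -> None:
        # Anything written to the current working directory ends up in the job's output folder,
        # and is uploaded to blob storage when the AzureML job finishes.
        Path("training_summary.txt").write_text(f"Trained for {self.current_epoch + 1} epochs\n")
```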

### Trainer arguments
All arguments that control the PyTorch Lightning `Trainer` object are defined in the class `TrainerParams`. A
`LightningContainer` object inherits from this class. The most essential one is the `max_epochs` field, which controls
the `max_epochs` argument of the `Trainer`.

For example:
```python
from pytorch_lightning import LightningModule, LightningDataModule
from health_ml.lightning_container import LightningContainer

class MyContainer(LightningContainer):
    def __init__(self):
        super().__init__()
        self.max_epochs = 42

    def create_model(self) -> LightningModule:
        return MyLightningModel()

    def get_data_module(self) -> LightningDataModule:
        return MyDataModule(root_path=self.local_dataset)
```

### Optimizer and LR scheduler arguments
To set the optimizer and LR scheduler, the Lightning model returned by `create_model` should define its own
`configure_optimizers` method, with the same signature as `LightningModule.configure_optimizers`,
and return a tuple containing the Optimizer and LRScheduler objects.
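
For illustration, a minimal sketch of such a method (Adam and StepLR are arbitrary example choices here,
not defaults of the toolbox):
```python
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from pytorch_lightning import LightningModule

class MyLightningModel(LightningModule):
    ...

    def configure_optimizers(self):
        # Optimize all parameters of this model.
        optimizer = Adam(self.parameters(), lr=1e-3)
        # Decay the learning rate by a factor of 10 every 20 epochs.
        scheduler = StepLR(optimizer, step_size=20, gamma=0.1)
        return [optimizer], [scheduler]
```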

@@ -83,7 +83,7 @@ pytest_fast: pip_test call_pytest_fast

# run pytest with coverage on package, and format coverage output as a text file, assuming test requirements already installed
call_pytest_and_coverage:
	pytest --cov=health_azure --cov-branch --cov-report=html --cov-report=term-missing --cov-report=xml testazure
	pytest --cov=health_azure --cov-branch --cov-report=html --cov-report=xml --cov-report=term-missing --cov-config=.coveragerc testazure
	pycobertura show --format text --output coverage.txt coverage.xml

# install test requirements and run pytest coverage
@@ -5,7 +5,7 @@
import logging
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union

from azureml.core import Dataset, Datastore, Workspace
from azureml.data import FileDataset, OutputFileDatasetConfig
@@ -236,6 +236,61 @@ def _replace_string_datasets(datasets: List[StrOrDatasetConfig],
            for d in datasets]


def create_dataset_configs(all_azure_dataset_ids: List[str],
                           all_dataset_mountpoints: Sequence[PathOrString],
                           all_local_datasets: List[Optional[Path]],
                           datastore: Optional[str] = None,
                           use_mounting: bool = False) -> List[DatasetConfig]:
    """
    Sets up all the dataset consumption objects for the datasets provided. The returned list will have the same length
    as there are non-empty azure dataset IDs.

    Valid arguments combinations:
    N azure datasets, 0 or N mount points, 0 or N local datasets

    :param all_azure_dataset_ids: The name of all datasets on blob storage that will be used for this run.
    :param all_dataset_mountpoints: When using the datasets in AzureML, these are the per-dataset mount points.
    :param all_local_datasets: The paths for all local versions of the datasets.
    :param datastore: The name of the AzureML datastore that holds the dataset. This can be empty if the AzureML
        workspace has only a single datastore, or if the default datastore should be used.
    :param use_mounting: If True, the dataset will be "mounted", that is, individual files will be read
        or written on-demand over the network. If False, the dataset will be fully downloaded before the job starts,
        respectively fully uploaded at job end for output datasets.
    :return: A list of DatasetConfig objects, in the same order as datasets were provided in all_azure_dataset_ids,
        omitting datasets with an empty name.
    """
    datasets: List[DatasetConfig] = []
    num_local = len(all_local_datasets)
    num_azure = len(all_azure_dataset_ids)
    num_mount = len(all_dataset_mountpoints)
    if num_azure > 0 and (num_local == 0 or num_local == num_azure) and (num_mount == 0 or num_mount == num_azure):
        # Test for valid settings: If we have N azure datasets, the local datasets and mount points need to either
        # have exactly the same length, or 0. In the latter case, empty mount points and no local dataset will be
        # assumed below.
        count = num_azure
    elif num_azure == 0 and num_mount == 0:
        # No datasets in Azure at all: This is possible for runs that for example download their own data from the web.
        # There can be any number of local datasets, but we are not checking that. In MLRunner.setup, there is a check
        # that leaves local datasets intact if there are no Azure datasets.
        return []
    else:
        raise ValueError("Invalid dataset setup. You need to specify N entries in azure_datasets and a matching "
                         "number of local_datasets and dataset_mountpoints")
    for i in range(count):
        azure_dataset = all_azure_dataset_ids[i] if i < num_azure else ""
        if not azure_dataset:
            continue
        mount_point = all_dataset_mountpoints[i] if i < num_mount else ""
        local_dataset = all_local_datasets[i] if i < num_local else None
        config = DatasetConfig(name=azure_dataset,
                               target_folder=mount_point,
                               local_folder=local_dataset,
                               use_mounting=use_mounting,
                               datastore=datastore or "")
        datasets.append(config)
    return datasets


def find_workspace_for_local_datasets(aml_workspace: Optional[Workspace],
                                      workspace_config_path: Optional[Path],
                                      dataset_configs: List[DatasetConfig]) -> Optional[Workspace]:
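
As a usage sketch of `create_dataset_configs` (the dataset names, paths and datastore name below are made up
purely for illustration):
```python
from pathlib import Path

# Two datasets in blob storage, with matching local copies and no explicit mount points.
dataset_configs = create_dataset_configs(
    all_azure_dataset_ids=["dataset_a", "dataset_b"],
    all_dataset_mountpoints=[],
    all_local_datasets=[Path("/data/dataset_a"), Path("/data/dataset_b")],
    datastore="my_datastore",
    use_mounting=True)
assert [d.name for d in dataset_configs] == ["dataset_a", "dataset_b"]
```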

@@ -138,6 +138,9 @@ def create_run_configuration(workspace: Workspace,
                                     private_pip_wheel_path=private_pip_wheel_path,
                                     docker_base_image=docker_base_image,
                                     environment_variables=environment_variables)
        conda_deps = new_environment.python.conda_dependencies
        if conda_deps.get_python_version() is None:
            raise ValueError("If specifying a conda environment file, you must specify the python version within it")
        registered_env = register_environment(workspace, new_environment)
        run_config.environment = registered_env
    else:
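
The added check requires the Conda environment file to pin a Python version. A rough sketch of what it expects,
using AzureML's `CondaDependencies` directly (the file name is hypothetical):
```python
from azureml.core.conda_dependencies import CondaDependencies

# Build the dependencies object from a Conda environment file, similar to create_run_configuration.
conda_deps = CondaDependencies(conda_dependencies_file_path="environment.yml")
# This passes only if the file lists a pinned interpreter, e.g. "python=3.7", under its dependencies.
if conda_deps.get_python_version() is None:
    raise ValueError("If specifying a conda environment file, you must specify the python version within it")
```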

@@ -4,12 +4,14 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import param
import sys
from pathlib import Path
from typing import List

from azureml.core import Run

import health_azure.utils as azure_util
from health_azure.himl import RUN_RECOVERY_FILE


class HimlDownloadConfig(azure_util.AmlRunScriptConfig):

@@ -46,7 +48,8 @@ def retrieve_runs(download_config: HimlDownloadConfig) -> List[Run]:
        if len(runs) == 0:
            raise ValueError(f"Did not find any runs under the given experiment name: {download_config.experiment}")
    else:
        run_or_recovery_id = azure_util.get_most_recent_run_id(download_config.latest_run_file)
        most_recent_run_path = download_config.latest_run_file or Path(RUN_RECOVERY_FILE)
        run_or_recovery_id = azure_util.get_most_recent_run_id(most_recent_run_path)
        runs = [azure_util.get_aml_run_from_run_id(run_or_recovery_id,
                                                   workspace_config_path=download_config.config_file)]
    if len(runs) == 0:

@@ -57,7 +60,9 @@ def retrieve_runs(download_config: HimlDownloadConfig) -> List[Run]:

def main() -> None:  # pragma: no cover

    download_config = HimlDownloadConfig.parse_args()
    download_config = HimlDownloadConfig()
    download_config = azure_util.parse_args_and_update_config(download_config, sys.argv[1:])

    output_dir = download_config.output_dir
    output_dir.mkdir(exist_ok=True)
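
The new flow above (construct the config object, then update it from the command line) can be sketched in
isolation like this; the config class and arguments are made up for illustration:
```python
import param
from health_azure.utils import parse_args_and_update_config

class MyDownloadConfig(param.Parameterized):
    output_dir = param.String("outputs")

# Build the config with its defaults, then let the command-line arguments override them.
config = MyDownloadConfig()
config = parse_args_and_update_config(config, ["--output_dir", "my_outputs"])
assert config.output_dir == "my_outputs"
```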

@@ -108,8 +108,9 @@ class WrappedTensorboard(Tensorboard):


def main() -> None:  # pragma: no cover
    tb_config = HimlTensorboardConfig()
    tb_config = azure_util.parse_args_and_update_config(tb_config, sys.argv[1:])

    tb_config = HimlTensorboardConfig.parse_args()
    config_path = tb_config.config_file

    if not config_path:

@@ -10,9 +10,11 @@ import json
import logging
import os
import re
import sys
import tempfile
from argparse import ArgumentParser, OPTIONAL
from argparse import ArgumentParser, OPTIONAL, ArgumentError, _UNRECOGNIZED_ARGS_ATTR, Namespace, SUPPRESS
from collections import defaultdict
from dataclasses import dataclass
from itertools import islice
from pathlib import Path
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union

@@ -56,6 +58,7 @@ ENV_GLOBAL_RANK = "GLOBAL_RANK"
ENV_LOCAL_RANK = "LOCAL_RANK"

RUN_CONTEXT = Run.get_context()
PARENT_RUN_CONTEXT = getattr(RUN_CONTEXT, "parent", None)
WORKSPACE_CONFIG_JSON = "config.json"

# By default, define several environment variables that work around known issues in the software stack

@@ -126,223 +129,314 @@ class GenericConfig(param.Parameterized):
        """
        pass

    def add_and_validate(self, kwargs: Dict[str, Any], validate: bool = True) -> None:

def set_fields_and_validate(config: param.Parameterized, fields_to_set: Dict[str, Any], validate: bool = True) -> None:
    """
    Add further parameters and, if validate is True, validate. We first try set_param, but that
    fails when the parameter has a setter.

    :param config: The model configuration
    :param fields_to_set: A dictionary of key, value pairs where each key represents a parameter to be added
        and val represents its value
    :param validate: Whether to validate the value of the parameter after adding.
    """
    assert isinstance(config, param.Parameterized)
    for key, value in fields_to_set.items():
        try:
            config.set_param(key, value)
        except ValueError:
            setattr(config, key, value)
    if validate:
        config.validate()
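
A usage sketch for the helper above, mirroring the pattern used in the tests of this commit (the config class
is made up for illustration):
```python
import param

class MyConfig(param.Parameterized):
    learning_rate = param.Number(1e-3)
    num_epochs = param.Integer(10)

    def validate(self) -> None:
        assert self.num_epochs > 0

config = MyConfig()
# Set several fields at once, then run the config's own validation.
set_fields_and_validate(config, {"learning_rate": 1e-4, "num_epochs": 20})
assert config.num_epochs == 20
```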


def create_argparser(config: param.Parameterized) -> ArgumentParser:
    """
    Creates an ArgumentParser with all fields of the given config that are overridable.

    :param config: The config whose parameters should be used to populate the argument parser
    :return: ArgumentParser
    """
    assert isinstance(config, param.Parameterized)
    parser = ArgumentParser()
    _add_overrideable_config_args_to_parser(config, parser)
    return parser


def _add_overrideable_config_args_to_parser(config: param.Parameterized, parser: ArgumentParser) -> ArgumentParser:
    """
    Adds all overridable fields of the config class to the given argparser.
    Fields that are marked as readonly, constant or private are ignored.

    :param parser: Parser to add properties to.
    """
|
||||
def parse_bool(x: str) -> bool:
|
||||
"""
|
||||
Add further parameters and, if validate is True, validate. We first try set_param, but that
|
||||
fails when the parameter has a setter.
|
||||
Parse a string as a bool. Supported values are case insensitive and one of:
|
||||
'on', 't', 'true', 'y', 'yes', '1' for True
|
||||
'off', 'f', 'false', 'n', 'no', '0' for False.
|
||||
|
||||
:param kwargs: A dictionary of key, value pairs where each key represents a parameter to be added
|
||||
and val represents its value
|
||||
:param validate: Whether to validate the value of the parameter after adding.
|
||||
:param x: string to test.
|
||||
:return: Bool value if string valid, otherwise a ValueError is raised.
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
self.set_param(key, value)
|
||||
except ValueError:
|
||||
setattr(self, key, value)
|
||||
if validate:
|
||||
self.validate()
|
||||
sx = str(x).lower()
|
||||
if sx in ('on', 't', 'true', 'y', 'yes', '1'):
|
||||
return True
|
||||
if sx in ('off', 'f', 'false', 'n', 'no', '0'):
|
||||
return False
|
||||
raise ValueError(f"Invalid value {x}, please supply one of True, true, false or False.")
|
||||
|
||||
@classmethod
|
||||
def create_argparser(cls) -> ArgumentParser:
|
||||
def _get_basic_type(_p: param.Parameter) -> Union[type, Callable]:
|
||||
"""
|
||||
Creates an ArgumentParser with all fields of the given argparser that are overridable.
|
||||
Given a parameter, get its basic Python type, e.g.: param.Boolean -> bool.
|
||||
Throw exception if it is not supported.
|
||||
|
||||
:return: ArgumentParser
|
||||
:param _p: parameter to get type and nargs for.
|
||||
:return: Type
|
||||
"""
|
||||
parser = ArgumentParser()
|
||||
cls.add_args(parser)
|
||||
if isinstance(_p, param.Boolean):
|
||||
p_type: Callable = parse_bool
|
||||
elif isinstance(_p, param.Integer):
|
||||
p_type = lambda x: _p.default if x == "" else int(x)
|
||||
elif isinstance(_p, param.Number):
|
||||
p_type = lambda x: _p.default if x == "" else float(x)
|
||||
elif isinstance(_p, param.String):
|
||||
p_type = str
|
||||
elif isinstance(_p, param.List):
|
||||
p_type = lambda x: [_p.class_(item) for item in x.split(',')]
|
||||
elif isinstance(_p, param.NumericTuple):
|
||||
float_or_int = lambda y: int(y) if isinstance(_p, IntTuple) else float(y)
|
||||
p_type = lambda x: tuple([float_or_int(item) for item in x.split(',')])
|
||||
elif isinstance(_p, param.ClassSelector):
|
||||
p_type = _p.class_
|
||||
elif isinstance(_p, CustomTypeParam):
|
||||
p_type = _p.from_string
|
||||
|
||||
return parser
|
||||
else:
|
||||
raise TypeError("Parameter of type: {} is not supported".format(_p))
|
||||
|
||||
@classmethod
|
||||
def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
|
||||
return p_type
|
||||
|
||||
def add_boolean_argument(parser: ArgumentParser, k: str, p: param.Parameter) -> None:
|
||||
"""
|
||||
Adds all overridable fields of the current class to the given argparser.
|
||||
Fields that are marked as readonly, constant or private are ignored.
|
||||
Add a boolean argument.
|
||||
If the parameter default is False then allow --flag (to set it True) and --flag=Bool as usual.
|
||||
If the parameter default is True then allow --no-flag (to set it to False) and --flag=Bool as usual.
|
||||
|
||||
:param parser: Parser to add properties to.
|
||||
:param parser: parser to add a boolean argument to.
|
||||
:param k: argument name.
|
||||
:param p: boolean parameter.
|
||||
"""
|
||||
if not p.default:
|
||||
# If the parameter default is False then use nargs="?" (OPTIONAL).
|
||||
# This means that the argument is optional.
|
||||
# If it is not supplied, i.e. in the --flag mode, use the "const" value, i.e. True.
|
||||
# Otherwise, i.e. in the --flag=value mode, try to parse the argument as a bool.
|
||||
parser.add_argument("--" + k, help=p.doc, type=parse_bool, default=False,
|
||||
nargs=OPTIONAL, const=True)
|
||||
else:
|
||||
# If the parameter default is True then create an exclusive group of arguments.
|
||||
# Either --flag=value as usual
|
||||
# Or --no-flag to store False in the parameter k.
|
||||
group = parser.add_mutually_exclusive_group(required=False)
|
||||
group.add_argument("--" + k, help=p.doc, type=parse_bool)
|
||||
group.add_argument('--no-' + k, dest=k, action='store_false')
|
||||
parser.set_defaults(**{k: p.default})
|
||||
|
||||
def parse_bool(x: str) -> bool:
|
||||
"""
|
||||
Parse a string as a bool. Supported values are case insensitive and one of:
|
||||
'on', 't', 'true', 'y', 'yes', '1' for True
|
||||
'off', 'f', 'false', 'n', 'no', '0' for False.
|
||||
for k, p in get_overridable_parameters(config).items():
|
||||
# param.Booleans need to be handled separately, they are more complicated because they have
|
||||
# an optional argument.
|
||||
if isinstance(p, param.Boolean):
|
||||
add_boolean_argument(parser, k, p)
|
||||
else:
|
||||
parser.add_argument("--" + k, help=p.doc, type=_get_basic_type(p), default=p.default)
|
||||
|
||||
:param x: string to test.
|
||||
:return: Bool value if string valid, otherwise a ValueError is raised.
|
||||
"""
|
||||
sx = str(x).lower()
|
||||
if sx in ('on', 't', 'true', 'y', 'yes', '1'):
|
||||
return True
|
||||
if sx in ('off', 'f', 'false', 'n', 'no', '0'):
|
||||
return False
|
||||
raise ValueError(f"Invalid value {x}, please supply one of True, true, false or False.")
|
||||
return parser
|
||||
|
||||
def _get_basic_type(_p: param.Parameter) -> Union[type, Callable]:
|
||||
"""
|
||||
Given a parameter, get its basic Python type, e.g.: param.Boolean -> bool.
|
||||
Throw exception if it is not supported.
|
||||
|
||||
:param _p: parameter to get type and nargs for.
|
||||
:return: Type
|
||||
"""
|
||||
if isinstance(_p, param.Boolean):
|
||||
p_type: Callable = parse_bool
|
||||
elif isinstance(_p, param.Integer):
|
||||
p_type = lambda x: _p.default if x == "" else int(x)
|
||||
elif isinstance(_p, param.Number):
|
||||
p_type = lambda x: _p.default if x == "" else float(x)
|
||||
elif isinstance(_p, param.String):
|
||||
p_type = str
|
||||
elif isinstance(_p, param.List):
|
||||
p_type = lambda x: [_p.class_(item) for item in x.split(',')]
|
||||
elif isinstance(_p, param.NumericTuple):
|
||||
float_or_int = lambda y: int(y) if isinstance(_p, IntTuple) else float(y)
|
||||
p_type = lambda x: tuple([float_or_int(item) for item in x.split(',')])
|
||||
elif isinstance(_p, param.ClassSelector):
|
||||
p_type = _p.class_
|
||||
elif isinstance(_p, CustomTypeParam):
|
||||
p_type = _p.from_string
|
||||
@dataclass
|
||||
class ParserResult:
|
||||
"""
|
||||
Stores the results of running an argument parser, broken down into a argument-to-value dictionary,
|
||||
arguments that the parser does not recognize.
|
||||
"""
|
||||
args: Dict[str, Any]
|
||||
unknown: List[str]
|
||||
overrides: Dict[str, Any]
|
||||
|
||||
|
||||
def _create_default_namespace(parser: ArgumentParser) -> Namespace:
|
||||
"""
|
||||
Creates an argparse Namespace with all parser-specific default values set.
|
||||
|
||||
:param parser: The parser to work with.
|
||||
:return: the Namespace object
|
||||
"""
|
||||
# This is copy/pasted from parser.parse_known_args
|
||||
namespace = Namespace()
|
||||
for action in parser._actions:
|
||||
if action.dest is not SUPPRESS:
|
||||
if not hasattr(namespace, action.dest):
|
||||
if action.default is not SUPPRESS:
|
||||
setattr(namespace, action.dest, action.default)
|
||||
for dest in parser._defaults:
|
||||
if not hasattr(namespace, dest):
|
||||
setattr(namespace, dest, parser._defaults[dest])
|
||||
return namespace
|
||||
|
||||
|
||||
def parse_arguments(parser: ArgumentParser,
|
||||
fail_on_unknown_args: bool = False,
|
||||
args: List[str] = None) -> ParserResult:
|
||||
"""
|
||||
Parses a list of commandline arguments with a given parser. Returns results broken down into a full
|
||||
arguments dictionary, a dictionary of arguments that were set to non-default values, and unknown
|
||||
arguments.
|
||||
|
||||
:param parser: The parser to use
|
||||
:param fail_on_unknown_args: If True, raise an exception if the parser encounters an argument that it does
|
||||
not recognize. If False, unrecognized arguments will be ignored, and added to the "unknown" field of
|
||||
the parser result.
|
||||
:param args: Arguments to parse. If not given, use those in sys.argv
|
||||
:return: The parsed arguments, and overrides
|
||||
"""
|
||||
if args is None:
|
||||
args = sys.argv[1:]
|
||||
# The following code is a slightly modified version of what happens in parser.parse_known_args. This had to be
|
||||
# copied here because otherwise we would not be able to achieve the priority order that we desire.
|
||||
namespace = _create_default_namespace(parser)
|
||||
|
||||
try:
|
||||
namespace, unknown = parser._parse_known_args(args, namespace)
|
||||
if hasattr(namespace, _UNRECOGNIZED_ARGS_ATTR):
|
||||
unknown.extend(getattr(namespace, _UNRECOGNIZED_ARGS_ATTR))
|
||||
delattr(namespace, _UNRECOGNIZED_ARGS_ATTR)
|
||||
except ArgumentError:
|
||||
parser.print_usage(sys.stderr)
|
||||
err = sys.exc_info()[1]
|
||||
parser._print_message(str(err), sys.stderr)
|
||||
raise
|
||||
# Parse the arguments a second time, without supplying defaults, to see which arguments actually differ
|
||||
# from defaults.
|
||||
namespace_without_defaults, _ = parser._parse_known_args(args, Namespace())
|
||||
parsed_args = vars(namespace).copy()
|
||||
overrides = vars(namespace_without_defaults).copy()
|
||||
if len(unknown) > 0 and fail_on_unknown_args:
|
||||
raise ValueError(f'Unknown arguments: {unknown}')
|
||||
return ParserResult(
|
||||
args=parsed_args,
|
||||
unknown=unknown,
|
||||
overrides=overrides,
|
||||
)
|
||||
|
||||
|
||||
def parse_args_and_update_config(config: Any, args: List[str]) -> Any:
|
||||
"""
|
||||
Given a model config and a list of command line arguments, creates an argparser, adds arguments from the config
|
||||
parses the list of provided args and updates the config accordingly. Returns the updated config
|
||||
|
||||
:param config: The model configuration
|
||||
:param args: A list of command line args to parse
|
||||
:return: The config, updated with the values of the provided args
|
||||
"""
|
||||
parser = create_argparser(config)
|
||||
parser_results = parse_arguments(parser, args=args)
|
||||
_ = apply_overrides(config, parser_results.args)
|
||||
return config
|
||||
|
||||
|
||||
def get_overridable_parameters(config: Any) -> Dict[str, param.Parameter]:
|
||||
"""
|
||||
Get properties that are not constant, readonly or private (eg: prefixed with an underscore).
|
||||
|
||||
:param config: The model configuration
|
||||
:return: A dictionary of parameter names and their definitions.
|
||||
"""
|
||||
assert isinstance(config, param.Parameterized)
|
||||
return dict((k, v) for k, v in config.params().items()
|
||||
if reason_not_overridable(v) is None)
|
||||
|
||||
|
||||
def reason_not_overridable(value: param.Parameter) -> Optional[str]:
|
||||
"""
|
||||
Given a parameter, check for attributes that denote it is not overrideable (e.g. readonly, constant,
|
||||
private etc). If such an attribute exists, return a string containing a single-word description of the
|
||||
reason. Otherwise returns None.
|
||||
|
||||
:param value: a parameter value
|
||||
:return: None if the parameter is overridable; otherwise a one-word string explaining why not.
|
||||
"""
|
||||
if value.readonly:
|
||||
return "readonly"
|
||||
elif value.constant:
|
||||
return "constant"
|
||||
elif is_private_field_name(value.name):
|
||||
return "private"
|
||||
elif isinstance(value, param.Callable):
|
||||
return "callable"
|
||||
return None
|
||||
|
||||
|
||||
def apply_overrides(config: Any, overrides_to_apply: Optional[Dict[str, Any]], should_validate: bool = False,
|
||||
keys_to_ignore: Optional[Set[str]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Applies the provided `values` overrides to the config.
|
||||
Only properties that are marked as overridable are actually overwritten.
|
||||
|
||||
:param config: The model configuration
|
||||
:param overrides_to_apply: A dictionary mapping from field name to value.
|
||||
:param should_validate: If true, run the .validate() method after applying overrides.
|
||||
:param keys_to_ignore: keys to ignore in reporting failed overrides. If None, do not report.
|
||||
:return: A dictionary with all the fields that were modified.
|
||||
"""
|
||||
|
||||
def _apply(_overrides: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
applied: Dict[str, Any] = {}
|
||||
if _overrides is not None:
|
||||
overridable_parameters = get_overridable_parameters(config).keys()
|
||||
for k, v in _overrides.items():
|
||||
if k in overridable_parameters:
|
||||
applied[k] = v
|
||||
setattr(config, k, v)
|
||||
|
||||
return applied
|
||||
|
||||
actual_overrides = _apply(overrides_to_apply)
|
||||
if keys_to_ignore is not None:
|
||||
report_on_overrides(config, overrides_to_apply, keys_to_ignore) # type: ignore
|
||||
if should_validate:
|
||||
config.validate()
|
||||
return actual_overrides
|
||||
|
||||
|
||||
def report_on_overrides(config: Any, overrides_to_apply: Dict[str, Any], keys_to_ignore: Set[str]) -> None:
|
||||
"""
|
||||
Logs a warning for every parameter whose value is not as given in "overrides_to_apply", other than those
|
||||
in keys_to_ignore.
|
||||
|
||||
:param config: The model configuration
|
||||
:param overrides_to_apply: override dictionary, parameter names to values
|
||||
:param keys_to_ignore: set of dictionary keys not to report on
|
||||
"""
|
||||
assert isinstance(config, param.Parameterized)
|
||||
for key, desired in overrides_to_apply.items():
|
||||
if key in keys_to_ignore:
|
||||
continue
|
||||
actual = getattr(config, key, None)
|
||||
if actual == desired:
|
||||
continue
|
||||
if key not in config.params():
|
||||
reason = "parameter is undefined"
|
||||
else:
|
||||
val = config.params()[key]
|
||||
reason = reason_not_overridable(val) # type: ignore
|
||||
if reason is None:
|
||||
reason = "for UNKNOWN REASONS"
|
||||
else:
|
||||
raise TypeError("Parameter of type: {} is not supported".format(_p))
|
||||
|
||||
return p_type
|
||||
|
||||
def add_boolean_argument(parser: ArgumentParser, k: str, p: param.Parameter) -> None:
|
||||
"""
|
||||
Add a boolean argument.
|
||||
If the parameter default is False then allow --flag (to set it True) and --flag=Bool as usual.
|
||||
If the parameter default is True then allow --no-flag (to set it to False) and --flag=Bool as usual.
|
||||
|
||||
:param parser: parser to add a boolean argument to.
|
||||
:param k: argument name.
|
||||
:param p: boolean parameter.
|
||||
"""
|
||||
if not p.default:
|
||||
# If the parameter default is False then use nargs="?" (OPTIONAL).
|
||||
# This means that the argument is optional.
|
||||
# If it is not supplied, i.e. in the --flag mode, use the "const" value, i.e. True.
|
||||
# Otherwise, i.e. in the --flag=value mode, try to parse the argument as a bool.
|
||||
parser.add_argument("--" + k, help=p.doc, type=parse_bool, default=False,
|
||||
nargs=OPTIONAL, const=True)
|
||||
else:
|
||||
# If the parameter default is True then create an exclusive group of arguments.
|
||||
# Either --flag=value as usual
|
||||
# Or --no-flag to store False in the parameter k.
|
||||
group = parser.add_mutually_exclusive_group(required=False)
|
||||
group.add_argument("--" + k, help=p.doc, type=parse_bool)
|
||||
group.add_argument('--no-' + k, dest=k, action='store_false')
|
||||
parser.set_defaults(**{k: p.default})
|
||||
|
||||
for k, p in cls.get_overridable_parameters().items():
|
||||
# param.Booleans need to be handled separately, they are more complicated because they have
|
||||
# an optional argument.
|
||||
if isinstance(p, param.Boolean):
|
||||
add_boolean_argument(parser, k, p)
|
||||
else:
|
||||
parser.add_argument("--" + k, help=p.doc, type=_get_basic_type(p), default=p.default)
|
||||
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def parse_args(cls: Type[T], args: Optional[List[str]] = None) -> T:
|
||||
"""
|
||||
Creates an argparser based on the params class and parses stdin args (or the args provided)
|
||||
|
||||
:param args: The arguments to be parsed
|
||||
"""
|
||||
return cls(**vars(cls.create_argparser().parse_args(args))) # type: ignore
|
||||
|
||||
@classmethod
|
||||
def get_overridable_parameters(cls) -> Dict[str, param.Parameter]:
|
||||
"""
|
||||
Get properties that are not constant, readonly or private (eg: prefixed with an underscore).
|
||||
|
||||
:return: A dictionary of parameter names and their definitions.
|
||||
"""
|
||||
return dict((k, v) for k, v in cls.params().items()
|
||||
if cls.reason_not_overridable(v) is None)
|
||||
|
||||
@staticmethod
|
||||
def reason_not_overridable(value: param.Parameter) -> Optional[str]:
|
||||
"""
|
||||
Given a parameter, check for attributes that denote it is not overrideable (e.g. readonly, constant,
|
||||
private etc). If such an attribute exists, return a string containing a single-word description of the
|
||||
reason. Otherwise returns None.
|
||||
|
||||
:param value: a parameter value
|
||||
:return: None if the parameter is overridable; otherwise a one-word string explaining why not.
|
||||
"""
|
||||
if value.readonly:
|
||||
return "readonly"
|
||||
elif value.constant:
|
||||
return "constant"
|
||||
elif is_private_field_name(value.name):
|
||||
return "private"
|
||||
elif isinstance(value, param.Callable):
|
||||
return "callable"
|
||||
return None
|
||||
|
||||
def apply_overrides(self, values: Optional[Dict[str, Any]], should_validate: bool = True,
|
||||
keys_to_ignore: Optional[Set[str]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Applies the provided `values` overrides to the config.
|
||||
Only properties that are marked as overridable are actually overwritten.
|
||||
|
||||
:param values: A dictionary mapping from field name to value.
|
||||
:param should_validate: If true, run the .validate() method after applying overrides.
|
||||
:param keys_to_ignore: keys to ignore in reporting failed overrides. If None, do not report.
|
||||
:return: A dictionary with all the fields that were modified.
|
||||
"""
|
||||
|
||||
def _apply(_overrides: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
applied: Dict[str, Any] = {}
|
||||
if _overrides is not None:
|
||||
overridable_parameters = self.get_overridable_parameters().keys()
|
||||
for k, v in _overrides.items():
|
||||
if k in overridable_parameters:
|
||||
applied[k] = v
|
||||
setattr(self, k, v)
|
||||
|
||||
return applied
|
||||
|
||||
actual_overrides = _apply(values)
|
||||
if keys_to_ignore is not None:
|
||||
self.report_on_overrides(values, keys_to_ignore) # type: ignore
|
||||
if should_validate:
|
||||
self.validate()
|
||||
return actual_overrides
|
||||
|
||||
def report_on_overrides(self, values: Dict[str, Any], keys_to_ignore: Optional[Set[str]] = None) -> None:
|
||||
"""
|
||||
Logs a warning for every parameter whose value is not as given in "values", other than those
|
||||
in keys_to_ignore.
|
||||
|
||||
:param values: override dictionary, parameter names to values
|
||||
:param keys_to_ignore: set of dictionary keys not to report on
|
||||
:return: None
|
||||
"""
|
||||
for key, desired in values.items():
|
||||
# If this isn't an AzureConfig instance, we don't want to warn on keys intended for it.
|
||||
if keys_to_ignore and (key in keys_to_ignore):
|
||||
continue
|
||||
actual = getattr(self, key, None)
|
||||
if actual == desired:
|
||||
continue
|
||||
if key not in self.params():
|
||||
reason = "parameter is undefined"
|
||||
else:
|
||||
val = self.params()[key]
|
||||
reason = self.reason_not_overridable(val) # type: ignore
|
||||
if reason is None:
|
||||
reason = "for UNKNOWN REASONS"
|
||||
else:
|
||||
reason = f"parameter is {reason}"
|
||||
# We could raise an error here instead - to be discussed.
|
||||
logging.warning(f"Override {key}={desired} failed: {reason} in class {self.__class__.name}")
|
||||
reason = f"parameter is {reason}"
|
||||
# We could raise an error here instead - to be discussed.
|
||||
logging.warning(f"Override {key}={desired} failed: {reason} in class {config.__class__.name}")
|
||||
|
||||
|
||||
def create_from_matching_params(from_object: param.Parameterized, cls_: Type[T]) -> T:
|
||||
|
@ -565,6 +659,7 @@ def _find_file(file_name: str, stop_at_pythonpath: bool = True) -> Optional[Path
|
|||
file_name: str,
|
||||
stop_at_pythonpath: bool,
|
||||
pythonpaths: List[Path]) -> Optional[Path]:
|
||||
|
||||
logging.debug(f"Searching for file {file_name} in {start_at}")
|
||||
expected = start_at / file_name
|
||||
if expected.is_file() and expected.name == file_name:
|
||||
|
@ -761,23 +856,72 @@ def _log_conda_dependencies_stats(conda: CondaDependencies, message_prefix: str)
|
|||
logging.debug(f" {p}")
|
||||
|
||||
|
||||
def merge_conda_files(files: List[Path], result_file: Path) -> None:
|
||||
def _retrieve_unique_deps(dependencies: List[str], keep_method: str = "first") -> List[str]:
|
||||
"""
|
||||
Merges the given Conda environment files using the conda_merge package, and writes the merged file to disk.
|
||||
Given a list of conda dependencies, which may contain duplicate versions
|
||||
of the same package name with the same or different versions, returns a
|
||||
list of them where each package name occurs only once. If a
|
||||
package name appears more than once, only the first value will be retained.
|
||||
|
||||
:param files: The Conda environment files to read.
|
||||
:param result_file: The location where the merge results should be written.
|
||||
:param dependencies: The original list of package names to deduplicate
|
||||
:param keep_method: The strategy for choosing which package version to keep
|
||||
:return: a list in which each package name occurs only once
|
||||
"""
|
||||
for file in files:
|
||||
_log_conda_dependencies_stats(CondaDependencies(file), f"Conda environment in {file}")
|
||||
# This code is a slightly modified version of conda_merge. That code can't be re-used easily
|
||||
# it defaults to writing to stdout
|
||||
env_definitions = [conda_merge.read_file(str(f)) for f in files]
|
||||
unique_deps: Dict[str, Tuple[str, str]] = {}
|
||||
for dep in dependencies:
|
||||
dep_parts: List[str] = re.split("(=<|==|=|>=|<|>)", dep)
|
||||
len_parts = len(dep_parts)
|
||||
dep_name = dep_parts[0]
|
||||
if len_parts > 1:
|
||||
dep_join = ''.join(dep_parts[1:-1])
|
||||
dep_version = dep_parts[-1]
|
||||
else:
|
||||
dep_join = ''
|
||||
dep_version = ''
|
||||
|
||||
if dep_name in unique_deps:
|
||||
if keep_method == "first":
|
||||
keep_version, _ = unique_deps[dep_name]
|
||||
elif keep_method == "last":
|
||||
keep_version = dep_version
|
||||
unique_deps[dep_name] = (keep_version, dep_join)
|
||||
else:
|
||||
raise ValueError(f"Unrecognised value of 'keep_method: {keep_method}'. Accepted values"
|
||||
f" include: ['first', 'last']")
|
||||
logging.warning(f"Found duplicate requirements: {dep}. Keeping the {keep_method} "
|
||||
f"version: {keep_version}")
|
||||
|
||||
else:
|
||||
unique_deps[dep_name] = (dep_version, dep_join)
|
||||
|
||||
unique_deps_list = [f"{pkg}{joiner}{vrsn}" for pkg, (vrsn, joiner) in unique_deps.items()]
|
||||
return unique_deps_list
|
||||
|
||||
|
||||
def merge_conda_files(conda_files: List[Path], result_file: Path, pip_files: List[Path] = None,
|
||||
pip_clash_keep_method: str = "first") -> None:
|
||||
"""
|
||||
Merges the given Conda environment files using the conda_merge package, optionally adds any
|
||||
dependencies from pip requirements files, and writes the merged file to disk.
|
||||
|
||||
:param conda_files: The Conda environment files to read.
|
||||
:param result_file: The location where the merge results should be written.
|
||||
:param pip_files: An optional list of one or more pip requirements files including extra dependencies.
|
||||
:param pip_clash_keep_method: If two or more pip packages are specified with the same name, this determines
|
||||
which one should be kept. Current options: ['first', 'last']
|
||||
"""
|
||||
env_definitions = [conda_merge.read_file(str(f)) for f in conda_files]
|
||||
unified_definition = {}
|
||||
NAME = "name"
|
||||
CHANNELS = "channels"
|
||||
DEPENDENCIES = "dependencies"
|
||||
|
||||
extra_pip_deps = []
|
||||
for pip_file in pip_files or []:
|
||||
with open(pip_file, "r") as f_path:
|
||||
additional_pip_deps = [d for d in f_path.read().split("\n") if d]
|
||||
extra_pip_deps.extend(additional_pip_deps)
|
||||
|
||||
name = conda_merge.merge_names(env.get(NAME) for env in env_definitions)
|
||||
if name:
|
||||
unified_definition[NAME] = name
|
||||
|
@ -791,12 +935,34 @@ def merge_conda_files(files: List[Path], result_file: Path) -> None:
|
|||
unified_definition[CHANNELS] = channels
|
||||
|
||||
try:
|
||||
deps = conda_merge.merge_dependencies(env.get(DEPENDENCIES) for env in env_definitions)
|
||||
deps_to_merge = [env.get(DEPENDENCIES) for env in env_definitions]
|
||||
if len(extra_pip_deps) > 0:
|
||||
deps_to_merge.extend([[{"pip": extra_pip_deps}]])
|
||||
deps = conda_merge.merge_dependencies(deps_to_merge)
|
||||
|
||||
# Remove duplicated pip packages from merged dependencies sections. Note that for a package that is
|
||||
# duplicated, the first value encountered will be retained.
|
||||
pip_deps_entries = [d for d in deps if isinstance(d, dict) and "pip" in d] # type: ignore
|
||||
if len(pip_deps_entries) == 0:
|
||||
raise ValueError("Didn't find a dictionary with the key 'pip' in the list of dependencies")
|
||||
pip_deps_entry: Dict[str, List[str]] = pip_deps_entries[0]
|
||||
pip_deps = pip_deps_entry["pip"]
|
||||
# temporarily remove pip dependencies from deps to be added back after deduplicaton
|
||||
deps.remove(pip_deps_entry)
|
||||
|
||||
# remove all non-pip duplicates from the list of dependencies
|
||||
unique_deps = _retrieve_unique_deps(deps, keep_method=pip_clash_keep_method)
|
||||
|
||||
unique_pip_deps = _retrieve_unique_deps(pip_deps, keep_method=pip_clash_keep_method)
|
||||
|
||||
# finally add back the deduplicated list of dependencies
|
||||
unique_deps.append({"pip": unique_pip_deps}) # type: ignore
|
||||
|
||||
except conda_merge.MergeError:
|
||||
logging.error("Failed to merge dependencies.")
|
||||
raise
|
||||
if deps:
|
||||
unified_definition[DEPENDENCIES] = deps
|
||||
if unique_deps:
|
||||
unified_definition[DEPENDENCIES] = unique_deps
|
||||
else:
|
||||
raise ValueError("No dependencies found in any of the conda files.")
|
||||
|
||||

@@ -1230,7 +1396,7 @@ def upload_to_datastore(datastore_name: str, local_data_folder: Path, remote_pat
    logging.info(f"Uploaded data to {str(remote_path)}")


class AmlRunScriptConfig(GenericConfig):
class AmlRunScriptConfig(param.Parameterized):
    """
    Base config for a script that handles Azure ML Runs, which can be retrieved with either a run id, latest_run_file,
    or by giving the experiment name (optionally alongside tags and number of runs to retrieve). A config file path can

@@ -1319,6 +1485,16 @@ def is_running_in_azure_ml(aml_run: Run = RUN_CONTEXT) -> bool:
    return hasattr(aml_run, 'experiment')


def is_running_on_azure_agent() -> bool:
    """
    Determine whether the current code is running on an Azure agent by examining the environment variable
    for AGENT_OS, that all Azure hosted agents define.

    :return: True if the code appears to be running on an Azure build agent, and False otherwise.
    """
    return bool(os.environ.get("AGENT_OS", None))


def torch_barrier() -> None:
    """
    This is a barrier to use in distributed jobs. Use it to make all processes that participate in a distributed
@ -10,10 +10,10 @@ import logging
|
|||
import os
|
||||
import sys
|
||||
import time
|
||||
from argparse import ArgumentParser, Namespace, ArgumentError
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from random import randint
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
@ -36,6 +36,7 @@ from testazure.test_himl import RunTarget, render_and_run_test_script
|
|||
from testazure.utils_testazure import (DEFAULT_IGNORE_FOLDERS, DEFAULT_WORKSPACE, MockRun, change_working_directory,
|
||||
repository_root)
|
||||
|
||||
|
||||
RUN_ID = uuid4().hex
|
||||
RUN_NUMBER = 42
|
||||
EXPERIMENT_NAME = "fancy-experiment"
|
||||
|
@ -252,6 +253,16 @@ def test_split_recovery_id(id: str, expected1: str, expected2: str) -> None:
|
|||
assert util.split_recovery_id(id) == (expected1, expected2)
|
||||
|
||||
|
||||
def test_retrieve_unique_deps() -> None:
|
||||
deps_with_duplicates = ["package==1.0", "package==1.1", "git+https:www.github.com/something.git"]
|
||||
|
||||
dedup_deps = util._retrieve_unique_deps(deps_with_duplicates) # type: ignore
|
||||
assert dedup_deps == ["package==1.0", "git+https:www.github.com/something.git"]
|
||||
|
||||
dedup_deps_keep_last = util._retrieve_unique_deps(deps_with_duplicates, keep_method="last")
|
||||
assert dedup_deps_keep_last == ["package==1.1", "git+https:www.github.com/something.git"]
|
||||
|
||||
|
||||
def test_merge_conda(
|
||||
random_folder: Path,
|
||||
caplog: CaptureFixture,
|
||||
|
@ -298,12 +309,10 @@ dependencies:
|
|||
- pytorch
|
||||
dependencies:
|
||||
- conda1=1.0
|
||||
- conda1=1.1
|
||||
- conda2=2.0
|
||||
- conda_both=3.0
|
||||
- pip:
|
||||
- azureml-sdk==1.6.0
|
||||
- azureml-sdk==1.7.0
|
||||
- bar==2.0
|
||||
- foo==1.0
|
||||
""".splitlines()
|
||||
|
@ -313,8 +322,30 @@ dependencies:
|
|||
assert list(conda_dep.conda_channels) == ["defaults", "pytorch"]
|
||||
|
||||
# Package version conflicts are not resolved, both versions are retained.
|
||||
assert list(conda_dep.conda_packages) == ["conda1=1.0", "conda1=1.1", "conda2=2.0", "conda_both=3.0"]
|
||||
assert list(conda_dep.pip_packages) == ["azureml-sdk==1.6.0", "azureml-sdk==1.7.0", "bar==2.0", "foo==1.0"]
|
||||
assert list(conda_dep.conda_packages) == ["conda1=1.0", "conda2=2.0", "conda_both=3.0"]
|
||||
assert list(conda_dep.pip_packages) == ["azureml-sdk==1.6.0", "bar==2.0", "foo==1.0"]
|
||||
|
||||
# Assert that extra pip requirements are added correctly
|
||||
pip_contents = """package1==0.0.1
|
||||
package2==0.0.1
|
||||
"""
|
||||
pip_file = random_folder / "req.txt"
|
||||
pip_file.write_text(pip_contents)
|
||||
util.merge_conda_files(files, merged_file, pip_files=[pip_file])
|
||||
merged_file_text = merged_file.read_text()
|
||||
assert merged_file_text.splitlines() == """channels:
|
||||
- defaults
|
||||
- pytorch
|
||||
dependencies:
|
||||
- conda1=1.0
|
||||
- conda2=2.0
|
||||
- conda_both=3.0
|
||||
- pip:
|
||||
- azureml-sdk==1.6.0
|
||||
- bar==2.0
|
||||
- foo==1.0
|
||||
- package1==0.0.1
|
||||
- package2==0.0.1""".splitlines()
|
||||
|
||||
# Are names merged correctly?
|
||||
assert "name:" not in merged_file_text
|
||||
|
@ -349,9 +380,8 @@ dependencies:
|
|||
# If there are no dependencies then something is wrong with the conda files or our parsing of them
|
||||
with mock.patch("health_azure.utils.conda_merge.merge_dependencies") as mock_merge_dependencies:
|
||||
mock_merge_dependencies.return_value = []
|
||||
with pytest.raises(ValueError) as e:
|
||||
with pytest.raises(ValueError):
|
||||
util.merge_conda_files(files, merged_file)
|
||||
assert "No dependencies found in any of the conda files" in str(e.value)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["s", "expected"],
|
||||
|
@ -933,12 +963,13 @@ def test_get_run_source(dummy_recovery_id: str,
|
|||
arguments = ["", "--run", dummy_recovery_id]
|
||||
with patch.object(sys, "argv", arguments):
|
||||
|
||||
run_source = util.AmlRunScriptConfig.parse_args()
|
||||
script_config = util.AmlRunScriptConfig()
|
||||
script_config = util.parse_args_and_update_config(script_config, arguments)
|
||||
|
||||
if isinstance(run_source.run, List):
|
||||
assert isinstance(run_source.run[0], str)
|
||||
if isinstance(script_config.run, List):
|
||||
assert isinstance(script_config.run[0], str)
|
||||
else:
|
||||
assert isinstance(run_source.run, str)
|
||||
assert isinstance(script_config.run, str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("overwrite", [True, False])
|
||||
|
@ -1036,7 +1067,8 @@ def test_upload_to_datastore(tmp_path: Path, overwrite: bool, show_progress: boo
|
|||
])
|
||||
def test_script_config_run_src(arguments: List[str], run_id: Union[str, List[str]]) -> None:
|
||||
with patch.object(sys, "argv", arguments):
|
||||
script_config = util.AmlRunScriptConfig.parse_args()
|
||||
script_config = util.AmlRunScriptConfig()
|
||||
script_config = util.parse_args_and_update_config(script_config, arguments)
|
||||
|
||||
if isinstance(run_id, list):
|
||||
for script_config_run, expected_run_id in zip(script_config.run, run_id):
|
||||
|
@ -1172,7 +1204,65 @@ class IllegalCustomTypeNoValidate(util.CustomTypeParam):
|
|||
return x
|
||||
|
||||
|
||||
class ParamClass(util.GenericConfig):
|
||||
class DummyConfig(param.Parameterized):
|
||||
string_param = param.String()
|
||||
int_param = param.Integer()
|
||||
|
||||
def validate(self) -> None:
|
||||
assert isinstance(self.string_param, str)
|
||||
assert isinstance(self.int_param, int)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def dummy_model_config() -> DummyConfig:
|
||||
string_param = "dummy"
|
||||
int_param = 1
|
||||
return DummyConfig(param1=string_param, param2=int_param)
|
||||
|
||||
|
||||
def test_add_and_validate(dummy_model_config: DummyConfig) -> None:
|
||||
new_string_param = "new_dummy"
|
||||
new_int_param = 2
|
||||
new_args = {"string_param": new_string_param, "int_param": new_int_param}
|
||||
util.set_fields_and_validate(dummy_model_config, new_args)
|
||||
|
||||
assert dummy_model_config.string_param == new_string_param
|
||||
assert dummy_model_config.int_param == new_int_param
|
||||
|
||||
|
||||
def test_create_argparse(dummy_model_config: DummyConfig) -> None:
|
||||
with patch("health_azure.utils._add_overrideable_config_args_to_parser") as mock_add_args:
|
||||
parser = util.create_argparser(dummy_model_config)
|
||||
mock_add_args.assert_called_once()
|
||||
assert isinstance(parser, ArgumentParser)
|
||||
|
||||
|
||||
def test_add_args(dummy_model_config: DummyConfig) -> None:
|
||||
parser = ArgumentParser()
|
||||
# assert that calling parse_args on a default ArgumentParser returns an empty Namespace
|
||||
args = parser.parse_args([])
|
||||
assert args == Namespace()
|
||||
# now call _add_overrideable_config_args_to_parser and assert that calling parse_args on the result
|
||||
# of that is a non-empty Namepsace
|
||||
with patch("health_azure.utils.get_overridable_parameters") as mock_get_overridable_parameters:
|
||||
mock_get_overridable_parameters.return_value = {"string_param": param.String(default="Hello")}
|
||||
parser = util._add_overrideable_config_args_to_parser(dummy_model_config, parser)
|
||||
assert isinstance(parser, ArgumentParser)
|
||||
args = parser.parse_args([])
|
||||
assert args != Namespace()
|
||||
assert args.string_param == "Hello"
|
||||
|
||||
|
||||
def test_parse_args(dummy_model_config: DummyConfig) -> None:
|
||||
new_string_arg = "dummy_string"
|
||||
new_args = ["--string_param", new_string_arg]
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--string_param", type=str, default=None)
|
||||
parser_result = util.parse_arguments(parser, args=new_args)
|
||||
assert parser_result.args.get("string_param") == new_string_arg
|
||||
|
||||
|
||||
class ParamClass(param.Parameterized):
|
||||
name: str = param.String(None, doc="Name")
|
||||
seed: int = param.Integer(42, doc="Seed")
|
||||
flag: bool = param.Boolean(False, doc="Flag")
|
||||
|
@ -1190,6 +1280,9 @@ class ParamClass(util.GenericConfig):
|
|||
constant: str = param.String("Nope", constant=True)
|
||||
other_args = util.ListOrDictParam(None, doc="List or dictionary of other args")
|
||||
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class ClassFrom(param.Parameterized):
|
||||
foo = param.String("foo")
|
||||
|
@ -1210,12 +1303,20 @@ class NotParameterized:
|
|||
foo = 1
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def parameterized_config_and_parser() -> Tuple[ParamClass, ArgumentParser]:
|
||||
parameterized_config = ParamClass()
|
||||
parser = util.create_argparser(parameterized_config)
|
||||
return parameterized_config, parser
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
def test_overridable_parameter() -> None:
|
||||
def test_get_overridable_parameter(parameterized_config_and_parser: Tuple[ParamClass, ArgumentParser]) -> None:
|
||||
"""
|
||||
Test to check overridable parameters are correctly identified.
|
||||
"""
|
||||
param_dict = ParamClass.get_overridable_parameters()
|
||||
parameterized_config = parameterized_config_and_parser[0]
|
||||
param_dict = util.get_overridable_parameters(parameterized_config)
|
||||
assert "name" in param_dict
|
||||
assert "flag" in param_dict
|
||||
assert "not_flag" in param_dict
|
||||
|
@ -1235,12 +1336,13 @@ def test_overridable_parameter() -> None:
|
|||
|
||||
|
||||
@pytest.mark.fast
|
||||
def test_parser_defaults() -> None:
|
||||
def test_parser_defaults(parameterized_config_and_parser: Tuple[ParamClass, ArgumentParser]) -> None:
|
||||
"""
|
||||
Check that default values are created as expected, and that the non-overridable parameters
|
||||
are omitted.
|
||||
"""
|
||||
defaults = vars(ParamClass.create_argparser().parse_args([]))
|
||||
parameterized_config = parameterized_config_and_parser[0]
|
||||
defaults = vars(util.create_argparser(parameterized_config).parse_args([]))
|
||||
assert defaults["seed"] == 42
|
||||
assert defaults["tuple1"] == (1, 2.3)
|
||||
assert defaults["int_tuple"] == (1, 1, 1)
|
||||
|
@ -1254,17 +1356,20 @@ def test_parser_defaults() -> None:
|
|||
# upon errors.
|
||||
|
||||
|
||||
def check_parsing_succeeds(arg: List[str], expected_key: str, expected_value: Any) -> None:
|
||||
parsed = ParamClass.parse_args(arg)
|
||||
assert getattr(parsed, expected_key) == expected_value
|
||||
def check_parsing_succeeds(parameterized_config_and_parser: Tuple[ParamClass, ArgumentParser],
|
||||
arg: List[str],
|
||||
expected_key: str,
|
||||
expected_value: Any) -> None:
|
||||
parameterized_config, parser = parameterized_config_and_parser
|
||||
parser_result = util.parse_arguments(parser, args=arg)
|
||||
assert parser_result.args.get(expected_key) == expected_value
|
||||
|
||||
|
||||
def check_parsing_fails(arg: List[str], expected_key: Optional[str] = None, expected_value: Optional[Any] = None
|
||||
) -> None:
|
||||
with pytest.raises(SystemExit) as e:
|
||||
ParamClass.parse_args(arg)
|
||||
assert e.type == SystemExit
|
||||
assert e.value.code == 2
|
||||
def check_parsing_fails(parameterized_config_and_parser: Tuple[ParamClass, ArgumentParser],
|
||||
arg: List[str]) -> None:
|
||||
parameterized_config, parser = parameterized_config_and_parser
|
||||
with pytest.raises(Exception):
|
||||
util.parse_arguments(parser, args=arg, fail_on_unknown_args=True)
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
|
@ -1296,14 +1401,18 @@ def check_parsing_fails(arg: List[str], expected_key: Optional[str] = None, expe
|
|||
(["--other_args={'learning':3"], None, None, False),
|
||||
(["--other_args=['foo','bar'"], None, None, False)
|
||||
])
|
||||
def test_create_parser(args: List[str], expected_key: str, expected_value: Any, expected_pass: bool) -> None:
|
||||
def test_create_parser(parameterized_config_and_parser: Tuple[ParamClass, ArgumentParser],
|
||||
args: List[str],
|
||||
expected_key: str,
|
||||
expected_value: Any,
|
||||
expected_pass: bool) -> None:
|
||||
"""
|
||||
Check that parse_args works as expected, with both non default and default values.
|
||||
"""
|
||||
if expected_pass:
|
||||
check_parsing_succeeds(args, expected_key, expected_value)
|
||||
check_parsing_succeeds(parameterized_config_and_parser, args, expected_key, expected_value)
|
||||
else:
|
||||
check_parsing_fails(args)
|
||||
check_parsing_fails(parameterized_config_and_parser, args)
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
|
@ -1311,73 +1420,106 @@ def test_create_parser(args: List[str], expected_key: str, expected_value: Any,
|
|||
('on', True), ('t', True), ('true', True), ('y', True), ('yes', True), ('1', True),
|
||||
('off', False), ('f', False), ('false', False), ('n', False), ('no', False), ('0', False)
|
||||
])
|
||||
def test_parsing_bools(flag: str, expected_value: bool) -> None:
|
||||
def test_parsing_bools(parameterized_config_and_parser: Tuple[ParamClass, ArgumentParser],
|
||||
flag: str,
|
||||
expected_value: bool) -> None:
|
||||
"""
|
||||
Check all the ways of passing in True and False, with and without the first letter capitialized
|
||||
"""
|
||||
check_parsing_succeeds([f"--flag={flag}"], "flag", expected_value)
|
||||
check_parsing_succeeds([f"--flag={flag.capitalize()}"], "flag", expected_value)
|
||||
check_parsing_succeeds([f"--not_flag={flag}"], "not_flag", expected_value)
|
||||
check_parsing_succeeds([f"--not_flag={flag.capitalize()}"], "not_flag", expected_value)
|
||||
check_parsing_succeeds(parameterized_config_and_parser,
|
||||
[f"--flag={flag}"],
|
||||
"flag",
|
||||
expected_value)
|
||||
check_parsing_succeeds(parameterized_config_and_parser,
|
||||
[f"--flag={flag.capitalize()}"],
|
||||
"flag",
|
||||
expected_value)
|
||||
check_parsing_succeeds(parameterized_config_and_parser,
|
||||
[f"--not_flag={flag}"],
|
||||
"not_flag",
|
||||
expected_value)
|
||||
check_parsing_succeeds(parameterized_config_and_parser,
|
||||
[f"--not_flag={flag.capitalize()}"],
|
||||
"not_flag",
|
||||
expected_value)
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
@patch("health_azure.utils.GenericConfig.report_on_overrides")
|
||||
@patch("health_azure.utils.GenericConfig.validate")
|
||||
def test_apply_overrides(mock_validate: MagicMock, mock_report_on_overrides: MagicMock) -> None:
|
||||
def test_apply_overrides(parameterized_config_and_parser: Tuple[ParamClass, ArgumentParser]) -> None:
|
||||
"""
|
||||
Test that overrides are applied correctly, ond only to overridable parameters,
|
||||
Test that overrides are applied correctly, ond only to overridable parameters
|
||||
"""
|
||||
m = ParamClass()
|
||||
overrides = {"name": "newName", "int_tuple": (0, 1, 2)}
|
||||
actual_overrides = m.apply_overrides(overrides)
|
||||
assert actual_overrides == overrides
|
||||
assert all([x == i and isinstance(x, int) for i, x in enumerate(m.int_tuple)])
|
||||
assert m.name == "newName"
|
||||
# Attempt to change seed and constant, but the latter should be ignored.
|
||||
change_seed = {"seed": 123}
|
||||
old_constant = m.constant
|
||||
changes2 = m.apply_overrides({**change_seed, "constant": "Nothing"}) # type: ignore
|
||||
assert changes2 == change_seed
|
||||
assert m.seed == 123
|
||||
assert m.constant == old_constant
|
||||
parameterized_config = parameterized_config_and_parser[0]
|
||||
with patch("health_azure.utils.report_on_overrides") as mock_report_on_overrides:
|
||||
overrides = {"name": "newName", "int_tuple": (0, 1, 2)}
|
||||
actual_overrides = util.apply_overrides(parameterized_config, overrides)
|
||||
assert actual_overrides == overrides
|
||||
assert all([x == i and isinstance(x, int) for i, x in enumerate(parameterized_config.int_tuple)])
|
||||
assert parameterized_config.name == "newName"
|
||||
|
||||
# Check the call count of mock_validate and check it doesn't increase if should_validate is set to False
# and that setting this flag doesn't affect the outputs
|
||||
mock_validate_call_count = mock_validate.call_count
|
||||
actual_overrides = m.apply_overrides(values=overrides, should_validate=False)
|
||||
assert actual_overrides == overrides
|
||||
assert mock_validate.call_count == mock_validate_call_count
|
||||
# Attempt to change seed and constant, but the latter should be ignored.
|
||||
change_seed = {"seed": 123}
|
||||
old_constant = parameterized_config.constant
|
||||
extra_overrides = {**change_seed, "constant": "Nothing"} # type: ignore
|
||||
changes2 = util.apply_overrides(parameterized_config, overrides_to_apply=extra_overrides) # type: ignore
|
||||
assert changes2 == change_seed
|
||||
assert parameterized_config.seed == 123
|
||||
assert parameterized_config.constant == old_constant
|
||||
|
||||
# Check that report_on_overrides has not yet been called, but is called if keys_to_ignore is not None
# and that setting this flag doesn't affect the outputs
|
||||
assert mock_report_on_overrides.call_count == 0
|
||||
actual_overrides = m.apply_overrides(values=overrides, keys_to_ignore={"name"})
|
||||
assert actual_overrides == overrides
|
||||
assert mock_report_on_overrides.call_count == 1
|
||||
# Check the call count of mock_validate and check it doesn't increase if should_validate is set to False
# and that setting this flag doesn't affect the outputs
|
||||
# mock_validate_call_count = mock_validate.call_count
|
||||
actual_overrides = util.apply_overrides(parameterized_config,
|
||||
overrides_to_apply=overrides,
|
||||
should_validate=False)
|
||||
assert actual_overrides == overrides
|
||||
# assert mock_validate.call_count == mock_validate_call_count
|
||||
|
||||
# Check that report_on_overrides has not yet been called, but is called if keys_to_ignore is not None
# and that setting this flag doesn't affect the outputs
|
||||
assert mock_report_on_overrides.call_count == 0
|
||||
actual_overrides = util.apply_overrides(parameterized_config,
|
||||
overrides_to_apply=overrides,
|
||||
keys_to_ignore={"name"})
|
||||
assert actual_overrides == overrides
|
||||
assert mock_report_on_overrides.call_count == 1
|
||||
|
||||
|
||||
def test_report_on_overrides() -> None:
|
||||
m = ParamClass()
|
||||
overrides = {"name": "newName", "int_tuple": (0, 1, 2)}
|
||||
m.report_on_overrides(overrides)
|
||||
def test_report_on_overrides(parameterized_config_and_parser: Tuple[ParamClass, ArgumentParser],
|
||||
caplog: LogCaptureFixture) -> None:
|
||||
caplog.set_level(logging.WARNING)
|
||||
parameterized_config = parameterized_config_and_parser[0]
|
||||
old_logs = caplog.messages
|
||||
assert len(old_logs) == 0
|
||||
# the following overrides are expected to cause logged warnings because
# a) parameter 'constant' is constant
# b) parameter 'readonly' is readonly
# c) parameter 'idontexist' is undefined (not the name of a parameter of ParamClass)
|
||||
overrides = {"constant": "dif_value", "readonly": "new_value", "idontexist": (0, 1, 2)}
|
||||
keys_to_ignore: Set = set()
|
||||
util.report_on_overrides(parameterized_config, overrides, keys_to_ignore)
|
||||
# Expect one warning message per failed override
|
||||
new_logs = caplog.messages
|
||||
expected_warnings = len(overrides.keys())
|
||||
assert len(new_logs) == expected_warnings, f"Expected {expected_warnings} warnings but found: {caplog.records}"
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
@pytest.mark.parametrize("value_idx_0", [1.0, 1])
|
||||
@pytest.mark.parametrize("value_idx_1", [2.0, 2])
|
||||
@pytest.mark.parametrize("value_idx_2", [3.0, 3])
|
||||
def test_int_tuple_validation(value_idx_0: Any, value_idx_1: Any, value_idx_2: Any) -> None:
|
||||
def test_int_tuple_validation(value_idx_0: Any, value_idx_1: Any, value_idx_2: Any,
|
||||
parameterized_config_and_parser: Tuple[ParamClass, ArgumentParser]) -> None:
|
||||
"""
|
||||
Test integer tuple parameter is validated correctly.
|
||||
"""
|
||||
m = ParamClass()
|
||||
parameterized_config = parameterized_config_and_parser[0]
|
||||
val = (value_idx_0, value_idx_1, value_idx_2)
|
||||
if not all([isinstance(x, int) for x in val]):
|
||||
with pytest.raises(ValueError):
|
||||
m.int_tuple = (value_idx_0, value_idx_1, value_idx_2)
|
||||
parameterized_config.int_tuple = (value_idx_0, value_idx_1, value_idx_2)
|
||||
else:
|
||||
m.int_tuple = (value_idx_0, value_idx_1, value_idx_2)
|
||||
parameterized_config.int_tuple = (value_idx_0, value_idx_1, value_idx_2)
|
||||
|
||||
|
||||
@pytest.mark.fast
|
||||
|
@ -1404,42 +1546,24 @@ def test_create_from_matching_params() -> None:
|
|||
|
||||
|
||||
def test_parse_illegal_params() -> None:
|
||||
with pytest.raises(ValueError) as e:
|
||||
with pytest.raises(TypeError) as e:
|
||||
ParamClass(readonly="abc")
|
||||
assert "cannot be overridden" in str(e.value)
|
||||
|
||||
|
||||
def test_parse_throw_if_unknown() -> None:
|
||||
with pytest.raises(ValueError) as e:
|
||||
ParamClass(throw_if_unknown_param=True, idontexist="hello")
|
||||
assert "parameters do not exist" in str(e.value)
|
||||
|
||||
|
||||
@patch("health_azure.utils.GenericConfig.validate")
|
||||
def test_config_validate(mock_validate: MagicMock) -> None:
|
||||
_ = ParamClass(should_validate=False)
|
||||
assert mock_validate.call_count == 0
|
||||
|
||||
_ = ParamClass(should_validate=True)
|
||||
assert mock_validate.call_count == 1
|
||||
|
||||
_ = ParamClass()
|
||||
assert mock_validate.call_count == 2
|
||||
assert "cannot be modified" in str(e.value)
|
||||
|
||||
|
||||
def test_config_add_and_validate() -> None:
|
||||
config = ParamClass.parse_args([])
|
||||
assert config.name == "ParamClass"
|
||||
config.add_and_validate({"name": "foo"})
|
||||
config = ParamClass()
|
||||
assert config.name.startswith("ParamClass")
|
||||
util.set_fields_and_validate(config, {"name": "foo"})
|
||||
assert config.name == "foo"
|
||||
|
||||
assert hasattr(config, "new_property") is False
|
||||
config.add_and_validate({"new_property": "bar"})
|
||||
util.set_fields_and_validate(config, {"new_property": "bar"})
|
||||
assert hasattr(config, "new_property") is True
|
||||
assert config.new_property == "bar"
|
||||
|
||||
|
||||
class IllegalParamClassNoString(util.GenericConfig):
|
||||
class IllegalParamClassNoString(param.Parameterized):
|
||||
custom_type_no_from_string = IllegalCustomTypeNoFromString(
|
||||
None, doc="This should fail since from_string method is missing"
|
||||
)
|
||||
|
@ -1449,8 +1573,10 @@ def test_cant_parse_param_type() -> None:
|
|||
"""
|
||||
Assert that a TypeError is raised when trying to add a custom type with no from_string method as an argument
|
||||
"""
|
||||
config = IllegalParamClassNoString()
|
||||
|
||||
with pytest.raises(TypeError) as e:
|
||||
IllegalParamClassNoString.parse_args([])
|
||||
util.create_argparser(config)
|
||||
assert "is not supported" in str(e.value)
|
||||
|
||||
|
||||
|
@ -1469,33 +1595,39 @@ class EvenNumberParam(util.CustomTypeParam):
|
|||
return int(x)
|
||||
|
||||
|
||||
class MyScriptConfig(util.AmlRunScriptConfig):
|
||||
class MyScriptConfig(param.Parameterized):
|
||||
simple_string: str = param.String(default="")
|
||||
even_number: int = EvenNumberParam(2, doc="your choice of even number", allow_None=False)
|
||||
|
||||
|
||||
def test_my_script_config() -> None:
|
||||
even_number = randint(0, 100) * 2
|
||||
odd_number = even_number + 1
|
||||
none_number = "None"
|
||||
def test_parse_args_and_apply_overrides() -> None:
|
||||
config = MyScriptConfig()
|
||||
assert config.even_number == 2
|
||||
assert config.simple_string == ""
|
||||
|
||||
config = MyScriptConfig.parse_args(["--even_number", f"{even_number}"])
|
||||
assert config.even_number == even_number
|
||||
new_even_number = config.even_number * 2
|
||||
new_string = config.simple_string + "something_new"
|
||||
config_w_results = util.parse_args_and_update_config(config, ["--even_number", str(new_even_number),
|
||||
"--simple_string", new_string])
|
||||
assert config_w_results.even_number == new_even_number
|
||||
assert config_w_results.simple_string == new_string
|
||||
|
||||
# parsing args with unaccepted values should cause an exception to be raised
|
||||
odd_number = new_even_number + 1
|
||||
with pytest.raises(ValueError) as e:
|
||||
MyScriptConfig.parse_args(["--even_number", f"{odd_number}"])
|
||||
util.parse_args_and_update_config(config, args=["--even_number", f"{odd_number}"])
|
||||
assert "not an even number" in str(e.value)
|
||||
|
||||
# If parser can't parse type, will raise a SystemExit
|
||||
with pytest.raises(SystemExit):
|
||||
MyScriptConfig.parse_args(["--even_number", f"{none_number}"])
|
||||
none_number = "None"
|
||||
with pytest.raises(ArgumentError):
|
||||
util.parse_args_and_update_config(config, args=["--even_number", f"{none_number}"])
|
||||
|
||||
# Mock from_string to check test _validate
|
||||
mock_from_string_none = lambda a, b: None
|
||||
mock_from_string_none = lambda a, b: None # type: ignore
|
||||
with patch.object(EvenNumberParam, "from_string", new=mock_from_string_none):
|
||||
# Check that _validate fails with None value
|
||||
with pytest.raises(ValueError) as e:
|
||||
MyScriptConfig.parse_args(["--even_number", f"{none_number}"])
|
||||
util.parse_args_and_update_config(config, ["--even_number", f"{none_number}"])
|
||||
assert "must not be None" in str(e.value)
|
||||
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ Test the data input and output functionality
|
|||
from pathlib import Path
|
||||
from unittest import mock
|
||||
from health_azure.utils import PathOrString
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Optional
|
||||
|
||||
import pytest
|
||||
from azureml._restclient.exceptions import ServiceException
|
||||
|
@ -19,7 +19,8 @@ from azureml.data.dataset_consumption_config import DatasetConsumptionConfig
|
|||
from azureml.exceptions._azureml_exception import UserErrorException
|
||||
|
||||
from health_azure.datasets import (DatasetConfig, _input_dataset_key, _output_dataset_key,
|
||||
_replace_string_datasets, get_datastore, get_or_create_dataset)
|
||||
_replace_string_datasets, get_datastore, get_or_create_dataset,
|
||||
create_dataset_configs)
|
||||
from testazure.utils_testazure import DEFAULT_DATASTORE, DEFAULT_WORKSPACE
|
||||
|
||||
|
||||
|
@ -205,3 +206,58 @@ def test_dataset_keys() -> None:
|
|||
assert in1
|
||||
assert out1
|
||||
assert in1 != out1
|
||||
|
||||
|
||||
def test_create_dataset_configs() -> None:
|
||||
azure_datasets: List[str] = []
|
||||
dataset_mountpoints: List[str] = []
|
||||
local_datasets: List[Optional[Path]] = []
|
||||
datastore = None
|
||||
use_mounting = False
|
||||
datasets = create_dataset_configs(azure_datasets,
|
||||
dataset_mountpoints,
|
||||
local_datasets,
|
||||
datastore,
|
||||
use_mounting)
|
||||
assert datasets == []
|
||||
|
||||
# if local_datasets is not empty but azure_datasets still is, expect an empty list
|
||||
local_datasets = [Path("dummy")]
|
||||
datasets = create_dataset_configs(azure_datasets,
|
||||
dataset_mountpoints,
|
||||
local_datasets,
|
||||
datastore,
|
||||
use_mounting)
|
||||
assert datasets == []
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
azure_datasets = ["dummy"]
|
||||
local_datasets = [Path("another_dummy"), Path("another_extra_dummy")]
|
||||
create_dataset_configs(azure_datasets,
|
||||
dataset_mountpoints,
|
||||
local_datasets,
|
||||
datastore,
|
||||
use_mounting)
|
||||
assert "Invalid dataset setup" in str(e)
|
||||
|
||||
az_dataset_name = "dummy"
|
||||
azure_datasets = [az_dataset_name]
|
||||
local_datasets = [Path("another_dummy")]
|
||||
datasets = create_dataset_configs(azure_datasets,
|
||||
dataset_mountpoints,
|
||||
local_datasets,
|
||||
datastore,
|
||||
use_mounting)
|
||||
assert len(datasets) == 1
|
||||
assert isinstance(datasets[0], DatasetConfig)
|
||||
assert datasets[0].name == az_dataset_name
|
||||
|
||||
# If the Azure dataset name is blank, expect an "Invalid dataset setup" error
azure_datasets = [" "]
|
||||
with pytest.raises(Exception) as e:
|
||||
create_dataset_configs(azure_datasets,
|
||||
dataset_mountpoints,
|
||||
local_datasets,
|
||||
datastore,
|
||||
use_mounting)
|
||||
assert "Invalid dataset setup" in str(e)
|
||||
|
|
|
@ -278,6 +278,28 @@ def test_create_run_configuration_correct_env(mock_create_environment: MagicMock
|
|||
mock_register.assert_called_once()
|
||||
assert mock_environment_get.call_count == 2
|
||||
|
||||
# Assert that a Conda env spec with no python version raises an exception
|
||||
conda_env_spec = OrderedDict({"name": "dummy_env",
|
||||
"channels": OrderedList("default"),
|
||||
"dependencies": OrderedList(["- pip=20.1.1"])})
|
||||
|
||||
conda_env_path = tmp_path / "dummy_conda_env_no_python.yml"
|
||||
with open(conda_env_path, "w+") as f_path:
|
||||
yaml.dump(conda_env_spec, f_path)
|
||||
assert conda_env_path.is_file()
|
||||
|
||||
with patch.object(mock_environment, "register") as mock_register:
|
||||
mock_register.return_value = mock_environment
|
||||
|
||||
with patch("azureml.core.Environment.get") as mock_environment_get: # type: ignore
|
||||
mock_environment_get.side_effect = Exception()
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
himl.create_run_configuration(mock_workspace,
|
||||
"dummy_compute_cluster",
|
||||
conda_environment_file=conda_env_path)
|
||||
assert "you must specify the python version" in str(e)
|
||||
|
||||
# check that when create_run_configuration is called, whatever is returned from register_environment
|
||||
# is set as the new "environment" attribute of the run config
|
||||
with patch("health_azure.himl.register_environment") as mock_register_environment:
|
||||
|
|
|
@ -82,7 +82,7 @@ pytest_fast: pip_test call_pytest_fast
|
|||
|
||||
# run pytest with coverage on package, and format coverage output as a text file, assuming test requirements already installed
call_pytest_and_coverage:
pytest --cov=health_ml --cov-branch --cov-report=html --cov-report=term-missing --cov-report=xml testhiml
pytest --cov=health_ml --cov-branch --cov-report=html --cov-report=xml --cov-report=term-missing --cov-config=.coveragerc testhiml
pycobertura show --format text --output coverage.txt coverage.xml

# install test requirements and run pytest coverage
|
||||
|
|
|
@ -3,6 +3,7 @@ jinja2==3.0.2
|
|||
matplotlib==3.4.3
opencv-python-headless==4.5.1.48
pandas==1.3.4
pytorch-lightning>=1.4.9
torchvision==0.9.0
pytorch-lightning==1.5.5
rpdb==0.1.6
torchvision==0.11.1
torch>=1.8
|
||||
|
|
|
@ -86,6 +86,7 @@ setup(
|
|||
install_requires=install_requires,
entry_points={
'console_scripts': [
'himl-runner = health_ml.runner:main'
]
}
)
|
||||
|
|
|
@ -2,3 +2,13 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from health_ml.model_trainer import model_train
from health_ml.run_ml import MLRunner
from health_ml.runner import Runner


__all__ = [
"model_train",
"MLRunner",
"Runner"
]
|
||||
|
|
|
@ -0,0 +1,235 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch_lightning import LightningDataModule, LightningModule
|
||||
from torchmetrics import MeanAbsoluteError
|
||||
from torch.optim import Adam, Optimizer
|
||||
from torch.optim.lr_scheduler import StepLR, _LRScheduler
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from health_ml.lightning_container import LightningContainer
|
||||
|
||||
|
||||
class HelloDataset(Dataset):
|
||||
"""
|
||||
A simple 1-dim regression task, read from a data file stored in the test data folder.
|
||||
"""
|
||||
# Creating the data file:
|
||||
# import numpy as np
|
||||
# import torch
|
||||
#
|
||||
# N = 100
|
||||
# x = torch.rand((N, 1)) * 10
|
||||
# y = 0.2 * x + 0.1 * torch.randn(x.size())
|
||||
# xy = torch.cat((x, y), dim=1)
|
||||
# np.savetxt("health_ml/configs/hellocontainer.csv", xy.numpy(), delimiter=",")
|
||||
def __init__(self, raw_data: List[List[float]]) -> None:
|
||||
"""
|
||||
Creates the 1-dim regression dataset.
|
||||
|
||||
:param raw_data: The raw data. This must be numeric data which can be converted into a tensor.
|
||||
See the static method from_path_and_indexes for an example call.
|
||||
"""
|
||||
super().__init__() # type: ignore
|
||||
self.data = torch.tensor(raw_data, dtype=torch.float)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.data.shape[0]
|
||||
|
||||
def __getitem__(self, item: int) -> Dict[str, torch.Tensor]:
|
||||
return {'x': self.data[item][0:1], 'y': self.data[item][1:2]}
|
||||
|
||||
@staticmethod
|
||||
def from_path_and_indexes(
|
||||
root_folder: Path,
|
||||
start_index: int,
|
||||
end_index: int) -> 'HelloDataset':
|
||||
"""
|
||||
Static method to instantiate a HelloDataset from the root folder with the start and end indexes.
|
||||
|
||||
:param root_folder: The folder in which the data file lives ("hellocontainer.csv")
|
||||
:param start_index: The first row to read.
|
||||
:param end_index: The last row to read (exclusive)
|
||||
:return: A new instance based on the root folder and the start and end indexes.
|
||||
"""
|
||||
raw_data = np.loadtxt(root_folder / "hellocontainer.csv", delimiter=",")[start_index:end_index]
|
||||
return HelloDataset(raw_data)
|
||||
|
||||
|
||||
class HelloDataModule(LightningDataModule):
|
||||
"""
|
||||
A data module that gives the training, validation and test data for a simple 1-dim regression task.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
root_folder: Path) -> None:
|
||||
super().__init__()
|
||||
self.train = HelloDataset.from_path_and_indexes(root_folder, start_index=0, end_index=50)
|
||||
self.val = HelloDataset.from_path_and_indexes(root_folder, start_index=50, end_index=70)
|
||||
self.test = HelloDataset.from_path_and_indexes(root_folder, start_index=70, end_index=100)
|
||||
|
||||
def prepare_data(self, *args: Any, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def setup(self, stage: Optional[str] = None) -> None:
|
||||
pass
|
||||
|
||||
def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
|
||||
return DataLoader(self.train, batch_size=5)
|
||||
|
||||
def val_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
|
||||
return DataLoader(self.val, batch_size=5)
|
||||
|
||||
def test_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
|
||||
return DataLoader(self.test, batch_size=5)
|
||||
|
||||
|
||||
class HelloRegression(LightningModule):
|
||||
"""
|
||||
A simple 1-dim regression model.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.model = torch.nn.Linear(in_features=1, out_features=1, bias=True) # type: ignore
|
||||
self.test_mse: List[torch.Tensor] = []
|
||||
self.test_mae = MeanAbsoluteError()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
|
||||
"""
|
||||
This method is part of the standard PyTorch Lightning interface. For an introduction, please see
|
||||
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
|
||||
It runs a forward pass of a tensor through the model.
|
||||
|
||||
:param x: The input tensor(s)
|
||||
:return: The model output.
|
||||
"""
|
||||
return self.model(x)
|
||||
|
||||
def training_step(self, batch: Dict[str, torch.Tensor], *args: Any, **kwargs: Any) -> torch.Tensor: # type: ignore
|
||||
"""
|
||||
This method is part of the standard PyTorch Lightning interface. For an introduction, please see
|
||||
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
|
||||
It consumes a minibatch of training data (coming out of the data loader), does forward propagation, and
|
||||
computes the loss.
|
||||
|
||||
:param batch: The batch of training data
|
||||
:return: The loss value with a computation graph attached.
|
||||
"""
|
||||
loss = self.shared_step(batch)
|
||||
self.log("loss", loss, on_epoch=True, on_step=False)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch: Dict[str, torch.Tensor], *args: Any, # type: ignore
|
||||
**kwargs: Any) -> torch.Tensor:
|
||||
"""
|
||||
This method is part of the standard PyTorch Lightning interface. For an introduction, please see
|
||||
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
|
||||
It consumes a minibatch of validation data (coming out of the data loader), does forward propagation, and
|
||||
computes the loss.
|
||||
|
||||
:param batch: The batch of validation data
|
||||
:return: The loss value on the validation data.
|
||||
"""
|
||||
loss = self.shared_step(batch)
|
||||
self.log("val_loss", loss, on_epoch=True, on_step=False)
|
||||
return loss
|
||||
|
||||
def shared_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
This is a convenience method to reduce code duplication, because training, validation, and test step share
|
||||
large amounts of code.
|
||||
|
||||
:param batch: The batch of data to process, with input data and targets.
|
||||
:return: The MSE loss that the model achieved on this batch.
|
||||
"""
|
||||
input = batch["x"]
|
||||
target = batch["y"]
|
||||
prediction = self.forward(input)
|
||||
return torch.nn.functional.mse_loss(prediction, target) # type: ignore
|
||||
|
||||
def configure_optimizers(self) -> Tuple[List[Optimizer], List[_LRScheduler]]:
|
||||
"""
|
||||
This method is part of the standard PyTorch Lightning interface. For an introduction, please see
|
||||
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
|
||||
It returns the PyTorch optimizer(s) and learning rate scheduler(s) that should be used for training.
|
||||
"""
|
||||
optimizer = Adam(self.parameters(), lr=1e-1)
|
||||
scheduler = StepLR(optimizer, step_size=20, gamma=0.5)
|
||||
return [optimizer], [scheduler]
|
||||
|
||||
def on_test_epoch_start(self) -> None:
|
||||
"""
|
||||
This method is part of the standard PyTorch Lightning interface. For an introduction, please see
|
||||
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
|
||||
In this method, you can prepare data structures that need to be in place before evaluating the model on the
|
||||
test set (that is done in the test_step).
|
||||
"""
|
||||
self.test_mse = []
|
||||
self.test_mae.reset()
|
||||
|
||||
def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: # type: ignore
|
||||
"""
|
||||
This method is part of the standard PyTorch Lightning interface. For an introduction, please see
|
||||
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
|
||||
It evaluates the model in "inference mode" on data coming from the test set. It could, for example,
|
||||
also write each model prediction to disk.
|
||||
|
||||
:param batch: The batch of test data.
|
||||
:param batch_idx: The index (0, 1, ...) of the batch when the data loader is enumerated.
|
||||
:return: The loss on the test data.
|
||||
"""
|
||||
input = batch["x"]
|
||||
target = batch["y"]
|
||||
prediction = self.forward(input)
|
||||
# This illustrates two ways of computing metrics: Using standard torch
|
||||
loss = torch.nn.functional.mse_loss(prediction, target) # type: ignore
|
||||
self.test_mse.append(loss)
|
||||
# Metrics computed using PyTorch Lightning objects. Note that these will, by default, attempt
|
||||
# to synchronize across GPUs.
|
||||
self.test_mae.update(preds=prediction, target=target)
|
||||
return loss
|
||||
|
||||
def on_test_epoch_end(self) -> None:
|
||||
"""
|
||||
This method is part of the standard PyTorch Lightning interface. For an introduction, please see
|
||||
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
|
||||
In this method, you can finish off anything to do with evaluating the model on the test set,
|
||||
for example writing aggregate metrics to disk.
|
||||
"""
|
||||
average_mse = torch.mean(torch.stack(self.test_mse))
|
||||
Path("test_mse.txt").write_text(str(average_mse.item()))
|
||||
Path("test_mae.txt").write_text(str(self.test_mae.compute().item()))
|
||||
|
||||
|
||||
class HelloContainer(LightningContainer):
|
||||
"""
|
||||
An example container for using the hi-ml runner. This container has methods
|
||||
to generate the actual Lightning model, and read out the datamodule that will be used for training.
|
||||
The number of training epochs is controlled at container level.
|
||||
You can train this model by running `python health_ml/runner.py --model=HelloContainer` on the local box,
|
||||
or via `python health_ml/runner.py --model=HelloContainer --azureml=True` in AzureML
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.local_dataset_dir = Path(__file__).parent
|
||||
self.max_epochs = 20
|
||||
|
||||
# This method must be overridden by any subclass of LightningContainer. It returns the model that you wish to
|
||||
# train, as a LightningModule
|
||||
def create_model(self) -> LightningModule:
|
||||
return HelloRegression()
|
||||
|
||||
# This method must be overridden by any subclass of LightningContainer. It returns a data module, which
|
||||
# in turn contains 3 data loaders for training, validation, and test set.
|
||||
def get_data_module(self) -> LightningDataModule:
|
||||
assert self.local_dataset_dir is not None
|
||||
return HelloDataModule(
|
||||
root_folder=self.local_dataset_dir) # type: ignore
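# A minimal sketch of exercising this container directly, without the hi-ml runner (illustration only;
# it uses nothing beyond the classes defined above and a plain PyTorch Lightning Trainer):
if __name__ == "__main__":
    from pytorch_lightning import Trainer

    container = HelloContainer()
    container.setup()
    model = container.create_model()
    data_module = container.get_data_module()
    # Train and evaluate for the number of epochs configured on the container (20 here)
    trainer = Trainer(max_epochs=container.max_epochs)
    trainer.fit(model, datamodule=data_module)
    trainer.test(model, datamodule=data_module)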
@ -0,0 +1,413 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from enum import Enum, unique
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import param
|
||||
from param import Parameterized
|
||||
|
||||
from health_azure.utils import RUN_CONTEXT, PathOrString, is_running_in_azure_ml
|
||||
|
||||
from health_ml.utils import fixed_paths
|
||||
from health_ml.utils.common_utils import (CHECKPOINT_FOLDER,
|
||||
create_unique_timestamp_id,
|
||||
DEFAULT_AML_UPLOAD_DIR,
|
||||
DEFAULT_LOGS_DIR_NAME, is_windows, parse_model_id_and_version)
|
||||
from health_ml.utils.type_annotations import TupleFloat2
|
||||
|
||||
|
||||
@unique
|
||||
class LRWarmUpType(Enum):
|
||||
"""
|
||||
Supported LR warm up types for model training
|
||||
"""
|
||||
NoWarmUp = "NoWarmUp"
|
||||
Linear = "Linear"
|
||||
|
||||
|
||||
@unique
|
||||
class LRSchedulerType(Enum):
|
||||
"""
|
||||
Supported lr scheduler types for model training
|
||||
"""
|
||||
Exponential = "Exponential"
|
||||
Step = "Step"
|
||||
Polynomial = "Polynomial"
|
||||
Cosine = "Cosine"
|
||||
MultiStep = "MultiStep"
|
||||
|
||||
|
||||
@unique
|
||||
class MultiprocessingStartMethod(Enum):
|
||||
"""
|
||||
Different methods for starting data loader processes.
|
||||
"""
|
||||
fork = "fork"
|
||||
forkserver = "forkserver"
|
||||
spawn = "spawn"
|
||||
|
||||
|
||||
@unique
|
||||
class OptimizerType(Enum):
|
||||
"""
|
||||
Supported optimizers for model training
|
||||
"""
|
||||
Adam = "Adam"
|
||||
AMSGrad = "AMSGrad"
|
||||
SGD = "SGD"
|
||||
RMSprop = "RMSprop"
|
||||
|
||||
|
||||
class ExperimentFolderHandler(Parameterized):
|
||||
"""High level config to abstract the file system related settings for experiments"""
|
||||
outputs_folder: Path = param.ClassSelector(class_=Path, default=Path(), instantiate=False,
|
||||
doc="The folder where all training and test outputs should go.")
|
||||
logs_folder: Path = param.ClassSelector(class_=Path, default=Path(), instantiate=False,
|
||||
doc="The folder for all log files and Tensorboard event files")
|
||||
project_root: Path = param.ClassSelector(class_=Path, default=Path(), instantiate=False,
|
||||
doc="The root folder for the codebase that triggers the training run.")
|
||||
run_folder: Path = param.ClassSelector(class_=Path, default=Path(), instantiate=False,
|
||||
doc="The folder that contains outputs and the logs subfolder.")
|
||||
|
||||
@staticmethod
|
||||
def create(project_root: Path,
|
||||
is_offline_run: bool,
|
||||
model_name: str,
|
||||
output_to: Optional[str] = None) -> ExperimentFolderHandler:
|
||||
"""
|
||||
Creates a new object that holds output folder configurations. When running inside of AzureML, the output
|
||||
folders will be directly under the project root. If not running inside AzureML, a folder with a timestamp
|
||||
will be created for all outputs and logs.
|
||||
|
||||
:param project_root: The root folder that contains the code that submitted the present training run.
|
||||
When running inside the hi-ml repository, it is the git repo root. When consuming hi-ml as a package,
|
||||
this should be the root of the source code that calls the package.
|
||||
:param is_offline_run: If true, this is a run outside AzureML. If False, it is inside AzureML.
|
||||
:param model_name: The name of the model that is trained. This is used to generate a run-specific output
|
||||
folder.
|
||||
:param output_to: If provided, the output folders will be created as a subfolder of this argument. If not
|
||||
given, the output folders will be created inside of the project root.
|
||||
"""
|
||||
if not project_root.is_absolute():
|
||||
raise ValueError(f"The project root is required to be an absolute path, but got {project_root}")
|
||||
|
||||
if is_offline_run or output_to:
|
||||
if output_to:
|
||||
logging.info(f"All results will be written to the specified output folder {output_to}")
|
||||
root = Path(output_to).absolute()
|
||||
else:
|
||||
logging.info("All results will be written to a subfolder of the project root folder.")
|
||||
root = project_root.absolute() / DEFAULT_AML_UPLOAD_DIR
|
||||
timestamp = create_unique_timestamp_id()
|
||||
run_folder = root / f"{timestamp}_{model_name}"
|
||||
outputs_folder = run_folder
|
||||
logs_folder = run_folder / DEFAULT_LOGS_DIR_NAME
|
||||
else:
|
||||
logging.info("Running inside AzureML.")
|
||||
logging.info("All results will be written to a subfolder of the project root folder.")
|
||||
run_folder = project_root
|
||||
outputs_folder = project_root / DEFAULT_AML_UPLOAD_DIR
|
||||
logs_folder = project_root / DEFAULT_LOGS_DIR_NAME
|
||||
logging.info(f"Run outputs folder: {outputs_folder}")
|
||||
logging.info(f"Logs folder: {logs_folder}")
|
||||
return ExperimentFolderHandler(
|
||||
outputs_folder=outputs_folder,
|
||||
logs_folder=logs_folder,
|
||||
project_root=project_root,
|
||||
run_folder=run_folder
|
||||
)
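# Illustrative sketch: for a run outside AzureML with no explicit output folder, the handler places
# all outputs in a timestamped run folder below <project_root>/<DEFAULT_AML_UPLOAD_DIR>.
def _example_folder_handler_usage() -> None:
    handler = ExperimentFolderHandler.create(project_root=Path("/home/me/repo"),
                                             is_offline_run=True,
                                             model_name="HelloContainer")
    # logs_folder is the DEFAULT_LOGS_DIR_NAME subfolder of the timestamped run folder.
    print(handler.outputs_folder, handler.logs_folder)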
|
||||
|
||||
class WorkflowParams(param.Parameterized):
|
||||
"""
|
||||
This class contains all parameters that affect how the whole training and testing workflow is executed.
|
||||
"""
|
||||
random_seed: int = param.Integer(42, doc="The seed to use for all random number generators.")
|
||||
weights_url: List[str] = param.List(default=[], class_=str,
|
||||
doc="If provided, a set of urls from which checkpoints will be downloaded"
|
||||
"and used for inference.")
|
||||
local_weights_path: List[Path] = param.List(default=[], class_=Path,
|
||||
doc="A list of checkpoints paths to use for inference, "
|
||||
"when the job is running outside Azure.")
|
||||
model_id: str = param.String(default="",
|
||||
doc="A model id string in the form 'model name:version' "
|
||||
"to use a registered model for inference.")
|
||||
multiprocessing_start_method: MultiprocessingStartMethod = \
|
||||
param.ClassSelector(class_=MultiprocessingStartMethod,
|
||||
default=(MultiprocessingStartMethod.spawn if is_windows()
|
||||
else MultiprocessingStartMethod.fork),
|
||||
doc="Method to be used to start child processes in pytorch. Should be one of forkserver, "
|
||||
"fork or spawn. If not specified, fork is used on Linux and spawn on Windows. "
|
||||
"Set to forkserver as a possible remedy for stuck jobs.")
|
||||
regression_test_folder: Optional[Path] = \
|
||||
param.ClassSelector(class_=Path, default=None, allow_None=True,
|
||||
doc="A path to a folder that contains a set of files. At the end of training and "
|
||||
"model evaluation, all files given in that folder must be present in the job's output "
|
||||
"folder, and their contents must match exactly. When running in AzureML, you need to "
|
||||
"ensure that this folder is part of the snapshot that gets uploaded. The path should "
|
||||
"be relative to the repository root directory.")
|
||||
|
||||
def validate(self) -> None:
|
||||
if sum([bool(param) for param in [self.weights_url, self.local_weights_path, self.model_id]]) > 1:
|
||||
raise ValueError("Cannot specify more than one of local_weights_path, weights_url or model_id.")
|
||||
|
||||
if self.model_id:
|
||||
parse_model_id_and_version(self.model_id)
|
||||
|
||||
@property
|
||||
def is_running_in_aml(self) -> bool:
|
||||
"""
|
||||
Whether the current run is executing inside Azure ML
|
||||
|
||||
:return: True if the run is executing inside Azure ML, or False if outside AzureML.
|
||||
"""
|
||||
return is_running_in_azure_ml(RUN_CONTEXT)
|
||||
|
||||
def get_effective_random_seed(self) -> int:
|
||||
"""
|
||||
Returns the random seed set as part of this configuration.
|
||||
|
||||
:return:
|
||||
"""
|
||||
seed = self.random_seed
|
||||
return seed
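# Illustrative sketch: validation accepts at most one source of model weights.
def _example_workflow_params_validation() -> None:
    params = WorkflowParams(model_id="MyModel:1")
    params.validate()   # passes: only model_id is set
    params.local_weights_path = [Path("my_checkpoint.ckpt")]
    # params.validate() would now raise ValueError, because two weight sources are specified.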
|
||||
|
||||
class DatasetParams(param.Parameterized):
|
||||
azure_datasets: List[str] = param.List(default=[], class_=str,
|
||||
doc="If provided, the ID of one or more datasets to use when running in"
|
||||
" AzureML.This dataset must exist as a folder of the same name in the"
|
||||
" 'datasets' container in the datasets storage account. This dataset"
|
||||
" will be mounted and made available at the 'local_dataset' path"
|
||||
" when running in AzureML.")
|
||||
local_datasets: List[Path] = param.List(default=[], class_=Path,
|
||||
doc="A list of one or more paths to the dataset to use, when training"
|
||||
" outside of Azure ML.")
|
||||
dataset_mountpoints: List[Path] = param.List(default=[], class_=Path,
|
||||
doc="The path at which the AzureML dataset should be made available "
|
||||
"via mounting or downloading. This only affects jobs running in "
|
||||
"AzureML. If empty, use a random mount/download point.")
|
||||
|
||||
def validate(self) -> None:
|
||||
if (not self.azure_datasets) and (not self.local_datasets):
|
||||
raise ValueError("Either local_datasets or azure_datasets must be set.")
|
||||
|
||||
if self.dataset_mountpoints and len(self.azure_datasets) != len(self.dataset_mountpoints):
|
||||
raise ValueError(f"Expected the number of azure datasets to equal the number of mountpoints, "
|
||||
f"got datasets [{','.join(self.azure_datasets)}] "
|
||||
f"and mountpoints [{','.join([str(m) for m in self.dataset_mountpoints])}]")
|
||||
|
||||
class OutputParams(param.Parameterized):
|
||||
output_to: Path = param.ClassSelector(class_=Path, default=Path(),
|
||||
doc="If provided, the run outputs will be written to the given folder. If "
|
||||
"not provided, outputs will go into a subfolder of the project root "
|
||||
"folder.")
|
||||
file_system_config: ExperimentFolderHandler = param.ClassSelector(default=ExperimentFolderHandler(),
|
||||
class_=ExperimentFolderHandler,
|
||||
instantiate=False,
|
||||
doc="File system related configs")
|
||||
_model_name: str = param.String("", doc="The human readable name of the model (for example, Liver). This is "
|
||||
"usually set from the class name.")
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
"""
|
||||
Gets the human readable name of the model (e.g., Liver). This is usually set from the class name.
|
||||
|
||||
:return: A model name as a string.
|
||||
"""
|
||||
return self._model_name
|
||||
|
||||
def set_output_to(self, output_to: PathOrString) -> None:
|
||||
"""
|
||||
Adjusts the file system settings in the present object such that all outputs are written to the given folder.
|
||||
|
||||
:param output_to: The absolute path to a folder that should contain the outputs.
|
||||
"""
|
||||
self.output_to = Path(output_to)
|
||||
self.create_filesystem()
|
||||
|
||||
def create_filesystem(self, project_root: Path = fixed_paths.repository_root_directory()) -> None:
|
||||
"""
|
||||
Creates new file system settings (outputs folder, logs folder) based on the information stored in the
|
||||
present object. If any of the folders do not yet exist, they are created.
|
||||
|
||||
:param project_root: The root folder for the codebase that triggers the training run.
|
||||
"""
|
||||
self.file_system_config = ExperimentFolderHandler.create(
|
||||
project_root=project_root,
|
||||
model_name=self.model_name,
|
||||
is_offline_run=not is_running_in_azure_ml(RUN_CONTEXT),
|
||||
output_to=str(self.output_to)
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs_folder(self) -> Path:
|
||||
"""Gets the full path in which the model outputs should be stored."""
|
||||
return self.file_system_config.outputs_folder
|
||||
|
||||
@property
|
||||
def logs_folder(self) -> Path:
|
||||
"""Gets the full path in which the model logs should be stored."""
|
||||
return self.file_system_config.logs_folder
|
||||
|
||||
@property
|
||||
def checkpoint_folder(self) -> Path:
|
||||
"""Gets the full path in which the model checkpoints should be stored during training."""
|
||||
return self.outputs_folder / CHECKPOINT_FOLDER
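# Illustrative sketch: redirecting all outputs to a chosen folder.
def _example_output_params_usage() -> None:
    output_params = OutputParams()
    output_params.set_output_to("/tmp/himl_outputs")   # rebuilds the filesystem config under this folder
    # outputs_folder is a timestamped run folder below /tmp/himl_outputs; logs_folder and
    # checkpoint_folder are the DEFAULT_LOGS_DIR_NAME and CHECKPOINT_FOLDER subfolders of it.
    print(output_params.outputs_folder, output_params.logs_folder, output_params.checkpoint_folder)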
|
||||
|
||||
class OptimizerParams(param.Parameterized):
|
||||
l_rate: float = param.Number(1e-4, doc="The initial learning rate", bounds=(0, None))
|
||||
_min_l_rate: float = param.Number(0.0, doc="The minimum learning rate for the Polynomial and Cosine schedulers.",
|
||||
bounds=(0.0, None))
|
||||
l_rate_scheduler: LRSchedulerType = param.ClassSelector(default=LRSchedulerType.Polynomial,
|
||||
class_=LRSchedulerType,
|
||||
instantiate=False,
|
||||
doc="Learning rate decay method (Cosine, Polynomial, "
|
||||
"Step, MultiStep or Exponential)")
|
||||
l_rate_exponential_gamma: float = param.Number(0.9, doc="Controls the rate of decay for the Exponential "
|
||||
"LR scheduler.")
|
||||
l_rate_step_gamma: float = param.Number(0.1, doc="Controls the rate of decay for the "
|
||||
"Step LR scheduler.")
|
||||
l_rate_step_step_size: int = param.Integer(50, bounds=(0, None),
|
||||
doc="The step size for Step LR scheduler")
|
||||
l_rate_multi_step_gamma: float = param.Number(0.1, doc="Controls the rate of decay for the "
|
||||
"MultiStep LR scheduler.")
|
||||
l_rate_multi_step_milestones: Optional[List[int]] = param.List(None, bounds=(1, None),
|
||||
allow_None=True, class_=int,
|
||||
doc="The milestones for MultiStep decay.")
|
||||
l_rate_polynomial_gamma: float = param.Number(1e-4, doc="Controls the rate of decay for the "
|
||||
"Polynomial LR scheduler.")
|
||||
l_rate_warmup: LRWarmUpType = param.ClassSelector(default=LRWarmUpType.NoWarmUp, class_=LRWarmUpType,
|
||||
instantiate=False,
|
||||
doc="The type of learning rate warm up to use. "
|
||||
"Can be NoWarmUp (default) or Linear.")
|
||||
l_rate_warmup_epochs: int = param.Integer(0, bounds=(0, None),
|
||||
doc="Number of warmup epochs (linear warmup) before the "
|
||||
"scheduler starts decaying the learning rate. "
|
||||
"For example, if you are using MultiStepLR with "
|
||||
"milestones [50, 100, 200] and warmup epochs = 100, warmup "
|
||||
"will last for 100 epochs and the first decay of LR "
|
||||
"will happen on epoch 150")
|
||||
optimizer_type: OptimizerType = param.ClassSelector(default=OptimizerType.Adam, class_=OptimizerType,
|
||||
instantiate=False, doc="The optimizer_type to use")
|
||||
opt_eps: float = param.Number(1e-4, doc="The epsilon parameter of RMSprop or Adam")
|
||||
rms_alpha: float = param.Number(0.9, doc="The alpha parameter of RMSprop")
|
||||
adam_betas: TupleFloat2 = param.NumericTuple((0.9, 0.999), length=2,
|
||||
doc="The betas parameter of Adam, default is (0.9, 0.999)")
|
||||
momentum: float = param.Number(0.6, doc="The momentum parameter of the optimizers")
|
||||
weight_decay: float = param.Number(1e-4, doc="The weight decay used to control L2 regularization")
|
||||
|
||||
def validate(self) -> None:
|
||||
if len(self.adam_betas) < 2:
|
||||
raise ValueError(
|
||||
"The adam_betas parameter should be the coefficients used for computing running averages of "
|
||||
"gradient and its square")
|
||||
|
||||
if self.l_rate_scheduler == LRSchedulerType.MultiStep:
|
||||
if not self.l_rate_multi_step_milestones:
|
||||
raise ValueError("Must specify l_rate_multi_step_milestones to use LR scheduler MultiStep")
|
||||
if sorted(set(self.l_rate_multi_step_milestones)) != self.l_rate_multi_step_milestones:
|
||||
raise ValueError("l_rate_multi_step_milestones must be a strictly increasing list")
|
||||
if self.l_rate_multi_step_milestones[0] <= 0:
|
||||
raise ValueError("l_rate_multi_step_milestones cannot be negative or 0.")
|
||||
|
||||
@property
|
||||
def min_l_rate(self) -> float:
|
||||
return self._min_l_rate
|
||||
|
||||
@min_l_rate.setter
|
||||
def min_l_rate(self, value: float) -> None:
|
||||
if value > self.l_rate:
|
||||
raise ValueError("l_rate must be >= min_l_rate, found: {}, {}".format(self.l_rate, value))
|
||||
self._min_l_rate = value
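# Illustrative sketch: MultiStep scheduling requires strictly increasing, positive milestones.
def _example_optimizer_params_validation() -> None:
    opt = OptimizerParams(l_rate_scheduler=LRSchedulerType.MultiStep,
                          l_rate_multi_step_milestones=[50, 100, 200])
    opt.validate()   # passes
    # Milestones such as [100, 50] or [0, 10], or omitting them entirely, raise ValueError.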
|
||||
|
||||
class TrainerParams(param.Parameterized):
|
||||
max_epochs: int = param.Integer(100, bounds=(1, None), doc="Number of epochs to train.")
|
||||
autosave_every_n_val_epochs: int = param.Integer(1, bounds=(0, None),
|
||||
doc="Save epoch checkpoints every N validation epochs. "
|
||||
"If pl_check_val_every_n_epoch > 1, this means that "
|
||||
"checkpoints are saved every N * pl_check_val_every_n_epoch "
|
||||
"training epochs.")
|
||||
detect_anomaly: bool = param.Boolean(False, doc="If true, test gradients for anomalies (NaN or Inf) during "
|
||||
"training.")
|
||||
use_mixed_precision: bool = param.Boolean(False, doc="If true, mixed precision training is activated during "
|
||||
"training.")
|
||||
max_num_gpus: int = param.Integer(default=-1, doc="The maximum number of GPUs to use. If set to a value < 0, use "
"all available GPUs. In distributed training, this is the "
"maximum number of GPUs per node.")
|
||||
pl_progress_bar_refresh_rate: Optional[int] = \
|
||||
param.Integer(default=None,
|
||||
doc="PyTorch Lightning trainer flag 'progress_bar_refresh_rate': How often to refresh progress "
|
||||
"bar (in steps). Value 0 disables progress bar. Value None chooses automatically.")
|
||||
pl_num_sanity_val_steps: int = \
|
||||
param.Integer(default=0,
|
||||
doc="PyTorch Lightning trainer flag 'num_sanity_val_steps': Number of validation "
|
||||
"steps to run before training, to identify possible problems")
|
||||
pl_deterministic: bool = \
|
||||
param.Boolean(default=False,
|
||||
doc="Controls the PyTorch Lightning trainer flags 'deterministic' and 'benchmark'. If "
|
||||
"'pl_deterministic' is True, results are perfectly reproducible. If False, they are not, but "
|
||||
"you may see training speed increases.")
|
||||
pl_find_unused_parameters: bool = \
|
||||
param.Boolean(default=False,
|
||||
doc="Controls the PyTorch Lightning flag 'find_unused_parameters' for the DDP plugin. "
|
||||
"Setting it to True comes with a performance hit.")
|
||||
pl_limit_train_batches: Optional[int] = \
|
||||
param.Integer(default=None,
|
||||
doc="PyTorch Lightning trainer flag 'limit_train_batches': Limit the training dataset to the "
|
||||
"given number of batches.")
|
||||
pl_limit_val_batches: Optional[int] = \
|
||||
param.Integer(default=None,
|
||||
doc="PyTorch Lightning trainer flag 'limit_val_batches': Limit the validation dataset to the "
|
||||
"given number of batches.")
|
||||
pl_profiler: Optional[str] = \
|
||||
param.String(default=None,
|
||||
doc="The value to use for the 'profiler' argument for the Lightning trainer. "
|
||||
"Set to either 'simple', 'advanced', or 'pytorch'")
|
||||
monitor_gpu: bool = param.Boolean(default=False,
|
||||
doc="If True, add the GPUStatsMonitor callback to the Lightning trainer object. "
|
||||
"This will write GPU utilization metrics every 50 batches by default.")
|
||||
monitor_loading: bool = param.Boolean(default=False,
|
||||
doc="If True, add the BatchTimeCallback callback to the Lightning trainer "
|
||||
"object. This will monitor how long individual batches take to load.")
|
||||
additional_env_files: List[str] = param.List(class_=Path, default=[],
|
||||
doc="Additional conda environment (.yml) files to merge into the"
|
||||
" overall environment definition")
|
||||
|
||||
@property
|
||||
def use_gpu(self) -> bool:
|
||||
"""
|
||||
Returns True if a GPU is available, and the self.max_num_gpus flag allows it to be used. Returns False
|
||||
otherwise (i.e., if there is no GPU available, or self.max_num_gpus==0)
|
||||
"""
|
||||
if self.max_num_gpus == 0:
|
||||
return False
|
||||
from health_ml.utils.common_utils import is_gpu_available
|
||||
return is_gpu_available()
|
||||
|
||||
def num_gpus_per_node(self) -> int:
|
||||
"""
|
||||
Computes the number of GPUs to use for each node: either the number of GPUs available on the device
or max_num_gpus, whichever is smaller. Returns 0 if running on a CPU device.
|
||||
"""
|
||||
import torch
|
||||
available_gpus = torch.cuda.device_count() # type: ignore
|
||||
num_gpus = available_gpus if self.use_gpu else 0
|
||||
message_suffix = "" if self.use_gpu else ", but not using them because use_gpu == False"
|
||||
logging.info(f"Number of available GPUs: {available_gpus}{message_suffix}")
|
||||
if 0 <= self.max_num_gpus < num_gpus:
|
||||
num_gpus = self.max_num_gpus
|
||||
logging.info(f"Restricting the number of GPUs to {num_gpus}")
|
||||
elif self.max_num_gpus > num_gpus:
|
||||
logging.warning(f"You requested max_num_gpus {self.max_num_gpus} but there are only {num_gpus} available.")
|
||||
return num_gpus
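# Illustrative sketch: how the GPU count per node is resolved.
def _example_num_gpus_per_node() -> None:
    trainer_params = TrainerParams(max_num_gpus=2)
    num_gpus = trainer_params.num_gpus_per_node()
    # On a node with 4 GPUs this returns 2; on a CPU-only machine it returns 0.
    print(num_gpus)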
@ -0,0 +1,13 @@
|
|||
import param
from typing import Optional


class ExperimentConfig(param.Parameterized):
cluster: Optional[str] = param.String(default=None, allow_None=True,
doc="The name of the GPU or CPU cluster inside the AzureML workspace "
"that should execute the job.")
num_nodes: int = param.Integer(default=1, doc="The number of virtual machines that will be allocated for this "
"job in AzureML.")
model: str = param.String(doc="The fully qualified name of the model to train/test, e.g. "
"mymodule.configs.MyConfig.")
azureml: bool = param.Boolean(False, doc="If True, submit the executing script to run on AzureML.")
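# Illustrative sketch: these fields are what the runner's command-line flags map onto. Assuming the
# parse_args_and_update_config helper from health_azure.utils accepts any Parameterized config
# (as the tests in this change do), the same overrides can be applied programmatically:
def _example_experiment_config_overrides() -> None:
    from health_azure.utils import parse_args_and_update_config
    config = ExperimentConfig()
    config = parse_args_and_update_config(
        config, ["--model=HelloContainer", "--cluster=my-gpu-cluster", "--azureml=True"])
    # config.model == "HelloContainer", config.cluster == "my-gpu-cluster", config.azureml is True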

@ -0,0 +1,171 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any, Dict, Optional
from pathlib import Path

import param
from azureml.core import ScriptRunConfig
from azureml.train.hyperdrive import HyperDriveConfig
from pytorch_lightning import LightningDataModule, LightningModule

from health_ml.deep_learning_config import DatasetParams, OptimizerParams, OutputParams, TrainerParams, \
    WorkflowParams
from health_ml.experiment_config import ExperimentConfig


class LightningContainer(WorkflowParams,
                         DatasetParams,
                         OutputParams,
                         TrainerParams,
                         OptimizerParams):
    """
    A LightningContainer contains all information to train a user-specified PyTorch Lightning model. The model that
    should be trained is returned by the `create_model` method. The training data must be returned in the form of
    a LightningDataModule, by the `get_data_module` method.
    """
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._model: Optional[LightningModule] = None
        self._model_name = type(self).__name__
        self.num_nodes = 1

    def validate(self) -> None:
        WorkflowParams.validate(self)
        OptimizerParams.validate(self)

    def setup(self) -> None:
        """
        This method is called as one of the first operations of the training/testing workflow, before any other
        operations on the present object. At the point when called, the datasets are already available in
        the locations given by self.local_datasets. Use this method to prepare datasets or data loaders, for example.
        """
        pass

    def create_model(self) -> LightningModule:  # type: ignore
        """
        This method must create the actual Lightning model that will be trained. It can read out parameters from the
        container and pass them into the model, for example.
        """
        pass

    def get_data_module(self) -> LightningDataModule:
        """
        Gets the data that is used for the training, validation, and test steps.
        This should read datasets from the self.local_datasets folder or download from a web location.
        The format of the data is not specified any further.

        :return: A LightningDataModule
        """
        return None  # type: ignore

    def get_trainer_arguments(self) -> Dict[str, Any]:
        """
        Gets additional parameters that will be passed on to the PyTorch Lightning trainer.
        """
        return dict()

    def get_parameter_search_hyperdrive_config(self, _: ScriptRunConfig) -> HyperDriveConfig:  # type: ignore
        """
        Parameter search is not implemented. It should be implemented in a subclass if needed.
        """
        raise NotImplementedError("Parameter search is not implemented. It should be implemented in "
                                  "a subclass if needed.")

    def update_experiment_config(self, experiment_config: ExperimentConfig) -> None:
        """
        This method allows overriding ExperimentConfig parameters from within a LightningContainer.
        It is called right after the ExperimentConfig and container are initialised.
        Be careful when using class parameters to override these values. If the parameter names clash,
        CLI values will be consumed by the ExperimentConfig, but container parameters will keep their defaults.
        This can be avoided by always using unique parameter names.
        Also note that saving a reference to `experiment_config` and updating its attributes at any other
        point may lead to unexpected behaviour.

        :param experiment_config: The initialised ExperimentConfig whose parameters to override in-place.
        """
        pass

    def before_training_on_global_rank_zero(self) -> None:
        """
        A hook that will be called before starting model training, before creating the Lightning Trainer object.
        In distributed training, this is only run on global rank zero (i.e., on the process that runs on node 0,
        GPU 0). The order in which hooks are called is: before_training_on_global_rank_zero,
        before_training_on_local_rank_zero, before_training_on_all_ranks.
        """
        pass

    def before_training_on_local_rank_zero(self) -> None:
        """
        A hook that will be called before starting model training.
        In distributed training, this hook will be called once per node (i.e., whenever the LOCAL_RANK environment
        variable is zero).
        The order in which hooks are called is: before_training_on_global_rank_zero,
        before_training_on_local_rank_zero, before_training_on_all_ranks.
        """
        pass

    def before_training_on_all_ranks(self) -> None:
        """
        A hook that will be called before starting model training.
        In distributed training, this hook will be called on all ranks (i.e., once per GPU).
        The order in which hooks are called is: before_training_on_global_rank_zero,
        before_training_on_local_rank_zero, before_training_on_all_ranks.
        """
        pass

    # The code from here on does not need to be modified.

    @property
    def model(self) -> LightningModule:
        """
        Returns the PyTorch Lightning module that the present container object manages.

        :return: A PyTorch Lightning module
        """
        if self._model is None:
            raise ValueError("No Lightning module has been set yet.")
        return self._model

    def create_lightning_module_and_store(self) -> None:
        """
        Creates the Lightning model.
        """
        self._model = self.create_model()

    def get_hyperdrive_config(self, run_config: ScriptRunConfig) -> HyperDriveConfig:
        """
        Returns the HyperDrive config for parameter search.

        :param run_config: The AzureML run configuration (ScriptRunConfig) to use.
        :return: A HyperDriveConfig object.
        """
        return self.get_parameter_search_hyperdrive_config(run_config)

    def load_model_checkpoint(self, checkpoint_path: Path) -> None:
        """
        Load a checkpoint from the given path. We need to define a separate method since PyTorch Lightning cannot
        access the _model attribute to modify it.
        """
        if self._model is None:
            raise ValueError("No Lightning module has been set yet.")
        self._model = type(self._model).load_from_checkpoint(checkpoint_path=str(checkpoint_path))

    def __str__(self) -> str:
        """Returns a string describing the present object, as a list of key: value strings."""
        arguments_str = "\nContainer:\n"
        # Avoid callable params, the bindings that are printed out can be humongous.
        # Avoid dataframes
        skip_params = {name for name, value in self.param.params().items()
                       if isinstance(value, (param.Callable, param.DataFrame))}
        for key, value in self.param.get_param_values():
            if key not in skip_params:
                arguments_str += f"\t{key:40}: {value}\n"
        # Print out all other separate vars that are not under the guidance of the params library,
        # skipping the two that are introduced by params
        skip_vars = {"param", "initialized"}
        for key, value in vars(self).items():
            if key not in skip_vars and key[0] != "_":
                arguments_str += f"\t{key:40}: {value}\n"
        return arguments_str
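To make the container contract concrete, here is a minimal sketch of a user-defined container. The names `TinyModel`, `TinyDataModule` and `MyContainer` are invented for illustration; a real container would return the project's actual LightningModule and LightningDataModule, and may also override `setup`, `get_trainer_arguments` and the pre-training hooks described above.

```python
import torch
from pytorch_lightning import LightningDataModule, LightningModule
from torch.utils.data import DataLoader, TensorDataset

from health_ml.lightning_container import LightningContainer


class TinyModel(LightningModule):
    """A deliberately trivial model, only to illustrate the container contract."""
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):  # type: ignore
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):  # type: ignore
        return torch.optim.SGD(self.parameters(), lr=0.1)


class TinyDataModule(LightningDataModule):
    def train_dataloader(self) -> DataLoader:
        x = torch.rand(32, 1)
        return DataLoader(TensorDataset(x, 2.0 * x), batch_size=8)


class MyContainer(LightningContainer):
    def create_model(self) -> LightningModule:
        # Return the LightningModule that should be trained.
        return TinyModel()

    def get_data_module(self) -> LightningDataModule:
        # Return the data used for the training, validation, and test steps.
        return TinyDataModule()
```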

@ -0,0 +1,207 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
import os
import sys
from pathlib import Path
from typing import Any, List, Tuple, TypeVar

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import GPUStatsMonitor, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin


from health_azure.utils import (ENV_GLOBAL_RANK, ENV_LOCAL_RANK, ENV_NODE_RANK, RUN_CONTEXT, is_global_rank_zero,
                                is_local_rank_zero, is_running_in_azure_ml)

from health_ml.lightning_container import LightningContainer
from health_ml.utils import AzureMLLogger, AzureMLProgressBar
from health_ml.utils.common_utils import EXPERIMENT_SUMMARY_FILE
from health_ml.utils.lightning_loggers import StoringLogger

TEMP_PREFIX = "temp/"

T = TypeVar('T')


def write_experiment_summary_file(config: Any, outputs_folder: Path) -> None:
    """
    Writes the given config to disk in plain text in the default output folder.
    """
    output = str(config)
    outputs_folder.mkdir(exist_ok=True, parents=True)
    dst = outputs_folder / EXPERIMENT_SUMMARY_FILE
    dst.write_text(output)
    logging.info(output)


def create_lightning_trainer(container: LightningContainer,
                             num_nodes: int = 1) -> Tuple[Trainer, StoringLogger]:
    """
    Creates a PyTorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers
    and loggers. That includes a diagnostic logger for use in unit tests, which is also returned as the second
    return value.

    :param container: The container with model and data.
    :param num_nodes: The number of nodes to use in distributed training.
    :return: A tuple [Trainer object, diagnostic logger]
    """
    num_gpus = container.num_gpus_per_node()
    effective_num_gpus = num_gpus * num_nodes
    strategy = None
    if effective_num_gpus == 0:
        accelerator = "cpu"
        devices = 1
        message = "CPU"
    else:
        accelerator = "gpu"
        devices = num_gpus
        message = f"{devices} GPU"
        if effective_num_gpus > 1:
            # Accelerator should be "ddp" when running large models in AzureML (when using DDP_spawn, we get out of
            # GPU memory).
            # Initialize the DDP plugin. The default for pl_find_unused_parameters is False. If True, the plugin
            # prints out lengthy warnings about the performance impact of find_unused_parameters.
            strategy = DDPPlugin(find_unused_parameters=container.pl_find_unused_parameters)
            message += "s per node with DDP"
    logging.info(f"Using {message}")
    tensorboard_logger = TensorBoardLogger(save_dir=str(container.logs_folder), name="Lightning", version="")
    loggers = [tensorboard_logger, AzureMLLogger(False)]
    storing_logger = StoringLogger()
    loggers.append(storing_logger)
    # Use 32bit precision when running on CPU. Otherwise, make it depend on the use_mixed_precision flag.
    precision = 32 if num_gpus == 0 else 16 if container.use_mixed_precision else 32
    # The next two flags control the settings in torch.backends.cudnn.deterministic and torch.backends.cudnn.benchmark
    # https://pytorch.org/docs/stable/notes/randomness.html
    # Note that switching to deterministic models can have a large performance downside.
    if container.pl_deterministic:
        deterministic = True
        benchmark = False
    else:
        deterministic = False
        benchmark = True

    # Get more callbacks
    callbacks: List[Any] = []
    if container.monitor_loading:
        # TODO antonsc: Remove after fixing the callback.
        raise NotImplementedError("Monitoring batch loading times has been temporarily disabled.")
        # callbacks.append(BatchTimeCallback())
    if num_gpus > 0 and container.monitor_gpu:
        logging.info("Adding monitoring for GPU utilization")
        callbacks.append(GPUStatsMonitor(intra_step_time=True, inter_step_time=True))  # type: ignore
    # Add the additional callbacks that were specified in get_trainer_arguments for LightningContainers
    additional_args = container.get_trainer_arguments()
    # Callbacks can be specified via the "callbacks" argument (the legacy behaviour) or the new get_callbacks method
    if "callbacks" in additional_args:
        more_callbacks = additional_args.pop("callbacks")
        if isinstance(more_callbacks, list):
            callbacks.extend(more_callbacks)  # type: ignore
        else:
            callbacks.append(more_callbacks)  # type: ignore

    is_azureml_run = is_running_in_azure_ml(RUN_CONTEXT)
    progress_bar_refresh_rate = container.pl_progress_bar_refresh_rate
    if progress_bar_refresh_rate is None:
        progress_bar_refresh_rate = 50
        logging.info(f"The progress bar refresh rate is not set. Using a default of {progress_bar_refresh_rate}. "
                     f"To change, modify the pl_progress_bar_refresh_rate field of the container.")
    if is_azureml_run:
        callbacks.append(AzureMLProgressBar(refresh_rate=progress_bar_refresh_rate,
                                            write_to_logging_info=True,
                                            print_timestamp=False))
    else:
        callbacks.append(TQDMProgressBar(refresh_rate=progress_bar_refresh_rate))
    # Read out additional model-specific args here.
    # We probably want to keep essential ones like num_gpus and logging.
    trainer = Trainer(default_root_dir=str(container.outputs_folder),
                      deterministic=deterministic,
                      benchmark=benchmark,
                      accelerator=accelerator,
                      strategy=strategy,
                      max_epochs=container.max_epochs,
                      # Both these arguments can be integers or floats. If integers, it is the number of batches.
                      # If float, it's the fraction of batches. We default to 1.0 (processing all batches).
                      limit_train_batches=container.pl_limit_train_batches or 1.0,
                      limit_val_batches=container.pl_limit_val_batches or 1.0,
                      num_sanity_val_steps=container.pl_num_sanity_val_steps,
                      callbacks=callbacks,
                      logger=loggers,
                      num_nodes=num_nodes,
                      devices=devices,
                      precision=precision,
                      sync_batchnorm=True,
                      detect_anomaly=container.detect_anomaly,
                      profiler=container.pl_profiler,
                      **additional_args)
    return trainer, storing_logger


def model_train(container: LightningContainer) -> Tuple[Trainer, StoringLogger]:
    """
    The main training loop. It creates the PyTorch model based on the configuration options passed in,
    creates a PyTorch Lightning trainer, and trains the model.
    If a checkpoint was specified, then it loads the checkpoint before resuming training.

    :param container: A container object that holds the training data in PyTorch Lightning format
        and the model to train.
    :return: A tuple of [Trainer, StoringLogger]. Trainer is the Lightning Trainer object that was used for fitting
        the model. The StoringLogger object is returned when training a built-in model; it is None when
        fitting other models.
    """
    lightning_model = container.model

    # resource_monitor: Optional[ResourceMonitor] = None
    # Execute some bookkeeping tasks only once if running distributed:
    if is_global_rank_zero():
        logging.info(f"Model checkpoints are saved at {container.checkpoint_folder}")
        write_experiment_summary_file(container,
                                      outputs_folder=container.outputs_folder)

    data_module = container.get_data_module()
    if is_global_rank_zero():
        container.before_training_on_global_rank_zero()
    if is_local_rank_zero():
        container.before_training_on_local_rank_zero()
    container.before_training_on_all_ranks()

    # Create the trainer object. Backup the environment variables before doing that, in case we need to run a second
    # training in the unit tests.
    old_environ = dict(os.environ)
    # Set random seeds just before training
    seed_everything(container.get_effective_random_seed())
    trainer, storing_logger = create_lightning_trainer(container,
                                                       num_nodes=container.num_nodes)
    rank_info = ", ".join(f"{env}: {os.getenv(env)}"
                          for env in [ENV_GLOBAL_RANK, ENV_LOCAL_RANK, ENV_NODE_RANK])
    logging.info(f"Environment variables: {rank_info}. trainer.global_rank: {trainer.global_rank}")

    # get recovery checkpoint if it exists

    logging.info("Starting training")
    trainer.fit(lightning_model, datamodule=data_module)
    trainer.logger.finalize('success')

    # DDP will start multiple instances of the runner, one for each GPU. Those should terminate here after training.
    # We can now use the global_rank of the Lightning model, rather than environment variables, because DDP has set
    # all necessary properties.
    if lightning_model.global_rank != 0:
        logging.info(f"Terminating training thread with rank {lightning_model.global_rank}.")
        sys.exit()

    logging.info("Removing redundant checkpoint files.")
    # get_best_checkpoint_path(container.checkpoint_folder)
    # Lightning modifies a ton of environment variables. If we first run training and then the test suite,
    # those environment variables will mislead the training runs in the test suite, and make them crash.
    # Hence, restore the original environment after training.
    os.environ.clear()
    os.environ.update(old_environ)

    logging.info("Finished training")

    return trainer, storing_logger
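Putting the pieces together, a training run driven directly through `model_train` might look roughly like the sketch below. This bypasses the runner and is only meant to show the call order; `MyContainer` is the hypothetical container from the earlier sketch, and it is assumed here that `create_filesystem` accepts the working directory as its root (mirroring how MLRunner calls it further down).

```python
from pathlib import Path

from health_ml.model_trainer import model_train

# MyContainer is the hypothetical container sketched after the LightningContainer class above.
container = MyContainer()  # noqa: F821
container.create_filesystem(Path.cwd())   # create output/checkpoint folders under the current directory (assumed)
container.setup()                         # user hook: prepare datasets, data loaders, etc.
container.create_lightning_module_and_store()
trainer, storing_logger = model_train(container=container)
```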

@ -0,0 +1,164 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional

import torch.multiprocessing
from pytorch_lightning import seed_everything

from health_azure import AzureRunInfo
from health_azure.utils import (ENV_OMPI_COMM_WORLD_RANK, RUN_CONTEXT, create_run_recovery_id,
                                PARENT_RUN_CONTEXT, is_running_in_azure_ml)

from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
from health_ml.model_trainer import create_lightning_trainer, model_train
from health_ml.utils import fixed_paths
from health_ml.utils.common_utils import (
    change_working_directory, logging_section, RUN_RECOVERY_ID_KEY,
    EFFECTIVE_RANDOM_SEED_KEY_NAME, RUN_RECOVERY_FROM_ID_KEY_NAME)
from health_ml.utils.lightning_loggers import StoringLogger
from health_ml.utils.type_annotations import PathOrString


def check_dataset_folder_exists(local_dataset: PathOrString) -> Path:
    """
    Checks if a folder with a local dataset exists. If it does exist, return the argument converted
    to a Path instance. If it does not exist, raise a FileNotFoundError.

    :param local_dataset: The dataset folder to check.
    :return: The local_dataset argument, converted to a Path.
    """
    expected_dir = Path(local_dataset)
    if not expected_dir.is_dir():
        raise FileNotFoundError(f"The model uses a dataset in {expected_dir}, but that does not exist.")
    logging.info(f"Model training will use the local dataset provided in {expected_dir}")
    return expected_dir


class MLRunner:

    def __init__(self,
                 experiment_config: ExperimentConfig,
                 container: LightningContainer,
                 project_root: Optional[Path] = None) -> None:
        """
        Driver class to run an ML experiment. Note that the project root argument MUST be supplied when using hi-ml
        as a package!

        :param container: The LightningContainer object to use for training.
        :param project_root: Project root. This should only be omitted if calling run_ml from the test suite. Supplying
            it is crucial when using hi-ml as a package or submodule!
        """
        self.container = container
        self.experiment_config = experiment_config
        self.container.num_nodes = self.experiment_config.num_nodes
        self.project_root: Path = project_root or fixed_paths.repository_root_directory()
        self.storing_logger: Optional[StoringLogger] = None
        self._has_setup_run = False

    def setup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None:
        """
        Sets the random seeds, calls the setup method on the LightningContainer and then creates the actual
        Lightning modules.

        :param azure_run_info: When running in AzureML or on a local VM, this contains the paths to the datasets.
            This can be missing when running in unit tests, where the local dataset paths are already populated.
        """
        if self._has_setup_run:
            return
        if azure_run_info:
            # Set up the paths to the datasets. azure_run_info already has all necessary information, using either
            # the provided local datasets for VM runs, or the AzureML mount points when running in AML.
            # This must happen before container setup because that could already read datasets.
            if len(azure_run_info.input_datasets) > 0:
                input_datasets = azure_run_info.input_datasets
                assert len(input_datasets) > 0
                local_datasets = [
                    check_dataset_folder_exists(input_dataset) for input_dataset in input_datasets  # type: ignore
                ]
                self.container.local_datasets = local_datasets  # type: ignore
        # Ensure that we use fixed seeds before initializing the PyTorch models
        seed_everything(self.container.get_effective_random_seed())

        # Creating the folder structure must happen before the LightningModule is created, because the output
        # parameters of the container will be copied into the module.
        self.container.create_filesystem(self.project_root)

        self.container.setup()
        self.container.create_lightning_module_and_store()
        self._has_setup_run = True

    def set_run_tags_from_parent(self) -> None:
        """
        Set metadata for the run.
        """
        assert PARENT_RUN_CONTEXT, "This function should only be called in a Hyperdrive run."
        run_tags_parent = PARENT_RUN_CONTEXT.get_tags()
        tags_to_copy = [
            "tag",
            "model_name",
            "execution_mode",
            "recovered_from",
            "friendly_name",
            "build_number",
            "build_user",
            RUN_RECOVERY_FROM_ID_KEY_NAME
        ]
        new_tags = {tag: run_tags_parent.get(tag, "") for tag in tags_to_copy}
        new_tags[RUN_RECOVERY_ID_KEY] = create_run_recovery_id(run=RUN_CONTEXT)
        new_tags[EFFECTIVE_RANDOM_SEED_KEY_NAME] = str(self.container.get_effective_random_seed())
        RUN_CONTEXT.set_tags(new_tags)

    def run(self) -> None:
        """
        Driver function to run an ML experiment.
        """
        self.setup()
        is_offline_run = not is_running_in_azure_ml(RUN_CONTEXT)
        # Get the AzureML context in which the script is running
        if not is_offline_run and PARENT_RUN_CONTEXT is not None:
            logging.info("Setting tags from parent run.")
            self.set_run_tags_from_parent()

        # do training
        with logging_section("Model training"):
            _, storing_logger = model_train(container=self.container)
            self.storing_logger = storing_logger

    def run_inference_for_lightning_models(self, checkpoint_paths: List[Path]) -> List[Dict[str, float]]:
        """
        Run inference on the test set for all models that are specified via a LightningContainer.

        :param checkpoint_paths: The path to the checkpoint that should be used for inference.
        """
        if len(checkpoint_paths) != 1:
            raise ValueError(f"This method expects exactly 1 checkpoint for inference, but got {len(checkpoint_paths)}")
        # lightning_model = self.container.model

        # Run Lightning's built-in test procedure if the `test_step` method has been overridden
        logging.info("Running inference via the LightningModule.test_step method")
        # Lightning does not cope with having two calls to .fit or .test in the same script. As a workaround for
        # now, restrict the number of GPUs to 1, meaning that it will not start DDP.
        self.container.max_num_gpus = 1
        # Without this, the trainer will think it should still operate in multi-node mode, and wrongly start
        # searching for Horovod
        if ENV_OMPI_COMM_WORLD_RANK in os.environ:
            del os.environ[ENV_OMPI_COMM_WORLD_RANK]
        # From the training setup, torch still thinks that it should run in a distributed manner,
        # and would block on some GPU operations. Hence, clean up distributed training.
        if torch.distributed.is_initialized():  # type: ignore
            torch.distributed.destroy_process_group()  # type: ignore

        trainer, _ = create_lightning_trainer(self.container, num_nodes=1)

        self.container.load_model_checkpoint(checkpoint_path=checkpoint_paths[0])
        # Change the current working directory to ensure that test files go to the right folder
        data_module = self.container.get_data_module()
        with change_working_directory(self.container.outputs_folder):
            results = trainer.test(self.container.model, datamodule=data_module)
        return results
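For orientation, this is roughly how the runner (defined next) wires MLRunner up; treat it as a sketch rather than a supported API. The model name is a placeholder and `MyContainer` is the hypothetical container from the earlier sketch.

```python
from pathlib import Path

from health_ml.experiment_config import ExperimentConfig
from health_ml.run_ml import MLRunner

# Placeholder model name; in practice this comes from the --model commandline argument.
experiment_config = ExperimentConfig(model="mymodule.configs.MyConfig")
container = MyContainer()  # noqa: F821
ml_runner = MLRunner(experiment_config=experiment_config,
                     container=container,
                     project_root=Path.cwd())
ml_runner.setup()   # seeds, folder structure, container.setup(), model creation
ml_runner.run()     # training via model_train, wrapped in a logging_section
```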

@ -0,0 +1,299 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import argparse
import logging
import os
import param
import sys
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import matplotlib

# Add hi-ml packages to sys.path so that AML can find them
himl_root = Path(__file__).parent.parent.parent.parent
print(f"Starting the himl runner at {himl_root}")
print(f"health_ml pkg: {himl_root}")
health_ml_pkg = himl_root / "hi-ml" / "src"
health_azure_pkg = himl_root / "hi-ml-azure" / "src"

sys.path.insert(0, str(health_azure_pkg))
sys.path.insert(0, str(health_ml_pkg))
print(f"sys path: {sys.path}")

from health_azure import AzureRunInfo, submit_to_azure_if_needed  # noqa: E402
from health_azure.datasets import create_dataset_configs  # noqa: E402
from health_azure.utils import (get_workspace, is_local_rank_zero, merge_conda_files,  # noqa: E402
                                set_environment_variables_for_multi_node, create_argparser, parse_arguments,
                                ParserResult, apply_overrides)

from health_ml.experiment_config import ExperimentConfig  # noqa: E402
from health_ml.lightning_container import LightningContainer  # noqa: E402
from health_ml.run_ml import MLRunner  # noqa: E402
from health_ml.utils import fixed_paths  # noqa: E402
from health_ml.utils.common_utils import (get_all_environment_files,  # noqa: E402
                                          get_all_pip_requirements_files,
                                          is_linux, logging_to_stdout)
from health_ml.utils.config_loader import ModelConfigLoader  # noqa: E402


DEFAULT_DOCKER_BASE_IMAGE = "mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.2-cudnn8-ubuntu18.04"


def initialize_rpdb() -> None:
    """
    On Linux only, import and initialize rpdb, to enable remote debugging if necessary.
    """
    # rpdb signal trapping does not work on Windows, as there is no SIGTRAP:
    if not is_linux():
        return
    import rpdb
    rpdb_port = 4444
    rpdb.handle_trap(port=rpdb_port)
    # For some reason, os.getpid() does not return the ID of what appears to be the currently running process.
    logging.info("rpdb is handling traps. To debug: identify the main runner.py process, then as root: "
                 f"kill -TRAP <process_id>; nc 127.0.0.1 {rpdb_port}")


def package_setup_and_hacks() -> None:
    """
    Set up the Python packages where needed. In particular, reduce the logging level for some of the used
    libraries, which are particularly talkative in DEBUG mode. Usually when running in DEBUG mode, we want
    diagnostics about the model building itself, but not for the underlying libraries.
    It also adds workarounds for known issues in some packages.
    """
    # Numba code generation is extremely talkative in DEBUG mode, disable that.
    logging.getLogger('numba').setLevel(logging.WARNING)
    # Matplotlib is also very talkative in DEBUG mode, filling half of the log file in a PR build.
    logging.getLogger('matplotlib').setLevel(logging.INFO)
    # Urllib3 prints out connection information for each call to write metrics, etc.
    logging.getLogger('urllib3').setLevel(logging.INFO)
    logging.getLogger('msrest').setLevel(logging.INFO)
    # AzureML prints too many details about logging metrics
    logging.getLogger('azureml').setLevel(logging.INFO)
    # Jupyter notebook report generation
    logging.getLogger('papermill').setLevel(logging.INFO)
    logging.getLogger('nbconvert').setLevel(logging.INFO)
    # This is working around a spurious error message thrown by MKL, see
    # https://github.com/pytorch/pytorch/issues/37377
    os.environ['MKL_THREADING_LAYER'] = 'GNU'
    # Workaround for issues with matplotlib on some X servers, see
    # https://stackoverflow.com/questions/45993879/matplot-lib-fatal-io-error-25-inappropriate-ioctl-for-device-on-x
    # -server-loc
    matplotlib.use('Agg')


def create_runner_parser() -> argparse.ArgumentParser:
    """
    Creates a commandline parser that understands all necessary arguments for training a model.

    :return: An instance of ArgumentParser with args from ExperimentConfig added.
    """
    config = ExperimentConfig()
    parser = create_argparser(config)
    return parser


def additional_run_tags(commandline_args: str) -> Dict[str, str]:
    """
    Gets the set of tags from the commandline arguments that will be added to the AzureML run as metadata.

    :param commandline_args: A string that holds all commandline arguments that were used for the present run.
    """
    return {
        "commandline_args": commandline_args,
    }


class Runner:
    """
    This class contains the high-level logic to start a training run: choose a model configuration by name,
    submit to AzureML if needed, or otherwise start the actual training and test loop.

    :param project_root: The root folder that contains all of the source code that should be executed.
    """

    def __init__(self, project_root: Path):
        self.project_root = project_root
        self.experiment_config: ExperimentConfig = ExperimentConfig()
        self.lightning_container: LightningContainer = None  # type: ignore
        # This field stores the MLRunner object that has been created in the most recent call to the run() method.
        self.ml_runner: Optional[MLRunner] = None

    def parse_and_load_model(self) -> ParserResult:
        """
        Parses the command line arguments, and creates configuration objects for the model itself, and for the
        Azure-related parameters. Sets self.experiment_config to its proper values. Returns the
        parser output from parsing the model commandline arguments.

        :return: ParserResult object containing args, overrides and settings
        """
        parser = create_runner_parser()
        parser_result = parse_arguments(parser, args=sys.argv[1:])
        experiment_config = ExperimentConfig(**parser_result.args)

        self.experiment_config = experiment_config
        if not experiment_config.model:
            raise ValueError("Parameter 'model' needs to be set to specify which model to run.")
        print(f"Creating model loader with the following args: {parser_result.args}")
        model_config_loader: ModelConfigLoader = ModelConfigLoader(**parser_result.args)
        # Create the model as per the "model" commandline option. This is a LightningContainer.
        container = model_config_loader.create_model_config_from_name(model_name=experiment_config.model)

        # parse overrides and apply
        assert isinstance(container, param.Parameterized)
        parser_ = create_argparser(container)
        # For each parser, feed in the unknown settings from the previous parser. All commandline args should
        # be consumed by name, hence fail if there is something that is still unknown.
        parser_result_ = parse_arguments(parser_, args=parser_result.unknown)
        # Apply the overrides and validate. Overrides can come from either YAML settings or the commandline.
        _ = apply_overrides(container, parser_result_.overrides)  # type: ignore
        container.validate()

        self.lightning_container = container

        return parser_result_

    def run(self) -> Tuple[LightningContainer, AzureRunInfo]:
        """
        The main entry point for training and testing models from the commandline. This chooses a model to train
        via a commandline argument, runs training or testing, and writes all required info to disk and logs.

        :return: a tuple of the LightningContainer object and an AzureRunInfo containing all information about
            the present run (whether running in AzureML or not)
        """
        # Usually, when we set logging to DEBUG, we want diagnostics about the model
        # build itself, but not the tons of debug information that AzureML submissions create.
        logging_to_stdout(logging.INFO if is_local_rank_zero() else "ERROR")
        initialize_rpdb()
        self.parse_and_load_model()
        azure_run_info = self.submit_to_azureml_if_needed()
        self.run_in_situ(azure_run_info)
        return self.lightning_container, azure_run_info

    def submit_to_azureml_if_needed(self) -> AzureRunInfo:
        """
        Submit a job to AzureML, returning the resulting Run object, or exiting if we were asked to wait for
        completion and the Run did not succeed.

        :return: an AzureRunInfo object containing all of the details of the present run. If AzureML is not
            specified, the attribute 'run' will be None, but the object still contains helpful information
            about datasets etc.
        """
        root_folder = self.project_root
        entry_script = Path(sys.argv[0]).resolve()
        script_params = sys.argv[1:]

        additional_conda_env_files = self.lightning_container.additional_env_files
        additional_env_files: Optional[List[Path]]
        if additional_conda_env_files is not None:
            additional_env_files = [Path(f) for f in additional_conda_env_files]
        else:
            additional_env_files = None

        conda_dependencies_files = get_all_environment_files(self.project_root,
                                                             additional_files=additional_env_files)
        pip_requirements_files = get_all_pip_requirements_files()

        # Merge the project-specific dependencies with the packages and write the unified definition
        # to a temp file. In case of version conflicts, the package version in the outer project is given priority.
        temp_conda: Optional[Path] = None
        if len(conda_dependencies_files) > 1 or len(pip_requirements_files) > 0:
            temp_conda = root_folder / f"temp_environment-{uuid.uuid4().hex[:8]}.yml"
            merge_conda_files(conda_dependencies_files, temp_conda, pip_files=pip_requirements_files)

        # TODO: Update environment variables
        environment_variables: Dict[str, Any] = {}

        # get default datastore from provided workspace
        workspace = get_workspace()
        default_datastore = workspace.get_default_datastore().name

        local_datasets = self.lightning_container.local_datasets
        all_local_datasets = [Path(p) for p in local_datasets] if len(local_datasets) > 0 else []
        input_datasets = \
            create_dataset_configs(all_azure_dataset_ids=self.lightning_container.azure_datasets,
                                   all_dataset_mountpoints=self.lightning_container.dataset_mountpoints,
                                   all_local_datasets=all_local_datasets,  # type: ignore
                                   datastore=default_datastore)
        try:
            if self.experiment_config.azureml:
                if not self.experiment_config.cluster:
                    raise ValueError("You need to specify a cluster name via '--cluster NAME' to submit "
                                     "the script to run in AzureML")
                azure_run_info = submit_to_azure_if_needed(
                    entry_script=entry_script,
                    snapshot_root_directory=root_folder,
                    script_params=script_params,
                    conda_environment_file=temp_conda or conda_dependencies_files[0],
                    aml_workspace=workspace,
                    compute_cluster_name=self.experiment_config.cluster,
                    environment_variables=environment_variables,
                    default_datastore=default_datastore,
                    experiment_name=self.lightning_container.name,  # create_experiment_name(),
                    input_datasets=input_datasets,  # type: ignore
                    num_nodes=self.experiment_config.num_nodes,
                    wait_for_completion=False,
                    ignored_folders=[],
                    submit_to_azureml=self.experiment_config.azureml,
                    docker_base_image=DEFAULT_DOCKER_BASE_IMAGE,
                    tags=additional_run_tags(
                        commandline_args=" ".join(script_params))
                )
            else:
                azure_run_info = submit_to_azure_if_needed(
                    input_datasets=input_datasets,  # type: ignore
                    submit_to_azureml=False)
        finally:
            if temp_conda:
                temp_conda.unlink()
        # submit_to_azure_if_needed calls sys.exit after submitting to AzureML. We only reach this when running
        # the script locally or in AzureML.
        return azure_run_info

    def run_in_situ(self, azure_run_info: AzureRunInfo) -> None:
        """
        Actually run the AzureML job; this method will typically run on an Azure VM.

        :param azure_run_info: Contains all information about the present run in AzureML, in particular where the
            datasets are mounted.
        """
        # Only set the logging level now. Usually, when we set logging to DEBUG, we want diagnostics about the model
        # build itself, but not the tons of debug information that AzureML submissions create.
        # Suppress the logging from all processes but the one for GPU 0 on each node, to make log files more readable
        logging_to_stdout("INFO" if is_local_rank_zero() else "ERROR")
        package_setup_and_hacks()

        # Set environment variables for multi-node training if needed. This function will terminate early
        # if it detects that it is not in a multi-node environment.
        set_environment_variables_for_multi_node()
        self.ml_runner = MLRunner(
            experiment_config=self.experiment_config,
            container=self.lightning_container,
            project_root=self.project_root)
        self.ml_runner.setup(azure_run_info)
        self.ml_runner.run()


def run(project_root: Path) -> Tuple[LightningContainer, AzureRunInfo]:
    """
    The main entry point for training and testing models from the commandline. This chooses a model to train
    via a commandline argument, runs training or testing, and writes all required info to disk and logs.

    :param project_root: The root folder that contains all of the source code that should be executed.
    :return: If submitting to AzureML, returns the model configuration that was used for training,
        including commandline overrides applied (if any). For details on the arguments, see the constructor of Runner.
    """
    runner = Runner(project_root)
    return runner.run()


def main() -> None:
    run(project_root=fixed_paths.repository_root_directory())


if __name__ == '__main__':
    main()
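The same entry point can also be driven programmatically, which is what the console script ultimately does. The sketch below assumes the runner module lives at `health_ml.runner` (inferred from the package layout, not stated in the diff) and uses a placeholder model name; note that even a local run resolves the AzureML workspace via `get_workspace()` for dataset handling, as the code above shows.

```python
import sys
from pathlib import Path

from health_ml.runner import Runner  # module path assumed from the package layout

# Equivalent to invoking the runner with "--model=mymodule.configs.MyConfig" (placeholder name),
# which trains locally because AzureML submission is not requested.
sys.argv = ["runner.py", "--model=mymodule.configs.MyConfig"]
runner = Runner(project_root=Path.cwd())
container, azure_run_info = runner.run()
```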

@ -1,11 +1,264 @@
from typing import Optional
import logging
import os
import sys
import time
from contextlib import contextmanager
from datetime import datetime
from enum import Enum, unique
from pathlib import Path
from typing import Any, Generator, Iterable, List, Optional, Union

import torch

from health_azure.utils import PathOrString

from health_ml.utils import fixed_paths


MAX_PATH_LENGTH = 260

# convert string to None if an empty string or whitespace is provided
empty_string_to_none = lambda x: None if (x is None or len(x.strip()) == 0) else x
string_to_path = lambda x: None if (x is None or len(x.strip()) == 0) else Path(x)

EXPERIMENT_SUMMARY_FILE = "experiment_summary.txt"
CHECKPOINT_FOLDER = "checkpoints"
RUN_RECOVERY_ID_KEY = 'run_recovery_id'
EFFECTIVE_RANDOM_SEED_KEY_NAME = "effective_random_seed"
RUN_RECOVERY_FROM_ID_KEY_NAME = "recovered_from"
DEFAULT_AML_UPLOAD_DIR = "outputs"
DEFAULT_LOGS_DIR_NAME = "logs"


@unique
class ModelExecutionMode(Enum):
    """
    Model execution mode
    """
    TRAIN = "Train"
    TEST = "Test"
    VAL = "Val"


def check_is_any_of(message: str, actual: Optional[str], valid: Iterable[Optional[str]]) -> None:
    """
    Raises an exception if 'actual' is not any of the given valid values.

    :param message: The prefix for the error message.
    :param actual: The actual value.
    :param valid: The set of valid strings that 'actual' is allowed to take on.
    """
    if actual not in valid:
        all_valid = ", ".join(["<None>" if v is None else v for v in valid])
        raise ValueError("{} must be one of [{}], but got: {}".format(message, all_valid, actual))


logging_stdout_handler: Optional[logging.StreamHandler] = None
logging_to_file_handler: Optional[logging.StreamHandler] = None


def logging_to_stdout(log_level: Union[int, str] = logging.INFO) -> None:
    """
    Instructs the Python logging libraries to start writing logs to stdout up to the given logging level.
    Logging will use a timestamp as the prefix, using UTC.

    :param log_level: The logging level. All logging messages with a level at or above this level will be written to
        stdout. log_level can be numeric, or one of the pre-defined logging strings (INFO, DEBUG, ...).
    """
    log_level = standardize_log_level(log_level)
    logger = logging.getLogger()
    # This function can be called multiple times, in particular in AzureML when we first run a training job and
    # then a couple of tests, which also often enable logging. This would then add multiple handlers, and repeated
    # logging lines.
    global logging_stdout_handler
    if not logging_stdout_handler:
        print("Setting up logging to stdout.")
        # At startup, logging has one handler set, that writes to stderr, with a log level of 0 (logging.NOTSET)
        if len(logger.handlers) == 1:
            logger.removeHandler(logger.handlers[0])
        logging_stdout_handler = logging.StreamHandler(stream=sys.stdout)
        _add_formatter(logging_stdout_handler)
        logger.addHandler(logging_stdout_handler)
    print(f"Setting logging level to {log_level}")
    logging_stdout_handler.setLevel(log_level)
    logger.setLevel(log_level)


def standardize_log_level(log_level: Union[int, str]) -> int:
    """
    :param log_level: integer or string (any casing) version of a log level, e.g. 20 or "INFO".
    :return: integer version of the level; throws if the string does not name a level.
    """
    if isinstance(log_level, str):
        log_level = log_level.upper()
        check_is_any_of("log_level", log_level, logging._nameToLevel.keys())
        return logging._nameToLevel[log_level]
    return log_level


def _add_formatter(handler: logging.StreamHandler) -> None:
    """
    Adds a logging formatter that includes the timestamp and the logging level.
    """
    formatter = logging.Formatter(fmt="%(asctime)s %(levelname)-8s %(message)s",
                                  datefmt="%Y-%m-%dT%H:%M:%SZ")
    # noinspection PyTypeHints
    formatter.converter = time.gmtime  # type: ignore
    handler.setFormatter(formatter)


@contextmanager
def logging_section(gerund: str) -> Generator:
    """
    Context manager to print "**** STARTING: ..." and "**** FINISHED: ..." lines around sections of the log,
    to help people locate particular sections. Usage:
    with logging_section("doing this and that"):
        do_this_and_that()

    :param gerund: string expressing what happens in this section of the log.
    """
    from time import time
    logging.info("")
    msg = f"**** STARTING: {gerund} "
    logging.info(msg + (100 - len(msg)) * "*")
    logging.info("")
    start_time = time()
    yield
    elapsed = time() - start_time
    logging.info("")
    if elapsed >= 3600:
        time_expr = f"{elapsed / 3600:0.2f} hours"
    elif elapsed >= 60:
        time_expr = f"{elapsed / 60:0.2f} minutes"
    else:
        time_expr = f"{elapsed:0.2f} seconds"
    msg = f"**** FINISHED: {gerund} after {time_expr} "
    logging.info(msg + (100 - len(msg)) * "*")
    logging.info("")


def is_windows() -> bool:
    """
    Returns True if the host operating system is Windows.
    """
    return os.name == 'nt'


def is_linux() -> bool:
    """
    Returns True if the host operating system is a flavour of Linux.
    """
    return os.name == 'posix'


def check_properties_are_not_none(obj: Any, ignore: Optional[List[str]] = None) -> None:
    """
    Checks to make sure the provided object has no properties that have a None value assigned.
    """
    if ignore is not None:
        none_props = [k for k, v in vars(obj).items() if v is None and k not in ignore]
        if len(none_props) > 0:
            raise ValueError("Properties had None value: {}".format(none_props))


@contextmanager
def change_working_directory(path_or_str: PathOrString) -> Generator:
    """
    Context manager for changing the current working directory to the value provided. Outside the context
    manager, the original working directory will be restored.

    :param path_or_str: The new directory to change to
    :yield: a _GeneratorContextManager object (this object itself is of no use, rather we are interested in
        the side effect of the working directory temporarily changing)
    """
    new_path = Path(path_or_str).expanduser()
    old_path = Path.cwd()
    os.chdir(new_path)
    yield
    os.chdir(str(old_path))


def _create_generator(seed: Optional[int] = None) -> torch.Generator:
    """
    Creates a Torch generator and seeds it with the given value if provided, or else with a random seed.

    :param seed: Optional seed to set the Generator object with
    :return: Torch Generator object
    """
    generator = torch.Generator()
    if seed is None:
        seed = int(torch.empty((), dtype=torch.int64).random_().item())
    generator.manual_seed(seed)
    return generator


def get_all_environment_files(project_root: Path, additional_files: Optional[List[Path]] = None) -> List[Path]:
    """
    Returns a list of all Conda environment files that should be used. This is just an
    environment.yml file that lives at the project root folder, plus any additional files provided.

    :param project_root: The root folder of the code that starts the present training run.
    :param additional_files: Optional list of additional environment files to merge
    :return: A list containing the root-level repo's Conda environment file, plus any additional files that exist.
    """
    env_files = []
    project_yaml = project_root / fixed_paths.ENVIRONMENT_YAML_FILE_NAME
    if project_yaml.exists():
        env_files.append(project_yaml)
    if additional_files:
        for additional_file in additional_files:
            if additional_file.exists():
                env_files.append(additional_file)
    return env_files


def get_all_pip_requirements_files() -> List[Path]:
    """
    If the root-level hi-ml directory is available (e.g. it has been installed as a submodule or
    downloaded directly into a parent repo) then we must add its pip requirements to any environment
    definition. This function returns a list of the necessary pip requirements files. If the hi-ml
    root directory does not exist (e.g. hi-ml has been installed as a pip package), this is not necessary
    and this function returns an empty list.

    :return: A list of pip requirements files in the hi-ml and hi-ml-azure packages if relevant,
        or else an empty list
    """
    files = []
    himl_root_dir = fixed_paths.himl_root_dir()
    if himl_root_dir is not None:
        himl_yaml = himl_root_dir / "hi-ml" / "run_requirements.txt"
        himl_az_yaml = himl_root_dir / "hi-ml-azure" / "run_requirements.txt"
        files.append(himl_yaml)
        files.append(himl_az_yaml)
        return files
    return []


def create_unique_timestamp_id() -> str:
    """
    Creates a unique string using the current time in UTC, up to seconds precision, with characters that
    are suitable for use in filenames. For example, on 31 Dec 2019 at 11:59:59pm UTC, the result would be
    2019-12-31T235959Z.
    """
    unique_id = datetime.utcnow().strftime("%Y-%m-%dT%H%M%SZ")
    return unique_id


def is_gpu_available() -> bool:
    """
    :return: True if a GPU with at least 1 device is available.
    """
    return torch.cuda.is_available() and torch.cuda.device_count() > 0  # type: ignore


def parse_model_id_and_version(model_id_and_version: str) -> None:
    """
    When using registered models, the model id must have both the id and version present, in the format
    model_name:version. This function checks the input model id and raises a ValueError if it is not of the
    expected format.
    """
    if len(model_id_and_version.split(":")) != 2:
        raise ValueError(
            f"model id should be in the form 'model_name:version', got {model_id_and_version}")
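The logging helpers above are typically combined as follows. This is a small usage sketch; the section name and the sleep are arbitrary stand-ins for real work.

```python
import logging
import time

from health_ml.utils.common_utils import logging_section, logging_to_stdout

logging_to_stdout("INFO")
with logging_section("doing a short piece of work"):
    time.sleep(1)
    logging.info("work would happen here")
# The log now contains a "**** STARTING: doing a short piece of work ****..." banner and a matching
# "**** FINISHED: ..." banner that reports the elapsed time for the section.
```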
@ -0,0 +1,177 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
from importlib.util import find_spec
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import param
|
||||
from importlib._bootstrap import ModuleSpec
|
||||
|
||||
from health_azure.utils import PathOrString
|
||||
from health_ml.lightning_container import LightningContainer
|
||||
from health_ml.utils import fixed_paths
|
||||
|
||||
|
||||
class ModelConfigLoader(param.Parameterized):
|
||||
"""
|
||||
Helper class to manage model config loading.
|
||||
"""
|
||||
|
||||
def __init__(self, **params: Any):
|
||||
super().__init__(**params)
|
||||
default_module = self.get_default_search_module()
|
||||
self.module_search_specs: List[ModuleSpec] = [importlib.util.find_spec(default_module)] # type: ignore
|
||||
self._find_module_search_specs()
|
||||
|
||||
def _find_module_search_specs(self) -> None:
|
||||
"""
|
||||
Given the fully qualified model name, append the root folder to the system path (so that the config
|
||||
file can be discovered) and try to find a spec for the specifed module. If found, appends the spec
|
||||
to self.module_search_specs
|
||||
"""
|
||||
model_namespace_parts = self.model.split(".")
|
||||
if len(model_namespace_parts) == 1:
|
||||
# config must be in the default path. This is already in module_search_specs so we dont need to do anything
|
||||
return
|
||||
else:
|
||||
# Get the root folder of the fully qualified model name and ensure it is in the path to enable
|
||||
# discovery of the config file
|
||||
root_namespace = str(Path(model_namespace_parts[0]).absolute())
|
||||
if root_namespace not in sys.path:
|
||||
print(f"Adding {str(root_namespace)} to path")
|
||||
sys.path.insert(0, str(root_namespace))
|
||||
|
||||
# Strip the root folder (now in the path) and the class name from the model namespace, leaving the
|
||||
# module name - e.g. "mymodule.configs"
|
||||
model_namespace = ".".join([str(p) for p in model_namespace_parts[1:-1]]) # type: ignore
|
||||
|
||||
custom_spec = importlib.util.find_spec(model_namespace) # type: ignore
|
||||
if custom_spec is None:
|
||||
raise ValueError(f"Search namespace {model_namespace} was not found.")
|
||||
self.module_search_specs.append(custom_spec)
|
||||
|
||||
@staticmethod
|
||||
def get_default_search_module() -> str:
|
||||
from health_ml import configs # type: ignore
|
||||
return configs.__name__
|
||||
|
||||
def create_model_config_from_name(self, model_name: str) -> LightningContainer:
|
||||
"""
|
||||
Returns a model configuration for a model of the given name.
|
||||
To avoid having to import torch here, there are no references to LightningContainer.
|
||||
Searching for a class member called <model_name> in the search modules provided recursively.
|
||||
|
||||
:param model_name: Fully qualified name of the model for which to get the configs for - i.e.
|
||||
mymodule.configs.MyConfig
|
||||
"""
|
||||
if not model_name:
|
||||
raise ValueError("Unable to load a model configuration because the model name is missing.")
|
||||
|
||||
# get the class name from the fully qualified name
|
||||
model_name = model_name.split(".")[-1]
|
||||
|
||||
configs: Dict[str, LightningContainer] = {}
|
||||
|
||||
def _get_model_config(module_spec: ModuleSpec) -> Optional[LightningContainer]:
|
||||
"""
|
||||
Given a module specification check to see if it has a class property with
|
||||
the <model_name> provided, and instantiate that config class with the
|
||||
provided <config_overrides>. Otherwise, return None.
|
||||
|
||||
:param module_spec:
|
||||
:return: Instantiated model config if it was found.
|
||||
"""
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
logging.debug(f"Importing {module_spec.name}")
|
||||
target_module = importlib.import_module(module_spec.name)
|
||||
# The "if" clause checks that obj is a class, of the desired name, that is
|
||||
# defined in this module rather than being imported into it (and hence potentially
|
||||
# being found twice).
|
||||
_class = next(obj for name, obj in inspect.getmembers(target_module)
|
||||
if inspect.isclass(obj)
|
||||
and name == model_name # noqa: W503
|
||||
and inspect.getmodule(obj) == target_module) # noqa: W503
|
||||
logging.info(f"Found class {_class} in file {module_spec.origin}")
|
||||
# ignore the exception which will occur if the provided module cannot be loaded
|
||||
# or the loaded module does not have the required class as a member
|
||||
except Exception as e:
|
||||
exception_text = str(e)
|
||||
if exception_text != "":
|
||||
logging.warning(f"(from attempt to import module {module_spec.name}): {exception_text}")
|
||||
return None
|
||||
model_config = _class()
|
||||
return model_config
|
||||
|
||||
def _search_recursively_and_store(module_search_spec: ModuleSpec) -> None:
|
||||
"""
|
||||
Given a root namespace eg: A.B.C searches recursively in all child namespaces
|
||||
for class property with the <model_name> provided. If found, this is
|
||||
instantiated with the provided overrides, and added to the configs dictionary.
|
||||
|
||||
:param module_search_spec:
|
||||
"""
|
||||
root_namespace = module_search_spec.name
|
||||
namespaces_to_search: List[str] = []
|
||||
if module_search_spec.submodule_search_locations:
|
||||
logging.debug(f"Searching through {len(module_search_spec.submodule_search_locations)} folders that "
|
||||
f"match namespace {module_search_spec.name}: "
|
||||
f"{module_search_spec.submodule_search_locations}")
|
||||
for root in module_search_spec.submodule_search_locations:
|
||||
# List all python files in all the dirs under root, except for private dirs (prefixed with .)
|
||||
all_py_files = [x for x in Path(root).rglob("*.py") if ".." not in str(x)]
|
||||
for f in all_py_files:
|
||||
if f.is_file() and "__pycache__" not in str(f) and f.name != "setup.py":
|
||||
sub_namespace = path_to_namespace(f, root=root)
|
||||
namespaces_to_search.append(root_namespace + "." + sub_namespace)
|
||||
elif module_search_spec.origin:
|
||||
# The module search spec already points to a python file: Search only that.
|
||||
namespaces_to_search.append(module_search_spec.name)
|
||||
else:
|
||||
raise ValueError(f"Unable to process module spec: {module_search_spec}")
|
||||
|
||||
for n in namespaces_to_search: # type: ignore
|
||||
_module_spec = None
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
_module_spec = find_spec(n) # type: ignore
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if _module_spec:
|
||||
config = _get_model_config(_module_spec)
|
||||
if config:
|
||||
configs[n] = config # type: ignore
|
||||
|
||||
for search_spec in self.module_search_specs:
|
||||
_search_recursively_and_store(search_spec)
|
||||
|
||||
if len(configs) == 0:
|
||||
raise ValueError(
|
||||
f"Model name {model_name} was not found in search namespaces: "
|
||||
f"{[s.name for s in self.module_search_specs]}.")
|
||||
elif len(configs) > 1:
|
||||
raise ValueError(
|
||||
f"Multiple instances of model name {model_name} were found in namespaces: {configs.keys()}.")
|
||||
else:
|
||||
return list(configs.values())[0]
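# --- Illustrative usage sketch (added for documentation, not part of the original diff) ---
# A minimal example of resolving a config by name with the loader above. It assumes the
# bundled sample config "HelloContainer" is importable from the default search module.
def _example_create_model_config() -> None:
    loader = ModelConfigLoader(model="HelloContainer")
    container = loader.create_model_config_from_name("HelloContainer")
    # The result is an instance of the LightningContainer subclass of that name.
    assert type(container).__name__ == "HelloContainer"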
|
||||
|
||||
|
||||
def path_to_namespace(path: Path, root: PathOrString = fixed_paths.repository_root_directory()) -> str:
|
||||
"""
|
||||
Given a path (in form R/A/B/C) and an optional root directory R, create a namespace A.B.C.
|
||||
If root is provided, then path must be a child of it.
|
||||
|
||||
:param path: Path to convert to namespace
|
||||
:param root: Path prefix to remove from namespace (default is project root)
|
||||
:return:
|
||||
"""
|
||||
return ".".join([Path(x).stem for x in path.relative_to(root).parts])
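# --- Illustrative sketch (added for documentation; the paths below are hypothetical) ---
# path_to_namespace is a plain dot-join of the path parts relative to `root`:
def _example_path_to_namespace() -> None:
    root = Path("/repo")
    namespace = path_to_namespace(root / "mymodule" / "configs" / "my_config.py", root=root)
    assert namespace == "mymodule.configs.my_config"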
|
|
@@ -211,7 +211,7 @@ class BatchTimeCallback(Callback):
|
|||
self.write_and_log_epoch_time(is_training=True)
|
||||
self.write_and_log_epoch_time(is_training=False)
|
||||
|
||||
def on_train_batch_start(self,
|
||||
def on_train_batch_start(self, # type: ignore
|
||||
trainer: Trainer,
|
||||
pl_module: LightningModule,
|
||||
batch: Any,
|
||||
|
@@ -229,7 +229,7 @@ class BatchTimeCallback(Callback):
|
|||
) -> None:
|
||||
self.batch_start(batch_idx=batch_idx, is_training=False)
|
||||
|
||||
def on_train_batch_end(self,
|
||||
def on_train_batch_end(self, # type: ignore
|
||||
trainer: Trainer,
|
||||
pl_module: LightningModule,
|
||||
outputs: Any,
|
||||
|
|
|
@@ -0,0 +1,58 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from health_azure.utils import PathOrString
|
||||
|
||||
ENVIRONMENT_YAML_FILE_NAME = "environment.yml"
|
||||
|
||||
|
||||
def get_environment_yaml_file() -> Path:
|
||||
"""
|
||||
Returns the path where the environment.yml file is located, in the repository root directory.
|
||||
The function throws an exception if the file is not found.
|
||||
|
||||
:return: The full path to the environment files.
|
||||
"""
|
||||
# The environment file is copied into the package folder in setup.py.
|
||||
root_dir = repository_root_directory()
|
||||
env = root_dir / ENVIRONMENT_YAML_FILE_NAME
|
||||
if not env.exists():
|
||||
raise ValueError(f"File {ENVIRONMENT_YAML_FILE_NAME} was not found in the repository root "
f"{root_dir}.")
|
||||
return env
|
||||
|
||||
|
||||
def repository_root_directory(path: Optional[PathOrString] = None) -> Path:
|
||||
"""
|
||||
Gets the full path to the root directory that holds the present repository.
|
||||
|
||||
:param path: if provided, a relative path to append to the absolute path to the repository root.
|
||||
:return: The full path to the repository's root directory, with symlinks resolved if any.
|
||||
"""
|
||||
root = Path.cwd()
|
||||
if path:
|
||||
full_path = root / path
|
||||
assert full_path.exists(), f"Path {full_path} doesn't exist"
|
||||
return root / path
|
||||
else:
|
||||
return root
|
||||
|
||||
|
||||
def himl_root_dir() -> Optional[Path]:
|
||||
"""
|
||||
Attempts to return the path to the top-level hi-ml repo that contains the hi-ml and hi-ml-azure packages.
|
||||
This top level repo will only be present if hi-ml has been installed as a git submodule, or the repo has
|
||||
been directly downloaded. Otherwise (e.g. if hi-ml has been installed as a pip package) returns None.
|
||||
|
||||
:return: Path to the himl root dir if it exists, else None
|
||||
"""
|
||||
health_ml_root = Path(__file__).parent.parent
|
||||
if health_ml_root.parent.stem == "site-packages":
|
||||
return None
|
||||
himl_root = health_ml_root.parent.parent.parent
|
||||
assert (himl_root / "hi-ml").is_dir() and (himl_root / "hi-ml-azure").is_dir()
|
||||
return himl_root
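# --- Illustrative sketch (added for documentation, not part of the original module) ---
# Typical use of the helpers above: resolve the Conda environment file relative to the current
# repository root. This assumes the process was started from the repository root and that
# environment.yml exists there, since repository_root_directory() is based on Path.cwd().
def _example_resolve_environment_file() -> None:
    root = repository_root_directory()
    env_file = get_environment_yaml_file()
    assert env_file == root / ENVIRONMENT_YAML_FILE_NAME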
|
|
@@ -0,0 +1,110 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from pytorch_lightning.loggers import LightningLoggerBase
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
from health_ml.utils.type_annotations import DictStrFloat, DictStrFloatOrFloatList
|
||||
|
||||
|
||||
class StoringLogger(LightningLoggerBase):
|
||||
"""
|
||||
A Pytorch Lightning logger that simply stores the metrics that are written to it.
|
||||
Used for diagnostic purposes in unit tests.
|
||||
"""
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.results_per_epoch: Dict[int, DictStrFloatOrFloatList] = {}
|
||||
self.hyperparams: Any = None
|
||||
# Fields to store diagnostics for unit testing
|
||||
self.train_diagnostics: List[Any] = []
|
||||
self.val_diagnostics: List[Any] = []
|
||||
self.results_without_epoch: List[DictStrFloat] = []
|
||||
|
||||
@rank_zero_only
|
||||
def log_metrics(self, metrics: DictStrFloat, step: Optional[int] = None) -> None:
|
||||
logging.debug(f"StoringLogger step={step}: {metrics}")
|
||||
epoch_name = "epoch"
|
||||
if epoch_name not in metrics:
|
||||
# Metrics without an "epoch" key are logged during testing, for example
|
||||
self.results_without_epoch.append(metrics)
|
||||
return
|
||||
epoch = int(metrics[epoch_name])
|
||||
del metrics[epoch_name]
|
||||
for key, value in metrics.items():
|
||||
if isinstance(value, int):
|
||||
metrics[key] = float(value)
|
||||
if epoch in self.results_per_epoch:
|
||||
current_results = self.results_per_epoch[epoch]
|
||||
for key, value in metrics.items():
|
||||
if key in current_results:
|
||||
logging.debug(f"StoringLogger: appending results for metric {key}")
|
||||
current_metrics = current_results[key]
|
||||
if isinstance(current_metrics, list):
|
||||
current_metrics.append(value)
|
||||
else:
|
||||
current_results[key] = [current_metrics, value]
|
||||
else:
|
||||
current_results[key] = value
|
||||
else:
|
||||
self.results_per_epoch[epoch] = metrics # type: ignore
|
||||
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params: Any) -> None:
|
||||
self.hyperparams = params
|
||||
|
||||
def experiment(self) -> Any:
|
||||
return None
|
||||
|
||||
def name(self) -> Any:
|
||||
return ""
|
||||
|
||||
def version(self) -> int:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def epochs(self) -> Iterable[int]:
|
||||
"""
|
||||
Gets the epochs for which the present object holds any results.
|
||||
"""
|
||||
return self.results_per_epoch.keys()
|
||||
|
||||
def extract_by_prefix(self, epoch: int, prefix_filter: str = "") -> DictStrFloat:
|
||||
"""
|
||||
Reads the set of metrics for a given epoch, filters them to retain only those that have the given prefix,
|
||||
and returns the filtered ones. This is used to break a set
|
||||
of results down into those for training data (prefix "Train/") or validation data (prefix "Val/").
|
||||
|
||||
:param epoch: The epoch for which results should be read.
|
||||
:param prefix_filter: If empty string, return all metrics. If not empty, return only those metrics that
|
||||
have a name starting with `prefix_filter`, and strip off the prefix.
|
||||
:return: A metrics dictionary.
|
||||
"""
|
||||
epoch_results = self.results_per_epoch.get(epoch, None)
|
||||
if epoch_results is None:
|
||||
raise KeyError(f"No results are stored for epoch {epoch}")
|
||||
filtered = {}
|
||||
for key, value in epoch_results.items():
|
||||
assert isinstance(key, str), f"All dictionary keys should be strings, but got: {type(key)}"
|
||||
# Add the metric if either there is no prefix filter (prefix does not matter), or if the prefix
|
||||
# filter is supplied and really matches the metric name
|
||||
if (not prefix_filter) or key.startswith(prefix_filter):
|
||||
stripped_key = key[len(prefix_filter):]
|
||||
filtered[stripped_key] = value # type: ignore
|
||||
return filtered # type: ignore
|
||||
|
||||
def to_metrics_dicts(self, prefix_filter: str = "") -> Dict[int, DictStrFloat]:
|
||||
"""
|
||||
Converts the results stored in the present object into a two-level dictionary, mapping from epoch number to
|
||||
metric name to metric value. Only metrics where the name starts with the given prefix are retained, and the
|
||||
prefix is stripped off in the result.
|
||||
|
||||
:param prefix_filter: If empty string, return all metrics. If not empty, return only those metrics that
|
||||
have a name starting with `prefix_filter`, and strip off the prefix.
|
||||
:return: A dictionary mapping from epoch number to metric name to metric value.
|
||||
"""
|
||||
return {epoch: self.extract_by_prefix(epoch, prefix_filter) for epoch in self.epochs}
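# --- Illustrative sketch (added for documentation, not part of the original module) ---
# StoringLogger can be exercised directly by feeding it metric dictionaries. The "epoch" key
# routes values into results_per_epoch; extract_by_prefix and to_metrics_dicts then filter by prefix.
def _example_storing_logger() -> None:
    logger = StoringLogger()
    logger.log_metrics({"epoch": 0, "Train/loss": 0.5, "Val/loss": 0.7})
    logger.log_metrics({"epoch": 1, "Train/loss": 0.4, "Val/loss": 0.6})
    assert logger.extract_by_prefix(epoch=0, prefix_filter="Train/") == {"loss": 0.5}
    assert logger.to_metrics_dicts(prefix_filter="Val/") == {0: {"loss": 0.7}, 1: {"loss": 0.6}}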
|
|
@@ -0,0 +1,287 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from itertools import combinations
|
||||
from math import ceil
|
||||
from typing import Any, Dict, Iterable, Optional, Sequence, Set, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from health_ml.utils import common_utils
|
||||
from health_ml.utils.common_utils import ModelExecutionMode
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetSplits:
|
||||
train: pd.DataFrame
|
||||
val: pd.DataFrame
|
||||
test: pd.DataFrame
|
||||
subject_column: Optional[str] = None
|
||||
group_column: Optional[str] = None
|
||||
allow_empty: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
common_utils.check_properties_are_not_none(self)
|
||||
|
||||
def pairwise_intersection(*collections: Iterable) -> Set:
|
||||
"""
|
||||
Returns any element that appears in more than one collection
|
||||
|
||||
:return: a Set of elements that appear in more than one collection
|
||||
"""
|
||||
intersection = set()
|
||||
for col1, col2 in combinations(map(set, collections), 2):
|
||||
intersection |= col1 & col2
|
||||
return intersection
|
||||
|
||||
# perform dataset split validity assertions
|
||||
unique_train, unique_test, unique_val = self.unique_subjects()
|
||||
intersection = pairwise_intersection(unique_train, unique_test, unique_val)
|
||||
|
||||
if len(intersection) != 0:
|
||||
raise ValueError("Train, Test, and Val splits must have no intersection, found: {}".format(intersection))
|
||||
|
||||
if self.group_column is not None:
|
||||
groups_train = self.train[self.group_column].unique()
|
||||
groups_test = self.test[self.group_column].unique()
|
||||
groups_val = self.val[self.group_column].unique()
|
||||
group_intersection = pairwise_intersection(groups_train, groups_test, groups_val)
|
||||
if len(group_intersection) != 0:
|
||||
raise ValueError("Train, Test, and Val splits must have no intersecting groups, found: {}"
|
||||
.format(group_intersection))
|
||||
|
||||
if (not self.allow_empty) and any([len(x) == 0 for x in [unique_train, unique_val]]):
|
||||
raise ValueError("train_ids({}), val_ids({}) must have at least one value"
|
||||
.format(len(unique_train), len(unique_val)))
|
||||
|
||||
def __str__(self) -> str:
|
||||
unique_train, unique_test, unique_val = self.unique_subjects()
|
||||
return f'Train: {len(unique_train)}, Test: {len(unique_test)}, and Val: {len(unique_val)}. ' \
|
||||
f'Total subjects: {len(unique_train) + len(unique_test) + len(unique_val)}'
|
||||
|
||||
def unique_subjects(self) -> Tuple[Any, Any, Any]:
|
||||
"""
|
||||
Return a tuple of pandas Series of unique subjects across train, test and validation data splits,
|
||||
based on self.subject_column
|
||||
|
||||
:return: a tuple of pandas Series
|
||||
"""
|
||||
return (self.train[self.subject_column].unique(),
|
||||
self.test[self.subject_column].unique(),
|
||||
self.val[self.subject_column].unique())
|
||||
|
||||
def number_of_subjects(self) -> int:
|
||||
"""
|
||||
Returns the sum of unique subjects in the dataset (identified by self.subject_column), summed
|
||||
over train, test and validation data splits
|
||||
|
||||
:return: An integer representing the number of unique subjects
|
||||
"""
|
||||
unique_train, unique_test, unique_val = self.unique_subjects()
|
||||
return len(unique_train) + len(unique_test) + len(unique_val)
|
||||
|
||||
def __getitem__(self, mode: ModelExecutionMode) -> pd.DataFrame:
|
||||
"""
|
||||
Retrieve either the train, validation or test data in the form of a Pandas dataframe, depending
|
||||
on the current execution mode
|
||||
|
||||
:param mode: The current ModelExecutionMode
|
||||
:return: A dataframe of the relevant data split
|
||||
"""
|
||||
if mode == ModelExecutionMode.TRAIN:
|
||||
return self.train
|
||||
elif mode == ModelExecutionMode.TEST:
|
||||
return self.test
|
||||
elif mode == ModelExecutionMode.VAL:
|
||||
return self.val
|
||||
else:
|
||||
raise ValueError(f"Model execution mode not recognized: {mode}")
|
||||
|
||||
@staticmethod
|
||||
def get_subject_ranges_for_splits(population: Sequence[str],
|
||||
proportion_train: float,
|
||||
proportion_test: float,
|
||||
proportion_val: float) \
|
||||
-> Dict[ModelExecutionMode, Set[str]]:
|
||||
"""
|
||||
Get mutually exclusive subject ranges for each dataset split (w.r.t. the proportions provided),
|
||||
ensuring all sets have at least one item in them when possible.
|
||||
|
||||
:param population: all subjects
|
||||
:param proportion_train: proportion for the train set.
|
||||
:param proportion_test: proportion for the test set.
|
||||
:param proportion_val: proportion for the validation set.
|
||||
:return: Train, Test, and Val splits
|
||||
"""
|
||||
sum_proportions = proportion_train + proportion_val + proportion_test
|
||||
if not np.isclose(sum_proportions, 1):
|
||||
raise ValueError("proportion_train({}) + proportion_val({}) + proportion_test({}) must be ~ 1, found: {}"
|
||||
.format(proportion_train, proportion_val, proportion_test, sum_proportions))
|
||||
|
||||
if not 0 <= proportion_test < 1:
|
||||
raise ValueError("proportion_test({}) must be in range [0, 1)"
|
||||
.format(proportion_test))
|
||||
|
||||
if not all([0 < x < 1 for x in [proportion_train, proportion_val]]):
|
||||
raise ValueError("proportion_train({}) and proportion_val({}) must be in range (0, 1)"
|
||||
.format(proportion_train, proportion_val))
|
||||
|
||||
subjects_train, subjects_test, subjects_val = (set(population[0:1]),
|
||||
set(population[1:2]),
|
||||
set(population[2:3]))
|
||||
remaining = list(population[3:])
|
||||
if proportion_test == 0:
|
||||
remaining = list(subjects_test) + remaining
|
||||
subjects_test = set()
|
||||
|
||||
subjects_train |= set(remaining[: ceil(len(remaining) * proportion_train)])
|
||||
if len(subjects_test) > 0:
|
||||
subjects_test |= set(remaining[len(subjects_train):
|
||||
len(subjects_train) + ceil(len(remaining) * proportion_test)])
|
||||
subjects_val |= set(remaining) - (subjects_train | subjects_test)
|
||||
result = {
|
||||
ModelExecutionMode.TRAIN: subjects_train,
|
||||
ModelExecutionMode.TEST: subjects_test,
|
||||
ModelExecutionMode.VAL: subjects_val
|
||||
}
|
||||
return result
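# --- Illustrative sketch (added for documentation, not part of the original class) ---
# With 10 hypothetical subjects and a 50/20/30 split, the helper above returns three disjoint
# sets that jointly cover the whole population (set sizes are approximate because of ceil()).
def _example_subject_ranges() -> None:
    population = [f"subject_{i}" for i in range(10)]
    ranges = DatasetSplits.get_subject_ranges_for_splits(
        population, proportion_train=0.5, proportion_test=0.2, proportion_val=0.3)
    assert set.union(*ranges.values()) == set(population)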
|
||||
|
||||
@staticmethod
|
||||
def _from_split_keys(df: pd.DataFrame,
|
||||
train_keys: Sequence[str],
|
||||
test_keys: Sequence[str],
|
||||
val_keys: Sequence[str],
|
||||
*, # make column names keyword-only arguments to avoid mistakes when providing both
|
||||
key_column: str,
|
||||
subject_column: str,
|
||||
group_column: Optional[str]) -> DatasetSplits:
|
||||
"""
|
||||
Takes a slice of values from each data split train/test/val for the provided keys.
|
||||
|
||||
:param df: the input DataFrame
|
||||
:param train_keys: keys for training.
|
||||
:param test_keys: keys for testing.
|
||||
:param val_keys: keys for validation.
|
||||
:param key_column: name of the column the provided keys belong to
|
||||
:param subject_column: subject id column name
|
||||
:param group_column: grouping column name; if given, samples from each group will always be
|
||||
in the same subset (train, val, or test) and cross-validation fold.
|
||||
:return: Data splits with respected dataset split ids.
|
||||
"""
|
||||
train_df = DatasetSplits.get_df_from_ids(df, train_keys, key_column)
|
||||
test_df = DatasetSplits.get_df_from_ids(df, test_keys, key_column)
|
||||
val_df = DatasetSplits.get_df_from_ids(df, val_keys, key_column)
|
||||
|
||||
return DatasetSplits(train=train_df, test=test_df, val=val_df,
|
||||
subject_column=subject_column, group_column=group_column)
|
||||
|
||||
@staticmethod
|
||||
def from_proportions(df: pd.DataFrame,
|
||||
proportion_train: float,
|
||||
proportion_test: float,
|
||||
proportion_val: float,
|
||||
*, # make column names keyword-only arguments to avoid mistakes when providing both
|
||||
subject_column: str = "",
|
||||
group_column: Optional[str] = None,
|
||||
shuffle: bool = True,
|
||||
random_seed: int = 0) -> DatasetSplits:
|
||||
"""
|
||||
Creates a split of a dataset into train, test, and validation set, according to fixed proportions using
|
||||
the "subject" column in the dataframe, or the group column, if given.
|
||||
|
||||
:param df: The dataframe containing all subjects.
|
||||
:param proportion_train: proportion for the train set.
|
||||
:param proportion_test: proportion for the test set.
|
||||
:param subject_column: Subject id column name
|
||||
:param group_column: grouping column name; if given, samples from each group will always be
|
||||
in the same subset (train, val, or test) and cross-validation fold.
|
||||
:param proportion_val: proportion for the validation set.
|
||||
:param shuffle: If True, the subjects in the dataframe will be shuffled before performing splits.
|
||||
:param random_seed: Random seed to be used for shuffling. Defaults to 0.
|
||||
:return:
|
||||
"""
|
||||
key_column: str = subject_column if group_column is None else group_column
|
||||
split_keys = df[key_column].unique()
|
||||
if shuffle:
|
||||
# fix the random seed so we can guarantee reproducibility when working with shuffle
|
||||
random.Random(random_seed).shuffle(split_keys)
|
||||
ranges = DatasetSplits.get_subject_ranges_for_splits(
|
||||
split_keys,
|
||||
proportion_train=proportion_train,
|
||||
proportion_val=proportion_val,
|
||||
proportion_test=proportion_test
|
||||
)
|
||||
return DatasetSplits._from_split_keys(df,
|
||||
list(ranges[ModelExecutionMode.TRAIN]),
|
||||
list(ranges[ModelExecutionMode.TEST]),
|
||||
list(ranges[ModelExecutionMode.VAL]),
|
||||
key_column=key_column,
|
||||
subject_column=subject_column,
|
||||
group_column=group_column)
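# --- Illustrative sketch (added for documentation, not part of the original class) ---
# A minimal end-to-end example of from_proportions on a hypothetical dataframe with a
# "subject" column; the three proportions must sum to approximately 1.
def _example_from_proportions() -> None:
    df = pd.DataFrame({"subject": [str(i) for i in range(10)], "value": list(range(10))})
    splits = DatasetSplits.from_proportions(df, proportion_train=0.6, proportion_test=0.2,
                                            proportion_val=0.2, subject_column="subject")
    assert splits.number_of_subjects() == 10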
|
||||
|
||||
@staticmethod
|
||||
def from_subject_ids(df: pd.DataFrame,
|
||||
train_ids: Sequence[str],
|
||||
test_ids: Sequence[str],
|
||||
val_ids: Sequence[str],
|
||||
*, # make column names keyword-only arguments to avoid mistakes when providing both
|
||||
subject_column: str = "",
|
||||
group_column: Optional[str] = None) -> DatasetSplits:
|
||||
"""
|
||||
Assumes a DataFrame with a subject column.
|
||||
Takes a slice of values from each data split train/test/val for the provided ids.
|
||||
|
||||
:param df: the input DataFrame
|
||||
:param train_ids: ids for training.
|
||||
:param test_ids: ids for testing.
|
||||
:param val_ids: ids for validation.
|
||||
:param subject_column: subject id column name
|
||||
:param group_column: grouping column name; if given, samples from each group will always be
|
||||
in the same subset (train, val, or test) and cross-validation fold.
|
||||
:return: Data splits with respected dataset split ids.
|
||||
"""
|
||||
return DatasetSplits._from_split_keys(df, train_ids, test_ids, val_ids, key_column=subject_column,
|
||||
subject_column=subject_column, group_column=group_column)
|
||||
|
||||
@staticmethod
|
||||
def from_groups(df: pd.DataFrame,
|
||||
train_groups: Sequence[str],
|
||||
test_groups: Sequence[str],
|
||||
val_groups: Sequence[str],
|
||||
*, # make column names keyword-only arguments to avoid mistakes when providing both
|
||||
group_column: str,
|
||||
subject_column: str = "") -> DatasetSplits:
|
||||
"""
|
||||
Assumes a DataFrame with a group column.
|
||||
Takes a slice of values from each data split train/test/val for the provided groups.
|
||||
|
||||
:param df: the input DataFrame
|
||||
:param train_groups: groups for training.
|
||||
:param test_groups: groups for testing.
|
||||
:param val_groups: groups for validation.
|
||||
:param subject_column: subject id column name
|
||||
:param group_column: grouping column name; if given, samples from each group will always be
|
||||
in the same subset (train, val, or test) and cross-validation fold.
|
||||
:return: Data splits with respected dataset split ids.
|
||||
"""
|
||||
return DatasetSplits._from_split_keys(df, train_groups, test_groups, val_groups, key_column=group_column,
|
||||
subject_column=subject_column, group_column=group_column)
|
||||
|
||||
@staticmethod
|
||||
def get_df_from_ids(df: pd.DataFrame, ids: Sequence[str],
|
||||
subject_column: str = "") -> pd.DataFrame:
|
||||
"""
|
||||
Retrieve a subset dataframe where the subject column is restricted to a sequence of provided ids
|
||||
|
||||
:param df: The dataframe to restrict
|
||||
:param ids: The ids to lookup
|
||||
:param subject_column: The column to lookup ids in. Defaults to ""
|
||||
:return: A subset of the dataframe
|
||||
"""
|
||||
return df[df[subject_column].isin(ids)]
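# --- Illustrative sketch (added for documentation, not part of the original module) ---
# Splits can also be built from explicit subject id lists and read back per execution mode.
def _example_from_subject_ids() -> None:
    df = pd.DataFrame({"subject": ["a", "b", "c", "d"], "value": [1, 2, 3, 4]})
    splits = DatasetSplits.from_subject_ids(df, train_ids=["a", "b"], test_ids=["c"],
                                            val_ids=["d"], subject_column="subject")
    assert list(splits[ModelExecutionMode.TRAIN]["subject"]) == ["a", "b"]
    assert len(splits[ModelExecutionMode.TEST]) == 1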
|
|
@@ -0,0 +1,12 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, TypeVar, Union
|
||||
|
||||
T = TypeVar('T')
|
||||
PathOrString = Union[Path, str]
|
||||
TupleFloat2 = Tuple[float, float]
|
||||
DictStrFloat = Dict[str, float]
|
||||
DictStrFloatOrFloatList = Dict[str, Union[float, List[float]]]
|
|
@@ -0,0 +1,44 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from health_ml.utils import common_utils
|
||||
|
||||
|
||||
@pytest.mark.parametrize("os_name, expected_val", [
|
||||
("nt", True),
|
||||
("None", False),
|
||||
("posix", False),
|
||||
("", False)])
|
||||
def test_is_windows(os_name: str, expected_val: bool) -> None:
|
||||
with patch.object(os, "name", new=os_name):
|
||||
assert common_utils.is_windows() == expected_val
|
||||
|
||||
|
||||
@pytest.mark.parametrize("os_name, expected_val", [
|
||||
("nt", False),
|
||||
("None", False),
|
||||
("posix", True),
|
||||
("", False)])
|
||||
def test_is_linux(os_name: str, expected_val: bool) -> None:
|
||||
with patch.object(os, "name", new=os_name):
|
||||
assert common_utils.is_linux() == expected_val
|
||||
|
||||
|
||||
def test_change_working_directory(tmp_path: Path) -> None:
|
||||
"""
|
||||
Test that change_working_directory temporarily changes the current working directory, and that the original
working directory is restored when the context manager exits
|
||||
"""
|
||||
orig_cwd_str = str(Path.cwd())
|
||||
tmp_path_str = str(tmp_path)
|
||||
assert orig_cwd_str != tmp_path_str
|
||||
with common_utils.change_working_directory(tmp_path):
|
||||
assert str(Path.cwd()) == tmp_path_str
|
||||
# outside of the context, the original working directory should be restored
|
||||
assert str(Path.cwd()) == orig_cwd_str != tmp_path_str
|
|
@@ -0,0 +1,139 @@
|
|||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from health_azure.utils import is_running_on_azure_agent
|
||||
from health_ml.lightning_container import LightningContainer
|
||||
from health_ml.utils.config_loader import ModelConfigLoader, path_to_namespace
|
||||
from testhiml.utils.fixed_paths_for_tests import full_ml_test_data_path, tests_root_directory
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def config_loader() -> ModelConfigLoader:
|
||||
return ModelConfigLoader(**{"model": "HelloContainer"})
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def hello_config() -> Any:
|
||||
from health_ml.configs import hello_container # type: ignore
|
||||
assert Path(hello_container.__file__).exists(), "Can't find hello_container config"
|
||||
return hello_container
|
||||
|
||||
|
||||
def test_find_module_search_specs(config_loader: ModelConfigLoader) -> None:
|
||||
# By default, property module_search_specs includes the default config path - health_ml.configs
|
||||
len_search_specs_before = len(config_loader.module_search_specs)
|
||||
assert any([m.name == "health_ml.configs" for m in config_loader.module_search_specs])
|
||||
config_loader._find_module_search_specs()
|
||||
# nothing should have been added to module_search_specs
|
||||
assert len(config_loader.module_search_specs) == len_search_specs_before
|
||||
|
||||
|
||||
def test_find_module_search_specs_outside_default_dir() -> None:
|
||||
if is_running_on_azure_agent():
|
||||
return
|
||||
model_name = "NewConfig"
|
||||
|
||||
dummy_config_dir = Path.cwd() / "test_configs"
|
||||
dummy_config_dir.mkdir()
|
||||
dummy_config_path = dummy_config_dir / "new_config.py"
|
||||
dummy_config = f"""class {model_name}:
|
||||
def __init__(self):
|
||||
pass
|
||||
"""
|
||||
dummy_config_path.touch()
|
||||
dummy_config_path.write_text(dummy_config)
|
||||
|
||||
dummy_config_namespace = f"test_configs.new_config.{model_name}"
|
||||
config_loader2 = ModelConfigLoader(**{"model": f"{dummy_config_namespace}"})
|
||||
# The root of the dummy config dir should now be in the system path, and the module "new_config" should be
# in module_search_specs. This won't be in the previous results, since the default path was used. The default
# search_spec (health_ml.configs) should also be in the results for the new config loader.
|
||||
assert any([m.name == "new_config" for m in config_loader2.module_search_specs])
|
||||
assert any([m.name == "health_ml.configs" for m in config_loader2.module_search_specs])
|
||||
|
||||
# If the named config doesn't exist but the parent module does, the module will still be appended to module_search_specs
|
||||
# at this stage
|
||||
config_loader3 = ModelConfigLoader(**{"model": "test_configs.new_config.idontexist"})
|
||||
assert any([m.name == "new_config" for m in config_loader3.module_search_specs])
|
||||
|
||||
# If the parent module doesn't exist, an Exception should be raised
|
||||
with pytest.raises(Exception) as e:
|
||||
ModelConfigLoader(**{"model": "testhiml.idontexist.idontexist"})
|
||||
assert "was not found" in str(e)
|
||||
|
||||
shutil.rmtree(dummy_config_dir)
|
||||
|
||||
|
||||
def test_get_default_search_module(config_loader: ModelConfigLoader) -> None:
|
||||
search_module = config_loader.get_default_search_module()
|
||||
assert search_module == "health_ml.configs"
|
||||
|
||||
|
||||
def test_create_model_config_from_name(config_loader: ModelConfigLoader, hello_config: Any
|
||||
) -> None:
|
||||
# if no model name is given, an exception should be raised
|
||||
with pytest.raises(Exception) as e:
|
||||
config_loader.create_model_config_from_name("")
|
||||
assert "the model name is missing" in str(e)
|
||||
|
||||
# if no config is found matching the model name, an exception should be raised
|
||||
with pytest.raises(Exception) as e:
|
||||
config_loader.create_model_config_from_name("idontexist")
|
||||
assert "was not found in search namespaces" in str(e)
|
||||
|
||||
# if > 1 config is found matching the model name, an exception should be raised
|
||||
config_name = "HelloContainer"
|
||||
hello_config_path = Path(hello_config.__file__)
|
||||
duplicate_config_file = hello_config_path.parent / "hello_container_2.py"
|
||||
duplicate_config_file.touch()
|
||||
shutil.copyfile(str(hello_config_path), str(duplicate_config_file))
|
||||
with pytest.raises(Exception) as e:
|
||||
config_loader.create_model_config_from_name(config_name)
|
||||
assert "Multiple instances of model name " in str(e)
|
||||
duplicate_config_file.unlink()
|
||||
|
||||
# if exactly one config is found, expect a LightningContainer to be returned
|
||||
container = config_loader.create_model_config_from_name(config_name)
|
||||
assert isinstance(container, LightningContainer)
|
||||
assert container.model_name == config_name
|
||||
|
||||
|
||||
def test_config_in_dif_location(tmp_path: Path, hello_config: Any) -> None:
|
||||
himl_root = Path(hello_config.__file__).parent.parent
|
||||
model_name = "HelloContainer"
|
||||
new_config_path = himl_root / "hello_container_to_delete.py"
|
||||
new_config_path.touch()
|
||||
hello_config_path = Path(hello_config.__file__)
|
||||
shutil.copyfile(str(hello_config_path), str(new_config_path))
|
||||
config_loader = ModelConfigLoader(model=model_name)
|
||||
|
||||
# Trying to find this config should now cause an exception as it should find it in both "health_ml" and
|
||||
# in "health_ml.configs"
|
||||
with pytest.raises(Exception) as e:
|
||||
config_loader.create_model_config_from_name(model_name)
|
||||
assert "Multiple instances of model name HelloContainer were found in namespaces: " \
|
||||
"dict_keys(['health_ml.configs.hello_container', 'health_ml.hello_container_to_delete']) " in str(e)
|
||||
new_config_path.unlink()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_external", [True, False])
|
||||
def test_path_to_namespace(is_external: bool) -> None:
|
||||
"""
|
||||
A test to check conversion from path to namespace for InnerEye and external namespaces
|
||||
"""
|
||||
tests_root_dir = tests_root_directory()
|
||||
if is_external:
|
||||
folder_name = "logs"
|
||||
full_folder = tests_root_dir / folder_name
|
||||
assert path_to_namespace(
|
||||
path=full_folder,
|
||||
root=tests_root_dir
|
||||
) == folder_name
|
||||
else:
|
||||
assert path_to_namespace(
|
||||
path=full_ml_test_data_path(),
|
||||
root=tests_root_dir
|
||||
) == "ML.test_data"
|
|
@@ -0,0 +1,229 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from random import randint
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from _pytest.logging import LogCaptureFixture
|
||||
|
||||
from param import Number
|
||||
from pathlib import Path
|
||||
|
||||
from health_ml.deep_learning_config import DatasetParams, WorkflowParams, OutputParams, OptimizerParams, \
|
||||
ExperimentFolderHandler, TrainerParams
|
||||
|
||||
|
||||
def test_validate_workflow_params() -> None:
|
||||
|
||||
# WorkflowParams cannot be validated with more than one of these parameters set
|
||||
with pytest.raises(ValueError) as ex:
|
||||
WorkflowParams(local_datasets=Path("foo"),
|
||||
local_weights_path=[Path("foo")],
|
||||
weights_url=["bar"]).validate()
|
||||
assert ex.value.args[0] == "Cannot specify more than one of local_weights_path, weights_url or model_id."
|
||||
|
||||
with pytest.raises(ValueError) as ex:
|
||||
WorkflowParams(local_dataset=Path("foo"),
|
||||
local_weights_path=[Path("foo")],
|
||||
model_id="foo:1").validate()
|
||||
assert ex.value.args[0] == "Cannot specify more than one of local_weights_path, weights_url or model_id."
|
||||
|
||||
with pytest.raises(ValueError) as ex:
|
||||
WorkflowParams(local_dataset=Path("foo"),
|
||||
weights_url=["foo"],
|
||||
model_id="foo:1").validate()
|
||||
assert ex.value.args[0] == "Cannot specify more than one of local_weights_path, weights_url or model_id."
|
||||
|
||||
with pytest.raises(ValueError) as ex:
|
||||
WorkflowParams(local_dataset=Path("foo"),
|
||||
local_weights_path=[Path("foo")],
|
||||
weights_url=["foo"],
|
||||
model_id="foo:1").validate()
|
||||
assert ex.value.args[0] == "Cannot specify more than one of local_weights_path, weights_url or model_id."
|
||||
|
||||
with pytest.raises(ValueError) as ex:
|
||||
WorkflowParams(local_dataset=Path("foo"),
|
||||
model_id="foo").validate()
|
||||
assert "model id should be in the form 'model_name:version'" in ex.value.args[0]
|
||||
|
||||
# The following should be okay
|
||||
WorkflowParams(local_dataset=Path("foo"), local_weights_path=[Path("foo")]).validate()
|
||||
WorkflowParams(local_dataset=Path("foo"), weights_url=["foo"]).validate()
|
||||
WorkflowParams(local_dataset=Path("foo"), model_id="foo:1").validate()
|
||||
|
||||
|
||||
def test_workflow_params_get_effective_random_seed() -> None:
|
||||
params = WorkflowParams(local_dataset=Path("foo"), weights_url=["foo"])
|
||||
seed = params.get_effective_random_seed()
|
||||
assert seed == params.random_seed
|
||||
|
||||
|
||||
def test_validate_dataset_params() -> None:
|
||||
# DatasetParams cannot be validated when neither azure_datasets nor local_datasets is set
|
||||
with pytest.raises(ValueError) as ex:
|
||||
DatasetParams(local_datasets=[], azure_datasets=[]).validate()
|
||||
assert ex.value.args[0] == "Either local_datasets or azure_datasets must be set."
|
||||
|
||||
# If azure_datasets or local_datasets is not a list an exception should be raised
|
||||
with pytest.raises(Exception) as e:
|
||||
DatasetParams(local_datasets="", azure_datasets=[]).validate()
|
||||
assert "must be a list" in str(e)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
DatasetParams(local_datasets=[], azure_datasets=None).validate()
|
||||
assert "must be a list" in str(e)
|
||||
|
||||
# local datasets and dataset_mountpoints must be Paths
|
||||
with pytest.raises(Exception) as e:
|
||||
DatasetParams(local_datasets=["foo"])
|
||||
assert "is not an instance of" in str(e)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
DatasetParams(dataset_mountpoints=["foo"])
|
||||
assert "is not an instance of" in str(e)
|
||||
|
||||
# The following should be okay
|
||||
DatasetParams(local_datasets=[Path("foo")]).validate()
|
||||
DatasetParams(azure_datasets=["bar"]).validate()
|
||||
|
||||
config = DatasetParams(local_datasets=[Path("foo")],
|
||||
azure_datasets=[""])
|
||||
config.validate()
|
||||
assert config.azure_datasets == [""]
|
||||
|
||||
config = DatasetParams(azure_datasets=["foo"])
|
||||
config.validate()
|
||||
assert len(config.azure_datasets) == 1
|
||||
|
||||
config = DatasetParams(local_datasets=[Path("foo")],
|
||||
azure_datasets=[""])
|
||||
config.validate()
|
||||
assert len(config.azure_datasets) == 1
|
||||
|
||||
config = DatasetParams(azure_datasets=["foo", "bar"])
|
||||
config.validate()
|
||||
assert len(config.azure_datasets) == 2
|
||||
|
||||
config = DatasetParams(azure_datasets=["foo"],
|
||||
dataset_mountpoints=[Path()])
|
||||
config.validate()
|
||||
assert config.dataset_mountpoints == [Path()]
|
||||
|
||||
config = DatasetParams(azure_datasets=["foo"],
|
||||
dataset_mountpoints=[Path("foo")])
|
||||
config.validate()
|
||||
assert len(config.dataset_mountpoints) == 1
|
||||
|
||||
# the number of mountpoints must not be larger than the number of datasets
|
||||
with pytest.raises(ValueError) as e:
|
||||
DatasetParams(azure_datasets=["foo"],
|
||||
dataset_mountpoints=[Path("foo"), Path("bar")]).validate()
|
||||
assert "Expected the number of azure datasets to equal the number of mountpoints" in str(e)
|
||||
|
||||
|
||||
def test_output_params_set_output_to() -> None:
|
||||
# output_to must be Path type
|
||||
with pytest.raises(Exception) as e:
|
||||
OutputParams(output_to="foo")
|
||||
assert "must be an instance of Path" in str(e)
|
||||
|
||||
old_path = Path()
|
||||
config = OutputParams(output_to=old_path)
|
||||
assert config.outputs_folder == old_path
|
||||
new_path = Path("dummy")
|
||||
config.set_output_to(new_path)
|
||||
# create_filesystem gets called inside
|
||||
assert config.output_to == new_path
|
||||
|
||||
|
||||
def test_output_params_create_filesystem(tmp_path: Path) -> None:
|
||||
# file_system_config must be of type ExperimentFolderHandler
|
||||
with pytest.raises(Exception) as e:
|
||||
OutputParams(file_system_config="foo")
|
||||
assert "value must be an instance of ExperimentFolderHandler" in str(e)
|
||||
|
||||
config = OutputParams()
|
||||
default_file_system_config = config.file_system_config
|
||||
assert isinstance(default_file_system_config, ExperimentFolderHandler)
|
||||
assert default_file_system_config.project_root == Path(".")
|
||||
# Now call create_filesystem with a different project_root path
|
||||
config.create_filesystem(tmp_path)
|
||||
new_file_system_config = config.file_system_config
|
||||
assert new_file_system_config.project_root == tmp_path
|
||||
|
||||
|
||||
def test_validate_optimizer_params() -> None:
|
||||
# Instantiating OptimizerParams with no non-default values should be ok
|
||||
config = OptimizerParams()
|
||||
config.validate()
|
||||
|
||||
# assert that passing a string to a param expecting a numeric value causes an Exception to be raised
|
||||
numeric_params = [k for k, v in config.params().items() if isinstance(v, Number)]
|
||||
for numeric_param_name in numeric_params:
|
||||
with pytest.raises(Exception) as e:
|
||||
config = OptimizerParams()
|
||||
setattr(config, numeric_param_name, "foo")
|
||||
config.validate()
|
||||
|
||||
# For non-numeric parameters, check that Exceptions are raised when params with invalid types are provided
|
||||
with pytest.raises(Exception) as e:
|
||||
OptimizerParams(l_rate_scheduler="foo").validate()
|
||||
assert "must be an instance of LRSchedulerType" in str(e)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
OptimizerParams(l_rate_multi_step_milestones="foo")
|
||||
assert "must be a list" in str(e)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
OptimizerParams(l_rate_warmup="foo").validate()
|
||||
assert "must be an instance of LRWarmUpType" in str(e)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
OptimizerParams(optimizer_type="foo").validate()
|
||||
assert "must be an instance of OptimizerType" in str(e)
|
||||
|
||||
with pytest.raises(Exception) as e:
|
||||
OptimizerParams(adam_betas="foo").validate()
|
||||
assert "only takes a tuple value" in str(e)
|
||||
|
||||
|
||||
def test_optimizer_params_min_l_rate() -> None:
|
||||
config = OptimizerParams()
|
||||
min_l_rate = config.min_l_rate
|
||||
assert min_l_rate == config._min_l_rate
|
||||
|
||||
|
||||
def test_trainer_params_use_gpu() -> None:
|
||||
config = TrainerParams()
|
||||
for patch_gpu in [False, True]:
|
||||
with patch("health_ml.utils.common_utils.is_gpu_available") as mock_gpu_available:
|
||||
mock_gpu_available.return_value = patch_gpu
|
||||
assert config.use_gpu is patch_gpu
|
||||
|
||||
|
||||
@patch("health_ml.utils.common_utils.is_gpu_available")
|
||||
def test_trainer_params_num_gpus_per_node(mock_gpu_available: MagicMock, caplog: LogCaptureFixture) -> None:
|
||||
mock_gpu_available.return_value = True
|
||||
# if the requested number of gpus is available and less than the total available number of gpus, a warning
|
||||
# should be logged to let the user know that they aren't using the full capacity
|
||||
requested_gpus = 3
|
||||
config = TrainerParams(max_num_gpus=requested_gpus)
|
||||
random_num_available_gpus = randint(requested_gpus, requested_gpus + 5)
|
||||
with patch("torch.cuda.device_count") as mock_gpu_count:
|
||||
mock_gpu_count.return_value = random_num_available_gpus
|
||||
assert config.num_gpus_per_node() == requested_gpus
|
||||
message = caplog.messages[-1]
|
||||
assert f"Restricting the number of GPUs to {requested_gpus}" in message
|
||||
|
||||
# if the max number of gpus is set as less than the number available, expect a warning
|
||||
requested_gpus = 3
|
||||
random_num_available_gpus = randint(1, requested_gpus - 1)
|
||||
with patch("torch.cuda.device_count") as mock_gpu_count:
|
||||
mock_gpu_count.return_value = random_num_available_gpus
|
||||
config = TrainerParams(max_num_gpus=requested_gpus)
|
||||
assert config.num_gpus_per_node() == random_num_available_gpus
|
||||
message = caplog.messages[-1]
|
||||
assert f"You requested max_num_gpus {requested_gpus} but there are only {random_num_available_gpus}" \
|
||||
f" available." in message
|
|
@@ -0,0 +1,110 @@
|
|||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import MagicMock, patch, Mock
|
||||
|
||||
from pytorch_lightning import Callback, Trainer
|
||||
from pytorch_lightning.callbacks import GradientAccumulationScheduler, ModelCheckpoint, ModelSummary, TQDMProgressBar
|
||||
|
||||
from health_ml.configs.hello_container import HelloContainer # type: ignore
|
||||
from health_ml.lightning_container import LightningContainer
|
||||
from health_ml.model_trainer import (create_lightning_trainer, write_experiment_summary_file, model_train)
|
||||
from health_ml.utils.common_utils import EXPERIMENT_SUMMARY_FILE
|
||||
from health_ml.utils.config_loader import ModelConfigLoader
|
||||
from health_ml.utils.lightning_loggers import StoringLogger
|
||||
|
||||
|
||||
def test_write_experiment_summary_file(tmp_path: Path) -> None:
|
||||
config = {
|
||||
"Container": {
|
||||
"_min_l_rate": 0.0,
|
||||
"_model_name": "HelloContainer",
|
||||
"adam_betas": "(0.9, 0.999)",
|
||||
"azure_datasets": "[]"}
|
||||
}
|
||||
expected_args_path = tmp_path / EXPERIMENT_SUMMARY_FILE
|
||||
write_experiment_summary_file(config, tmp_path)
|
||||
|
||||
actual_args = expected_args_path.read_text()
|
||||
assert actual_args == str(config)
|
||||
|
||||
|
||||
def test_create_lightning_trainer() -> None:
|
||||
container = LightningContainer()
|
||||
trainer, storing_logger = create_lightning_trainer(container)
|
||||
|
||||
assert trainer.num_gpus == container.num_gpus_per_node()
|
||||
# by default, trainer's num_nodes is 1
|
||||
assert trainer.num_nodes == 1
|
||||
assert trainer.default_root_dir == str(container.outputs_folder)
|
||||
assert trainer.limit_train_batches == 1.0
|
||||
assert trainer._detect_anomaly == container.detect_anomaly
|
||||
|
||||
assert isinstance(trainer.callbacks[0], TQDMProgressBar)
|
||||
assert isinstance(trainer.callbacks[1], ModelSummary)
|
||||
assert isinstance(trainer.callbacks[2], GradientAccumulationScheduler)
|
||||
assert isinstance(trainer.callbacks[3], ModelCheckpoint)
|
||||
|
||||
assert isinstance(storing_logger, StoringLogger)
|
||||
assert storing_logger.hyperparams is None
|
||||
assert len(storing_logger.results_per_epoch) == 0
|
||||
assert len(storing_logger.train_diagnostics) == 0
|
||||
assert len(storing_logger.val_diagnostics) == 0
|
||||
assert len(storing_logger.results_without_epoch) == 0
|
||||
|
||||
|
||||
class MyCallback(Callback):
|
||||
def on_init_start(self, trainer: Trainer) -> None:
|
||||
print("Starting to init trainer")
|
||||
|
||||
|
||||
def test_create_lightning_trainer_with_callbacks() -> None:
|
||||
"""
|
||||
Test that create_lightning_trainer picks up on additional Container callbacks
|
||||
"""
|
||||
def _get_trainer_arguments() -> Dict[str, Any]:
|
||||
callbacks = [MyCallback()]
|
||||
return {"callbacks": callbacks}
|
||||
|
||||
model_name = "HelloContainer"
|
||||
model_config_loader = ModelConfigLoader(model=model_name)
|
||||
container = model_config_loader.create_model_config_from_name(model_name)
|
||||
container.monitor_gpu = False
|
||||
container.monitor_loading = False
|
||||
# mock get_trainer_arguments method, since default HelloContainer class doesn't specify any additional callbacks
|
||||
container.get_trainer_arguments = _get_trainer_arguments # type: ignore
|
||||
|
||||
kwargs = container.get_trainer_arguments()
|
||||
assert "callbacks" in kwargs
|
||||
trainer, storing_logger = create_lightning_trainer(container)
|
||||
# expect trainer to have 4 default callbacks: TQDMProgressBar, ModelSummary, GradientAccumulationScheduler
# and ModelCheckpoint, plus any additional callbacks specified in the get_trainer_arguments method
|
||||
kwarg_callbacks = kwargs.get("callbacks") or []
|
||||
expected_num_callbacks = len(kwarg_callbacks) + 4
|
||||
assert len(trainer.callbacks) == expected_num_callbacks, f"Found callbacks: {trainer.callbacks}"
|
||||
assert any([isinstance(c, MyCallback) for c in trainer.callbacks])
|
||||
|
||||
assert isinstance(storing_logger, StoringLogger)
|
||||
|
||||
|
||||
def test_model_train() -> None:
|
||||
container = HelloContainer()
|
||||
container.create_lightning_module_and_store()
|
||||
|
||||
with patch.object(container, "get_data_module"):
|
||||
with patch("health_ml.model_trainer.create_lightning_trainer") as mock_create_trainer:
|
||||
mock_trainer = MagicMock()
|
||||
mock_storing_logger = MagicMock()
|
||||
mock_create_trainer.return_value = mock_trainer, mock_storing_logger
|
||||
|
||||
mock_trainer.fit = Mock()
|
||||
mock_close_logger = Mock()
|
||||
mock_trainer.logger = MagicMock(close=mock_close_logger)
|
||||
|
||||
trainer, storing_logger = model_train(container)
|
||||
|
||||
mock_trainer.fit.assert_called_once()
|
||||
mock_trainer.logger.finalize.assert_called_once()
|
||||
|
||||
assert trainer == mock_trainer
|
||||
assert storing_logger == mock_storing_logger
|
|
@@ -0,0 +1,94 @@
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from typing import Tuple
|
||||
from unittest.mock import patch, MagicMock, Mock
|
||||
|
||||
from pytorch_lightning import Callback
|
||||
|
||||
from health_ml.experiment_config import ExperimentConfig
|
||||
from health_ml.lightning_container import LightningContainer
|
||||
from health_ml.run_ml import MLRunner
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ml_runner() -> MLRunner:
|
||||
experiment_config = ExperimentConfig()
|
||||
container = LightningContainer(num_epochs=1)
|
||||
return MLRunner(experiment_config=experiment_config, container=container)
|
||||
|
||||
|
||||
def test_ml_runner_setup(ml_runner: MLRunner) -> None:
|
||||
"""
|
||||
Check that all the necessary methods get called during setup
|
||||
"""
|
||||
assert not ml_runner._has_setup_run
|
||||
with patch.object(ml_runner, "container", spec=LightningContainer) as mock_container:
|
||||
with patch("health_ml.run_ml.seed_everything") as mock_seed:
|
||||
ml_runner.setup()
|
||||
mock_container.get_effective_random_seed.assert_called_once()
|
||||
mock_container.setup.assert_called_once()
|
||||
mock_container.create_lightning_module_and_store.assert_called_once()
|
||||
assert ml_runner._has_setup_run
|
||||
mock_seed.assert_called_once()
|
||||
|
||||
|
||||
def test_set_run_tags_from_parent(ml_runner: MLRunner) -> None:
|
||||
with pytest.raises(AssertionError) as ae:
|
||||
ml_runner.set_run_tags_from_parent()
|
||||
assert "should only be called in a Hyperdrive run" in str(ae)
|
||||
|
||||
with patch("health_ml.run_ml.PARENT_RUN_CONTEXT") as mock_parent_run_context:
|
||||
with patch("health_ml.run_ml.RUN_CONTEXT") as mock_run_context:
|
||||
mock_parent_run_context.get_tags.return_value = {"tag": "dummy_tag"}
|
||||
ml_runner.set_run_tags_from_parent()
|
||||
mock_run_context.set_tags.assert_called()
|
||||
|
||||
|
||||
def test_run(ml_runner: MLRunner) -> None:
|
||||
|
||||
def _mock_model_train(container: LightningContainer) -> Tuple[str, str]:
|
||||
return "trainer", dummy_storing_logger
|
||||
|
||||
dummy_storing_logger = "storing_logger"
|
||||
|
||||
with patch.object(ml_runner, "setup") as mock_setup:
|
||||
with patch("health_ml.run_ml.model_train", new=_mock_model_train):
|
||||
ml_runner.run()
|
||||
mock_setup.assert_called_once()
|
||||
# expect _mock_model_train to be called and ml_runner.storing_logger to be updated accordingly
|
||||
assert ml_runner.storing_logger == dummy_storing_logger
|
||||
|
||||
|
||||
@patch("health_ml.run_ml.create_lightning_trainer")
|
||||
def test_run_inference_for_lightning_models(mock_create_trainer: MagicMock, ml_runner: MLRunner,
|
||||
tmp_path: Path) -> None:
|
||||
"""
|
||||
Check that all expected methods are called during inference
|
||||
"""
|
||||
mock_trainer = MagicMock()
|
||||
mock_test_result = [{"result": 1.0}]
|
||||
mock_trainer.test.return_value = mock_test_result
|
||||
mock_create_trainer.return_value = mock_trainer, ""
|
||||
|
||||
with patch.object(ml_runner, "container") as mock_container:
|
||||
mock_container.num_gpus_per_node.return_value = 0
|
||||
mock_container.get_trainer_arguments.return_value = {"callbacks": Callback()}
|
||||
mock_container.load_model_checkpoint.return_value = Mock()
|
||||
mock_container.get_data_module.return_value = Mock()
|
||||
mock_container.pl_progress_bar_refresh_rate = None
|
||||
mock_container.detect_anomaly = False
|
||||
mock_container.pl_limit_train_batches = 1.0
|
||||
mock_container.pl_limit_val_batches = 1.0
|
||||
mock_container.outputs_folder = tmp_path
|
||||
|
||||
checkpoint_paths = [Path("dummy")]
|
||||
result = ml_runner.run_inference_for_lightning_models(checkpoint_paths)
|
||||
assert result == mock_test_result
|
||||
|
||||
mock_create_trainer.assert_called_once()
|
||||
mock_container.load_model_checkpoint.assert_called_once()
|
||||
mock_container.get_data_module.assert_called_once()
|
||||
mock_trainer.test.assert_called_once()
|
|
@@ -0,0 +1,115 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from health_azure import AzureRunInfo, DatasetConfig
|
||||
from health_ml.lightning_container import LightningContainer
|
||||
from health_ml.runner import Runner
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runner(tmp_path: Path) -> Runner:
|
||||
|
||||
return Runner(project_root=tmp_path)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name, cluster, num_nodes, should_raise_value_error", [
|
||||
("HelloContainer", "dummyCluster", 1, False),
|
||||
("", "", None, True),
|
||||
("HelloContainer", "", None, False),
|
||||
("a", None, 0, True),
|
||||
(None, "b", 10, True),
|
||||
("HelloContainer", "b", 10, False)
|
||||
])
|
||||
def test_parse_and_load_model(mock_runner: Runner, model_name: Optional[str], cluster: Optional[str],
|
||||
num_nodes: Optional[int], should_raise_value_error: bool) -> None:
|
||||
"""
|
||||
Test that command line args are parsed, a LightningContainer is instantiated with the expected attributes,
and a ParserResult object with the expected values is returned. If model_name cannot be found in the
namespace (i.e. the config does not exist) a ValueError should be raised
|
||||
"""
|
||||
dummy_args = [""]
|
||||
if model_name is not None:
|
||||
dummy_args.append(f"--model={model_name}")
|
||||
if cluster is not None:
|
||||
dummy_args.append(f"--cluster={cluster}")
|
||||
if num_nodes is not None:
|
||||
dummy_args.append(f"--num_nodes={num_nodes}")
|
||||
|
||||
with patch.object(sys, "argv", new=dummy_args):
|
||||
if should_raise_value_error:
|
||||
with pytest.raises(ValueError) as ve:
|
||||
mock_runner.parse_and_load_model()
|
||||
assert "Parameter 'model' needs to be set" in str(ve)
|
||||
else:
|
||||
parser_result = mock_runner.parse_and_load_model()
|
||||
# if model, cluster or num_nodes are provided in command line args, the corresponding attributes of
|
||||
# the LightningContainer will be set accordingly and they will be dropped from ParserResult during
|
||||
# parse_overrides_and_apply
|
||||
assert parser_result.args.get("model") is None
|
||||
assert parser_result.args.get("cluster") is None
|
||||
assert parser_result.args.get("num_nodes") is None
|
||||
|
||||
assert isinstance(mock_runner.lightning_container, LightningContainer)
|
||||
assert mock_runner.lightning_container.initialized
|
||||
assert mock_runner.lightning_container.model_name == model_name
|
||||
|
||||
|
||||
def test_run(mock_runner: Runner) -> None:
|
||||
model_name = "HelloContainer"
|
||||
arguments = ["", f"--model={model_name}"]
|
||||
with patch("health_ml.runner.Runner.run_in_situ") as mock_run_in_situ:
|
||||
with patch("health_ml.runner.get_workspace"):
|
||||
with patch.object(sys, "argv", arguments):
|
||||
model_config, azure_run_info = mock_runner.run()
|
||||
mock_run_in_situ.assert_called_once()
|
||||
|
||||
assert model_config is not None # for pyright
|
||||
assert model_config.model_name == model_name
|
||||
assert azure_run_info.run is None
|
||||
assert len(azure_run_info.input_datasets) == len(azure_run_info.output_datasets) == 0
|
||||
|
||||
|
||||
@patch("health_ml.runner.get_all_environment_files")
|
||||
@patch("health_ml.runner.get_all_pip_requirements_files")
|
||||
@patch("health_ml.runner.get_workspace")
|
||||
def test_submit_to_azureml_if_needed(mock_get_workspace: MagicMock,
|
||||
mock_get_pip_req_files: MagicMock,
|
||||
mock_get_env_files: MagicMock,
|
||||
mock_runner: Runner
|
||||
) -> None:
|
||||
def _mock_dont_submit_to_aml(input_datasets: List[DatasetConfig], submit_to_azureml: bool # type: ignore
|
||||
) -> AzureRunInfo:
|
||||
datasets_input = [d.target_folder for d in input_datasets] if input_datasets else []
|
||||
return AzureRunInfo(input_datasets=datasets_input,
|
||||
output_datasets=[],
|
||||
mount_contexts=[],
|
||||
run=None,
|
||||
is_running_in_azure_ml=False,
|
||||
output_folder=None, # type: ignore
|
||||
logs_folder=None) # type: ignore
|
||||
|
||||
mock_get_env_files.return_value = []
|
||||
mock_get_pip_req_files.return_value = []
|
||||
|
||||
mock_default_datastore = MagicMock()
|
||||
mock_default_datastore.name.return_value = "dummy_datastore"
|
||||
mock_get_workspace.get_default_datastore.return_value = mock_default_datastore
|
||||
|
||||
with patch("health_ml.runner.create_dataset_configs") as mock_create_datasets:
|
||||
mock_create_datasets.return_value = []
|
||||
with patch("health_ml.runner.submit_to_azure_if_needed") as mock_submit_to_aml:
|
||||
mock_submit_to_aml.side_effect = _mock_dont_submit_to_aml
|
||||
mock_runner.lightning_container = LightningContainer()
|
||||
run_info = mock_runner.submit_to_azureml_if_needed()
|
||||
assert isinstance(run_info, AzureRunInfo)
|
||||
assert run_info.input_datasets == []
|
||||
assert run_info.is_running_in_azure_ml is False
|
||||
assert run_info.output_folder is None
|
|
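For orientation, a minimal sketch of driving the same code path outside of the mocked parametrization above; the argument-free Runner() construction is an assumption and may not match the real constructor:

import sys
from unittest.mock import patch

from health_ml.runner import Runner  # same module path that is patched in the tests above

# Assumption: Runner can be built without arguments; the real class may require e.g. a project root.
runner = Runner()
with patch.object(sys, "argv", ["", "--model=HelloContainer", "--num_nodes=1"]):
    runner.parse_and_load_model()
assert runner.lightning_container.model_name == "HelloContainer"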
@ -0,0 +1,37 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import os
from pathlib import Path
from typing import Optional

from health_azure.utils import PathOrString


def tests_root_directory(path: Optional[PathOrString] = None) -> Path:
    """
    Gets the full path to the root directory that holds the tests.
    If a relative path is provided, it is appended to that tests root directory.

    :param path: An optional path relative to the tests root directory.
    :return: The full path to the tests root directory (or the given path inside it), with symlinks resolved
        if any.
    """
    root = Path(os.path.realpath(__file__)).parent.parent.parent
    return root / path if path else root


def full_ml_test_data_path(path: str = "") -> Path:
    """
    Takes a relative path inside the testhiml/ML/test_data folder, and returns its full absolute path.

    :param path: A path relative to the testhiml/ML/test_data folder.
    :return: The full absolute path of the argument.
    """
    return _full_test_data_path("ML", path)


def _full_test_data_path(prefix: str, suffix: str) -> Path:
    # Builds <tests root>/<prefix>/test_data/<suffix>.
    root = tests_root_directory()
    return root / prefix / "test_data" / suffix
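A short usage sketch for the path helpers above; the file name "dummy.csv" is purely illustrative and not part of this change:

from pathlib import Path

# Resolve an illustrative file under the shared ML test data folder.
csv_path: Path = full_ml_test_data_path("dummy.csv")
assert csv_path == tests_root_directory() / "ML" / "test_data" / "dummy.csv"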
@ -2,6 +2,48 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import pandas as pd

from health_azure.utils import UnitTestWorkspaceWrapper


DEFAULT_WORKSPACE = UnitTestWorkspaceWrapper()


def create_dataset_df() -> pd.DataFrame:
    """
    Create a test dataframe for DATASET_CSV_FILE_NAME.

    :return: Test dataframe.
    """
    dataset_df = pd.DataFrame()
    dataset_df['subject'] = list(range(10))
    dataset_df['seriesId'] = [f"s{i}" for i in range(10)]
    dataset_df['institutionId'] = ["xyz"] * 10
    return dataset_df


def create_metrics_df() -> pd.DataFrame:
    """
    Create a test dataframe for SUBJECT_METRICS_FILE_NAME.

    :return: Test dataframe.
    """
    metrics_df = pd.DataFrame()
    metrics_df['Patient'] = list(range(10))
    metrics_df['Structure'] = ['appendix'] * 10
    metrics_df['Dice'] = [0.5 + i * 0.02 for i in range(10)]
    return metrics_df


def create_comparison_metrics_df() -> pd.DataFrame:
    """
    Create a test dataframe for comparison metrics.

    :return: Test dataframe.
    """
    comparison_metrics_df = pd.DataFrame()
    comparison_metrics_df['Patient'] = list(range(10))
    comparison_metrics_df['Structure'] = ['appendix'] * 10
    comparison_metrics_df['Dice'] = [0.51 + i * 0.02 for i in range(10)]
    return comparison_metrics_df
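A quick sanity-check sketch of the fixture helpers above (illustrative only, not part of the diff); it assumes those helpers are in scope, e.g. imported from the test utilities module:

# Dataset fixture: 10 rows with subject, series and institution columns.
df = create_dataset_df()
assert list(df.columns) == ["subject", "seriesId", "institutionId"]
assert len(df) == 10

# Metrics fixture: Dice values run from 0.5 to 0.68 in steps of 0.02.
metrics = create_metrics_df()
assert metrics["Dice"].between(0.5, 0.68).all()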
@ -12,4 +12,18 @@
"reportMissingImports": true,
|
||||
"reportMissingTypeStubs": false,
|
||||
"reportPrivateImportUsage": false,
|
||||
"executionEnvironments": [
|
||||
{
|
||||
"root": "hi-ml/src"
|
||||
},
|
||||
{
|
||||
"root": "hi-ml/testhiml"
|
||||
},
|
||||
{
|
||||
"root": "hi-ml-azure/src"
|
||||
},
|
||||
{
|
||||
"root": "hi-ml-azure/testazure"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|