mirror of https://github.com/microsoft/hi-ml.git
ENH: Ability to evaluate model on a new dataset (#859)
Renamed the `MLRunner` to `TrainingRunner`. Introduced a new `EvalRunner`. Shared functionality moved from `MLRunner` to `RunnerBase`.
Parent
2cc2511efd
Commit
e73c92649d
|
@ -7,6 +7,7 @@ use of these features:
|
||||||
- Working with different models in the same codebase, and selecting one by name
|
- Working with different models in the same codebase, and selecting one by name
|
||||||
- Distributed training in AzureML
|
- Distributed training in AzureML
|
||||||
- Logging via AzureML's native capabilities
|
- Logging via AzureML's native capabilities
|
||||||
|
- Evaluation of the trained model on new datasets
|
||||||
|
|
||||||
This can be used by invoking the hi-ml runner and providing the name of the container class, like this:
|
This can be used by invoking the hi-ml runner and providing the name of the container class, like this:
|
||||||
`himl-runner --model=MyContainer`.
|
`himl-runner --model=MyContainer`.
|
||||||
|
@ -215,7 +216,13 @@ and returns a tuple containing the Optimizer and LRScheduler objects
|
||||||
## Run inference with a pretrained model
|
## Run inference with a pretrained model
|
||||||
|
|
||||||
You can use the hi-ml-runner in inference mode only by switching the `--run_inference_only` flag on and specifying
|
You can use the hi-ml-runner in inference mode only by switching the `--run_inference_only` flag on and specifying
|
||||||
the model weights by setting `--src_checkpoint` argument that supports three types of checkpoints:
|
the model weights by setting `--src_checkpoint` argument. With this flag, the model will be evaluated on the test set
|
||||||
|
only. There is also an option for evaluating the model on a full dataset, described further below.
|
||||||
|
|
||||||
|
### Specifying the checkpoint to use
|
||||||
|
|
||||||
|
When running inference on a trained model, you need to provide a model checkpoint that should be used. This is done via
|
||||||
|
the `--src_checkpoint` argument. This supports three types of checkpoints:
|
||||||
|
|
||||||
- A local path where the checkpoint is stored `--src_checkpoint=local/path/to/my_checkpoint/model.ckpt`
|
- A local path where the checkpoint is stored `--src_checkpoint=local/path/to/my_checkpoint/model.ckpt`
|
||||||
- A remote URL from where to download the weights `--src_checkpoint=https://my_checkpoint_url.com/model.ckpt`
|
- A remote URL from where to download the weights `--src_checkpoint=https://my_checkpoint_url.com/model.ckpt`
|
||||||
|
@ -228,13 +235,69 @@ the model weights by setting `--src_checkpoint` argument that supports three typ
|
||||||
|
|
||||||
Refer to [Checkpoints Utils](checkpoints.md) for more details on how checkpoints are parsed.
|
Refer to [Checkpoints Utils](checkpoints.md) for more details on how checkpoints are parsed.
|
||||||
|
|
||||||
Running the following command line will run inference using `MyContainer` model with weights from the checkpoint saved
|
### Running inference on the test set
|
||||||
|
|
||||||
|
When supplying the flag `--run_inference_only` on the commandline, no model training will be run, and only inference on the
|
||||||
|
test set will be done:
|
||||||
|
|
||||||
|
- The model weights will be loaded from the location specified by `--src_checkpoint`
|
||||||
|
- A PyTorch Lightning `Trainer` object will be instantiated.
|
||||||
|
- The test set will be read out from the data module specified by the `get_data_module` method of the
|
||||||
|
`LightningContainer` object.
|
||||||
|
- The model will be evaluated on the test set, by running `trainer.test`. Any special logic to use during the test step
|
||||||
|
will need to be added to the model's `test_step` method.
|
||||||
|
|
||||||
|
Running the following command line will run inference using the `MyContainer` model with weights from the checkpoint saved
|
||||||
in the AzureMl run `MyContainer_XXXX_yyyy` at the best validation loss epoch `/outputs/checkpoints/best_val_loss.ckpt`.
|
in the AzureMl run `MyContainer_XXXX_yyyy` at the best validation loss epoch `/outputs/checkpoints/best_val_loss.ckpt`.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
himl-runner --model=MyContainer --run_inference_only --src_checkpoint=MyContainer_XXXX_yyyy:best_val_loss.ckpt
|
himl-runner --model=MyContainer --run_inference_only --src_checkpoint=MyContainer_XXXX_yyyy:best_val_loss.ckpt
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Running inference on a full dataset
|
||||||
|
|
||||||
|
When supplying the flag `--mode=eval_full` on the commandline, no model training will be run, and the model will be
|
||||||
|
evaluated on a dataset different from the training/validation/test dataset. This dataset is loaded via the
|
||||||
|
`get_eval_data_module` method of the container.
|
||||||
|
|
||||||
|
- The model weights will be loaded from the location specified by `--src_checkpoint`
|
||||||
|
- A PyTorch Lightning `Trainer` object will be instantiated.
|
||||||
|
- The test set will be read out from the data module specified by the `get_eval_data_module` method of the
|
||||||
|
`LightningContainer` object. The data module itself can read data from a mounted Azure dataset, which will be made
|
||||||
|
available to the container at the path `self.local_datasets`. In a typical use case, all the data in that dataset will be
|
||||||
|
put into the `test_dataloader` field of the data module.
|
||||||
|
- The model will be evaluated on the test set, by running `trainer.test`. Any special logic to use during the test step
|
||||||
|
will need to be added to the model's `test_step` method.
|
||||||
|
|
||||||
|
Running the following command line will run inference using the `MyContainer` model with weights from the checkpoint saved
|
||||||
|
in the AzureMl run `MyContainer_XXXX_yyyy` at the best validation loss epoch `/outputs/checkpoints/best_val_loss.ckpt`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
himl-runner --model=MyContainer --src_checkpoint=MyContainer_XXXX_yyyy:best_val_loss.ckpt --mode=eval_full --azure_datasets=my_new_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
The example code snippet here shows how to add a method that reads the inference dataset. In this example, we assume
|
||||||
|
that the `MyDataModule` class has an argument `splits` that specifies the fraction of data to go into the training,
|
||||||
|
validation, and test data loaders.
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MyContainer(LightningContainer):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.azure_datasets = ["folder_name_in_azure_blob_storage"]
|
||||||
|
self.local_datasets = [Path("/some/local/path")]
|
||||||
|
self.max_epochs = 42
|
||||||
|
|
||||||
|
def create_model(self) -> LightningModule:
|
||||||
|
return MyLightningModel()
|
||||||
|
|
||||||
|
def get_data_module(self) -> LightningDataModule:
|
||||||
|
return MyDataModule(root_path=self.local_datasets[0], splits=(0.7, 0.2, 0.1))
|
||||||
|
|
||||||
|
def get_eval_data_module(self) -> LightningDataModule:
|
||||||
|
return MyDataModule(root_path=self.local_datasets[0], splits=(0.0, 0.0, 1.0))
|
||||||
|
```
|
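To make the role of the `splits` argument concrete, here is a minimal sketch of how a data module might turn such a tuple into train/validation/test index ranges. The helper `split_indices` is hypothetical and not part of hi-ml; the `MyDataModule` class assumed in the documentation example may implement this differently.

```python
from typing import List, Tuple


def split_indices(n: int, splits: Tuple[float, float, float]) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: partition range(n) into train/val/test index lists."""
    if abs(sum(splits) - 1.0) > 1e-9:
        raise ValueError(f"Split fractions must sum to 1, got {splits}")
    n_train = int(round(n * splits[0]))
    n_val = int(round(n * splits[1]))
    indices = list(range(n))
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    # Everything that is left goes to the test set, so rounding never drops a sample.
    test = indices[n_train + n_val:]
    return train, val, test


# Training-style split: most data goes to training.
train, val, test = split_indices(10, (0.7, 0.2, 0.1))
# Evaluation-style split: all data ends up in the test set.
eval_train, eval_val, eval_test = split_indices(10, (0.0, 0.0, 1.0))
```

With `splits=(0.0, 0.0, 1.0)` the training and validation lists are empty, which matches the requirement that only the test dataloader is used in `eval_full` mode.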
||||||
|
|
||||||
## Resume training from a given checkpoint
|
## Resume training from a given checkpoint
|
||||||
|
|
||||||
Analogously, one can resume training by setting `--src_checkpoint` and `--resume_training` to train a model longer.
|
Analogously, one can resume training by setting `--src_checkpoint` and `--resume_training` to train a model longer.
|
||||||
|
@ -298,5 +361,6 @@ seeds on a local machine, other than by manually starting runs with `--random_se
|
||||||
the limit that AML sets for snapshots. Solution: check for cache files, log files or other files that are not
|
the limit that AML sets for snapshots. Solution: check for cache files, log files or other files that are not
|
||||||
necessary for running your experiment and add them to a `.amlignore` file in the root directory. Alternatively, you
|
necessary for running your experiment and add them to a `.amlignore` file in the root directory. Alternatively, you
|
||||||
can see Azure ML documentation for instructions on increasing this limit, although it will make your jobs slower.
|
can see Azure ML documentation for instructions on increasing this limit, although it will make your jobs slower.
|
||||||
2. `"FileNotFoundError"`. Possible cause: Symlinked files. Azure ML SDK v2 will resolve the symlink and attempt to upload
|
|
||||||
the resolved file. Solution: Remove symlinks from any files that should be uploaded to Azure ML.
|
1. `"FileNotFoundError"`. Possible cause: Symlinked files. Azure ML SDK v2 will resolve the symlink and attempt to upload
|
||||||
|
the resolved file. Solution: Remove symlinks from any files that should be uploaded to Azure ML.
|
||||||
|
|
|
@ -61,4 +61,9 @@
|
||||||
"${workspaceFolder}/src",
|
"${workspaceFolder}/src",
|
||||||
"${workspaceFolder}/testazure"
|
"${workspaceFolder}/testazure"
|
||||||
],
|
],
|
||||||
|
"workbench.colorCustomizations": {
|
||||||
|
"activityBar.background": "#5D091F",
|
||||||
|
"titleBar.activeBackground": "#830D2B",
|
||||||
|
"titleBar.activeForeground": "#FFFCFC"
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -524,7 +524,7 @@ def create_dataset_configs(
|
||||||
count = num_azure
|
count = num_azure
|
||||||
elif num_azure == 0 and num_mount == 0:
|
elif num_azure == 0 and num_mount == 0:
|
||||||
# No datasets in Azure at all: This is possible for runs that for example download their own data from the web.
|
# No datasets in Azure at all: This is possible for runs that for example download their own data from the web.
|
||||||
# There can be any number of local datasets, but we are not checking that. In MLRunner.setup, there is a check
|
# There can be any number of local datasets, but we are not checking that. In TrainingRunner.setup, there is a check
|
||||||
# that leaves local datasets intact if there are no Azure datasets.
|
# that leaves local datasets intact if there are no Azure datasets.
|
||||||
return []
|
return []
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -2,8 +2,8 @@
|
||||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||||
# ------------------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------------------
|
||||||
from health_ml.run_ml import MLRunner
|
from health_ml.training_runner import TrainingRunner
|
||||||
from health_ml.runner import Runner
|
from health_ml.runner import Runner
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["MLRunner", "Runner"]
|
__all__ = ["TrainingRunner", "Runner"]
|
||||||
|
|
|
@ -16,6 +16,9 @@ from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
from health_ml.lightning_container import LightningContainer
|
from health_ml.lightning_container import LightningContainer
|
||||||
|
|
||||||
|
TEST_MSE_FILE = "test_mse.txt"
|
||||||
|
TEST_MAE_FILE = "test_mae.txt"
|
||||||
|
|
||||||
|
|
||||||
def _create_1d_regression_dataset(n: int = 100, seed: int = 0) -> torch.Tensor:
|
def _create_1d_regression_dataset(n: int = 100, seed: int = 0) -> torch.Tensor:
|
||||||
"""Creates a simple 1-D dataset of a noisy linear function.
|
"""Creates a simple 1-D dataset of a noisy linear function.
|
||||||
|
@ -79,10 +82,10 @@ class HelloWorldDataModule(LightningDataModule):
|
||||||
A data module that gives the training, validation and test data for a simple 1-dim regression task.
|
A data module that gives the training, validation and test data for a simple 1-dim regression task.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, crossval_count: int, crossval_index: Optional[int]) -> None:
|
def __init__(self, crossval_count: int, crossval_index: Optional[int] = None, seed: int = 0) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
n_total = 200
|
n_total = 200
|
||||||
xy = _create_1d_regression_dataset(n=n_total)
|
xy = _create_1d_regression_dataset(n=n_total, seed=seed)
|
||||||
n_test = 40
|
n_test = 40
|
||||||
n_val = 50
|
n_val = 50
|
||||||
self.test = HelloWorldDataset(xy=xy[:n_test])
|
self.test = HelloWorldDataset(xy=xy[:n_test])
|
||||||
|
@ -229,8 +232,8 @@ class HelloRegression(LightningModule):
|
||||||
for example writing aggregate metrics to disk.
|
for example writing aggregate metrics to disk.
|
||||||
"""
|
"""
|
||||||
average_mse = torch.mean(torch.stack(self.test_mse))
|
average_mse = torch.mean(torch.stack(self.test_mse))
|
||||||
Path("test_mse.txt").write_text(str(average_mse.item()))
|
Path(TEST_MSE_FILE).write_text(str(average_mse.item()))
|
||||||
Path("test_mae.txt").write_text(str(self.test_mae.compute().item()))
|
Path(TEST_MAE_FILE).write_text(str(self.test_mae.compute().item()))
|
||||||
self.log("test_mse", average_mse, on_epoch=True, on_step=False)
|
self.log("test_mse", average_mse, on_epoch=True, on_step=False)
|
||||||
|
|
||||||
def on_run_extra_validation_epoch(self) -> None:
|
def on_run_extra_validation_epoch(self) -> None:
|
||||||
|
@ -266,6 +269,11 @@ class HelloWorld(LightningContainer):
|
||||||
# datamodule must carry out appropriate splitting of the data.
|
# datamodule must carry out appropriate splitting of the data.
|
||||||
return HelloWorldDataModule(crossval_count=self.crossval_count, crossval_index=self.crossval_index)
|
return HelloWorldDataModule(crossval_count=self.crossval_count, crossval_index=self.crossval_index)
|
||||||
|
|
||||||
|
# This method is optional. Override it to supply a data module to use for evaluating the model on a new dataset.
|
||||||
|
# Only the test data loader from the data module will be used.
|
||||||
|
def get_eval_data_module(self) -> LightningDataModule:
|
||||||
|
return HelloWorldDataModule(crossval_count=1, seed=1)
|
||||||
|
|
||||||
def get_callbacks(self) -> List[Callback]:
|
def get_callbacks(self) -> List[Callback]:
|
||||||
if self.save_checkpoint:
|
if self.save_checkpoint:
|
||||||
return [
|
return [
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
from pytorch_lightning import LightningDataModule
|
||||||
|
|
||||||
|
from health_azure.logging import logging_section
|
||||||
|
from health_ml.runner_base import RunnerBase
|
||||||
|
|
||||||
|
|
||||||
|
class EvalRunner(RunnerBase):
|
||||||
|
"""A class to run the evaluation of a model on a new dataset. The initialization logic is taken from the base
|
||||||
|
class `RunnerBase`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def validate(self) -> None:
|
||||||
|
"""Checks if the fields of the class are set up correctly."""
|
||||||
|
if self.container.src_checkpoint is None or self.container.src_checkpoint.checkpoint == "":
|
||||||
|
raise ValueError(
|
||||||
|
"To use model evaluation, you need to provide a checkpoint to use, via the --src_checkpoint argument."
|
||||||
|
)
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
"""Start the core workflow that the class implements: Initialize a PL Trainer object and use that to run
|
||||||
|
inference on the inference dataset."""
|
||||||
|
self.container.outputs_folder.mkdir(exist_ok=True, parents=True)
|
||||||
|
self.init_inference()
|
||||||
|
with logging_section("Model inference"):
|
||||||
|
self.run_inference()
|
||||||
|
|
||||||
|
def get_data_module(self) -> LightningDataModule:
|
||||||
|
"""Reads the evaluation data module from the underlying container."""
|
||||||
|
return self.container.get_eval_data_module()
|
|
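The division of labour between `RunnerBase` and `EvalRunner` can be illustrated with a stripped-down sketch. The classes below are stand-ins with hypothetical names (`SketchContainer`, `SketchRunnerBase`, `SketchEvalRunner`) that only mimic the shape of the real classes: the base class drives the workflow and the subclass decides which data module is used.

```python
from typing import Optional


class SketchContainer:
    """Stand-in for a LightningContainer; returns labels instead of real data modules."""

    def get_data_module(self) -> str:
        return "train/val/test data module"

    def get_eval_data_module(self) -> str:
        return "evaluation data module"


class SketchRunnerBase:
    """Stand-in for RunnerBase: owns the workflow, defers the data module choice."""

    def __init__(self, container: SketchContainer) -> None:
        self.container = container
        self.data_module: Optional[str] = None

    def get_data_module(self) -> str:
        # Subclasses must say which data module they want.
        raise NotImplementedError()

    def init_inference(self) -> None:
        self.data_module = self.get_data_module()


class SketchEvalRunner(SketchRunnerBase):
    """Stand-in for EvalRunner: always reads the evaluation data module."""

    def get_data_module(self) -> str:
        return self.container.get_eval_data_module()


runner = SketchEvalRunner(SketchContainer())
runner.init_inference()
```

Because the shared logic only ever calls `get_data_module` on the runner, swapping the dataset needs no changes to the inference code itself.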
@ -10,6 +10,11 @@ class DebugDDPOptions(Enum):
|
||||||
DETAIL = "DETAIL"
|
DETAIL = "DETAIL"
|
||||||
|
|
||||||
|
|
||||||
|
class RunnerMode(Enum):
|
||||||
|
TRAIN = "train"
|
||||||
|
EVAL_FULL = "eval_full"
|
||||||
|
|
||||||
|
|
||||||
DEBUG_DDP_ENV_VAR = "TORCH_DISTRIBUTED_DEBUG"
|
DEBUG_DDP_ENV_VAR = "TORCH_DISTRIBUTED_DEBUG"
|
||||||
|
|
||||||
|
|
||||||
|
@ -82,3 +87,10 @@ class ExperimentConfig(param.Parameterized):
|
||||||
doc="The maximum runtime that is allowed for this job in AzureML. This is given as a floating"
|
doc="The maximum runtime that is allowed for this job in AzureML. This is given as a floating"
|
||||||
"point number with a string suffix s, m, h, d for seconds, minutes, hours, day. Examples: '3.5h', '2d'",
|
"point number with a string suffix s, m, h, d for seconds, minutes, hours, day. Examples: '3.5h', '2d'",
|
||||||
)
|
)
|
||||||
|
mode: str = param.ClassSelector(
|
||||||
|
class_=RunnerMode,
|
||||||
|
default=RunnerMode.TRAIN,
|
||||||
|
doc=f"The mode to run the experiment in. Can be one of '{RunnerMode.TRAIN}' (training and evaluation on the "
|
||||||
|
f"test set), or '{RunnerMode.EVAL_FULL}' for evaluation on the full dataset specified by the "
|
||||||
|
"'get_eval_data_module' method of the container.",
|
||||||
|
)
|
||||||
|
|
|
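As a small illustration of how the `RunnerMode` values relate to the strings given on the command line, the sketch below converts a string into an enum member. The helper `parse_mode` is hypothetical; how hi-ml's argument parser performs this conversion is not shown in this diff. The last two lines show how a plain `Enum` member renders inside an f-string compared to its `.value`.

```python
from enum import Enum


class RunnerMode(Enum):
    TRAIN = "train"
    EVAL_FULL = "eval_full"


def parse_mode(text: str) -> RunnerMode:
    """Hypothetical helper: map a command line string to a RunnerMode member."""
    try:
        # Enum lookup by value: RunnerMode("eval_full") is RunnerMode.EVAL_FULL.
        return RunnerMode(text)
    except ValueError:
        valid = ", ".join(m.value for m in RunnerMode)
        raise ValueError(f"Unknown mode '{text}', expected one of: {valid}") from None


mode = parse_mode("eval_full")

# A plain Enum member formats as its qualified name, not as its value.
doc_with_member = f"'{RunnerMode.TRAIN}'"
doc_with_value = f"'{RunnerMode.TRAIN.value}'"
```

If a help text is meant to show the literal strings `train` and `eval_full`, using `.value` in the f-string is the safer choice.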
@ -65,6 +65,17 @@ class LightningContainer(WorkflowParams, DatasetParams, OutputParams, TrainerPar
|
||||||
"""
|
"""
|
||||||
return None # type: ignore
|
return None # type: ignore
|
||||||
|
|
||||||
|
def get_eval_data_module(self) -> LightningDataModule:
|
||||||
|
"""
|
||||||
|
Gets the data that is used when evaluating the model on a new dataset.
|
||||||
|
This data module should read datasets from the self.local_datasets folder or download from a web location.
|
||||||
|
Only the test dataloader is used, hence the method needs to put all data into the test dataloader, rather
|
||||||
|
than splitting into train/val/test.
|
||||||
|
|
||||||
|
:return: A LightningDataModule
|
||||||
|
"""
|
||||||
|
return None # type: ignore
|
||||||
|
|
||||||
def get_trainer_arguments(self) -> Dict[str, Any]:
|
def get_trainer_arguments(self) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Gets additional parameters that will be passed on to the PyTorch Lightning trainer.
|
Gets additional parameters that will be passed on to the PyTorch Lightning trainer.
|
||||||
|
|
|
@ -31,9 +31,9 @@ from health_azure.datasets import create_dataset_configs # noqa: E402
|
||||||
from health_azure.himl import DEFAULT_DOCKER_BASE_IMAGE, OUTPUT_FOLDER # noqa: E402
|
from health_azure.himl import DEFAULT_DOCKER_BASE_IMAGE, OUTPUT_FOLDER # noqa: E402
|
||||||
from health_azure.logging import logging_to_stdout # noqa: E402
|
from health_azure.logging import logging_to_stdout # noqa: E402
|
||||||
from health_azure.paths import is_himl_used_from_git_repo # noqa: E402
|
from health_azure.paths import is_himl_used_from_git_repo # noqa: E402
|
||||||
from health_azure.utils import (
|
from health_azure.utils import ( # noqa: E402
|
||||||
ENV_LOCAL_RANK,
|
ENV_LOCAL_RANK,
|
||||||
ENV_NODE_RANK, # noqa: E402
|
ENV_NODE_RANK,
|
||||||
get_workspace,
|
get_workspace,
|
||||||
get_ml_client,
|
get_ml_client,
|
||||||
is_local_rank_zero,
|
is_local_rank_zero,
|
||||||
|
@ -43,9 +43,11 @@ from health_azure.utils import (
|
||||||
is_global_rank_zero,
|
is_global_rank_zero,
|
||||||
)
|
)
|
||||||
|
|
||||||
from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, ExperimentConfig # noqa: E402
|
from health_ml.eval_runner import EvalRunner # noqa: E402
|
||||||
|
from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, ExperimentConfig, RunnerMode # noqa: E402
|
||||||
from health_ml.lightning_container import LightningContainer # noqa: E402
|
from health_ml.lightning_container import LightningContainer # noqa: E402
|
||||||
from health_ml.run_ml import MLRunner # noqa: E402
|
from health_ml.runner_base import RunnerBase # noqa: E402
|
||||||
|
from health_ml.training_runner import TrainingRunner # noqa: E402
|
||||||
from health_ml.utils import fixed_paths # noqa: E402
|
from health_ml.utils import fixed_paths # noqa: E402
|
||||||
from health_ml.utils.logging import ConsoleAndFileOutput # noqa: E402
|
from health_ml.utils.logging import ConsoleAndFileOutput # noqa: E402
|
||||||
from health_ml.utils.common_utils import check_conda_environment, choose_conda_env_file, is_linux # noqa: E402
|
from health_ml.utils.common_utils import check_conda_environment, choose_conda_env_file, is_linux # noqa: E402
|
||||||
|
@ -101,8 +103,8 @@ class Runner:
|
||||||
self.project_root = project_root
|
self.project_root = project_root
|
||||||
self.experiment_config: ExperimentConfig = ExperimentConfig()
|
self.experiment_config: ExperimentConfig = ExperimentConfig()
|
||||||
self.lightning_container: LightningContainer = None # type: ignore
|
self.lightning_container: LightningContainer = None # type: ignore
|
||||||
# This field stores the MLRunner object that has been created in the most recent call to the run() method.
|
# This field stores the TrainingRunner object that has been created in the most recent call to the run() method.
|
||||||
self.ml_runner: Optional[MLRunner] = None
|
self.ml_runner: Optional[RunnerBase] = None
|
||||||
|
|
||||||
def parse_and_load_model(self) -> ParserResult:
|
def parse_and_load_model(self) -> ParserResult:
|
||||||
"""
|
"""
|
||||||
|
@ -322,15 +324,26 @@ class Runner:
|
||||||
assert azure_run_info.run is not None
|
assert azure_run_info.run is not None
|
||||||
azure_run_info.run.set_tags(self.additional_run_tags(sys.argv[1:]))
|
azure_run_info.run.set_tags(self.additional_run_tags(sys.argv[1:]))
|
||||||
|
|
||||||
# Set environment variables for multi-node training if needed. This function will terminate early
|
if self.experiment_config.mode == RunnerMode.TRAIN:
|
||||||
# if it detects that it is not in a multi-node environment.
|
# Set environment variables for multi-node training if needed. This function will terminate early
|
||||||
if self.experiment_config.num_nodes > 1:
|
# if it detects that it is not in a multi-node environment.
|
||||||
set_environment_variables_for_multi_node()
|
if self.experiment_config.num_nodes > 1:
|
||||||
self.ml_runner = MLRunner(
|
set_environment_variables_for_multi_node()
|
||||||
experiment_config=self.experiment_config, container=self.lightning_container, project_root=self.project_root
|
self.ml_runner = TrainingRunner(
|
||||||
)
|
experiment_config=self.experiment_config,
|
||||||
self.ml_runner.setup(azure_run_info)
|
container=self.lightning_container,
|
||||||
self.ml_runner.run()
|
project_root=self.project_root,
|
||||||
|
)
|
||||||
|
elif self.experiment_config.mode == RunnerMode.EVAL_FULL:
|
||||||
|
self.ml_runner = EvalRunner(
|
||||||
|
experiment_config=self.experiment_config,
|
||||||
|
container=self.lightning_container,
|
||||||
|
project_root=self.project_root,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown mode {self.experiment_config.mode}")
|
||||||
|
self.ml_runner.validate()
|
||||||
|
self.ml_runner.run_and_cleanup(azure_run_info)
|
||||||
|
|
||||||
|
|
||||||
def run(project_root: Path) -> Tuple[LightningContainer, AzureRunInfo]:
|
def run(project_root: Path) -> Tuple[LightningContainer, AzureRunInfo]:
|
||||||
|
|
|
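The mode dispatch in `Runner` follows a common pattern: pick a runner class based on the mode, then call `validate()` before doing any work. The sketch below reproduces that pattern with stand-in classes; all names prefixed with `Sketch` and the function `create_runner` are hypothetical and exist only for illustration.

```python
from enum import Enum
from typing import Optional


class RunnerMode(Enum):
    TRAIN = "train"
    EVAL_FULL = "eval_full"


class SketchRunner:
    """Stand-in for RunnerBase: validate() is a no-op unless a subclass overrides it."""

    def __init__(self, src_checkpoint: Optional[str]) -> None:
        self.src_checkpoint = src_checkpoint

    def validate(self) -> None:
        pass


class SketchTrainingRunner(SketchRunner):
    pass


class SketchEvalRunner(SketchRunner):
    def validate(self) -> None:
        # Without a checkpoint there is no trained model to evaluate.
        if not self.src_checkpoint:
            raise ValueError("Evaluation requires a checkpoint, for example via --src_checkpoint.")


def create_runner(mode: RunnerMode, src_checkpoint: Optional[str] = None) -> SketchRunner:
    """Choose the runner for the given mode and fail early on invalid settings."""
    if mode == RunnerMode.TRAIN:
        runner: SketchRunner = SketchTrainingRunner(src_checkpoint)
    elif mode == RunnerMode.EVAL_FULL:
        runner = SketchEvalRunner(src_checkpoint)
    else:
        raise ValueError(f"Unknown mode {mode}")
    runner.validate()
    return runner
```

Validating right after construction means a missing `--src_checkpoint` is reported before any datasets are mounted or any trainer is created.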
@ -0,0 +1,247 @@
|
||||||
|
# ------------------------------------------------------------------------------------------
|
||||||
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||||
|
# ------------------------------------------------------------------------------------------
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from azureml.core import Run
|
||||||
|
from pytorch_lightning import LightningDataModule, Trainer, seed_everything
|
||||||
|
|
||||||
|
from health_azure import AzureRunInfo
|
||||||
|
from health_azure.utils import (
|
||||||
|
PARENT_RUN_CONTEXT,
|
||||||
|
RUN_CONTEXT,
|
||||||
|
create_aml_run_object,
|
||||||
|
create_run_recovery_id,
|
||||||
|
is_running_in_azure_ml,
|
||||||
|
)
|
||||||
|
from health_ml.experiment_config import ExperimentConfig
|
||||||
|
from health_ml.lightning_container import LightningContainer
|
||||||
|
from health_ml.model_trainer import create_lightning_trainer
|
||||||
|
from health_ml.utils import fixed_paths
|
||||||
|
from health_ml.utils.checkpoint_handler import CheckpointHandler
|
||||||
|
from health_ml.utils.common_utils import (
|
||||||
|
EFFECTIVE_RANDOM_SEED_KEY_NAME,
|
||||||
|
RUN_RECOVERY_FROM_ID_KEY_NAME,
|
||||||
|
RUN_RECOVERY_ID_KEY,
|
||||||
|
change_working_directory,
|
||||||
|
seed_monai_if_available,
|
||||||
|
)
|
||||||
|
from health_ml.utils.lightning_loggers import StoringLogger, get_mlflow_run_id_from_trainer
|
||||||
|
from health_ml.utils.type_annotations import PathOrString
|
||||||
|
|
||||||
|
|
||||||
|
def check_dataset_folder_exists(local_dataset: PathOrString) -> Path:
|
||||||
|
"""
|
||||||
|
Checks if a folder with a local dataset exists. If it does exist, return the argument converted
|
||||||
|
to a Path instance. If it does not exist, raise a FileNotFoundError.
|
||||||
|
|
||||||
|
:param local_dataset: The dataset folder to check.
|
||||||
|
:return: The local_dataset argument, converted to a Path.
|
||||||
|
"""
|
||||||
|
expected_dir = Path(local_dataset)
|
||||||
|
if not expected_dir.is_dir():
|
||||||
|
raise FileNotFoundError(f"The model uses a dataset in {expected_dir}, but that does not exist.")
|
||||||
|
logging.info(f"Model will use the local dataset provided in {expected_dir}")
|
||||||
|
return expected_dir
|
||||||
|
|
||||||
|
|
||||||
|
class RunnerBase:
|
||||||
|
"""
|
||||||
|
A base class with operations that are shared between the training/test runner and the evaluation-only runner.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, experiment_config: ExperimentConfig, container: LightningContainer, project_root: Optional[Path] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Driver class to run an ML experiment. Note that the project root argument MUST be supplied when using hi-ml
|
||||||
|
as a package!
|
||||||
|
|
||||||
|
:param experiment_config: The ExperimentConfig object to use for training.
|
||||||
|
:param container: The LightningContainer object to use for training.
|
||||||
|
:param project_root: Project root. This should only be omitted if calling run_ml from the test suite. Supplying
|
||||||
|
it is crucial when using hi-ml as a package or submodule!
|
||||||
|
"""
|
||||||
|
self.container = container
|
||||||
|
self.experiment_config = experiment_config
|
||||||
|
self.container.num_nodes = self.experiment_config.num_nodes
|
||||||
|
self.project_root: Path = project_root or fixed_paths.repository_root_directory()
|
||||||
|
self.storing_logger: Optional[StoringLogger] = None
|
||||||
|
self._has_setup_run = False
|
||||||
|
self.checkpoint_handler = CheckpointHandler(
|
||||||
|
container=self.container, project_root=self.project_root, run_context=RUN_CONTEXT
|
||||||
|
)
|
||||||
|
self.trainer: Optional[Trainer] = None
|
||||||
|
self.azureml_run_for_logging: Optional[Run] = None
|
||||||
|
self.mlflow_run_for_logging: Optional[str] = None
|
||||||
|
# This is passed to trainer.validate and trainer.test in inference mode
|
||||||
|
self.inference_checkpoint: Optional[str] = None
|
||||||
|
|
||||||
|
def validate(self) -> None:
|
||||||
|
"""
|
||||||
|
Checks if all arguments and settings of the object are correct.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def setup_azureml(self) -> None:
|
||||||
|
"""
|
||||||
|
Execute setup steps that are specific to AzureML.
|
||||||
|
"""
|
||||||
|
if PARENT_RUN_CONTEXT is not None:
|
||||||
|
# Set metadata for the run in AzureML if running in a Hyperdrive job.
|
||||||
|
run_tags_parent = PARENT_RUN_CONTEXT.get_tags()
|
||||||
|
tags_to_copy = [
|
||||||
|
"tag",
|
||||||
|
"model_name",
|
||||||
|
"execution_mode",
|
||||||
|
"recovered_from",
|
||||||
|
"friendly_name",
|
||||||
|
"build_number",
|
||||||
|
"build_user",
|
||||||
|
RUN_RECOVERY_FROM_ID_KEY_NAME,
|
||||||
|
]
|
||||||
|
new_tags = {tag: run_tags_parent.get(tag, "") for tag in tags_to_copy}
|
||||||
|
new_tags[RUN_RECOVERY_ID_KEY] = create_run_recovery_id(run=RUN_CONTEXT)
|
||||||
|
new_tags[EFFECTIVE_RANDOM_SEED_KEY_NAME] = str(self.container.get_effective_random_seed())
|
||||||
|
RUN_CONTEXT.set_tags(new_tags)
|
||||||
|
|
||||||
|
def setup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None:
|
||||||
|
"""
|
||||||
|
Sets the random seeds, calls the setup method on the LightningContainer and then creates the actual
|
||||||
|
Lightning modules.
|
||||||
|
|
||||||
|
:param azure_run_info: When running in AzureML or on a local VM, this contains the paths to the datasets.
|
||||||
|
This can be missing when running in unit tests, where the local dataset paths are already populated.
|
||||||
|
"""
|
||||||
|
if self._has_setup_run:
|
||||||
|
return
|
||||||
|
if azure_run_info:
|
||||||
|
# Set up the paths to the datasets. azure_run_info already has all necessary information, using either
|
||||||
|
# the provided local datasets for VM runs, or the AzureML mount points when running in AML.
|
||||||
|
# This must happen before container setup because that could already read datasets.
|
||||||
|
logging.info("Setting tags from parent run.")
|
||||||
|
input_datasets = azure_run_info.input_datasets
|
||||||
|
logging.info(f"Setting the following datasets as local datasets: {input_datasets}")
|
||||||
|
if len(input_datasets) > 0:
|
||||||
|
local_datasets: List[Path] = []
|
||||||
|
for i, dataset in enumerate(input_datasets):
|
||||||
|
if dataset is None:
|
||||||
|
raise ValueError(f"Invalid setup: The dataset at index {i} is None")
|
||||||
|
local_datasets.append(check_dataset_folder_exists(dataset))
|
||||||
|
self.container.local_datasets = local_datasets # type: ignore
|
||||||
|
# Ensure that we use fixed seeds before initializing the PyTorch models.
|
||||||
|
# MONAI needs a separate method to make all transforms deterministic by default
|
||||||
|
seed = self.container.get_effective_random_seed()
|
||||||
|
seed_monai_if_available(seed)
|
||||||
|
seed_everything(seed)
|
||||||
|
|
||||||
|
# Creating the folder structure must happen before the LightningModule is created, because the output
|
||||||
|
# parameters of the container will be copied into the module.
|
||||||
|
self.container.create_filesystem(self.project_root)
|
||||||
|
|
||||||
|
# configure recovery container if provided
|
||||||
|
self.checkpoint_handler.download_recovery_checkpoints_or_weights()
|
||||||
|
|
||||||
|
# Create an AzureML run for logging if running outside AzureML.
|
||||||
|
self.create_logger()
|
||||||
|
|
||||||
|
self.container.setup()
|
||||||
|
self.container.create_lightning_module_and_store()
|
||||||
|
self._has_setup_run = True
|
||||||
|
|
||||||
|
if is_running_in_azure_ml():
|
||||||
|
self.setup_azureml()
|
||||||
|
|
||||||
|
def create_logger(self) -> None:
|
||||||
|
"""
|
||||||
|
Create an AzureML run for logging if running outside AzureML. This run will be used for metrics logging
|
||||||
|
during both training and inference. We can't rely on the automatically generated run inside the AzureMLLogger
|
||||||
|
class because two of those logger objects will be created, so training and inference metrics would be logged
|
||||||
|
in different runs.
|
||||||
|
"""
|
||||||
|
if self.container.log_from_vm:
|
||||||
|
run = create_aml_run_object(experiment_name=self.container.effective_experiment_name)
|
||||||
|
# Display name should already be set when creating the Run object, but in some scenarios this
|
||||||
|
# does not happen. Hence, set it again.
|
||||||
|
run.display_name = self.container.tag if self.container.tag else None
|
||||||
|
self.azureml_run_for_logging = run
|
||||||
|
|
||||||
|
def get_data_module(self) -> LightningDataModule:
|
||||||
|
"""
|
||||||
|
Reads the datamodule that should be used for training or evaluation from the container. This must be
|
||||||
|
overridden in subclasses.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def set_trainer_for_inference(self) -> None:
|
||||||
|
"""Set the runner's PL Trainer object that should be used when running inference on the validation or test set.
|
||||||
|
We run inference on a single device because distributed strategies such as DDP use DistributedSampler
|
||||||
|
internally, which replicates some samples to make sure all devices have the same batch size in case of
|
||||||
|
uneven inputs, which biases the results."""
|
||||||
|
mlflow_run_id = get_mlflow_run_id_from_trainer(self.trainer)
|
||||||
|
self.container.max_num_gpus = self.container.max_num_gpus_inference
|
||||||
|
self.trainer, _ = create_lightning_trainer(
|
||||||
|
container=self.container,
|
||||||
|
num_nodes=1,
|
||||||
|
azureml_run_for_logging=self.azureml_run_for_logging,
|
||||||
|
mlflow_run_for_logging=mlflow_run_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_inference(self) -> None:
|
||||||
|
"""Prepare the runner for inference on validation set, test set, or a full dataset.
|
||||||
|
The following steps are performed:
|
||||||
|
|
||||||
|
1. Get the checkpoint to use for inference. This is either the checkpoint from the last training epoch or the
|
||||||
|
one specified in src_checkpoint argument.
|
||||||
|
|
||||||
|
2. Create a new trainer instance for inference. This is necessary because the trainer is created with a single
|
||||||
|
device in contrast to training that uses DDP if multiple GPUs are available.
|
||||||
|
|
||||||
|
3. Create a new data module instance for inference to account for any requested changes in the dataloading
|
||||||
|
parameters (e.g. batch_size, max_num_workers, etc) as part of on_run_extra_validation_epoch.
|
||||||
|
"""
|
||||||
|
self.inference_checkpoint = str(self.checkpoint_handler.get_checkpoint_to_test())
|
||||||
|
self.set_trainer_for_inference()
|
||||||
|
self.data_module = self.get_data_module()
|
||||||
|
|
||||||
|
def run_inference(self) -> None:
|
||||||
|
"""Run inference on the test set for all models. This is done by calling the LightningModule.test_step method.
|
||||||
|
If the LightningModule.test_step method is not overridden, then this method does nothing. The cwd is changed to
|
||||||
|
the outputs folder so that the model can write to current working directory, and still everything is put into
|
||||||
|
the right place in AzureML (there, only the contents of the "outputs" folder are treated as result files).
|
||||||
|
"""
|
||||||
|
if self.container.has_custom_test_step():
|
||||||
|
logging.info("Running inference via the LightningModule.test_step method")
|
||||||
|
with change_working_directory(self.container.outputs_folder):
|
||||||
|
assert self.trainer, "Trainer should be initialized before inference. Call self.init_inference()."
|
||||||
|
_ = self.trainer.test(
|
||||||
|
self.container.model, datamodule=self.data_module, ckpt_path=self.inference_checkpoint
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.warning("None of the suitable test methods is overridden. Skipping inference completely.")
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
"""
|
||||||
|
Run the training or evaluation. This method must be overridden in subclasses.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def run_and_cleanup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None:
|
||||||
|
"""
|
||||||
|
Run the training or evaluation via `self.run` and cleanup afterwards.
|
||||||
|
|
||||||
|
:param azure_run_info: When running in AzureML or on a local VM, this contains the paths to the datasets.
|
||||||
|
This can be missing when running in unit tests, where the local dataset paths are already populated.
|
||||||
|
"""
|
||||||
|
self.setup(azure_run_info)
|
||||||
|
try:
|
||||||
|
self.run()
|
||||||
|
finally:
|
||||||
|
if self.azureml_run_for_logging is not None:
|
||||||
|
try:
|
||||||
|
self.azureml_run_for_logging.complete()
|
||||||
|
except Exception as ex:
|
||||||
|
logging.error("Failed to complete AzureML run: %s", ex)
|
|
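The `run_and_cleanup` method above follows a template-method shape: `setup`, then the subclass-specific `run`, then a `finally` block that closes the logging run even when `run` raises. A minimal sketch of that control flow is below. `FakeRun`, `SketchRunnerBase` and `FailingRunner` are hypothetical stand-ins, not hi-ml classes, and unlike the diff (where `RunnerBase.run` is a plain `pass`) the sketch raises `NotImplementedError` in the base `run`.

```python
import logging
from typing import List, Optional


class FakeRun:
    """Hypothetical stand-in for an AzureML Run object; only `complete` is modelled."""

    def __init__(self) -> None:
        self.completed = False

    def complete(self) -> None:
        self.completed = True


class SketchRunnerBase:
    """Simplified stand-in for RunnerBase: set up, run, and always clean up."""

    def __init__(self, run_for_logging: Optional[FakeRun] = None) -> None:
        self.azureml_run_for_logging = run_for_logging
        self.calls: List[str] = []

    def setup(self) -> None:
        self.calls.append("setup")

    def run(self) -> None:
        # Subclasses provide the actual training or evaluation workflow.
        raise NotImplementedError()

    def run_and_cleanup(self) -> None:
        self.setup()
        try:
            self.run()
        finally:
            # Executed on success and on failure, so the logging run is always closed.
            if self.azureml_run_for_logging is not None:
                try:
                    self.azureml_run_for_logging.complete()
                except Exception as ex:
                    logging.error("Failed to complete AzureML run: %s", ex)


class FailingRunner(SketchRunnerBase):
    def run(self) -> None:
        self.calls.append("run")
        raise ValueError("simulated failure during run")


def demo() -> bool:
    """Return True if the logging run was completed although `run` raised."""
    fake_run = FakeRun()
    runner = FailingRunner(run_for_logging=fake_run)
    try:
        runner.run_and_cleanup()
    except ValueError:
        pass
    return fake_run.completed and runner.calls == ["setup", "run"]
```

The exception from `run` still propagates to the caller; the `finally` block only guarantees that the cleanup is attempted first.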
@ -6,14 +6,11 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from typing import Dict
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from azureml.core import Run
|
from pytorch_lightning import LightningDataModule, seed_everything
|
||||||
from pytorch_lightning import Trainer, seed_everything
|
|
||||||
|
|
||||||
from health_azure import AzureRunInfo
|
|
||||||
from health_azure.logging import logging_section, print_message_with_rank_pid
|
from health_azure.logging import logging_section, print_message_with_rank_pid
|
||||||
from health_azure.utils import (
|
from health_azure.utils import (
|
||||||
ENV_GLOBAL_RANK,
|
ENV_GLOBAL_RANK,
|
||||||
|
@ -22,140 +19,24 @@ from health_azure.utils import (
|
||||||
ENV_OMPI_COMM_WORLD_RANK,
|
ENV_OMPI_COMM_WORLD_RANK,
|
||||||
PARENT_RUN_CONTEXT,
|
PARENT_RUN_CONTEXT,
|
||||||
RUN_CONTEXT,
|
RUN_CONTEXT,
|
||||||
create_aml_run_object,
|
|
||||||
create_run_recovery_id,
|
|
||||||
get_metrics_for_hyperdrive_run,
|
get_metrics_for_hyperdrive_run,
|
||||||
get_metrics_for_run,
|
get_metrics_for_run,
|
||||||
is_global_rank_zero,
|
is_global_rank_zero,
|
||||||
is_local_rank_zero,
|
is_local_rank_zero,
|
||||||
is_running_in_azure_ml,
|
is_running_in_azure_ml,
|
||||||
)
|
)
|
||||||
from health_ml.experiment_config import ExperimentConfig
|
|
||||||
from health_ml.lightning_container import LightningContainer
|
|
||||||
from health_ml.model_trainer import create_lightning_trainer, write_experiment_summary_file
|
from health_ml.model_trainer import create_lightning_trainer, write_experiment_summary_file
|
||||||
from health_ml.utils import fixed_paths
|
from health_ml.runner_base import RunnerBase
|
||||||
from health_ml.utils.checkpoint_handler import CheckpointHandler
|
|
||||||
from health_ml.utils.checkpoint_utils import cleanup_checkpoints
|
from health_ml.utils.checkpoint_utils import cleanup_checkpoints
|
||||||
from health_ml.utils.common_utils import (
|
from health_ml.utils.common_utils import (
|
||||||
EFFECTIVE_RANDOM_SEED_KEY_NAME,
|
|
||||||
RUN_RECOVERY_FROM_ID_KEY_NAME,
|
|
||||||
RUN_RECOVERY_ID_KEY,
|
|
||||||
change_working_directory,
|
change_working_directory,
|
||||||
seed_monai_if_available,
|
|
||||||
)
|
)
|
||||||
from health_ml.utils.lightning_loggers import StoringLogger, get_mlflow_run_id_from_trainer
|
|
||||||
from health_ml.utils.regression_test_utils import REGRESSION_TEST_METRICS_FILENAME, compare_folders_and_run_outputs
|
from health_ml.utils.regression_test_utils import REGRESSION_TEST_METRICS_FILENAME, compare_folders_and_run_outputs
|
||||||
from health_ml.utils.type_annotations import PathOrString
|
|
||||||
|
|
||||||
|
|
||||||
def check_dataset_folder_exists(local_dataset: PathOrString) -> Path:
|
class TrainingRunner(RunnerBase):
|
||||||
"""
|
def get_data_module(self) -> LightningDataModule:
|
||||||
Checks if a folder with a local dataset exists. If it does exist, return the argument converted
|
return self.container.get_data_module()
|
||||||
to a Path instance. If it does not exist, raise a FileNotFoundError.
|
|
||||||
|
|
||||||
:param local_dataset: The dataset folder to check.
|
|
||||||
:return: The local_dataset argument, converted to a Path.
|
|
||||||
"""
|
|
||||||
expected_dir = Path(local_dataset)
|
|
||||||
if not expected_dir.is_dir():
|
|
||||||
raise FileNotFoundError(f"The model uses a dataset in {expected_dir}, but that does not exist.")
|
|
||||||
logging.info(f"Model training will use the local dataset provided in {expected_dir}")
|
|
||||||
return expected_dir
|
|
||||||
|
|
||||||
|
|
||||||
class MLRunner:
|
|
||||||
def __init__(
|
|
||||||
self, experiment_config: ExperimentConfig, container: LightningContainer, project_root: Optional[Path] = None
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Driver class to run a ML experiment. Note that the project root argument MUST be supplied when using hi-ml
|
|
||||||
as a package!
|
|
||||||
|
|
||||||
:param experiment_config: The ExperimentConfig object to use for training.
|
|
||||||
:param container: The LightningContainer object to use for training.
|
|
||||||
:param project_root: Project root. This should only be omitted if calling run_ml from the test suite. Supplying
|
|
||||||
it is crucial when using hi-ml as a package or submodule!
|
|
||||||
"""
|
|
||||||
self.container = container
|
|
||||||
self.experiment_config = experiment_config
|
|
||||||
self.container.num_nodes = self.experiment_config.num_nodes
|
|
||||||
self.project_root: Path = project_root or fixed_paths.repository_root_directory()
|
|
||||||
self.storing_logger: Optional[StoringLogger] = None
|
|
||||||
self._has_setup_run = False
|
|
||||||
self.checkpoint_handler = CheckpointHandler(
|
|
||||||
container=self.container, project_root=self.project_root, run_context=RUN_CONTEXT
|
|
||||||
)
|
|
||||||
self.trainer: Optional[Trainer] = None
|
|
||||||
self.azureml_run_for_logging: Optional[Run] = None
|
|
||||||
self.mlflow_run_for_logging: Optional[str] = None
|
|
||||||
self.inference_checkpoint: Optional[str] = None # Passed to trainer.validate and trainer.test in inference mode
|
|
||||||
|
|
||||||
def set_run_tags_from_parent(self) -> None:
|
|
||||||
"""
|
|
||||||
Set metadata for the run
|
|
||||||
"""
|
|
||||||
assert PARENT_RUN_CONTEXT, "This function should only be called in a Hyperdrive run."
|
|
||||||
run_tags_parent = PARENT_RUN_CONTEXT.get_tags()
|
|
||||||
tags_to_copy = [
|
|
||||||
"tag",
|
|
||||||
"model_name",
|
|
||||||
"execution_mode",
|
|
||||||
"recovered_from",
|
|
||||||
"friendly_name",
|
|
||||||
"build_number",
|
|
||||||
"build_user",
|
|
||||||
RUN_RECOVERY_FROM_ID_KEY_NAME,
|
|
||||||
]
|
|
||||||
new_tags = {tag: run_tags_parent.get(tag, "") for tag in tags_to_copy}
|
|
||||||
new_tags[RUN_RECOVERY_ID_KEY] = create_run_recovery_id(run=RUN_CONTEXT)
|
|
||||||
new_tags[EFFECTIVE_RANDOM_SEED_KEY_NAME] = str(self.container.get_effective_random_seed())
|
|
||||||
RUN_CONTEXT.set_tags(new_tags)
|
|
||||||
|
|
||||||
def setup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None:
|
|
||||||
"""
|
|
||||||
Sets the random seeds, calls the setup method on the LightningContainer and then creates the actual
|
|
||||||
Lightning modules.
|
|
||||||
|
|
||||||
:param azure_run_info: When running in AzureML or on a local VM, this contains the paths to the datasets.
|
|
||||||
This can be missing when running in unit tests, where the local dataset paths are already populated.
|
|
||||||
"""
|
|
||||||
if self._has_setup_run:
|
|
||||||
return
|
|
||||||
if azure_run_info:
|
|
||||||
# Set up the paths to the datasets. azure_run_info already has all necessary information, using either
|
|
||||||
# the provided local datasets for VM runs, or the AzureML mount points when running in AML.
|
|
||||||
# This must happen before container setup because that could already read datasets.
|
|
||||||
input_datasets = azure_run_info.input_datasets
|
|
||||||
logging.info(f"Setting the following datasets as local datasets: {input_datasets}")
|
|
||||||
if len(input_datasets) > 0:
|
|
||||||
local_datasets: List[Path] = []
|
|
||||||
for i, dataset in enumerate(input_datasets):
|
|
||||||
if dataset is None:
|
|
||||||
raise ValueError(f"Invalid setup: The dataset at index {i} is None")
|
|
||||||
local_datasets.append(check_dataset_folder_exists(dataset))
|
|
||||||
self.container.local_datasets = local_datasets # type: ignore
|
|
||||||
# Ensure that we use fixed seeds before initializing the PyTorch models.
|
|
||||||
# MONAI needs a separate method to make all transforms deterministic by default
|
|
||||||
seed = self.container.get_effective_random_seed()
|
|
||||||
seed_monai_if_available(seed)
|
|
||||||
seed_everything(seed)
|
|
||||||
|
|
||||||
# Creating the folder structure must happen before the LightningModule is created, because the output
|
|
||||||
# parameters of the container will be copied into the module.
|
|
||||||
self.container.create_filesystem(self.project_root)
|
|
||||||
|
|
||||||
# configure recovery container if provided
|
|
||||||
self.checkpoint_handler.download_recovery_checkpoints_or_weights() # type: ignore
|
|
||||||
|
|
||||||
self.container.setup()
|
|
||||||
self.container.create_lightning_module_and_store()
|
|
||||||
self._has_setup_run = True
|
|
||||||
|
|
||||||
is_offline_run = not is_running_in_azure_ml(RUN_CONTEXT)
|
|
||||||
# Get the AzureML context in which the script is running
|
|
||||||
if not is_offline_run and PARENT_RUN_CONTEXT is not None:
|
|
||||||
logging.info("Setting tags from parent run.")
|
|
||||||
self.set_run_tags_from_parent()
|
|
||||||
|
|
||||||
def get_multiple_trainloader_mode(self) -> str:
|
def get_multiple_trainloader_mode(self) -> str:
|
||||||
# Workaround for a bug in PL 1.5.5: We need to pass the cycle mode for the training data as a trainer argument
|
# Workaround for a bug in PL 1.5.5: We need to pass the cycle mode for the training data as a trainer argument
|
||||||
|
@ -190,18 +71,7 @@ class MLRunner:
|
||||||
seed_everything(self.container.get_effective_random_seed(), workers=True)
|
seed_everything(self.container.get_effective_random_seed(), workers=True)
|
||||||
|
|
||||||
# Get the container's datamodule
|
# Get the container's datamodule
|
||||||
self.data_module = self.container.get_data_module()
|
self.data_module = self.get_data_module()
|
||||||
|
|
||||||
# Create an AzureML run for logging if running outside AzureML. This run will be used for metrics logging
|
|
||||||
# during both training and inference. We can't rely on the automatically generated run inside the AzureMLLogger
|
|
||||||
# class because two of those logger objects will be created, so training and inference metrics would be logged
|
|
||||||
# in different runs.
|
|
||||||
if self.container.log_from_vm:
|
|
||||||
run = create_aml_run_object(experiment_name=self.container.effective_experiment_name)
|
|
||||||
# Display name should already be set when creating the Run object, but in some scenarios this
|
|
||||||
# does not happen. Hence, set it again.
|
|
||||||
run.display_name = self.container.tag if self.container.tag else None
|
|
||||||
self.azureml_run_for_logging = run
|
|
||||||
|
|
||||||
if not self.container.run_inference_only:
|
if not self.container.run_inference_only:
|
||||||
checkpoint_path_for_recovery = self.checkpoint_handler.get_recovery_or_checkpoint_path_train()
|
checkpoint_path_for_recovery = self.checkpoint_handler.get_recovery_or_checkpoint_path_train()
|
||||||
|
@ -275,36 +145,6 @@ class MLRunner:
|
||||||
return self.container.crossval_index == 0
|
return self.container.crossval_index == 0
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def set_trainer_for_inference(self) -> None:
|
|
||||||
"""Set the runner's PL Trainer object that should be used when running inference on the validation or test set.
|
|
||||||
We run inference on a single device because distributed strategies such as DDP use DistributedSampler
|
|
||||||
internally, which replicates some samples to make sure all devices have the same batch size in case of
|
|
||||||
uneven inputs which biases the results."""
|
|
||||||
mlflow_run_id = get_mlflow_run_id_from_trainer(self.trainer)
|
|
||||||
self.container.max_num_gpus = self.container.max_num_gpus_inference
|
|
||||||
self.trainer, _ = create_lightning_trainer(
|
|
||||||
container=self.container,
|
|
||||||
num_nodes=1,
|
|
||||||
azureml_run_for_logging=self.azureml_run_for_logging,
|
|
||||||
mlflow_run_for_logging=mlflow_run_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
def init_inference(self) -> None:
|
|
||||||
"""Prepare the runner for inference: validation or test. The following steps are performed:
|
|
||||||
1. Get the checkpoint to use for inference. This is either the checkpoint from the last training epoch or the
|
|
||||||
one specified in src_checkpoint argument.
|
|
||||||
2. If the container has a run_extra_val_epoch method, call it to run an extra validation epoch.
|
|
||||||
3. Create a new trainer instance for inference. This is necessary because the trainer is created with a single
|
|
||||||
device in contrast to training that uses DDP if multiple GPUs are available.
|
|
||||||
4. Create a new data module instance for inference to account for any requested changes in the dataloading
|
|
||||||
parameters (e.g. batch_size, max_num_workers, etc) as part of on_run_extra_validation_epoch.
|
|
||||||
"""
|
|
||||||
self.inference_checkpoint = str(self.checkpoint_handler.get_checkpoint_to_test())
|
|
||||||
if self.container.run_extra_val_epoch:
|
|
||||||
self.container.on_run_extra_validation_epoch()
|
|
||||||
self.set_trainer_for_inference()
|
|
||||||
self.data_module = self.container.get_data_module()
|
|
||||||
|
|
||||||
def run_training(self) -> None:
|
def run_training(self) -> None:
|
||||||
"""
|
"""
|
||||||
The main training loop. It creates the Pytorch model based on the configuration options passed in,
|
The main training loop. It creates the Pytorch model based on the configuration options passed in,
|
||||||
|
@ -321,6 +161,16 @@ class MLRunner:
|
||||||
assert logger is not None
|
assert logger is not None
|
||||||
logger.finalize('success')
|
logger.finalize('success')
|
||||||
|
|
||||||
|
def init_inference(self) -> None:
|
||||||
|
"""
|
||||||
|
Prepare the trainer for running inference on the validation and test sets. This chooses a checkpoint,
|
||||||
|
initializes the PL Trainer object, and chooses the right data module. Afterwards, the hook for running
|
||||||
|
inference on the validation set is run (`LightningContainer.on_run_extra_validation_epoch`)
|
||||||
|
"""
|
||||||
|
super().init_inference()
|
||||||
|
if self.container.run_extra_val_epoch:
|
||||||
|
self.container.on_run_extra_validation_epoch()
|
||||||
|
|
||||||
def run_validation(self) -> None:
|
def run_validation(self) -> None:
|
||||||
"""Run validation on the validation set for all models to save time/memory consuming outputs. This is done in
|
"""Run validation on the validation set for all models to save time/memory consuming outputs. This is done in
|
||||||
inference only mode or when the user has requested an extra validation epoch. The cwd is changed to the outputs
|
inference only mode or when the user has requested an extra validation epoch. The cwd is changed to the outputs
|
||||||
|
@ -334,22 +184,6 @@ class MLRunner:
|
||||||
else:
|
else:
|
||||||
logging.info("Skipping extra validation because the user has not requested it.")
|
logging.info("Skipping extra validation because the user has not requested it.")
|
||||||
|
|
||||||
def run_inference(self) -> None:
|
|
||||||
"""Run inference on the test set for all models. This is done by calling the LightningModule.test_step method.
|
|
||||||
If the LightningModule.test_step method is not overridden, then this method does nothing. The cwd is changed to
|
|
||||||
the outputs folder so that the model can write to current working directory, and still everything is put into
|
|
||||||
the right place in AzureML (there, only the contents of the "outputs" folder is treated as a result file).
|
|
||||||
"""
|
|
||||||
if self.container.has_custom_test_step():
|
|
||||||
logging.info("Running inference via the LightningModule.test_step method")
|
|
||||||
with change_working_directory(self.container.outputs_folder):
|
|
||||||
assert self.trainer, "Trainer should be initialized before inference. Call self.init_inference()."
|
|
||||||
_ = self.trainer.test(
|
|
||||||
self.container.model, datamodule=self.data_module, ckpt_path=self.inference_checkpoint
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logging.warning("None of the suitable test methods is overridden. Skipping inference completely.")
|
|
||||||
|
|
||||||
def run_regression_test(self) -> None:
|
def run_regression_test(self) -> None:
|
||||||
if self.container.regression_test_folder:
|
if self.container.regression_test_folder:
|
||||||
with logging_section("Regression Test"):
|
with logging_section("Regression Test"):
|
||||||
|
@ -392,32 +226,23 @@ class MLRunner:
|
||||||
"""
|
"""
|
||||||
Driver function to run a ML experiment
|
Driver function to run a ML experiment
|
||||||
"""
|
"""
|
||||||
self.setup()
|
self.init_training()
|
||||||
try:
|
|
||||||
self.init_training()
|
|
||||||
|
|
||||||
if not self.container.run_inference_only:
|
if not self.container.run_inference_only:
|
||||||
# Backup the environment variables in case we need to run a second training in the unit tests.
|
# Backup the environment variables in case we need to run a second training in the unit tests.
|
||||||
environ_before_training = dict(os.environ)
|
environ_before_training = dict(os.environ)
|
||||||
|
|
||||||
with logging_section("Model training"):
|
with logging_section("Model training"):
|
||||||
self.run_training()
|
self.run_training()
|
||||||
|
|
||||||
self.end_training(environ_before_training)
|
self.end_training(environ_before_training)
|
||||||
|
|
||||||
self.init_inference()
|
self.init_inference()
|
||||||
|
|
||||||
with logging_section("Model validation"):
|
with logging_section("Model validation"):
|
||||||
self.run_validation()
|
self.run_validation()
|
||||||
|
|
||||||
with logging_section("Model inference"):
|
with logging_section("Model inference"):
|
||||||
self.run_inference()
|
self.run_inference()
|
||||||
|
|
||||||
self.run_regression_test()
|
self.run_regression_test()
|
||||||
|
|
||||||
finally:
|
|
||||||
if self.azureml_run_for_logging is not None:
|
|
||||||
try:
|
|
||||||
self.azureml_run_for_logging.complete()
|
|
||||||
except Exception as ex:
|
|
||||||
logging.error("Failed to complete AzureML run: %s", ex)
|
|
|
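The refactoring in this file leaves two extension points on the base class: `get_data_module` selects the data, and `init_inference` can be extended via `super()`. The sketch below mimics that split with hypothetical stand-in classes (none of them are hi-ml types). That the evaluation runner reads `get_eval_data_module` is an assumption taken from the test code in this commit, since the `EvalRunner` source is not part of this excerpt. As the diff is written, the extra-validation hook runs after the base class has already created the data module, which the recorded event order makes visible.

```python
from typing import List


class SketchContainer:
    """Hypothetical stand-in for LightningContainer with only the members used here."""

    def __init__(self, run_extra_val_epoch: bool = False) -> None:
        self.run_extra_val_epoch = run_extra_val_epoch
        self.events: List[str] = []

    def get_data_module(self) -> str:
        self.events.append("get_data_module")
        return "train/val/test data"

    def get_eval_data_module(self) -> str:
        self.events.append("get_eval_data_module")
        return "full evaluation data"

    def on_run_extra_validation_epoch(self) -> None:
        self.events.append("on_run_extra_validation_epoch")


class SketchRunnerBase:
    def __init__(self, container: SketchContainer) -> None:
        self.container = container
        self.data_module = ""

    def get_data_module(self) -> str:
        raise NotImplementedError()

    def init_inference(self) -> None:
        # Checkpoint selection and trainer creation are omitted in this sketch.
        self.data_module = self.get_data_module()


class SketchTrainingRunner(SketchRunnerBase):
    def get_data_module(self) -> str:
        return self.container.get_data_module()

    def init_inference(self) -> None:
        super().init_inference()
        # The hook is specific to the training workflow, so it lives in the subclass.
        if self.container.run_extra_val_epoch:
            self.container.on_run_extra_validation_epoch()


class SketchEvalRunner(SketchRunnerBase):
    def get_data_module(self) -> str:
        # Assumption: evaluation uses a separate data module for the full dataset.
        return self.container.get_eval_data_module()


def events_after_init(runner: SketchRunnerBase) -> List[str]:
    runner.init_inference()
    return runner.container.events
```

With `run_extra_val_epoch=True`, the training variant records the data module call followed by the hook, while the evaluation variant never calls the hook.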
@ -1,4 +1,15 @@
|
||||||
|
from pathlib import Path
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from health_ml.runner import Runner
|
||||||
from health_ml.utils import health_ml_package_setup
|
from health_ml.utils import health_ml_package_setup
|
||||||
|
|
||||||
# Reduce logging noise in DEBUG mode
|
# Reduce logging noise in DEBUG mode
|
||||||
health_ml_package_setup()
|
health_ml_package_setup()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_runner(tmp_path: Path) -> Runner:
|
||||||
|
"""A test fixture that creates a Runner object in a temporary folder."""
|
||||||
|
|
||||||
|
return Runner(project_root=tmp_path)
|
||||||
|
|
|
@ -0,0 +1,78 @@
|
||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from health_ml import TrainingRunner
|
||||||
|
from health_ml.configs.hello_world import (
|
||||||
|
TEST_MAE_FILE,
|
||||||
|
TEST_MSE_FILE,
|
||||||
|
HelloWorld,
|
||||||
|
HelloWorldDataModule,
|
||||||
|
)
|
||||||
|
from health_ml.eval_runner import EvalRunner
|
||||||
|
from health_ml.experiment_config import ExperimentConfig, RunnerMode
|
||||||
|
from health_ml.runner import Runner
|
||||||
|
from health_ml.utils.checkpoint_utils import CheckpointParser
|
||||||
|
from testhiml.test_training_runner import training_runner_hello_world
|
||||||
|
from testhiml.utils.fixed_paths_for_tests import full_test_data_path
|
||||||
|
|
||||||
|
|
||||||
|
hello_world_checkpoint = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
|
||||||
|
|
||||||
|
|
||||||
|
def test_eval_runner_no_checkpoint(mock_runner: Runner) -> None:
|
||||||
|
"""Test that the evaluation mode fails if no checkpoint source is provided."""
|
||||||
|
arguments = ["", "--model=HelloWorld", f"--mode={RunnerMode.EVAL_FULL.value}"]
|
||||||
|
with pytest.raises(ValueError, match="To use model evaluation, you need to provide a checkpoint to use"):
|
||||||
|
with patch.object(sys, "argv", arguments):
|
||||||
|
mock_runner.run()
|
||||||
|
|
||||||
|
|
||||||
|
def test_eval_runner_end_to_end(mock_runner: Runner) -> None:
|
||||||
|
"""Test the end-to-end integration of the EvalRunner class into the overall Runner"""
|
||||||
|
arguments = [
|
||||||
|
"",
|
||||||
|
"--model=HelloWorld",
|
||||||
|
f"--mode={RunnerMode.EVAL_FULL.value}",
|
||||||
|
f"--src_checkpoint={hello_world_checkpoint}",
|
||||||
|
]
|
||||||
|
with patch("health_ml.training_runner.TrainingRunner.run_and_cleanup") as mock_training_run:
|
||||||
|
with patch.object(sys, "argv", arguments):
|
||||||
|
mock_runner.run()
|
||||||
|
# The training runner should not be invoked
|
||||||
|
mock_training_run.assert_not_called()
|
||||||
|
# The eval runner should have been invoked. The test step writes two files with metrics, check that
|
||||||
|
# they exist
|
||||||
|
output_folder = mock_runner.lightning_container.outputs_folder
|
||||||
|
for file_name in [TEST_MSE_FILE, TEST_MAE_FILE]:
|
||||||
|
assert (output_folder / file_name).exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_eval_runner_methods_called(tmp_path: Path) -> None:
|
||||||
|
"""Test if the eval runner uses the right data module from the HelloWorld model"""
|
||||||
|
container = HelloWorld()
|
||||||
|
container.src_checkpoint = CheckpointParser(str(hello_world_checkpoint))
|
||||||
|
eval_runner = EvalRunner(
|
||||||
|
container=container, experiment_config=ExperimentConfig(mode=RunnerMode.EVAL_FULL), project_root=tmp_path
|
||||||
|
)
|
||||||
|
with patch("health_ml.configs.hello_world.HelloWorld.get_eval_data_module") as mock_get_data_module:
|
||||||
|
mock_get_data_module.return_value = HelloWorldDataModule(crossval_count=1, seed=1)
|
||||||
|
eval_runner.run_and_cleanup()
|
||||||
|
mock_get_data_module.assert_called_once_with()
|
||||||
|
|
||||||
|
|
||||||
|
def test_eval_runner_no_extra_validation_epoch_called(tmp_path: Path) -> None:
|
||||||
|
"""
|
||||||
|
Ensure that the eval runner does not invoke the hook for the extra validation epoch that is used by the training runner.
|
||||||
|
"""
|
||||||
|
container = HelloWorld()
|
||||||
|
container.run_extra_val_epoch = True
|
||||||
|
container.src_checkpoint = CheckpointParser(str(hello_world_checkpoint))
|
||||||
|
eval_runner = EvalRunner(
|
||||||
|
container=container, experiment_config=ExperimentConfig(mode=RunnerMode.EVAL_FULL), project_root=tmp_path
|
||||||
|
)
|
||||||
|
with patch("health_ml.configs.hello_world.HelloRegression.on_run_extra_validation_epoch") as mock_hook:
|
||||||
|
eval_runner.run_and_cleanup()
|
||||||
|
mock_hook.assert_not_called()
|
|
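The tests above drive the runner through its command line by temporarily replacing `sys.argv` with `patch.object`. The sketch below shows the same technique with the standard library only. `parse_mode` is a hypothetical parser, and the mode string `eval_full` is an assumed value, since the actual value of `RunnerMode.EVAL_FULL` is not shown in this excerpt.

```python
import argparse
import sys
from typing import List, Optional
from unittest.mock import patch


def parse_mode() -> str:
    """Hypothetical entry point that reads its arguments from sys.argv."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", default="train")
    parser.add_argument("--src_checkpoint", default="")
    args = parser.parse_args()  # reads sys.argv[1:] at call time
    if args.mode == "eval_full" and not args.src_checkpoint:
        raise ValueError("To use model evaluation, you need to provide a checkpoint to use")
    return args.mode


def error_message_for(arguments: List[str]) -> Optional[str]:
    """Run parse_mode with a patched sys.argv; return the error message, or None if it succeeds."""
    with patch.object(sys, "argv", arguments):
        try:
            parse_mode()
        except ValueError as ex:
            return str(ex)
    return None
```

The first list element plays the role of the program name, which is why the argument lists in the tests start with an empty string. `sys.argv` is restored when the `with` block exits, so tests do not leak arguments into each other.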
@ -13,7 +13,7 @@ import pytest
|
||||||
|
|
||||||
from health_azure.utils import create_aml_run_object
|
from health_azure.utils import create_aml_run_object
|
||||||
from health_ml.experiment_config import ExperimentConfig
|
from health_ml.experiment_config import ExperimentConfig
|
||||||
from health_ml.run_ml import MLRunner
|
from health_ml.training_runner import TrainingRunner
|
||||||
from health_ml.configs.hello_world import HelloWorld
|
from health_ml.configs.hello_world import HelloWorld
|
||||||
from health_ml.utils.regression_test_utils import (
|
from health_ml.utils.regression_test_utils import (
|
||||||
CONTENTS_MISMATCH,
|
CONTENTS_MISMATCH,
|
||||||
|
@ -54,7 +54,7 @@ def test_regression_test() -> None:
|
||||||
"""
|
"""
|
||||||
container = HelloWorld()
|
container = HelloWorld()
|
||||||
container.regression_test_folder = Path(str(uuid.uuid4().hex))
|
container.regression_test_folder = Path(str(uuid.uuid4().hex))
|
||||||
runner = MLRunner(container=container, experiment_config=ExperimentConfig())
|
runner = TrainingRunner(container=container, experiment_config=ExperimentConfig())
|
||||||
runner.setup()
|
runner.setup()
|
||||||
with pytest.raises(ValueError) as ex:
|
with pytest.raises(ValueError) as ex:
|
||||||
runner.run()
|
runner.run()
|
||||||
|
|
|
@ -19,7 +19,7 @@ from health_azure import AzureRunInfo, DatasetConfig
|
||||||
from health_azure.himl import OUTPUT_FOLDER
|
from health_azure.himl import OUTPUT_FOLDER
|
||||||
from health_azure.utils import ENV_LOCAL_RANK, ENV_NODE_RANK
|
from health_azure.utils import ENV_LOCAL_RANK, ENV_NODE_RANK
|
||||||
from health_azure.paths import ENVIRONMENT_YAML_FILE_NAME
|
from health_azure.paths import ENVIRONMENT_YAML_FILE_NAME
|
||||||
from health_ml.configs.hello_world import HelloWorld # type: ignore
|
from health_ml.configs.hello_world import TEST_MAE_FILE, TEST_MSE_FILE, HelloWorld # type: ignore
|
||||||
from health_ml.deep_learning_config import WorkflowParams
|
from health_ml.deep_learning_config import WorkflowParams
|
||||||
from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, DebugDDPOptions
|
from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, DebugDDPOptions
|
||||||
from health_ml.lightning_container import LightningContainer
|
from health_ml.lightning_container import LightningContainer
|
||||||
|
@ -28,13 +28,6 @@ from health_ml.utils.common_utils import change_working_directory
|
||||||
from health_ml.utils.fixed_paths import repository_root_directory
|
from health_ml.utils.fixed_paths import repository_root_directory
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_runner(tmp_path: Path) -> Runner:
|
|
||||||
"""A test fixture that creates a Runner object in a temporary folder."""
|
|
||||||
|
|
||||||
return Runner(project_root=tmp_path)
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def change_working_folder_and_add_environment(tmp_path: Path) -> Generator:
|
def change_working_folder_and_add_environment(tmp_path: Path) -> Generator:
|
||||||
# Use a special simplified environment file only for the tests here. Copy that to a temp folder, then let the runner
|
# Use a special simplified environment file only for the tests here. Copy that to a temp folder, then let the runner
|
||||||
|
@ -407,7 +400,7 @@ def test_run_hello_world(mock_runner: Runner) -> None:
|
||||||
# time-consuming auth
|
# time-consuming auth
|
||||||
mock_get_workspace.assert_not_called()
|
mock_get_workspace.assert_not_called()
|
||||||
# Summary.txt is written at start, the other files during inference
|
# Summary.txt is written at start, the other files during inference
|
||||||
expected_files = ["experiment_summary.txt", "test_mae.txt", "test_mse.txt"]
|
expected_files = ["experiment_summary.txt", TEST_MSE_FILE, TEST_MAE_FILE]
|
||||||
for file in expected_files:
|
for file in expected_files:
|
||||||
assert (mock_runner.lightning_container.outputs_folder / file).is_file(), f"Missing file: {file}"
|
assert (mock_runner.lightning_container.outputs_folder / file).is_file(), f"Missing file: {file}"
|
||||||
|
|
||||||
|
@ -428,9 +421,8 @@ def test_invalid_profiler(mock_runner: Runner) -> None:
|
||||||
invalid_profile = "--pl_profiler=foo"
|
invalid_profile = "--pl_profiler=foo"
|
||||||
arguments = ["", "--model=HelloWorld", invalid_profile]
|
arguments = ["", "--model=HelloWorld", invalid_profile]
|
||||||
with patch.object(sys, "argv", arguments):
|
with patch.object(sys, "argv", arguments):
|
||||||
with pytest.raises(ValueError) as ex:
|
with pytest.raises(ValueError, match="Unsupported profiler."):
|
||||||
mock_runner.run()
|
mock_runner.run()
|
||||||
assert "Unsupported profiler." in str(ex)
|
|
||||||
|
|
||||||
|
|
||||||
def test_custom_datastore_outside_aml(mock_runner: Runner) -> None:
|
def test_custom_datastore_outside_aml(mock_runner: Runner) -> None:
|
||||||
|
|
|
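The change to `pytest.raises(ValueError, match="Unsupported profiler.")` swaps a substring check for a regular expression search: `match` is applied to the string form of the exception via `re.search`, so the trailing `.` matches any character. The sketch below imitates that behaviour without pytest. `raises_with_match` is a hypothetical helper, and both error messages are invented for illustration.

```python
import re
from typing import Callable


def raises_with_match(func: Callable[[], None], pattern: str) -> bool:
    """Rough stand-in for `pytest.raises(ValueError, match=pattern)`.

    Returns True if func raises a ValueError whose message contains a match
    for the regular expression, False otherwise.
    """
    try:
        func()
    except ValueError as ex:
        return re.search(pattern, str(ex)) is not None
    return False


def fail_expected() -> None:
    raise ValueError("Unsupported profiler. Got: foo")  # hypothetical message


def fail_similar() -> None:
    raise ValueError("Unsupported profilers requested")  # hypothetical message
```

Here `raises_with_match(fail_similar, "Unsupported profiler.")` is true because `.` matches the `s`, whereas wrapping the pattern in `re.escape` makes the period literal and rejects that message. For this test the looser match is harmless, but escaping is the safer habit when the message contains regex metacharacters.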
@ -18,34 +18,38 @@ from pytorch_lightning import LightningModule
|
||||||
import mlflow
|
import mlflow
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
|
|
||||||
from health_ml.configs.hello_world import HelloWorld # type: ignore
|
from health_ml.configs.hello_world import TEST_MAE_FILE, TEST_MSE_FILE, HelloWorld # type: ignore
|
||||||
from health_ml.experiment_config import ExperimentConfig
|
from health_ml.experiment_config import ExperimentConfig
|
||||||
from health_ml.lightning_container import LightningContainer
|
from health_ml.lightning_container import LightningContainer
|
||||||
from health_ml.run_ml import MLRunner
|
from health_ml.runner_base import RunnerBase
|
||||||
|
from health_ml.training_runner import TrainingRunner
|
||||||
from health_ml.utils.checkpoint_handler import CheckpointHandler
|
from health_ml.utils.checkpoint_handler import CheckpointHandler
|
||||||
from health_ml.utils.checkpoint_utils import CheckpointParser
|
from health_ml.utils.checkpoint_utils import CheckpointParser
|
||||||
from health_ml.utils.common_utils import is_gpu_available
|
from health_ml.utils.common_utils import EFFECTIVE_RANDOM_SEED_KEY_NAME, is_gpu_available
|
||||||
from health_ml.utils.lightning_loggers import HimlMLFlowLogger, StoringLogger, get_mlflow_run_id_from_trainer
|
from health_ml.utils.lightning_loggers import HimlMLFlowLogger, StoringLogger, get_mlflow_run_id_from_trainer
|
||||||
from health_azure.utils import ENV_EXPERIMENT_NAME, is_global_rank_zero
|
from health_azure.utils import ENV_EXPERIMENT_NAME, is_global_rank_zero
|
||||||
from testazure.utils_testazure import DEFAULT_WORKSPACE, experiment_for_unittests
|
from testazure.utils_testazure import DEFAULT_WORKSPACE, experiment_for_unittests
|
||||||
from testhiml.utils.fixed_paths_for_tests import mock_run_id
|
from testhiml.utils.fixed_paths_for_tests import full_test_data_path
|
||||||
|
|
||||||
no_gpu = not is_gpu_available()
|
no_gpu = not is_gpu_available()
|
||||||
|
hello_world_checkpoint = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture()
|
||||||
def ml_runner_no_setup() -> MLRunner:
|
def training_runner_no_setup(tmp_path: Path) -> TrainingRunner:
|
||||||
experiment_config = ExperimentConfig(model="HelloWorld")
|
experiment_config = ExperimentConfig(model="HelloWorld")
|
||||||
container = LightningContainer(num_epochs=1)
|
container = LightningContainer(num_epochs=1)
|
||||||
runner = MLRunner(experiment_config=experiment_config, container=container)
|
container.set_output_to(tmp_path)
|
||||||
|
runner = TrainingRunner(experiment_config=experiment_config, container=container)
|
||||||
return runner
|
return runner
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture()
|
||||||
def ml_runner() -> Generator:
|
def training_runner(tmp_path: Path) -> Generator:
|
||||||
experiment_config = ExperimentConfig(model="HelloWorld")
|
experiment_config = ExperimentConfig(model="HelloWorld")
|
||||||
container = LightningContainer(num_epochs=1)
|
container = LightningContainer(num_epochs=1)
|
||||||
runner = MLRunner(experiment_config=experiment_config, container=container)
|
container.set_output_to(tmp_path)
|
||||||
|
runner = TrainingRunner(experiment_config=experiment_config, container=container)
|
||||||
runner.setup()
|
runner.setup()
|
||||||
yield runner
|
yield runner
|
||||||
output_dir = runner.container.file_system_config.outputs_folder
|
output_dir = runner.container.file_system_config.outputs_folder
|
||||||
|
@ -54,10 +58,11 @@ def ml_runner() -> Generator:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def ml_runner_with_container() -> Generator:
|
def training_runner_hello_world(tmp_path: Path) -> Generator:
|
||||||
experiment_config = ExperimentConfig(model="HelloWorld")
|
experiment_config = ExperimentConfig(model="HelloWorld")
|
||||||
container = HelloWorld()
|
container = HelloWorld()
|
||||||
runner = MLRunner(experiment_config=experiment_config, container=container)
|
container.set_output_to(tmp_path)
|
||||||
|
runner = TrainingRunner(experiment_config=experiment_config, container=container)
|
||||||
runner.setup()
|
runner.setup()
|
||||||
yield runner
|
yield runner
|
||||||
output_dir = runner.container.file_system_config.outputs_folder
|
output_dir = runner.container.file_system_config.outputs_folder
|
||||||
|
@ -66,12 +71,15 @@ def ml_runner_with_container() -> Generator:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def ml_runner_with_run_id() -> Generator:
|
def training_runner_hello_world_with_checkpoint() -> Generator:
|
||||||
|
"""
|
||||||
|
A fixture with a training runner for the HelloWorld model that has a src_checkpoint set.
|
||||||
|
"""
|
||||||
experiment_config = ExperimentConfig(model="HelloWorld")
|
experiment_config = ExperimentConfig(model="HelloWorld")
|
||||||
container = HelloWorld()
|
container = HelloWorld()
|
||||||
container.save_checkpoint = True
|
container.save_checkpoint = True
|
||||||
container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
|
container.src_checkpoint = CheckpointParser(str(hello_world_checkpoint))
|
||||||
runner = MLRunner(experiment_config=experiment_config, container=container)
|
runner = TrainingRunner(experiment_config=experiment_config, container=container)
|
||||||
runner.setup()
|
runner.setup()
|
||||||
yield runner
|
yield runner
|
||||||
output_dir = runner.container.file_system_config.outputs_folder
|
output_dir = runner.container.file_system_config.outputs_folder
|
||||||
|
@ -92,14 +100,27 @@ def regression_datadir(tmp_path: Path) -> Generator:
|
||||||
shutil.rmtree(tmp_path)
|
shutil.rmtree(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
def test_ml_runner_setup(ml_runner_no_setup: MLRunner) -> None:
|
def create_mlflow_trash_folder(runner: RunnerBase) -> None:
|
||||||
|
"""Create a trash folder where MLFlow expects its deleted runs.
|
||||||
|
This is a workaround for sporadic test failures: When reading out the run_id, MLFlow checks its own
|
||||||
|
deleted runs folder, but that (or one of its parents) does not exist
|
||||||
|
"""
|
||||||
|
trash_folder = runner.container.outputs_folder / "mlruns" / ".trash"
|
||||||
|
trash_folder.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ml_runner_setup(training_runner_no_setup: TrainingRunner) -> None:
|
||||||
"""Check that all the necessary methods get called during setup"""
|
"""Check that all the necessary methods get called during setup"""
|
||||||
assert not ml_runner_no_setup._has_setup_run
|
assert not training_runner_no_setup._has_setup_run
|
||||||
with patch.object(ml_runner_no_setup, "container", spec=LightningContainer) as mock_container:
|
with patch.object(training_runner_no_setup, "container", spec=LightningContainer) as mock_container:
|
||||||
with patch.object(ml_runner_no_setup, "checkpoint_handler", spec=CheckpointHandler) as mock_checkpoint_handler:
|
# Without that, it would try to create a local run object for logging and fail there.
|
||||||
with patch("health_ml.run_ml.seed_everything") as mock_seed:
|
mock_container.log_from_vm = False
|
||||||
with patch("health_ml.run_ml.seed_monai_if_available") as mock_seed_monai:
|
with patch.object(
|
||||||
ml_runner_no_setup.setup()
|
training_runner_no_setup, "checkpoint_handler", spec=CheckpointHandler
|
||||||
|
) as mock_checkpoint_handler:
|
||||||
|
with patch("health_ml.runner_base.seed_everything") as mock_seed:
|
||||||
|
with patch("health_ml.runner_base.seed_monai_if_available") as mock_seed_monai:
|
||||||
|
training_runner_no_setup.setup()
|
||||||
mock_container.get_effective_random_seed.assert_called()
|
mock_container.get_effective_random_seed.assert_called()
|
||||||
mock_seed.assert_called_once()
|
mock_seed.assert_called_once()
|
||||||
mock_seed_monai.assert_called_once()
|
mock_seed_monai.assert_called_once()
|
||||||
|
@ -107,47 +128,58 @@ def test_ml_runner_setup(ml_runner_no_setup: MLRunner) -> None:
|
||||||
mock_checkpoint_handler.download_recovery_checkpoints_or_weights.assert_called_once()
|
mock_checkpoint_handler.download_recovery_checkpoints_or_weights.assert_called_once()
|
||||||
mock_container.setup.assert_called_once()
|
mock_container.setup.assert_called_once()
|
||||||
mock_container.create_lightning_module_and_store.assert_called_once()
|
mock_container.create_lightning_module_and_store.assert_called_once()
|
||||||
assert ml_runner_no_setup._has_setup_run
|
assert training_runner_no_setup._has_setup_run
|
||||||
|
|
||||||
|
|
||||||
def test_set_run_tags_from_parent(ml_runner: MLRunner) -> None:
|
def test_setup_azureml(training_runner: TrainingRunner) -> None:
|
||||||
"""Test that set_run_tags_from_parents causes set_tags to get called"""
|
"""Test that setup_azureml causes set_tags to get called when running in Hyperdrive"""
|
||||||
with pytest.raises(AssertionError) as ae:
|
with patch("health_ml.runner_base.RUN_CONTEXT") as mock_run_context:
|
||||||
ml_runner.set_run_tags_from_parent()
|
training_runner.setup_azureml()
|
||||||
assert "should only be called in a Hyperdrive run" in str(ae)
|
# Tests always run outside of a Hyperdrive run. In those cases, no tags should be set on
|
||||||
|
# the current run.
|
||||||
|
mock_run_context.set_tags.assert_not_called()
|
||||||
|
|
||||||
with patch("health_ml.run_ml.PARENT_RUN_CONTEXT") as mock_parent_run_context:
|
with patch("health_ml.runner_base.PARENT_RUN_CONTEXT") as mock_parent_run_context:
|
||||||
with patch("health_ml.run_ml.RUN_CONTEXT") as mock_run_context:
|
# Mock the presence of a parent run, and tags that are present there
|
||||||
mock_parent_run_context.get_tags.return_value = {"tag": "dummy_tag"}
|
tag_name = "tag"
|
||||||
ml_runner.set_run_tags_from_parent()
|
tag_value = "dummy_tag"
|
||||||
mock_run_context.set_tags.assert_called()
|
mock_parent_run_context.get_tags.return_value = {tag_name: tag_value}
|
||||||
|
training_runner.setup_azureml()
|
||||||
|
# The function should read out tags from the parent run, and set them on the current run
|
||||||
|
mock_parent_run_context.get_tags.assert_called_once_with()
|
||||||
|
mock_run_context.set_tags.assert_called_once()
|
||||||
|
call_args = mock_run_context.set_tags.call_args[0][0]
|
||||||
|
assert tag_name in call_args
|
||||||
|
assert call_args[tag_name] == tag_value
|
||||||
|
assert EFFECTIVE_RANDOM_SEED_KEY_NAME in call_args
|
||||||
|
|
||||||
|
|
||||||
def test_get_multiple_trainloader_mode(ml_runner: MLRunner) -> None:
|
def test_get_multiple_trainloader_mode(training_runner: TrainingRunner) -> None:
|
||||||
multiple_trainloader_mode = ml_runner.get_multiple_trainloader_mode()
|
training_runner.init_training()
|
||||||
|
multiple_trainloader_mode = training_runner.get_multiple_trainloader_mode()
|
||||||
assert multiple_trainloader_mode == "max_size_cycle", "train_loader_cycle_mode is available now, "
|
assert multiple_trainloader_mode == "max_size_cycle", "train_loader_cycle_mode is available now, "
|
||||||
"`get_multiple_trainloader_mode` workaround can be safely removed."
|
"`get_multiple_trainloader_mode` workaround can be safely removed."
|
||||||
|
|
||||||
|
|
||||||
def _test_init_training(run_inference_only: bool, ml_runner: MLRunner, caplog: LogCaptureFixture) -> None:
|
def _test_init_training(run_inference_only: bool, training_runner: TrainingRunner, caplog: LogCaptureFixture) -> None:
|
||||||
"""Test that training is initialized correctly"""
|
"""Test that training is initialized correctly"""
|
||||||
ml_runner.container.run_inference_only = run_inference_only
|
training_runner.container.run_inference_only = run_inference_only
|
||||||
ml_runner.setup()
|
training_runner.setup()
|
||||||
assert not ml_runner.checkpoint_handler.has_continued_training
|
assert not training_runner.checkpoint_handler.has_continued_training
|
||||||
assert ml_runner.trainer is None
|
assert training_runner.trainer is None
|
||||||
assert ml_runner.storing_logger is None
|
assert training_runner.storing_logger is None
|
||||||
|
|
||||||
with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
|
with patch("health_ml.training_runner.create_lightning_trainer") as mock_create_trainer:
|
||||||
with patch.object(ml_runner.container, "get_data_module") as mock_get_data_module:
|
with patch.object(training_runner.container, "get_data_module") as mock_get_data_module:
|
||||||
with patch("health_ml.run_ml.write_experiment_summary_file") as mock_write_experiment_summary_file:
|
with patch("health_ml.training_runner.write_experiment_summary_file") as mock_write_experiment_summary_file:
|
||||||
with patch.object(
|
with patch.object(
|
||||||
ml_runner.checkpoint_handler, "get_recovery_or_checkpoint_path_train"
|
training_runner.checkpoint_handler, "get_recovery_or_checkpoint_path_train"
|
||||||
) as mock_get_recovery_or_checkpoint_path_train:
|
) as mock_get_recovery_or_checkpoint_path_train:
|
||||||
with patch("health_ml.run_ml.seed_everything") as mock_seed:
|
with patch("health_ml.training_runner.seed_everything") as mock_seed:
|
||||||
mock_create_trainer.return_value = MagicMock(), MagicMock()
|
mock_create_trainer.return_value = MagicMock(), MagicMock()
|
||||||
mock_get_recovery_or_checkpoint_path_train.return_value = "dummy_path"
|
mock_get_recovery_or_checkpoint_path_train.return_value = "dummy_path"
|
||||||
|
|
||||||
ml_runner.init_training()
|
training_runner.init_training()
|
||||||
|
|
||||||
# Make sure write_experiment_summary_file is only called on rank 0
|
# Make sure write_experiment_summary_file is only called on rank 0
|
||||||
if is_global_rank_zero():
|
if is_global_rank_zero():
|
||||||
|
@ -157,47 +189,51 @@ def _test_init_training(run_inference_only: bool, ml_runner: MLRunner, caplog: L
|
||||||
|
|
||||||
# Make sure seed is set correctly with workers=True
|
# Make sure seed is set correctly with workers=True
|
||||||
mock_seed.assert_called_once()
|
mock_seed.assert_called_once()
|
||||||
assert mock_seed.call_args[0][0] == ml_runner.container.get_effective_random_seed()
|
assert mock_seed.call_args[0][0] == training_runner.container.get_effective_random_seed()
|
||||||
assert mock_seed.call_args[1]["workers"]
|
assert mock_seed.call_args[1]["workers"]
|
||||||
|
|
||||||
mock_get_data_module.assert_called_once()
|
mock_get_data_module.assert_called_once()
|
||||||
assert ml_runner.data_module is not None
|
assert training_runner.data_module is not None
|
||||||
|
|
||||||
if not run_inference_only:
|
if not run_inference_only:
|
||||||
mock_get_recovery_or_checkpoint_path_train.assert_called_once()
|
mock_get_recovery_or_checkpoint_path_train.assert_called_once()
|
||||||
# Validate that the trainer is created correctly
|
# Validate that the trainer is created correctly
|
||||||
assert mock_create_trainer.call_args[1]["resume_from_checkpoint"] == "dummy_path"
|
assert mock_create_trainer.call_args[1]["resume_from_checkpoint"] == "dummy_path"
|
||||||
assert ml_runner.storing_logger is not None
|
assert training_runner.storing_logger is not None
|
||||||
assert ml_runner.trainer is not None
|
assert training_runner.trainer is not None
|
||||||
assert "Environment variables:" in caplog.messages[-1]
|
assert "Environment variables:" in caplog.messages[-1]
|
||||||
else:
|
else:
|
||||||
assert ml_runner.trainer is None
|
assert training_runner.trainer is None
|
||||||
assert ml_runner.storing_logger is None
|
assert training_runner.storing_logger is None
|
||||||
mock_get_recovery_or_checkpoint_path_train.assert_not_called()
|
mock_get_recovery_or_checkpoint_path_train.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("run_inference_only", [True, False])
|
@pytest.mark.parametrize("run_inference_only", [True, False])
|
||||||
def test_init_training_cpu(run_inference_only: bool, ml_runner: MLRunner, caplog: LogCaptureFixture) -> None:
|
def test_init_training_cpu(
|
||||||
|
run_inference_only: bool, training_runner: TrainingRunner, caplog: LogCaptureFixture
|
||||||
|
) -> None:
|
||||||
"""Test that training is initialized correctly"""
|
"""Test that training is initialized correctly"""
|
||||||
ml_runner.container.max_num_gpus = 0
|
training_runner.container.max_num_gpus = 0
|
||||||
_test_init_training(run_inference_only, ml_runner, caplog)
|
_test_init_training(run_inference_only, training_runner, caplog)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
|
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
|
||||||
@pytest.mark.gpu
|
@pytest.mark.gpu
|
||||||
@pytest.mark.parametrize("run_inference_only", [True, False])
|
@pytest.mark.parametrize("run_inference_only", [True, False])
|
||||||
def test_init_training_gpu(run_inference_only: bool, ml_runner: MLRunner, caplog: LogCaptureFixture) -> None:
|
def test_init_training_gpu(
|
||||||
|
run_inference_only: bool, training_runner: TrainingRunner, caplog: LogCaptureFixture
|
||||||
|
) -> None:
|
||||||
"""Test that training is initialized correctly in DDP mode"""
|
"""Test that training is initialized correctly in DDP mode"""
|
||||||
_test_init_training(run_inference_only, ml_runner, caplog)
|
_test_init_training(run_inference_only, training_runner, caplog)
|
||||||
|
|
||||||
|
|
||||||
def test_run_training() -> None:
|
def test_run_training() -> None:
|
||||||
experiment_config = ExperimentConfig(model="HelloWorld")
|
experiment_config = ExperimentConfig(model="HelloWorld")
|
||||||
container = HelloWorld()
|
container = HelloWorld()
|
||||||
runner = MLRunner(experiment_config=experiment_config, container=container)
|
runner = TrainingRunner(experiment_config=experiment_config, container=container)
|
||||||
|
|
||||||
with patch.object(container, "get_data_module") as mock_get_data_module:
|
with patch.object(container, "get_data_module") as mock_get_data_module:
|
||||||
with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
|
with patch("health_ml.training_runner.create_lightning_trainer") as mock_create_trainer:
|
||||||
runner.setup()
|
runner.setup()
|
||||||
mock_trainer = MagicMock()
|
mock_trainer = MagicMock()
|
||||||
mock_storing_logger = MagicMock()
|
mock_storing_logger = MagicMock()
|
||||||
|
@ -225,17 +261,17 @@ def test_end_training(max_num_gpus_inf: int) -> None:
|
||||||
experiment_config = ExperimentConfig(model="HelloWorld")
|
experiment_config = ExperimentConfig(model="HelloWorld")
|
||||||
container = HelloWorld()
|
container = HelloWorld()
|
||||||
container.max_num_gpus_inference = max_num_gpus_inf
|
container.max_num_gpus_inference = max_num_gpus_inf
|
||||||
runner = MLRunner(experiment_config=experiment_config, container=container)
|
runner = TrainingRunner(experiment_config=experiment_config, container=container)
|
||||||
|
|
||||||
with patch.object(container, "get_data_module"):
|
with patch.object(container, "get_data_module"):
|
||||||
with patch("health_ml.run_ml.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
|
with patch("health_ml.training_runner.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
|
||||||
runner.setup()
|
runner.setup()
|
||||||
runner.init_training()
|
runner.init_training()
|
||||||
runner.run_training()
|
runner.run_training()
|
||||||
|
|
||||||
with patch.object(runner.checkpoint_handler, "additional_training_done") as mock_additional_training_done:
|
with patch.object(runner.checkpoint_handler, "additional_training_done") as mock_additional_training_done:
|
||||||
with patch.object(runner, "after_ddp_cleanup") as mock_after_ddp_cleanup:
|
with patch.object(runner, "after_ddp_cleanup") as mock_after_ddp_cleanup:
|
||||||
with patch("health_ml.run_ml.cleanup_checkpoints") as mock_cleanup_ckpt:
|
with patch("health_ml.training_runner.cleanup_checkpoints") as mock_cleanup_ckpt:
|
||||||
environ_before_training = {"old": "environ"}
|
environ_before_training = {"old": "environ"}
|
||||||
runner.end_training(environ_before_training=environ_before_training)
|
runner.end_training(environ_before_training=environ_before_training)
|
||||||
mock_additional_training_done.assert_called_once()
|
mock_additional_training_done.assert_called_once()
|
||||||
|
@ -251,71 +287,98 @@ def test_end_training(max_num_gpus_inf: int) -> None:
|
||||||
@pytest.mark.parametrize("run_extra_val_epoch", [True, False])
|
@pytest.mark.parametrize("run_extra_val_epoch", [True, False])
|
||||||
@pytest.mark.parametrize("run_inference_only", [True, False])
|
@pytest.mark.parametrize("run_inference_only", [True, False])
|
||||||
def test_init_inference(
|
def test_init_inference(
|
||||||
run_inference_only: bool, run_extra_val_epoch: bool, max_num_gpus_inf: int, ml_runner_with_run_id: MLRunner
|
run_inference_only: bool,
|
||||||
|
run_extra_val_epoch: bool,
|
||||||
|
max_num_gpus_inf: int,
|
||||||
|
training_runner_hello_world_with_checkpoint: TrainingRunner,
|
||||||
) -> None:
|
) -> None:
|
||||||
ml_runner_with_run_id.container.run_inference_only = run_inference_only
|
training_runner_hello_world_with_checkpoint.container.run_inference_only = run_inference_only
|
||||||
ml_runner_with_run_id.container.run_extra_val_epoch = run_extra_val_epoch
|
training_runner_hello_world_with_checkpoint.container.run_extra_val_epoch = run_extra_val_epoch
|
||||||
ml_runner_with_run_id.container.max_num_gpus_inference = max_num_gpus_inf
|
training_runner_hello_world_with_checkpoint.container.max_num_gpus_inference = max_num_gpus_inf
|
||||||
assert ml_runner_with_run_id.container.max_num_gpus == -1 # This is the default value of max_num_gpus
|
assert (
|
||||||
ml_runner_with_run_id.init_training()
|
training_runner_hello_world_with_checkpoint.container.max_num_gpus == -1
|
||||||
|
) # This is the default value of max_num_gpus
|
||||||
|
training_runner_hello_world_with_checkpoint.init_training()
|
||||||
if run_inference_only:
|
if run_inference_only:
|
||||||
expected_mlflow_run_id = None
|
expected_mlflow_run_id = None
|
||||||
else:
|
else:
|
||||||
assert ml_runner_with_run_id.trainer is not None
|
assert training_runner_hello_world_with_checkpoint.trainer is not None
|
||||||
expected_mlflow_run_id = ml_runner_with_run_id.trainer.loggers[1].run_id # type: ignore
|
create_mlflow_trash_folder(training_runner_hello_world_with_checkpoint)
|
||||||
|
expected_mlflow_run_id = training_runner_hello_world_with_checkpoint.trainer.loggers[1].run_id # type: ignore
|
||||||
if not run_inference_only:
|
if not run_inference_only:
|
||||||
ml_runner_with_run_id.checkpoint_handler.additional_training_done()
|
training_runner_hello_world_with_checkpoint.checkpoint_handler.additional_training_done()
|
||||||
with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
|
with patch("health_ml.runner_base.create_lightning_trainer") as mock_create_trainer:
|
||||||
with patch.object(ml_runner_with_run_id.container, "get_checkpoint_to_test") as mock_get_checkpoint_to_test:
|
with patch.object(
|
||||||
with patch.object(ml_runner_with_run_id.container, "get_data_module") as mock_get_data_module:
|
training_runner_hello_world_with_checkpoint.container, "get_checkpoint_to_test"
|
||||||
|
) as mock_get_checkpoint_to_test:
|
||||||
|
with patch.object(
|
||||||
|
training_runner_hello_world_with_checkpoint.container, "get_data_module"
|
||||||
|
) as mock_get_data_module:
|
||||||
mock_checkpoint = MagicMock(is_file=MagicMock(return_value=True))
|
mock_checkpoint = MagicMock(is_file=MagicMock(return_value=True))
|
||||||
mock_get_checkpoint_to_test.return_value = mock_checkpoint
|
mock_get_checkpoint_to_test.return_value = mock_checkpoint
|
||||||
mock_trainer = MagicMock()
|
mock_trainer = MagicMock()
|
||||||
mock_create_trainer.return_value = mock_trainer, MagicMock()
|
mock_create_trainer.return_value = mock_trainer, MagicMock()
|
||||||
mock_get_data_module.return_value = "dummy_data_module"
|
mock_get_data_module.return_value = "dummy_data_module"
|
||||||
|
|
||||||
assert ml_runner_with_run_id.inference_checkpoint is None
|
assert training_runner_hello_world_with_checkpoint.inference_checkpoint is None
|
||||||
assert not ml_runner_with_run_id.container.model._on_extra_val_epoch
|
assert not training_runner_hello_world_with_checkpoint.container.model._on_extra_val_epoch
|
||||||
|
|
||||||
ml_runner_with_run_id.init_inference()
|
training_runner_hello_world_with_checkpoint.init_inference()
|
||||||
|
|
||||||
expected_ckpt = str(ml_runner_with_run_id.checkpoint_handler.trained_weights_path)
|
expected_ckpt = str(training_runner_hello_world_with_checkpoint.checkpoint_handler.trained_weights_path)
|
||||||
expected_ckpt = expected_ckpt if run_inference_only else str(mock_checkpoint)
|
expected_ckpt = expected_ckpt if run_inference_only else str(mock_checkpoint)
|
||||||
assert ml_runner_with_run_id.inference_checkpoint == expected_ckpt
|
assert training_runner_hello_world_with_checkpoint.inference_checkpoint == expected_ckpt
|
||||||
|
|
||||||
assert hasattr(ml_runner_with_run_id.container.model, "on_run_extra_validation_epoch")
|
assert hasattr(
|
||||||
assert ml_runner_with_run_id.container.model._on_extra_val_epoch == run_extra_val_epoch
|
training_runner_hello_world_with_checkpoint.container.model, "on_run_extra_validation_epoch"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
training_runner_hello_world_with_checkpoint.container.model._on_extra_val_epoch
|
||||||
|
== run_extra_val_epoch
|
||||||
|
)
|
||||||
|
|
||||||
mock_create_trainer.assert_called_once()
|
mock_create_trainer.assert_called_once()
|
||||||
assert ml_runner_with_run_id.trainer == mock_trainer
|
assert training_runner_hello_world_with_checkpoint.trainer == mock_trainer
|
||||||
assert ml_runner_with_run_id.container.max_num_gpus == max_num_gpus_inf
|
assert training_runner_hello_world_with_checkpoint.container.max_num_gpus == max_num_gpus_inf
|
||||||
assert mock_create_trainer.call_args[1]["container"] == ml_runner_with_run_id.container
|
assert (
|
||||||
|
mock_create_trainer.call_args[1]["container"]
|
||||||
|
== training_runner_hello_world_with_checkpoint.container
|
||||||
|
)
|
||||||
assert mock_create_trainer.call_args[1]["num_nodes"] == 1
|
assert mock_create_trainer.call_args[1]["num_nodes"] == 1
|
||||||
assert mock_create_trainer.call_args[1]["mlflow_run_for_logging"] == expected_mlflow_run_id
|
assert mock_create_trainer.call_args[1]["mlflow_run_for_logging"] == expected_mlflow_run_id
|
||||||
mock_get_data_module.assert_called_once()
|
mock_get_data_module.assert_called_once()
|
||||||
assert ml_runner_with_run_id.data_module == "dummy_data_module"
|
assert training_runner_hello_world_with_checkpoint.data_module == "dummy_data_module"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("run_inference_only", [True, False])
|
@pytest.mark.parametrize("run_inference_only", [True, False])
|
||||||
@pytest.mark.parametrize("run_extra_val_epoch", [True, False])
|
@pytest.mark.parametrize("run_extra_val_epoch", [True, False])
|
||||||
def test_run_validation(
|
def test_run_validation(
|
||||||
run_extra_val_epoch: bool, run_inference_only: bool, ml_runner_with_run_id: MLRunner, caplog: LogCaptureFixture
|
run_extra_val_epoch: bool,
|
||||||
|
run_inference_only: bool,
|
||||||
|
training_runner_hello_world_with_checkpoint: TrainingRunner,
|
||||||
|
caplog: LogCaptureFixture,
|
||||||
) -> None:
|
) -> None:
|
||||||
ml_runner_with_run_id.container.run_extra_val_epoch = run_extra_val_epoch
|
training_runner_hello_world_with_checkpoint.container.run_extra_val_epoch = run_extra_val_epoch
|
||||||
ml_runner_with_run_id.container.run_inference_only = run_inference_only
|
training_runner_hello_world_with_checkpoint.container.run_inference_only = run_inference_only
|
||||||
ml_runner_with_run_id.init_training()
|
training_runner_hello_world_with_checkpoint.init_training()
|
||||||
mock_datamodule = MagicMock()
|
mock_datamodule = MagicMock()
|
||||||
with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
|
create_mlflow_trash_folder(training_runner_hello_world_with_checkpoint)
|
||||||
with patch.object(ml_runner_with_run_id.container, "get_data_module", return_value=mock_datamodule):
|
with patch("health_ml.runner_base.create_lightning_trainer") as mock_create_trainer:
|
||||||
|
with patch.object(
|
||||||
|
training_runner_hello_world_with_checkpoint.container, "get_data_module", return_value=mock_datamodule
|
||||||
|
):
|
||||||
mock_trainer = MagicMock()
|
mock_trainer = MagicMock()
|
||||||
mock_create_trainer.return_value = mock_trainer, MagicMock()
|
mock_create_trainer.return_value = mock_trainer, MagicMock()
|
||||||
ml_runner_with_run_id.init_inference()
|
training_runner_hello_world_with_checkpoint.init_inference()
|
||||||
assert ml_runner_with_run_id.trainer == mock_trainer
|
assert training_runner_hello_world_with_checkpoint.trainer == mock_trainer
|
||||||
mock_trainer.validate = Mock()
|
mock_trainer.validate = Mock()
|
||||||
ml_runner_with_run_id.run_validation()
|
training_runner_hello_world_with_checkpoint.run_validation()
|
||||||
if run_extra_val_epoch or run_inference_only:
|
if run_extra_val_epoch or run_inference_only:
|
||||||
mock_trainer.validate.assert_called_once()
|
mock_trainer.validate.assert_called_once()
|
||||||
assert mock_trainer.validate.call_args[1]["ckpt_path"] == ml_runner_with_run_id.inference_checkpoint
|
assert (
|
||||||
|
mock_trainer.validate.call_args[1]["ckpt_path"]
|
||||||
|
== training_runner_hello_world_with_checkpoint.inference_checkpoint
|
||||||
|
)
|
||||||
assert mock_trainer.validate.call_args[1]["datamodule"] == mock_datamodule
|
assert mock_trainer.validate.call_args[1]["datamodule"] == mock_datamodule
|
||||||
else:
|
else:
|
||||||
assert "Skipping extra validation" in caplog.messages[-1]
|
assert "Skipping extra validation" in caplog.messages[-1]
|
||||||
|
@ -332,12 +395,12 @@ def test_model_extra_val_epoch_missing_hook(caplog: LogCaptureFixture) -> None:
|
||||||
container = HelloWorld()
|
container = HelloWorld()
|
||||||
container.create_lightning_module_and_store()
|
container.create_lightning_module_and_store()
|
||||||
container.run_extra_val_epoch = True
|
container.run_extra_val_epoch = True
|
||||||
runner = MLRunner(experiment_config=experiment_config, container=container)
|
runner = TrainingRunner(experiment_config=experiment_config, container=container)
|
||||||
runner.setup()
|
runner.setup()
|
||||||
runner.checkpoint_handler.additional_training_done()
|
runner.checkpoint_handler.additional_training_done()
|
||||||
runner.container.outputs_folder.mkdir(parents=True, exist_ok=True)
|
runner.container.outputs_folder.mkdir(parents=True, exist_ok=True)
|
||||||
with patch.object(container, "get_data_module"):
|
with patch.object(container, "get_data_module"):
|
||||||
with patch("health_ml.run_ml.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
|
with patch("health_ml.runner_base.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
|
||||||
with patch.object(runner.container, "get_checkpoint_to_test") as mock_get_checkpoint_to_test:
|
with patch.object(runner.container, "get_checkpoint_to_test") as mock_get_checkpoint_to_test:
|
||||||
mock_get_checkpoint_to_test.return_value = MagicMock(is_file=MagicMock(return_value=True))
|
mock_get_checkpoint_to_test.return_value = MagicMock(is_file=MagicMock(return_value=True))
|
||||||
runner.init_inference()
|
runner.init_inference()
|
||||||
|
@ -346,34 +409,34 @@ def test_model_extra_val_epoch_missing_hook(caplog: LogCaptureFixture) -> None:
|
||||||
assert "Hook `on_run_extra_validation_epoch` is not implemented" in latest_message
|
assert "Hook `on_run_extra_validation_epoch` is not implemented" in latest_message
|
||||||
|
|
||||||
|
|
||||||
def test_run_inference(ml_runner_with_container: MLRunner, regression_datadir: Path) -> None:
|
def test_run_inference(training_runner_hello_world: TrainingRunner, regression_datadir: Path) -> None:
|
||||||
"""
|
"""
|
||||||
Test that run_inference gets called as expected.
|
Test that run_inference gets called as expected.
|
||||||
"""
|
"""
|
||||||
ml_runner_with_container.container.max_num_gpus = 0
|
training_runner_hello_world.container.max_num_gpus = 0
|
||||||
|
|
||||||
def _expected_files_exist() -> bool:
|
def _expected_files_exist() -> bool:
|
||||||
output_dir = ml_runner_with_container.container.outputs_folder
|
output_dir = training_runner_hello_world.container.outputs_folder
|
||||||
if not output_dir.is_dir():
|
if not output_dir.is_dir():
|
||||||
return False
|
return False
|
||||||
expected_files = ["test_mse.txt", "test_mae.txt"]
|
expected_files = [TEST_MSE_FILE, TEST_MAE_FILE]
|
||||||
return all([(output_dir / p).exists() for p in expected_files])
|
return all([(output_dir / p).exists() for p in expected_files])
|
||||||
|
|
||||||
expected_ckpt_path = ml_runner_with_container.container.outputs_folder / "checkpoints" / "last.ckpt"
|
expected_ckpt_path = training_runner_hello_world.container.outputs_folder / "checkpoints" / "last.ckpt"
|
||||||
assert not expected_ckpt_path.exists()
|
assert not expected_ckpt_path.exists()
|
||||||
# update the container to look for test data at this location
|
# update the container to look for test data at this location
|
||||||
ml_runner_with_container.container.local_dataset_dir = regression_datadir
|
training_runner_hello_world.container.local_dataset_dir = regression_datadir
|
||||||
assert not _expected_files_exist()
|
assert not _expected_files_exist()
|
||||||
|
|
||||||
actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
|
actual_train_ckpt_path = training_runner_hello_world.checkpoint_handler.get_recovery_or_checkpoint_path_train()
|
||||||
assert actual_train_ckpt_path is None
|
assert actual_train_ckpt_path is None
|
||||||
|
|
||||||
ml_runner_with_container.run()
|
training_runner_hello_world.run()
|
||||||
|
|
||||||
actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
|
actual_train_ckpt_path = training_runner_hello_world.checkpoint_handler.get_recovery_or_checkpoint_path_train()
|
||||||
assert actual_train_ckpt_path == expected_ckpt_path
|
assert actual_train_ckpt_path == expected_ckpt_path
|
||||||
|
|
||||||
actual_test_ckpt_path = ml_runner_with_container.checkpoint_handler.get_checkpoint_to_test()
|
actual_test_ckpt_path = training_runner_hello_world.checkpoint_handler.get_checkpoint_to_test()
|
||||||
assert actual_test_ckpt_path == expected_ckpt_path
|
assert actual_test_ckpt_path == expected_ckpt_path
|
||||||
assert actual_test_ckpt_path.is_file()
|
assert actual_test_ckpt_path.is_file()
|
||||||
# After training, the outputs directory should now exist and contain the 2 error files
|
# After training, the outputs directory should now exist and contain the 2 error files
|
||||||
|
@@ -382,25 +445,25 @@ def test_run_inference(ml_runner_with_container: MLRunner, regression_datadir: P
|
||||||
|
|
||||||
@pytest.mark.parametrize("run_extra_val_epoch", [True, False])
|
@pytest.mark.parametrize("run_extra_val_epoch", [True, False])
|
||||||
@pytest.mark.parametrize("run_inference_only", [True, False])
|
@pytest.mark.parametrize("run_inference_only", [True, False])
|
||||||
def test_run(run_inference_only: bool, run_extra_val_epoch: bool, ml_runner_with_container: MLRunner) -> None:
|
def test_run(run_inference_only: bool, run_extra_val_epoch: bool, training_runner_hello_world: TrainingRunner) -> None:
|
||||||
"""Test that model runner gets called"""
|
"""Test that model runner gets called"""
|
||||||
ml_runner_with_container.container.run_inference_only = run_inference_only
|
training_runner_hello_world.container.run_inference_only = run_inference_only
|
||||||
ml_runner_with_container.container.run_extra_val_epoch = run_extra_val_epoch
|
training_runner_hello_world.container.run_extra_val_epoch = run_extra_val_epoch
|
||||||
ml_runner_with_container.setup()
|
training_runner_hello_world.setup()
|
||||||
assert not ml_runner_with_container.checkpoint_handler.has_continued_training
|
assert not training_runner_hello_world.checkpoint_handler.has_continued_training
|
||||||
|
|
||||||
with patch("health_ml.run_ml.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
|
with patch("health_ml.runner_base.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
|
||||||
with patch.multiple(
|
with patch.multiple(
|
||||||
ml_runner_with_container,
|
training_runner_hello_world,
|
||||||
checkpoint_handler=mock.DEFAULT,
|
checkpoint_handler=mock.DEFAULT,
|
||||||
run_training=mock.DEFAULT,
|
run_training=mock.DEFAULT,
|
||||||
run_validation=mock.DEFAULT,
|
run_validation=mock.DEFAULT,
|
||||||
run_inference=mock.DEFAULT,
|
run_inference=mock.DEFAULT,
|
||||||
end_training=mock.DEFAULT,
|
end_training=mock.DEFAULT,
|
||||||
) as mocks:
|
) as mocks:
|
||||||
ml_runner_with_container.run()
|
training_runner_hello_world.run()
|
||||||
assert ml_runner_with_container.container.has_custom_test_step()
|
assert training_runner_hello_world.container.has_custom_test_step()
|
||||||
assert ml_runner_with_container._has_setup_run
|
assert training_runner_hello_world._has_setup_run
|
||||||
assert mocks["end_training"] != run_inference_only
|
assert mocks["end_training"] != run_inference_only
|
||||||
assert mocks["run_training"].called != run_inference_only
|
assert mocks["run_training"].called != run_inference_only
|
||||||
mocks["run_validation"].assert_called_once()
|
mocks["run_validation"].assert_called_once()
|
||||||
|
@@ -408,45 +471,59 @@ def test_run(run_inference_only: bool, run_extra_val_epoch: bool, ml_runner_with
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("run_extra_val_epoch", [True, False])
|
@pytest.mark.parametrize("run_extra_val_epoch", [True, False])
|
||||||
def test_run_inference_only(run_extra_val_epoch: bool, ml_runner_with_run_id: MLRunner) -> None:
|
def test_run_inference_only(
|
||||||
|
run_extra_val_epoch: bool, training_runner_hello_world_with_checkpoint: TrainingRunner
|
||||||
|
) -> None:
|
||||||
"""Test inference only mode. Validation should be run regardless of run_extra_val_epoch status."""
|
"""Test inference only mode. Validation should be run regardless of run_extra_val_epoch status."""
|
||||||
ml_runner_with_run_id.container.run_inference_only = True
|
training_runner_hello_world_with_checkpoint.container.run_inference_only = True
|
||||||
ml_runner_with_run_id.container.run_extra_val_epoch = run_extra_val_epoch
|
training_runner_hello_world_with_checkpoint.container.run_extra_val_epoch = run_extra_val_epoch
|
||||||
assert ml_runner_with_run_id.checkpoint_handler.trained_weights_path
|
assert training_runner_hello_world_with_checkpoint.checkpoint_handler.trained_weights_path
|
||||||
mock_datamodule = MagicMock()
|
mock_datamodule = MagicMock()
|
||||||
with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
|
with patch("health_ml.runner_base.create_lightning_trainer") as mock_create_trainer:
|
||||||
with patch.object(ml_runner_with_run_id.container, "get_data_module", return_value=mock_datamodule):
|
with patch.object(
|
||||||
|
training_runner_hello_world_with_checkpoint.container, "get_data_module", return_value=mock_datamodule
|
||||||
|
):
|
||||||
with patch.multiple(
|
with patch.multiple(
|
||||||
ml_runner_with_run_id,
|
training_runner_hello_world_with_checkpoint,
|
||||||
run_training=mock.DEFAULT,
|
run_training=mock.DEFAULT,
|
||||||
) as mocks:
|
) as mocks:
|
||||||
mock_trainer = MagicMock()
|
mock_trainer = MagicMock()
|
||||||
mock_create_trainer.return_value = mock_trainer, MagicMock()
|
mock_create_trainer.return_value = mock_trainer, MagicMock()
|
||||||
ml_runner_with_run_id.run()
|
training_runner_hello_world_with_checkpoint.run()
|
||||||
mock_create_trainer.assert_called_once()
|
mock_create_trainer.assert_called_once()
|
||||||
mocks["run_training"].assert_not_called()
|
mocks["run_training"].assert_not_called()
|
||||||
|
|
||||||
mock_trainer.validate.assert_called_once()
|
mock_trainer.validate.assert_called_once()
|
||||||
assert mock_trainer.validate.call_args[1]["ckpt_path"] == ml_runner_with_run_id.inference_checkpoint
|
assert (
|
||||||
|
mock_trainer.validate.call_args[1]["ckpt_path"]
|
||||||
|
== training_runner_hello_world_with_checkpoint.inference_checkpoint
|
||||||
|
)
|
||||||
assert mock_trainer.validate.call_args[1]["datamodule"] == mock_datamodule
|
assert mock_trainer.validate.call_args[1]["datamodule"] == mock_datamodule
|
||||||
mock_trainer.test.assert_called_once()
|
mock_trainer.test.assert_called_once()
|
||||||
assert mock_trainer.test.call_args[1]["ckpt_path"] == ml_runner_with_run_id.inference_checkpoint
|
assert (
|
||||||
|
mock_trainer.test.call_args[1]["ckpt_path"]
|
||||||
|
== training_runner_hello_world_with_checkpoint.inference_checkpoint
|
||||||
|
)
|
||||||
assert mock_trainer.test.call_args[1]["datamodule"] == mock_datamodule
|
assert mock_trainer.test.call_args[1]["datamodule"] == mock_datamodule
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("run_extra_val_epoch", [True, False])
|
@pytest.mark.parametrize("run_extra_val_epoch", [True, False])
|
||||||
def test_resume_training_from_run_id(run_extra_val_epoch: bool, ml_runner_with_run_id: MLRunner) -> None:
|
def test_resume_training_from_run_id(
|
||||||
ml_runner_with_run_id.container.run_extra_val_epoch = run_extra_val_epoch
|
run_extra_val_epoch: bool, training_runner_hello_world_with_checkpoint: TrainingRunner
|
||||||
ml_runner_with_run_id.container.max_num_gpus = 0
|
) -> None:
|
||||||
ml_runner_with_run_id.container.max_epochs += 10
|
training_runner_hello_world_with_checkpoint.container.run_extra_val_epoch = run_extra_val_epoch
|
||||||
assert ml_runner_with_run_id.checkpoint_handler.trained_weights_path
|
training_runner_hello_world_with_checkpoint.container.max_num_gpus = 0
|
||||||
|
training_runner_hello_world_with_checkpoint.container.max_epochs += 10
|
||||||
|
assert training_runner_hello_world_with_checkpoint.checkpoint_handler.trained_weights_path
|
||||||
mock_trainer = MagicMock()
|
mock_trainer = MagicMock()
|
||||||
with patch("health_ml.run_ml.create_lightning_trainer", return_value=(mock_trainer, MagicMock())):
|
with patch("health_ml.runner_base.create_lightning_trainer", return_value=(mock_trainer, MagicMock())):
|
||||||
with patch.object(ml_runner_with_run_id.container, "get_checkpoint_to_test") as mock_get_checkpoint_to_test:
|
with patch.object(
|
||||||
with patch.object(ml_runner_with_run_id, "run_inference") as mock_run_inference:
|
training_runner_hello_world_with_checkpoint.container, "get_checkpoint_to_test"
|
||||||
with patch("health_ml.run_ml.cleanup_checkpoints") as mock_cleanup_ckpt:
|
) as mock_get_checkpoint_to_test:
|
||||||
|
with patch.object(training_runner_hello_world_with_checkpoint, "run_inference") as mock_run_inference:
|
||||||
|
with patch("health_ml.training_runner.cleanup_checkpoints") as mock_cleanup_ckpt:
|
||||||
mock_get_checkpoint_to_test.return_value = MagicMock(is_file=MagicMock(return_value=True))
|
mock_get_checkpoint_to_test.return_value = MagicMock(is_file=MagicMock(return_value=True))
|
||||||
ml_runner_with_run_id.run()
|
training_runner_hello_world_with_checkpoint.run()
|
||||||
mock_get_checkpoint_to_test.assert_called_once()
|
mock_get_checkpoint_to_test.assert_called_once()
|
||||||
mock_cleanup_ckpt.assert_called_once()
|
mock_cleanup_ckpt.assert_called_once()
|
||||||
assert mock_trainer.validate.called == run_extra_val_epoch
|
assert mock_trainer.validate.called == run_extra_val_epoch
|
||||||
|
@@ -457,12 +534,12 @@ def test_model_weights_when_resume_training() -> None:
|
||||||
experiment_config = ExperimentConfig(model="HelloWorld")
|
experiment_config = ExperimentConfig(model="HelloWorld")
|
||||||
container = HelloWorld()
|
container = HelloWorld()
|
||||||
container.max_num_gpus = 0
|
container.max_num_gpus = 0
|
||||||
container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
|
container.src_checkpoint = CheckpointParser(str(hello_world_checkpoint))
|
||||||
container.resume_training = True
|
container.resume_training = True
|
||||||
runner = MLRunner(experiment_config=experiment_config, container=container)
|
runner = TrainingRunner(experiment_config=experiment_config, container=container)
|
||||||
runner.setup()
|
runner.setup()
|
||||||
assert runner.checkpoint_handler.trained_weights_path.is_file() # type: ignore
|
assert runner.checkpoint_handler.trained_weights_path.is_file() # type: ignore
|
||||||
with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
|
with patch("health_ml.training_runner.create_lightning_trainer") as mock_create_trainer:
|
||||||
mock_create_trainer.return_value = MagicMock(), MagicMock()
|
mock_create_trainer.return_value = MagicMock(), MagicMock()
|
||||||
runner.init_training()
|
runner.init_training()
|
||||||
mock_create_trainer.assert_called_once()
|
mock_create_trainer.assert_called_once()
|
||||||
|
@@ -482,14 +559,14 @@ def test_log_on_vm(log_from_vm: bool) -> None:
|
||||||
tag = f"test_log_on_vm [{log_from_vm}]"
|
tag = f"test_log_on_vm [{log_from_vm}]"
|
||||||
container.tag = tag
|
container.tag = tag
|
||||||
container.log_from_vm = log_from_vm
|
container.log_from_vm = log_from_vm
|
||||||
runner = MLRunner(experiment_config=experiment_config, container=container)
|
runner = TrainingRunner(experiment_config=experiment_config, container=container)
|
||||||
# When logging to AzureML, need to provide the unit test AML workspace.
|
# When logging to AzureML, need to provide the unit test AML workspace.
|
||||||
# When not logging to AzureML, no workspace (and no authentication) should be needed.
|
# When not logging to AzureML, no workspace (and no authentication) should be needed.
|
||||||
if log_from_vm:
|
if log_from_vm:
|
||||||
with patch("health_azure.utils.get_workspace", return_value=DEFAULT_WORKSPACE.workspace):
|
with patch("health_azure.utils.get_workspace", return_value=DEFAULT_WORKSPACE.workspace):
|
||||||
runner.run()
|
runner.run_and_cleanup()
|
||||||
else:
|
else:
|
||||||
runner.run()
|
runner.run_and_cleanup()
|
||||||
# The PL trainer object is created in the init_training method.
|
# The PL trainer object is created in the init_training method.
|
||||||
# Check that the AzureML logger is set up correctly.
|
# Check that the AzureML logger is set up correctly.
|
||||||
assert runner.trainer is not None
|
assert runner.trainer is not None
|
||||||
|
@@ -536,13 +613,15 @@ def test_get_mlflow_run_id_from_trainer() -> None:
|
||||||
assert run_id == mock_run_id
|
assert run_id == mock_run_id
|
||||||
|
|
||||||
|
|
||||||
def test_inference_only_metrics_correctness(ml_runner_with_run_id: MLRunner, regression_datadir: Path) -> None:
|
def test_inference_only_metrics_correctness(
|
||||||
ml_runner_with_run_id.container.run_inference_only = True
|
training_runner_hello_world_with_checkpoint: TrainingRunner, regression_datadir: Path
|
||||||
ml_runner_with_run_id.container.local_dataset_dir = regression_datadir
|
) -> None:
|
||||||
ml_runner_with_run_id.run()
|
training_runner_hello_world_with_checkpoint.container.run_inference_only = True
|
||||||
with open(ml_runner_with_run_id.container.outputs_folder / "test_mse.txt") as f:
|
training_runner_hello_world_with_checkpoint.container.local_dataset_dir = regression_datadir
|
||||||
|
training_runner_hello_world_with_checkpoint.run()
|
||||||
|
with open(training_runner_hello_world_with_checkpoint.container.outputs_folder / TEST_MSE_FILE) as f:
|
||||||
mse = float(f.readlines()[0])
|
mse = float(f.readlines()[0])
|
||||||
assert isclose(mse, 0.010806690901517868, abs_tol=1e-3)
|
assert isclose(mse, 0.010806690901517868, abs_tol=1e-3)
|
||||||
with open(ml_runner_with_run_id.container.outputs_folder / "test_mae.txt") as f:
|
with open(training_runner_hello_world_with_checkpoint.container.outputs_folder / TEST_MAE_FILE) as f:
|
||||||
mae = float(f.readlines()[0])
|
mae = float(f.readlines()[0])
|
||||||
assert isclose(mae, 0.08260975033044815, abs_tol=1e-3)
|
assert isclose(mae, 0.08260975033044815, abs_tol=1e-3)
|