Mirror of https://github.com/microsoft/hi-ml.git
ENH: Ability to evaluate model on a new dataset (#859)
Renamed `MLRunner` to `TrainingRunner`. Introduced a new `EvalRunner`. Shared functionality moved from `MLRunner` to `RunnerBase`.
This commit is contained in:
Parent
2cc2511efd
Commit
e73c92649d
@@ -7,6 +7,7 @@ use of these features:

- Working with different models in the same codebase, and selecting one by name
- Distributed training in AzureML
- Logging via AzureML's native capabilities
- Evaluation of the trained model on new datasets

This can be used by invoking the hi-ml runner and providing the name of the container class, like this:

`himl-runner --model=MyContainer`.

@@ -215,7 +216,13 @@ and returns a tuple containing the Optimizer and LRScheduler objects

## Run inference with a pretrained model

You can use the hi-ml-runner in inference mode only by switching the `--run_inference_only` flag on and specifying
the model weights by setting `--src_checkpoint` argument that supports three types of checkpoints:
the model weights by setting the `--src_checkpoint` argument. With this flag, the model will be evaluated on the test set
only. There is also an option for evaluating the model on a full dataset, described further below.

### Specifying the checkpoint to use

When running inference on a trained model, you need to provide a model checkpoint that should be used. This is done via
the `--src_checkpoint` argument. This supports three types of checkpoints:

- A local path where the checkpoint is stored: `--src_checkpoint=local/path/to/my_checkpoint/model.ckpt`
- A remote URL from where to download the weights: `--src_checkpoint=https://my_checkpoint_url.com/model.ckpt`

@@ -228,13 +235,69 @@ the model weights by setting `--src_checkpoint` argument that supports three typ

Refer to [Checkpoints Utils](checkpoints.md) for more details on how checkpoints are parsed.

Running the following command line will run inference using `MyContainer` model with weights from the checkpoint saved
### Running inference on the test set

When supplying the flag `--run_inference_only` on the commandline, no model training will be run, and only inference on the
test set will be done:

- The model weights will be loaded from the location specified by `--src_checkpoint`
- A PyTorch Lightning `Trainer` object will be instantiated.
- The test set will be read out from the data module specified by the `get_data_module` method of the
  `LightningContainer` object.
- The model will be evaluated on the test set, by running `trainer.test`. Any special logic to use during the test step
  will need to be added to the model's `test_step` method, as in the sketch below.
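
A minimal sketch of such a `test_step` override (the model class and its metric are illustrative, not part of hi-ml):

```python
import torch
from pytorch_lightning import LightningModule


class MyLightningModel(LightningModule):
    """A hypothetical 1-dim regression model with a custom test step."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)

    def test_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        # Any special evaluation logic goes here, for example computing and
        # logging a per-batch metric that is aggregated over the epoch.
        x, y = batch[:, 0:1], batch[:, 1:2]
        loss = torch.nn.functional.mse_loss(self(x), y)
        self.log("test_mse", loss, on_epoch=True, on_step=False)
        return loss
```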

Running the following command line will run inference using the `MyContainer` model, with weights from the checkpoint saved
in the AzureML run `MyContainer_XXXX_yyyy` at the best validation loss epoch (`/outputs/checkpoints/best_val_loss.ckpt`):

```bash
himl-runner --model=MyContainer --run_inference_only --src_checkpoint=MyContainer_XXXX_yyyy:best_val_loss.ckpt
```

### Running inference on a full dataset

When supplying the flag `--mode=eval_full` on the commandline, no model training will be run, and the model will be
evaluated on a dataset different from the training/validation/test dataset. This dataset is loaded via the
`get_eval_data_module` method of the container.

- The model weights will be loaded from the location specified by `--src_checkpoint`
- A PyTorch Lightning `Trainer` object will be instantiated.
- The test set will be read out from the data module specified by the `get_eval_data_module` method of the
  `LightningContainer` object. The data module itself can read data from a mounted Azure dataset, which will be made
  available for the container at the path `self.local_datasets`. In a typical use-case, all the data in that dataset
  will be put into the `test_dataloader` field of the data module.
- The model will be evaluated on the test set, by running `trainer.test`. Any special logic to use during the test step
  will need to be added to the model's `test_step` method.

Running the following command line will run inference using the `MyContainer` model, with weights from the checkpoint saved
in the AzureML run `MyContainer_XXXX_yyyy` at the best validation loss epoch (`/outputs/checkpoints/best_val_loss.ckpt`):

```bash
himl-runner --model=MyContainer --src_checkpoint=MyContainer_XXXX_yyyy:best_val_loss.ckpt --mode=eval_full --azure_datasets=my_new_dataset
```

The example code snippet here shows how to add a method that reads the inference dataset. In this example, we assume
that the `MyDataModule` class has an argument `splits` that specifies the fraction of data to go into the training,
validation, and test data loaders.

```python
class MyContainer(LightningContainer):
    def __init__(self):
        super().__init__()
        self.azure_datasets = ["folder_name_in_azure_blob_storage"]
        self.local_datasets = [Path("/some/local/path")]
        self.max_epochs = 42

    def create_model(self) -> LightningModule:
        return MyLightningModel()

    def get_data_module(self) -> LightningDataModule:
        # Use the first (and only) dataset configured above
        return MyDataModule(root_path=self.local_datasets[0], splits=(0.7, 0.2, 0.1))

    def get_eval_data_module(self) -> LightningDataModule:
        # Put all data into the test data loader
        return MyDataModule(root_path=self.local_datasets[0], splits=(0.0, 0.0, 1.0))
```

## Resume training from a given checkpoint

Analogously, one can resume training by setting `--src_checkpoint` and `--resume_training` to train a model longer.
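
For example (with an illustrative run ID and checkpoint name, following the same `--src_checkpoint` syntax as above):

```bash
himl-runner --model=MyContainer --src_checkpoint=MyContainer_XXXX_yyyy:best_val_loss.ckpt --resume_training
```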

@@ -298,5 +361,6 @@ seeds on a local machine, other than by manually starting runs with `--random_se

   the limit that AML sets for snapshots. Solution: check for cache files, log files or other files that are not
   necessary for running your experiment and add them to a `.amlignore` file in the root directory. Alternatively, you
   can see Azure ML documentation for instructions on increasing this limit, although it will make your jobs slower.
2. `"FileNotFoundError"`. Possible cause: Symlinked files. Azure ML SDK v2 will resolve the symlink and attempt to upload
   the resolved file. Solution: Remove symlinks from any files that should be uploaded to Azure ML.

1. `"FileNotFoundError"`. Possible cause: Symlinked files. Azure ML SDK v2 will resolve the symlink and attempt to upload
   the resolved file. Solution: Remove symlinks from any files that should be uploaded to Azure ML.

@@ -61,4 +61,9 @@

        "${workspaceFolder}/src",
        "${workspaceFolder}/testazure"
    ],
    "workbench.colorCustomizations": {
        "activityBar.background": "#5D091F",
        "titleBar.activeBackground": "#830D2B",
        "titleBar.activeForeground": "#FFFCFC"
    },
}

@@ -524,7 +524,7 @@ def create_dataset_configs(

        count = num_azure
    elif num_azure == 0 and num_mount == 0:
        # No datasets in Azure at all: This is possible for runs that for example download their own data from the web.
        # There can be any number of local datasets, but we are not checking that. In MLRunner.setup, there is a check
        # There can be any number of local datasets, but we are not checking that. In TrainingRunner.setup, there is a check
        # that leaves local datasets intact if there are no Azure datasets.
        return []
    else:

@@ -2,8 +2,8 @@

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from health_ml.run_ml import MLRunner
from health_ml.training_runner import TrainingRunner
from health_ml.runner import Runner


__all__ = ["MLRunner", "Runner"]
__all__ = ["TrainingRunner", "Runner"]

@@ -16,6 +16,9 @@ from torch.utils.data import DataLoader, Dataset

from health_ml.lightning_container import LightningContainer

TEST_MSE_FILE = "test_mse.txt"
TEST_MAE_FILE = "test_mae.txt"


def _create_1d_regression_dataset(n: int = 100, seed: int = 0) -> torch.Tensor:
    """Creates a simple 1-D dataset of a noisy linear function.

@@ -79,10 +82,10 @@ class HelloWorldDataModule(LightningDataModule):

    A data module that gives the training, validation and test data for a simple 1-dim regression task.
    """

    def __init__(self, crossval_count: int, crossval_index: Optional[int]) -> None:
    def __init__(self, crossval_count: int, crossval_index: Optional[int] = None, seed: int = 0) -> None:
        super().__init__()
        n_total = 200
        xy = _create_1d_regression_dataset(n=n_total)
        xy = _create_1d_regression_dataset(n=n_total, seed=seed)
        n_test = 40
        n_val = 50
        self.test = HelloWorldDataset(xy=xy[:n_test])

@@ -229,8 +232,8 @@ class HelloRegression(LightningModule):

        for example writing aggregate metrics to disk.
        """
        average_mse = torch.mean(torch.stack(self.test_mse))
        Path("test_mse.txt").write_text(str(average_mse.item()))
        Path("test_mae.txt").write_text(str(self.test_mae.compute().item()))
        Path(TEST_MSE_FILE).write_text(str(average_mse.item()))
        Path(TEST_MAE_FILE).write_text(str(self.test_mae.compute().item()))
        self.log("test_mse", average_mse, on_epoch=True, on_step=False)

    def on_run_extra_validation_epoch(self) -> None:

@@ -266,6 +269,11 @@ class HelloWorld(LightningContainer):

        # datamodule must carry out appropriate splitting of the data.
        return HelloWorldDataModule(crossval_count=self.crossval_count, crossval_index=self.crossval_index)

    # This method is optional. Override it to supply a data module to use for evaluating the model on a new dataset.
    # Only the test data loader from the data module will be used.
    def get_eval_data_module(self) -> LightningDataModule:
        return HelloWorldDataModule(crossval_count=1, seed=1)

    def get_callbacks(self) -> List[Callback]:
        if self.save_checkpoint:
            return [

@@ -0,0 +1,29 @@

from pytorch_lightning import LightningDataModule

from health_azure.logging import logging_section
from health_ml.runner_base import RunnerBase


class EvalRunner(RunnerBase):
    """A class to run the evaluation of a model on a new dataset. The initialization logic is taken from the base
    class `RunnerBase`.
    """

    def validate(self) -> None:
        """Checks if the fields of the class are set up correctly."""
        if self.container.src_checkpoint is None or self.container.src_checkpoint.checkpoint == "":
            raise ValueError(
                "To use model evaluation, you need to provide a checkpoint to use, via the --src_checkpoint argument."
            )

    def run(self) -> None:
        """Start the core workflow that the class implements: Initialize a PL Trainer object and use that to run
        inference on the inference dataset."""
        self.container.outputs_folder.mkdir(exist_ok=True, parents=True)
        self.init_inference()
        with logging_section("Model inference"):
            self.run_inference()

    def get_data_module(self) -> LightningDataModule:
        """Reads the evaluation data module from the underlying container."""
        return self.container.get_eval_data_module()
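
For reference, a minimal sketch of driving the new `EvalRunner` directly from Python rather than via `himl-runner`, mirroring the test code added in this commit (the checkpoint path is illustrative):

```python
from pathlib import Path

from health_ml.configs.hello_world import HelloWorld
from health_ml.eval_runner import EvalRunner
from health_ml.experiment_config import ExperimentConfig, RunnerMode
from health_ml.utils.checkpoint_utils import CheckpointParser

container = HelloWorld()
# A checkpoint source is required; validate() raises a ValueError without it.
container.src_checkpoint = CheckpointParser("path/to/hello_world_checkpoint.ckpt")
eval_runner = EvalRunner(
    container=container,
    experiment_config=ExperimentConfig(mode=RunnerMode.EVAL_FULL),
    project_root=Path.cwd(),
)
eval_runner.validate()
eval_runner.run_and_cleanup()
```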

@@ -10,6 +10,11 @@ class DebugDDPOptions(Enum):

    DETAIL = "DETAIL"


class RunnerMode(Enum):
    TRAIN = "train"
    EVAL_FULL = "eval_full"


DEBUG_DDP_ENV_VAR = "TORCH_DISTRIBUTED_DEBUG"

@@ -82,3 +87,10 @@ class ExperimentConfig(param.Parameterized):

        doc="The maximum runtime that is allowed for this job in AzureML. This is given as a floating"
        "point number with a string suffix s, m, h, d for seconds, minutes, hours, day. Examples: '3.5h', '2d'",
    )
    mode: str = param.ClassSelector(
        class_=RunnerMode,
        default=RunnerMode.TRAIN,
        doc=f"The mode to run the experiment in. Can be one of '{RunnerMode.TRAIN}' (training and evaluation on the "
        f"test set), or '{RunnerMode.EVAL_FULL}' for evaluation on the full dataset specified by the "
        "'get_eval_data_module' method of the container.",
    )

@@ -65,6 +65,17 @@ class LightningContainer(WorkflowParams, DatasetParams, OutputParams, TrainerPar

        """
        return None  # type: ignore

    def get_eval_data_module(self) -> LightningDataModule:
        """
        Gets the data that is used when evaluating the model on a new dataset.
        This data module should read datasets from the self.local_datasets folder or download from a web location.
        Only the test dataloader is used, hence the method needs to put all data into the test dataloader, rather
        than splitting into train/val/test.

        :return: A LightningDataModule
        """
        return None  # type: ignore

    def get_trainer_arguments(self) -> Dict[str, Any]:
        """
        Gets additional parameters that will be passed on to the PyTorch Lightning trainer.

@@ -31,9 +31,9 @@ from health_azure.datasets import create_dataset_configs  # noqa: E402

from health_azure.himl import DEFAULT_DOCKER_BASE_IMAGE, OUTPUT_FOLDER  # noqa: E402
from health_azure.logging import logging_to_stdout  # noqa: E402
from health_azure.paths import is_himl_used_from_git_repo  # noqa: E402
from health_azure.utils import (
from health_azure.utils import (  # noqa: E402
    ENV_LOCAL_RANK,
    ENV_NODE_RANK,  # noqa: E402
    ENV_NODE_RANK,
    get_workspace,
    get_ml_client,
    is_local_rank_zero,

@@ -43,9 +43,11 @@ from health_azure.utils import (

    is_global_rank_zero,
)

from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, ExperimentConfig  # noqa: E402
from health_ml.eval_runner import EvalRunner  # noqa: E402
from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, ExperimentConfig, RunnerMode  # noqa: E402
from health_ml.lightning_container import LightningContainer  # noqa: E402
from health_ml.run_ml import MLRunner  # noqa: E402
from health_ml.runner_base import RunnerBase  # noqa: E402
from health_ml.training_runner import TrainingRunner  # noqa: E402
from health_ml.utils import fixed_paths  # noqa: E402
from health_ml.utils.logging import ConsoleAndFileOutput  # noqa: E402
from health_ml.utils.common_utils import check_conda_environment, choose_conda_env_file, is_linux  # noqa: E402

@@ -101,8 +103,8 @@ class Runner:

        self.project_root = project_root
        self.experiment_config: ExperimentConfig = ExperimentConfig()
        self.lightning_container: LightningContainer = None  # type: ignore
        # This field stores the MLRunner object that has been created in the most recent call to the run() method.
        self.ml_runner: Optional[MLRunner] = None
        # This field stores the TrainingRunner object that has been created in the most recent call to the run() method.
        self.ml_runner: Optional[RunnerBase] = None

    def parse_and_load_model(self) -> ParserResult:
        """

@@ -322,15 +324,26 @@ class Runner:

        assert azure_run_info.run is not None
        azure_run_info.run.set_tags(self.additional_run_tags(sys.argv[1:]))

        # Set environment variables for multi-node training if needed. This function will terminate early
        # if it detects that it is not in a multi-node environment.
        if self.experiment_config.num_nodes > 1:
            set_environment_variables_for_multi_node()
        self.ml_runner = MLRunner(
            experiment_config=self.experiment_config, container=self.lightning_container, project_root=self.project_root
        )
        self.ml_runner.setup(azure_run_info)
        self.ml_runner.run()
        if self.experiment_config.mode == RunnerMode.TRAIN:
            # Set environment variables for multi-node training if needed. This function will terminate early
            # if it detects that it is not in a multi-node environment.
            if self.experiment_config.num_nodes > 1:
                set_environment_variables_for_multi_node()
            self.ml_runner = TrainingRunner(
                experiment_config=self.experiment_config,
                container=self.lightning_container,
                project_root=self.project_root,
            )
        elif self.experiment_config.mode == RunnerMode.EVAL_FULL:
            self.ml_runner = EvalRunner(
                experiment_config=self.experiment_config,
                container=self.lightning_container,
                project_root=self.project_root,
            )
        else:
            raise ValueError(f"Unknown mode {self.experiment_config.mode}")
        self.ml_runner.validate()
        self.ml_runner.run_and_cleanup(azure_run_info)


def run(project_root: Path) -> Tuple[LightningContainer, AzureRunInfo]:

@@ -0,0 +1,247 @@

# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from pathlib import Path
from typing import List, Optional

from azureml.core import Run
from pytorch_lightning import LightningDataModule, Trainer, seed_everything

from health_azure import AzureRunInfo
from health_azure.utils import (
    PARENT_RUN_CONTEXT,
    RUN_CONTEXT,
    create_aml_run_object,
    create_run_recovery_id,
    is_running_in_azure_ml,
)
from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
from health_ml.model_trainer import create_lightning_trainer
from health_ml.utils import fixed_paths
from health_ml.utils.checkpoint_handler import CheckpointHandler
from health_ml.utils.common_utils import (
    EFFECTIVE_RANDOM_SEED_KEY_NAME,
    RUN_RECOVERY_FROM_ID_KEY_NAME,
    RUN_RECOVERY_ID_KEY,
    change_working_directory,
    seed_monai_if_available,
)
from health_ml.utils.lightning_loggers import StoringLogger, get_mlflow_run_id_from_trainer
from health_ml.utils.type_annotations import PathOrString


def check_dataset_folder_exists(local_dataset: PathOrString) -> Path:
    """
    Checks if a folder with a local dataset exists. If it does exist, return the argument converted
    to a Path instance. If it does not exist, raise a FileNotFoundError.

    :param local_dataset: The dataset folder to check.
    :return: The local_dataset argument, converted to a Path.
    """
    expected_dir = Path(local_dataset)
    if not expected_dir.is_dir():
        raise FileNotFoundError(f"The model uses a dataset in {expected_dir}, but that does not exist.")
    logging.info(f"Model will use the local dataset provided in {expected_dir}")
    return expected_dir


class RunnerBase:
    """
    A base class with operations that are shared between the training/test runner and the evaluation-only runner.
    """

    def __init__(
        self, experiment_config: ExperimentConfig, container: LightningContainer, project_root: Optional[Path] = None
    ) -> None:
        """
        Driver class to run an ML experiment. Note that the project root argument MUST be supplied when using hi-ml
        as a package!

        :param experiment_config: The ExperimentConfig object to use for training.
        :param container: The LightningContainer object to use for training.
        :param project_root: Project root. This should only be omitted if calling run_ml from the test suite. Supplying
            it is crucial when using hi-ml as a package or submodule!
        """
        self.container = container
        self.experiment_config = experiment_config
        self.container.num_nodes = self.experiment_config.num_nodes
        self.project_root: Path = project_root or fixed_paths.repository_root_directory()
        self.storing_logger: Optional[StoringLogger] = None
        self._has_setup_run = False
        self.checkpoint_handler = CheckpointHandler(
            container=self.container, project_root=self.project_root, run_context=RUN_CONTEXT
        )
        self.trainer: Optional[Trainer] = None
        self.azureml_run_for_logging: Optional[Run] = None
        self.mlflow_run_for_logging: Optional[str] = None
        # This is passed to trainer.validate and trainer.test in inference mode
        self.inference_checkpoint: Optional[str] = None

    def validate(self) -> None:
        """
        Checks if all arguments and settings of the object are correct.
        """
        pass

    def setup_azureml(self) -> None:
        """
        Execute setup steps that are specific to AzureML.
        """
        if PARENT_RUN_CONTEXT is not None:
            # Set metadata for the run in AzureML if running in a Hyperdrive job.
            run_tags_parent = PARENT_RUN_CONTEXT.get_tags()
            tags_to_copy = [
                "tag",
                "model_name",
                "execution_mode",
                "recovered_from",
                "friendly_name",
                "build_number",
                "build_user",
                RUN_RECOVERY_FROM_ID_KEY_NAME,
            ]
            new_tags = {tag: run_tags_parent.get(tag, "") for tag in tags_to_copy}
            new_tags[RUN_RECOVERY_ID_KEY] = create_run_recovery_id(run=RUN_CONTEXT)
            new_tags[EFFECTIVE_RANDOM_SEED_KEY_NAME] = str(self.container.get_effective_random_seed())
            RUN_CONTEXT.set_tags(new_tags)

    def setup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None:
        """
        Sets the random seeds, calls the setup method on the LightningContainer and then creates the actual
        Lightning modules.

        :param azure_run_info: When running in AzureML or on a local VM, this contains the paths to the datasets.
            This can be missing when running in unit tests, where the local dataset paths are already populated.
        """
        if self._has_setup_run:
            return
        if azure_run_info:
            # Set up the paths to the datasets. azure_run_info already has all necessary information, using either
            # the provided local datasets for VM runs, or the AzureML mount points when running in AML.
            # This must happen before container setup because that could already read datasets.
            logging.info("Setting tags from parent run.")
            input_datasets = azure_run_info.input_datasets
            logging.info(f"Setting the following datasets as local datasets: {input_datasets}")
            if len(input_datasets) > 0:
                local_datasets: List[Path] = []
                for i, dataset in enumerate(input_datasets):
                    if dataset is None:
                        raise ValueError(f"Invalid setup: The dataset at index {i} is None")
                    local_datasets.append(check_dataset_folder_exists(dataset))
                self.container.local_datasets = local_datasets  # type: ignore
        # Ensure that we use fixed seeds before initializing the PyTorch models.
        # MONAI needs a separate method to make all transforms deterministic by default
        seed = self.container.get_effective_random_seed()
        seed_monai_if_available(seed)
        seed_everything(seed)

        # Creating the folder structure must happen before the LightningModule is created, because the output
        # parameters of the container will be copied into the module.
        self.container.create_filesystem(self.project_root)

        # configure recovery container if provided
        self.checkpoint_handler.download_recovery_checkpoints_or_weights()

        # Create an AzureML run for logging if running outside AzureML.
        self.create_logger()

        self.container.setup()
        self.container.create_lightning_module_and_store()
        self._has_setup_run = True

        if is_running_in_azure_ml():
            self.setup_azureml()

    def create_logger(self) -> None:
        """
        Create an AzureML run for logging if running outside AzureML. This run will be used for metrics logging
        during both training and inference. We can't rely on the automatically generated run inside the AzureMLLogger
        class because two of those logger objects will be created, so training and inference metrics would be logged
        in different runs.
        """
        if self.container.log_from_vm:
            run = create_aml_run_object(experiment_name=self.container.effective_experiment_name)
            # Display name should already be set when creating the Run object, but in some scenarios this
            # does not happen. Hence, set it again.
            run.display_name = self.container.tag if self.container.tag else None
            self.azureml_run_for_logging = run

    def get_data_module(self) -> LightningDataModule:
        """
        Reads the datamodule that should be used for training or evaluation from the container. This must be
        overridden in subclasses.
        """
        raise NotImplementedError()

    def set_trainer_for_inference(self) -> None:
        """Set the runner's PL Trainer object that should be used when running inference on the validation or test set.
        We run inference on a single device because distributed strategies such as DDP use DistributedSampler
        internally, which replicates some samples to make sure all devices have the same batch size in case of
        uneven inputs, which biases the results."""
        mlflow_run_id = get_mlflow_run_id_from_trainer(self.trainer)
        self.container.max_num_gpus = self.container.max_num_gpus_inference
        self.trainer, _ = create_lightning_trainer(
            container=self.container,
            num_nodes=1,
            azureml_run_for_logging=self.azureml_run_for_logging,
            mlflow_run_for_logging=mlflow_run_id,
        )

    def init_inference(self) -> None:
        """Prepare the runner for inference on the validation set, test set, or a full dataset.
        The following steps are performed:

        1. Get the checkpoint to use for inference. This is either the checkpoint from the last training epoch or the
        one specified in the src_checkpoint argument.

        2. Create a new trainer instance for inference. This is necessary because the trainer is created with a single
        device, in contrast to training that uses DDP if multiple GPUs are available.

        3. Create a new data module instance for inference to account for any requested changes in the dataloading
        parameters (e.g. batch_size, max_num_workers, etc) as part of on_run_extra_validation_epoch.
        """
        self.inference_checkpoint = str(self.checkpoint_handler.get_checkpoint_to_test())
        self.set_trainer_for_inference()
        self.data_module = self.get_data_module()

    def run_inference(self) -> None:
        """Run inference on the test set for all models. This is done by calling the LightningModule.test_step method.
        If the LightningModule.test_step method is not overridden, then this method does nothing. The cwd is changed to
        the outputs folder so that the model can write to current working directory, and still everything is put into
        the right place in AzureML (there, only the contents of the "outputs" folder is treated as a result file).
        """
        if self.container.has_custom_test_step():
            logging.info("Running inference via the LightningModule.test_step method")
            with change_working_directory(self.container.outputs_folder):
                assert self.trainer, "Trainer should be initialized before inference. Call self.init_inference()."
                _ = self.trainer.test(
                    self.container.model, datamodule=self.data_module, ckpt_path=self.inference_checkpoint
                )
        else:
            logging.warning("None of the suitable test methods is overridden. Skipping inference completely.")

    def run(self) -> None:
        """
        Run the training or evaluation. This method must be overridden in subclasses.
        """
        pass

    def run_and_cleanup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None:
        """
        Run the training or evaluation via `self.run` and clean up afterwards.

        :param azure_run_info: When running in AzureML or on a local VM, this contains the paths to the datasets.
            This can be missing when running in unit tests, where the local dataset paths are already populated.
        """
        self.setup(azure_run_info)
        try:
            self.run()
        finally:
            if self.azureml_run_for_logging is not None:
                try:
                    self.azureml_run_for_logging.complete()
                except Exception as ex:
                    logging.error("Failed to complete AzureML run: %s", ex)
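
A minimal sketch of the lifecycle that `Runner.run_in_situ` drives for these runner classes, using the `HelloWorld` container from this repository and omitting `azure_run_info` as the unit tests do:

```python
from health_ml.configs.hello_world import HelloWorld
from health_ml.experiment_config import ExperimentConfig
from health_ml.training_runner import TrainingRunner

container = HelloWorld()
experiment_config = ExperimentConfig(model="HelloWorld")
runner = TrainingRunner(experiment_config=experiment_config, container=container)
# validate() checks the runner's settings, then run_and_cleanup() calls setup(),
# run(), and finally completes any AzureML run that was opened for logging.
runner.validate()
runner.run_and_cleanup()
```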

@@ -6,14 +6,11 @@ import json

import logging
import os
import sys
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict

import torch
from azureml.core import Run
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning import LightningDataModule, seed_everything

from health_azure import AzureRunInfo
from health_azure.logging import logging_section, print_message_with_rank_pid
from health_azure.utils import (
    ENV_GLOBAL_RANK,

@@ -22,140 +19,24 @@ from health_azure.utils import (

    ENV_OMPI_COMM_WORLD_RANK,
    PARENT_RUN_CONTEXT,
    RUN_CONTEXT,
    create_aml_run_object,
    create_run_recovery_id,
    get_metrics_for_hyperdrive_run,
    get_metrics_for_run,
    is_global_rank_zero,
    is_local_rank_zero,
    is_running_in_azure_ml,
)
from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
from health_ml.model_trainer import create_lightning_trainer, write_experiment_summary_file
from health_ml.utils import fixed_paths
from health_ml.utils.checkpoint_handler import CheckpointHandler
from health_ml.runner_base import RunnerBase
from health_ml.utils.checkpoint_utils import cleanup_checkpoints
from health_ml.utils.common_utils import (
    EFFECTIVE_RANDOM_SEED_KEY_NAME,
    RUN_RECOVERY_FROM_ID_KEY_NAME,
    RUN_RECOVERY_ID_KEY,
    change_working_directory,
    seed_monai_if_available,
)
from health_ml.utils.lightning_loggers import StoringLogger, get_mlflow_run_id_from_trainer
from health_ml.utils.regression_test_utils import REGRESSION_TEST_METRICS_FILENAME, compare_folders_and_run_outputs
from health_ml.utils.type_annotations import PathOrString


def check_dataset_folder_exists(local_dataset: PathOrString) -> Path:
    """
    Checks if a folder with a local dataset exists. If it does exist, return the argument converted
    to a Path instance. If it does not exist, raise a FileNotFoundError.

    :param local_dataset: The dataset folder to check.
    :return: The local_dataset argument, converted to a Path.
    """
    expected_dir = Path(local_dataset)
    if not expected_dir.is_dir():
        raise FileNotFoundError(f"The model uses a dataset in {expected_dir}, but that does not exist.")
    logging.info(f"Model training will use the local dataset provided in {expected_dir}")
    return expected_dir


class MLRunner:
    def __init__(
        self, experiment_config: ExperimentConfig, container: LightningContainer, project_root: Optional[Path] = None
    ) -> None:
        """
        Driver class to run an ML experiment. Note that the project root argument MUST be supplied when using hi-ml
        as a package!

        :param experiment_config: The ExperimentConfig object to use for training.
        :param container: The LightningContainer object to use for training.
        :param project_root: Project root. This should only be omitted if calling run_ml from the test suite. Supplying
            it is crucial when using hi-ml as a package or submodule!
        """
        self.container = container
        self.experiment_config = experiment_config
        self.container.num_nodes = self.experiment_config.num_nodes
        self.project_root: Path = project_root or fixed_paths.repository_root_directory()
        self.storing_logger: Optional[StoringLogger] = None
        self._has_setup_run = False
        self.checkpoint_handler = CheckpointHandler(
            container=self.container, project_root=self.project_root, run_context=RUN_CONTEXT
        )
        self.trainer: Optional[Trainer] = None
        self.azureml_run_for_logging: Optional[Run] = None
        self.mlflow_run_for_logging: Optional[str] = None
        self.inference_checkpoint: Optional[str] = None  # Passed to trainer.validate and trainer.test in inference mode

    def set_run_tags_from_parent(self) -> None:
        """
        Set metadata for the run
        """
        assert PARENT_RUN_CONTEXT, "This function should only be called in a Hyperdrive run."
        run_tags_parent = PARENT_RUN_CONTEXT.get_tags()
        tags_to_copy = [
            "tag",
            "model_name",
            "execution_mode",
            "recovered_from",
            "friendly_name",
            "build_number",
            "build_user",
            RUN_RECOVERY_FROM_ID_KEY_NAME,
        ]
        new_tags = {tag: run_tags_parent.get(tag, "") for tag in tags_to_copy}
        new_tags[RUN_RECOVERY_ID_KEY] = create_run_recovery_id(run=RUN_CONTEXT)
        new_tags[EFFECTIVE_RANDOM_SEED_KEY_NAME] = str(self.container.get_effective_random_seed())
        RUN_CONTEXT.set_tags(new_tags)

    def setup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None:
        """
        Sets the random seeds, calls the setup method on the LightningContainer and then creates the actual
        Lightning modules.

        :param azure_run_info: When running in AzureML or on a local VM, this contains the paths to the datasets.
            This can be missing when running in unit tests, where the local dataset paths are already populated.
        """
        if self._has_setup_run:
            return
        if azure_run_info:
            # Set up the paths to the datasets. azure_run_info already has all necessary information, using either
            # the provided local datasets for VM runs, or the AzureML mount points when running in AML.
            # This must happen before container setup because that could already read datasets.
            input_datasets = azure_run_info.input_datasets
            logging.info(f"Setting the following datasets as local datasets: {input_datasets}")
            if len(input_datasets) > 0:
                local_datasets: List[Path] = []
                for i, dataset in enumerate(input_datasets):
                    if dataset is None:
                        raise ValueError(f"Invalid setup: The dataset at index {i} is None")
                    local_datasets.append(check_dataset_folder_exists(dataset))
                self.container.local_datasets = local_datasets  # type: ignore
        # Ensure that we use fixed seeds before initializing the PyTorch models.
        # MONAI needs a separate method to make all transforms deterministic by default
        seed = self.container.get_effective_random_seed()
        seed_monai_if_available(seed)
        seed_everything(seed)

        # Creating the folder structure must happen before the LightningModule is created, because the output
        # parameters of the container will be copied into the module.
        self.container.create_filesystem(self.project_root)

        # configure recovery container if provided
        self.checkpoint_handler.download_recovery_checkpoints_or_weights()  # type: ignore

        self.container.setup()
        self.container.create_lightning_module_and_store()
        self._has_setup_run = True

        is_offline_run = not is_running_in_azure_ml(RUN_CONTEXT)
        # Get the AzureML context in which the script is running
        if not is_offline_run and PARENT_RUN_CONTEXT is not None:
            logging.info("Setting tags from parent run.")
            self.set_run_tags_from_parent()
class TrainingRunner(RunnerBase):
    def get_data_module(self) -> LightningDataModule:
        return self.container.get_data_module()

    def get_multiple_trainloader_mode(self) -> str:
        # Workaround for a bug in PL 1.5.5: We need to pass the cycle mode for the training data as a trainer argument

@@ -190,18 +71,7 @@ class MLRunner:

        seed_everything(self.container.get_effective_random_seed(), workers=True)

        # Get the container's datamodule
        self.data_module = self.container.get_data_module()

        # Create an AzureML run for logging if running outside AzureML. This run will be used for metrics logging
        # during both training and inference. We can't rely on the automatically generated run inside the AzureMLLogger
        # class because two of those logger objects will be created, so training and inference metrics would be logged
        # in different runs.
        if self.container.log_from_vm:
            run = create_aml_run_object(experiment_name=self.container.effective_experiment_name)
            # Display name should already be set when creating the Run object, but in some scenarios this
            # does not happen. Hence, set it again.
            run.display_name = self.container.tag if self.container.tag else None
            self.azureml_run_for_logging = run
        self.data_module = self.get_data_module()

        if not self.container.run_inference_only:
            checkpoint_path_for_recovery = self.checkpoint_handler.get_recovery_or_checkpoint_path_train()

@@ -275,36 +145,6 @@ class MLRunner:

            return self.container.crossval_index == 0
        return True

    def set_trainer_for_inference(self) -> None:
        """Set the runner's PL Trainer object that should be used when running inference on the validation or test set.
        We run inference on a single device because distributed strategies such as DDP use DistributedSampler
        internally, which replicates some samples to make sure all devices have the same batch size in case of
        uneven inputs, which biases the results."""
        mlflow_run_id = get_mlflow_run_id_from_trainer(self.trainer)
        self.container.max_num_gpus = self.container.max_num_gpus_inference
        self.trainer, _ = create_lightning_trainer(
            container=self.container,
            num_nodes=1,
            azureml_run_for_logging=self.azureml_run_for_logging,
            mlflow_run_for_logging=mlflow_run_id,
        )

    def init_inference(self) -> None:
        """Prepare the runner for inference: validation or test. The following steps are performed:
        1. Get the checkpoint to use for inference. This is either the checkpoint from the last training epoch or the
        one specified in the src_checkpoint argument.
        2. If the container has a run_extra_val_epoch method, call it to run an extra validation epoch.
        3. Create a new trainer instance for inference. This is necessary because the trainer is created with a single
        device, in contrast to training that uses DDP if multiple GPUs are available.
        4. Create a new data module instance for inference to account for any requested changes in the dataloading
        parameters (e.g. batch_size, max_num_workers, etc) as part of on_run_extra_validation_epoch.
        """
        self.inference_checkpoint = str(self.checkpoint_handler.get_checkpoint_to_test())
        if self.container.run_extra_val_epoch:
            self.container.on_run_extra_validation_epoch()
        self.set_trainer_for_inference()
        self.data_module = self.container.get_data_module()

    def run_training(self) -> None:
        """
        The main training loop. It creates the Pytorch model based on the configuration options passed in,

@@ -321,6 +161,16 @@ class MLRunner:

            assert logger is not None
            logger.finalize('success')

    def init_inference(self) -> None:
        """
        Prepare the trainer for running inference on the validation and test set. This chooses a checkpoint,
        initializes the PL Trainer object, and chooses the right data module. Afterwards, the hook for running
        inference on the validation set is run (`LightningContainer.on_run_extra_validation_epoch`).
        """
        super().init_inference()
        if self.container.run_extra_val_epoch:
            self.container.on_run_extra_validation_epoch()

    def run_validation(self) -> None:
        """Run validation on the validation set for all models to save time/memory consuming outputs. This is done in
        inference only mode or when the user has requested an extra validation epoch. The cwd is changed to the outputs

@@ -334,22 +184,6 @@ class MLRunner:

        else:
            logging.info("Skipping extra validation because the user has not requested it.")

    def run_inference(self) -> None:
        """Run inference on the test set for all models. This is done by calling the LightningModule.test_step method.
        If the LightningModule.test_step method is not overridden, then this method does nothing. The cwd is changed to
        the outputs folder so that the model can write to current working directory, and still everything is put into
        the right place in AzureML (there, only the contents of the "outputs" folder is treated as a result file).
        """
        if self.container.has_custom_test_step():
            logging.info("Running inference via the LightningModule.test_step method")
            with change_working_directory(self.container.outputs_folder):
                assert self.trainer, "Trainer should be initialized before inference. Call self.init_inference()."
                _ = self.trainer.test(
                    self.container.model, datamodule=self.data_module, ckpt_path=self.inference_checkpoint
                )
        else:
            logging.warning("None of the suitable test methods is overridden. Skipping inference completely.")

    def run_regression_test(self) -> None:
        if self.container.regression_test_folder:
            with logging_section("Regression Test"):

@@ -392,32 +226,23 @@ class MLRunner:

        """
        Driver function to run an ML experiment
        """
        self.setup()
        try:
            self.init_training()
        self.init_training()

            if not self.container.run_inference_only:
                # Backup the environment variables in case we need to run a second training in the unit tests.
                environ_before_training = dict(os.environ)
        if not self.container.run_inference_only:
            # Backup the environment variables in case we need to run a second training in the unit tests.
            environ_before_training = dict(os.environ)

                with logging_section("Model training"):
                    self.run_training()
            with logging_section("Model training"):
                self.run_training()

                self.end_training(environ_before_training)
            self.end_training(environ_before_training)

            self.init_inference()
        self.init_inference()

            with logging_section("Model validation"):
                self.run_validation()
        with logging_section("Model validation"):
            self.run_validation()

            with logging_section("Model inference"):
                self.run_inference()
        with logging_section("Model inference"):
            self.run_inference()

            self.run_regression_test()

        finally:
            if self.azureml_run_for_logging is not None:
                try:
                    self.azureml_run_for_logging.complete()
                except Exception as ex:
                    logging.error("Failed to complete AzureML run: %s", ex)
        self.run_regression_test()

@@ -1,4 +1,15 @@

from pathlib import Path
import pytest

from health_ml.runner import Runner
from health_ml.utils import health_ml_package_setup

# Reduce logging noise in DEBUG mode
health_ml_package_setup()


@pytest.fixture
def mock_runner(tmp_path: Path) -> Runner:
    """A test fixture that creates a Runner object in a temporary folder."""

    return Runner(project_root=tmp_path)

@@ -0,0 +1,78 @@

from pathlib import Path
import sys
from unittest.mock import patch

import pytest

from health_ml import TrainingRunner
from health_ml.configs.hello_world import (
    TEST_MAE_FILE,
    TEST_MSE_FILE,
    HelloWorld,
    HelloWorldDataModule,
)
from health_ml.eval_runner import EvalRunner
from health_ml.experiment_config import ExperimentConfig, RunnerMode
from health_ml.runner import Runner
from health_ml.utils.checkpoint_utils import CheckpointParser
from testhiml.test_training_runner import training_runner_hello_world
from testhiml.utils.fixed_paths_for_tests import full_test_data_path


hello_world_checkpoint = full_test_data_path(suffix="hello_world_checkpoint.ckpt")


def test_eval_runner_no_checkpoint(mock_runner: Runner) -> None:
    """Test that evaluation mode fails if no checkpoint source is provided"""
    arguments = ["", "--model=HelloWorld", f"--mode={RunnerMode.EVAL_FULL.value}"]
    with pytest.raises(ValueError, match="To use model evaluation, you need to provide a checkpoint to use"):
        with patch.object(sys, "argv", arguments):
            mock_runner.run()


def test_eval_runner_end_to_end(mock_runner: Runner) -> None:
    """Test the end-to-end integration of the EvalRunner class into the overall Runner"""
    arguments = [
        "",
        "--model=HelloWorld",
        f"--mode={RunnerMode.EVAL_FULL.value}",
        f"--src_checkpoint={hello_world_checkpoint}",
    ]
    with patch("health_ml.training_runner.TrainingRunner.run_and_cleanup") as mock_training_run:
        with patch.object(sys, "argv", arguments):
            mock_runner.run()
            # The training runner should not be invoked
            mock_training_run.assert_not_called()
            # The eval runner should have been invoked. The test step writes two files with metrics, check that
            # they exist
            output_folder = mock_runner.lightning_container.outputs_folder
            for file_name in [TEST_MSE_FILE, TEST_MAE_FILE]:
                assert (output_folder / file_name).exists()


def test_eval_runner_methods_called(tmp_path: Path) -> None:
    """Test if the eval runner uses the right data module from the HelloWorld model"""
    container = HelloWorld()
    container.src_checkpoint = CheckpointParser(str(hello_world_checkpoint))
    eval_runner = EvalRunner(
        container=container, experiment_config=ExperimentConfig(mode=RunnerMode.EVAL_FULL), project_root=tmp_path
    )
    with patch("health_ml.configs.hello_world.HelloWorld.get_eval_data_module") as mock_get_data_module:
        mock_get_data_module.return_value = HelloWorldDataModule(crossval_count=1, seed=1)
        eval_runner.run_and_cleanup()
        mock_get_data_module.assert_called_once_with()


def test_eval_runner_no_extra_validation_epoch_called(tmp_path: Path) -> None:
    """
    Ensure that the eval runner does not invoke the hook for the extra validation epoch that the training runner uses.
    """
    container = HelloWorld()
    container.run_extra_val_epoch = True
    container.src_checkpoint = CheckpointParser(str(hello_world_checkpoint))
    eval_runner = EvalRunner(
        container=container, experiment_config=ExperimentConfig(mode=RunnerMode.EVAL_FULL), project_root=tmp_path
    )
    with patch("health_ml.configs.hello_world.HelloRegression.on_run_extra_validation_epoch") as mock_hook:
        eval_runner.run_and_cleanup()
        mock_hook.assert_not_called()

@@ -13,7 +13,7 @@ import pytest

from health_azure.utils import create_aml_run_object
from health_ml.experiment_config import ExperimentConfig
from health_ml.run_ml import MLRunner
from health_ml.training_runner import TrainingRunner
from health_ml.configs.hello_world import HelloWorld
from health_ml.utils.regression_test_utils import (
    CONTENTS_MISMATCH,

@@ -54,7 +54,7 @@ def test_regression_test() -> None:

    """
    container = HelloWorld()
    container.regression_test_folder = Path(str(uuid.uuid4().hex))
    runner = MLRunner(container=container, experiment_config=ExperimentConfig())
    runner = TrainingRunner(container=container, experiment_config=ExperimentConfig())
    runner.setup()
    with pytest.raises(ValueError) as ex:
        runner.run()

@@ -19,7 +19,7 @@ from health_azure import AzureRunInfo, DatasetConfig

from health_azure.himl import OUTPUT_FOLDER
from health_azure.utils import ENV_LOCAL_RANK, ENV_NODE_RANK
from health_azure.paths import ENVIRONMENT_YAML_FILE_NAME
from health_ml.configs.hello_world import HelloWorld  # type: ignore
from health_ml.configs.hello_world import TEST_MAE_FILE, TEST_MSE_FILE, HelloWorld  # type: ignore
from health_ml.deep_learning_config import WorkflowParams
from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, DebugDDPOptions
from health_ml.lightning_container import LightningContainer

@@ -28,13 +28,6 @@ from health_ml.utils.common_utils import change_working_directory

from health_ml.utils.fixed_paths import repository_root_directory


@pytest.fixture
def mock_runner(tmp_path: Path) -> Runner:
    """A test fixture that creates a Runner object in a temporary folder."""

    return Runner(project_root=tmp_path)


@contextmanager
def change_working_folder_and_add_environment(tmp_path: Path) -> Generator:
    # Use a special simplified environment file only for the tests here. Copy that to a temp folder, then let the runner

@@ -407,7 +400,7 @@ def test_run_hello_world(mock_runner: Runner) -> None:

    # time-consuming auth
    mock_get_workspace.assert_not_called()
    # Summary.txt is written at start, the other files during inference
    expected_files = ["experiment_summary.txt", "test_mae.txt", "test_mse.txt"]
    expected_files = ["experiment_summary.txt", TEST_MSE_FILE, TEST_MAE_FILE]
    for file in expected_files:
        assert (mock_runner.lightning_container.outputs_folder / file).is_file(), f"Missing file: {file}"


@@ -428,9 +421,8 @@ def test_invalid_profiler(mock_runner: Runner) -> None:

    invalid_profile = "--pl_profiler=foo"
    arguments = ["", "--model=HelloWorld", invalid_profile]
    with patch.object(sys, "argv", arguments):
        with pytest.raises(ValueError) as ex:
        with pytest.raises(ValueError, match="Unsupported profiler."):
            mock_runner.run()
    assert "Unsupported profiler." in str(ex)


def test_custom_datastore_outside_aml(mock_runner: Runner) -> None:
|
||||
|
|
|
@ -18,34 +18,38 @@ from pytorch_lightning import LightningModule
|
|||
import mlflow
|
||||
from pytorch_lightning import Trainer
|
||||
|
||||
from health_ml.configs.hello_world import HelloWorld # type: ignore
|
||||
from health_ml.configs.hello_world import TEST_MAE_FILE, TEST_MSE_FILE, HelloWorld # type: ignore
|
||||
from health_ml.experiment_config import ExperimentConfig
|
||||
from health_ml.lightning_container import LightningContainer
|
||||
from health_ml.run_ml import MLRunner
|
||||
from health_ml.runner_base import RunnerBase
|
||||
from health_ml.training_runner import TrainingRunner
|
||||
from health_ml.utils.checkpoint_handler import CheckpointHandler
|
||||
from health_ml.utils.checkpoint_utils import CheckpointParser
|
||||
from health_ml.utils.common_utils import is_gpu_available
|
||||
from health_ml.utils.common_utils import EFFECTIVE_RANDOM_SEED_KEY_NAME, is_gpu_available
|
||||
from health_ml.utils.lightning_loggers import HimlMLFlowLogger, StoringLogger, get_mlflow_run_id_from_trainer
|
||||
from health_azure.utils import ENV_EXPERIMENT_NAME, is_global_rank_zero
|
||||
from testazure.utils_testazure import DEFAULT_WORKSPACE, experiment_for_unittests
|
||||
from testhiml.utils.fixed_paths_for_tests import mock_run_id
|
||||
from testhiml.utils.fixed_paths_for_tests import full_test_data_path
|
||||
|
||||
no_gpu = not is_gpu_available()
|
||||
hello_world_checkpoint = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ml_runner_no_setup() -> MLRunner:
|
||||
@pytest.fixture()
|
||||
def training_runner_no_setup(tmp_path: Path) -> TrainingRunner:
|
||||
experiment_config = ExperimentConfig(model="HelloWorld")
|
||||
container = LightningContainer(num_epochs=1)
|
||||
runner = MLRunner(experiment_config=experiment_config, container=container)
|
||||
container.set_output_to(tmp_path)
|
||||
runner = TrainingRunner(experiment_config=experiment_config, container=container)
|
||||
return runner
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ml_runner() -> Generator:
|
||||
@pytest.fixture()
|
||||
def training_runner(tmp_path: Path) -> Generator:
|
||||
experiment_config = ExperimentConfig(model="HelloWorld")
|
||||
container = LightningContainer(num_epochs=1)
|
||||
runner = MLRunner(experiment_config=experiment_config, container=container)
|
||||
container.set_output_to(tmp_path)
|
||||
runner = TrainingRunner(experiment_config=experiment_config, container=container)
|
||||
runner.setup()
|
||||
yield runner
|
||||
output_dir = runner.container.file_system_config.outputs_folder
|
||||
|
@ -54,10 +58,11 @@ def ml_runner() -> Generator:
|
|||
|
||||
|
||||
@pytest.fixture()
|
||||
def ml_runner_with_container() -> Generator:
|
||||
def training_runner_hello_world(tmp_path: Path) -> Generator:
|
||||
experiment_config = ExperimentConfig(model="HelloWorld")
|
||||
container = HelloWorld()
|
||||
runner = MLRunner(experiment_config=experiment_config, container=container)
|
||||
container.set_output_to(tmp_path)
|
||||
runner = TrainingRunner(experiment_config=experiment_config, container=container)
|
||||
runner.setup()
|
||||
yield runner
|
||||
output_dir = runner.container.file_system_config.outputs_folder
|
||||
|
@ -66,12 +71,15 @@ def ml_runner_with_container() -> Generator:
|
|||
|
||||
|
||||
@pytest.fixture()
|
||||
def ml_runner_with_run_id() -> Generator:
|
||||
def training_runner_hello_world_with_checkpoint() -> Generator:
|
||||
"""
|
||||
A fixture with a training runner for the HelloWorld model that has a src_checkpoint set.
|
||||
"""
|
||||
experiment_config = ExperimentConfig(model="HelloWorld")
|
||||
container = HelloWorld()
|
||||
container.save_checkpoint = True
|
||||
container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
|
||||
runner = MLRunner(experiment_config=experiment_config, container=container)
|
||||
container.src_checkpoint = CheckpointParser(str(hello_world_checkpoint))
|
||||
runner = TrainingRunner(experiment_config=experiment_config, container=container)
|
||||
runner.setup()
|
||||
yield runner
|
||||
output_dir = runner.container.file_system_config.outputs_folder
@@ -92,14 +100,27 @@ def regression_datadir(tmp_path: Path) -> Generator:
     shutil.rmtree(tmp_path)


-def test_ml_runner_setup(ml_runner_no_setup: MLRunner) -> None:
+def create_mlflow_trash_folder(runner: RunnerBase) -> None:
+    """Create a trash folder where MLFlow expects its deleted runs.
+
+    This is a workaround for sporadic test failures: When reading out the run_id, MLFlow checks its own
+    deleted runs folder, but that (or one of its parents) does not exist.
+    """
+    trash_folder = runner.container.outputs_folder / "mlruns" / ".trash"
+    trash_folder.mkdir(exist_ok=True, parents=True)
+
+
+def test_ml_runner_setup(training_runner_no_setup: TrainingRunner) -> None:
     """Check that all the necessary methods get called during setup"""
-    assert not ml_runner_no_setup._has_setup_run
-    with patch.object(ml_runner_no_setup, "container", spec=LightningContainer) as mock_container:
-        with patch.object(ml_runner_no_setup, "checkpoint_handler", spec=CheckpointHandler) as mock_checkpoint_handler:
-            with patch("health_ml.run_ml.seed_everything") as mock_seed:
-                with patch("health_ml.run_ml.seed_monai_if_available") as mock_seed_monai:
-                    ml_runner_no_setup.setup()
+    assert not training_runner_no_setup._has_setup_run
+    with patch.object(training_runner_no_setup, "container", spec=LightningContainer) as mock_container:
+        # Disable logging from the VM: without that, setup would try to create a local run object
+        # for logging and fail there.
+        mock_container.log_from_vm = False
+        with patch.object(
+            training_runner_no_setup, "checkpoint_handler", spec=CheckpointHandler
+        ) as mock_checkpoint_handler:
+            with patch("health_ml.runner_base.seed_everything") as mock_seed:
+                with patch("health_ml.runner_base.seed_monai_if_available") as mock_seed_monai:
+                    training_runner_no_setup.setup()
                     mock_container.get_effective_random_seed.assert_called()
                     mock_seed.assert_called_once()
                     mock_seed_monai.assert_called_once()
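The `create_mlflow_trash_folder` helper added above works around MLFlow's file store expecting a `.trash` folder for deleted runs. Stripped of the runner plumbing, the workaround boils down to the sketch below; the function name and the `tracking_root` argument are hypothetical stand-ins for the runner's outputs folder:

```python
from pathlib import Path


def ensure_mlflow_trash_folder(tracking_root: Path) -> None:
    # MLFlow's file store looks for deleted runs under <root>/mlruns/.trash and can
    # raise if that folder (or a parent) is missing; pre-creating it avoids the flake.
    (tracking_root / "mlruns" / ".trash").mkdir(parents=True, exist_ok=True)
```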
@@ -107,47 +128,58 @@ def test_ml_runner_setup(ml_runner_no_setup: MLRunner) -> None:
                     mock_checkpoint_handler.download_recovery_checkpoints_or_weights.assert_called_once()
                     mock_container.setup.assert_called_once()
                     mock_container.create_lightning_module_and_store.assert_called_once()
-                    assert ml_runner_no_setup._has_setup_run
+                    assert training_runner_no_setup._has_setup_run


-def test_set_run_tags_from_parent(ml_runner: MLRunner) -> None:
-    """Test that set_run_tags_from_parents causes set_tags to get called"""
-    with pytest.raises(AssertionError) as ae:
-        ml_runner.set_run_tags_from_parent()
-    assert "should only be called in a Hyperdrive run" in str(ae)
-
-    with patch("health_ml.run_ml.PARENT_RUN_CONTEXT") as mock_parent_run_context:
-        with patch("health_ml.run_ml.RUN_CONTEXT") as mock_run_context:
-            mock_parent_run_context.get_tags.return_value = {"tag": "dummy_tag"}
-            ml_runner.set_run_tags_from_parent()
-            mock_run_context.set_tags.assert_called()
+def test_setup_azureml(training_runner: TrainingRunner) -> None:
+    """Test that setup_azureml causes set_tags to get called when running in Hyperdrive"""
+    with patch("health_ml.runner_base.RUN_CONTEXT") as mock_run_context:
+        training_runner.setup_azureml()
+        # Tests always run outside of a Hyperdrive run. In those cases, no tags should be set on
+        # the current run.
+        mock_run_context.set_tags.assert_not_called()
+
+        with patch("health_ml.runner_base.PARENT_RUN_CONTEXT") as mock_parent_run_context:
+            # Mock the presence of a parent run, and tags that are present there
+            tag_name = "tag"
+            tag_value = "dummy_tag"
+            mock_parent_run_context.get_tags.return_value = {tag_name: tag_value}
+            training_runner.setup_azureml()
+            # The function should read out tags from the parent run, and set them on the current run
+            mock_parent_run_context.get_tags.assert_called_once_with()
+            mock_run_context.set_tags.assert_called_once()
+            call_args = mock_run_context.set_tags.call_args[0][0]
+            assert tag_name in call_args
+            assert call_args[tag_name] == tag_value
+            assert EFFECTIVE_RANDOM_SEED_KEY_NAME in call_args


-def test_get_multiple_trainloader_mode(ml_runner: MLRunner) -> None:
-    multiple_trainloader_mode = ml_runner.get_multiple_trainloader_mode()
+def test_get_multiple_trainloader_mode(training_runner: TrainingRunner) -> None:
+    training_runner.init_training()
+    multiple_trainloader_mode = training_runner.get_multiple_trainloader_mode()
     assert multiple_trainloader_mode == "max_size_cycle", "train_loader_cycle_mode is available now, "
     "`get_multiple_trainloader_mode` workaround can be safely removed."


-def _test_init_training(run_inference_only: bool, ml_runner: MLRunner, caplog: LogCaptureFixture) -> None:
+def _test_init_training(run_inference_only: bool, training_runner: TrainingRunner, caplog: LogCaptureFixture) -> None:
     """Test that training is initialized correctly"""
-    ml_runner.container.run_inference_only = run_inference_only
-    ml_runner.setup()
-    assert not ml_runner.checkpoint_handler.has_continued_training
-    assert ml_runner.trainer is None
-    assert ml_runner.storing_logger is None
+    training_runner.container.run_inference_only = run_inference_only
+    training_runner.setup()
+    assert not training_runner.checkpoint_handler.has_continued_training
+    assert training_runner.trainer is None
+    assert training_runner.storing_logger is None

-    with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
-        with patch.object(ml_runner.container, "get_data_module") as mock_get_data_module:
-            with patch("health_ml.run_ml.write_experiment_summary_file") as mock_write_experiment_summary_file:
+    with patch("health_ml.training_runner.create_lightning_trainer") as mock_create_trainer:
+        with patch.object(training_runner.container, "get_data_module") as mock_get_data_module:
+            with patch("health_ml.training_runner.write_experiment_summary_file") as mock_write_experiment_summary_file:
                 with patch.object(
-                    ml_runner.checkpoint_handler, "get_recovery_or_checkpoint_path_train"
+                    training_runner.checkpoint_handler, "get_recovery_or_checkpoint_path_train"
                 ) as mock_get_recovery_or_checkpoint_path_train:
-                    with patch("health_ml.run_ml.seed_everything") as mock_seed:
+                    with patch("health_ml.training_runner.seed_everything") as mock_seed:
                         mock_create_trainer.return_value = MagicMock(), MagicMock()
                         mock_get_recovery_or_checkpoint_path_train.return_value = "dummy_path"

-                        ml_runner.init_training()
+                        training_runner.init_training()

                         # Make sure write_experiment_summary_file is only called on rank 0
                         if is_global_rank_zero():
@@ -157,47 +189,51 @@ def _test_init_training(run_inference_only: bool, ml_runner: MLRunner, caplog: L

                         # Make sure seed is set correctly with workers=True
                         mock_seed.assert_called_once()
-                        assert mock_seed.call_args[0][0] == ml_runner.container.get_effective_random_seed()
+                        assert mock_seed.call_args[0][0] == training_runner.container.get_effective_random_seed()
                         assert mock_seed.call_args[1]["workers"]

                         mock_get_data_module.assert_called_once()
-                        assert ml_runner.data_module is not None
+                        assert training_runner.data_module is not None

                         if not run_inference_only:
                             mock_get_recovery_or_checkpoint_path_train.assert_called_once()
                             # Validate that the trainer is created correctly
                             assert mock_create_trainer.call_args[1]["resume_from_checkpoint"] == "dummy_path"
-                            assert ml_runner.storing_logger is not None
-                            assert ml_runner.trainer is not None
+                            assert training_runner.storing_logger is not None
+                            assert training_runner.trainer is not None
                             assert "Environment variables:" in caplog.messages[-1]
                         else:
-                            assert ml_runner.trainer is None
-                            assert ml_runner.storing_logger is None
+                            assert training_runner.trainer is None
+                            assert training_runner.storing_logger is None
                             mock_get_recovery_or_checkpoint_path_train.assert_not_called()


 @pytest.mark.parametrize("run_inference_only", [True, False])
-def test_init_training_cpu(run_inference_only: bool, ml_runner: MLRunner, caplog: LogCaptureFixture) -> None:
+def test_init_training_cpu(
+    run_inference_only: bool, training_runner: TrainingRunner, caplog: LogCaptureFixture
+) -> None:
     """Test that training is initialized correctly"""
-    ml_runner.container.max_num_gpus = 0
-    _test_init_training(run_inference_only, ml_runner, caplog)
+    training_runner.container.max_num_gpus = 0
+    _test_init_training(run_inference_only, training_runner, caplog)


 @pytest.mark.skipif(no_gpu, reason="Test requires GPU")
 @pytest.mark.gpu
 @pytest.mark.parametrize("run_inference_only", [True, False])
-def test_init_training_gpu(run_inference_only: bool, ml_runner: MLRunner, caplog: LogCaptureFixture) -> None:
+def test_init_training_gpu(
+    run_inference_only: bool, training_runner: TrainingRunner, caplog: LogCaptureFixture
+) -> None:
     """Test that training is initialized correctly in DDP mode"""
-    _test_init_training(run_inference_only, ml_runner, caplog)
+    _test_init_training(run_inference_only, training_runner, caplog)


 def test_run_training() -> None:
     experiment_config = ExperimentConfig(model="HelloWorld")
     container = HelloWorld()
-    runner = MLRunner(experiment_config=experiment_config, container=container)
+    runner = TrainingRunner(experiment_config=experiment_config, container=container)

     with patch.object(container, "get_data_module") as mock_get_data_module:
-        with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
+        with patch("health_ml.training_runner.create_lightning_trainer") as mock_create_trainer:
             runner.setup()
             mock_trainer = MagicMock()
             mock_storing_logger = MagicMock()
@@ -225,17 +261,17 @@ def test_end_training(max_num_gpus_inf: int) -> None:
     experiment_config = ExperimentConfig(model="HelloWorld")
     container = HelloWorld()
     container.max_num_gpus_inference = max_num_gpus_inf
-    runner = MLRunner(experiment_config=experiment_config, container=container)
+    runner = TrainingRunner(experiment_config=experiment_config, container=container)

     with patch.object(container, "get_data_module"):
-        with patch("health_ml.run_ml.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
+        with patch("health_ml.training_runner.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
             runner.setup()
             runner.init_training()
             runner.run_training()

     with patch.object(runner.checkpoint_handler, "additional_training_done") as mock_additional_training_done:
         with patch.object(runner, "after_ddp_cleanup") as mock_after_ddp_cleanup:
-            with patch("health_ml.run_ml.cleanup_checkpoints") as mock_cleanup_ckpt:
+            with patch("health_ml.training_runner.cleanup_checkpoints") as mock_cleanup_ckpt:
                 environ_before_training = {"old": "environ"}
                 runner.end_training(environ_before_training=environ_before_training)
                 mock_additional_training_done.assert_called_once()
@@ -251,71 +287,98 @@ def test_end_training(max_num_gpus_inf: int) -> None:
 @pytest.mark.parametrize("run_extra_val_epoch", [True, False])
 @pytest.mark.parametrize("run_inference_only", [True, False])
 def test_init_inference(
-    run_inference_only: bool, run_extra_val_epoch: bool, max_num_gpus_inf: int, ml_runner_with_run_id: MLRunner
+    run_inference_only: bool,
+    run_extra_val_epoch: bool,
+    max_num_gpus_inf: int,
+    training_runner_hello_world_with_checkpoint: TrainingRunner,
 ) -> None:
-    ml_runner_with_run_id.container.run_inference_only = run_inference_only
-    ml_runner_with_run_id.container.run_extra_val_epoch = run_extra_val_epoch
-    ml_runner_with_run_id.container.max_num_gpus_inference = max_num_gpus_inf
-    assert ml_runner_with_run_id.container.max_num_gpus == -1  # This is the default value of max_num_gpus
-    ml_runner_with_run_id.init_training()
+    training_runner_hello_world_with_checkpoint.container.run_inference_only = run_inference_only
+    training_runner_hello_world_with_checkpoint.container.run_extra_val_epoch = run_extra_val_epoch
+    training_runner_hello_world_with_checkpoint.container.max_num_gpus_inference = max_num_gpus_inf
+    assert (
+        training_runner_hello_world_with_checkpoint.container.max_num_gpus == -1
+    )  # This is the default value of max_num_gpus
+    training_runner_hello_world_with_checkpoint.init_training()
     if run_inference_only:
         expected_mlflow_run_id = None
     else:
-        assert ml_runner_with_run_id.trainer is not None
-        expected_mlflow_run_id = ml_runner_with_run_id.trainer.loggers[1].run_id  # type: ignore
+        assert training_runner_hello_world_with_checkpoint.trainer is not None
+        create_mlflow_trash_folder(training_runner_hello_world_with_checkpoint)
+        expected_mlflow_run_id = training_runner_hello_world_with_checkpoint.trainer.loggers[1].run_id  # type: ignore
     if not run_inference_only:
-        ml_runner_with_run_id.checkpoint_handler.additional_training_done()
-    with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
-        with patch.object(ml_runner_with_run_id.container, "get_checkpoint_to_test") as mock_get_checkpoint_to_test:
-            with patch.object(ml_runner_with_run_id.container, "get_data_module") as mock_get_data_module:
+        training_runner_hello_world_with_checkpoint.checkpoint_handler.additional_training_done()
+    with patch("health_ml.runner_base.create_lightning_trainer") as mock_create_trainer:
+        with patch.object(
+            training_runner_hello_world_with_checkpoint.container, "get_checkpoint_to_test"
+        ) as mock_get_checkpoint_to_test:
+            with patch.object(
+                training_runner_hello_world_with_checkpoint.container, "get_data_module"
+            ) as mock_get_data_module:
                 mock_checkpoint = MagicMock(is_file=MagicMock(return_value=True))
                 mock_get_checkpoint_to_test.return_value = mock_checkpoint
                 mock_trainer = MagicMock()
                 mock_create_trainer.return_value = mock_trainer, MagicMock()
                 mock_get_data_module.return_value = "dummy_data_module"

-                assert ml_runner_with_run_id.inference_checkpoint is None
-                assert not ml_runner_with_run_id.container.model._on_extra_val_epoch
+                assert training_runner_hello_world_with_checkpoint.inference_checkpoint is None
+                assert not training_runner_hello_world_with_checkpoint.container.model._on_extra_val_epoch

-                ml_runner_with_run_id.init_inference()
+                training_runner_hello_world_with_checkpoint.init_inference()

-                expected_ckpt = str(ml_runner_with_run_id.checkpoint_handler.trained_weights_path)
+                expected_ckpt = str(training_runner_hello_world_with_checkpoint.checkpoint_handler.trained_weights_path)
                 expected_ckpt = expected_ckpt if run_inference_only else str(mock_checkpoint)
-                assert ml_runner_with_run_id.inference_checkpoint == expected_ckpt
+                assert training_runner_hello_world_with_checkpoint.inference_checkpoint == expected_ckpt

-                assert hasattr(ml_runner_with_run_id.container.model, "on_run_extra_validation_epoch")
-                assert ml_runner_with_run_id.container.model._on_extra_val_epoch == run_extra_val_epoch
+                assert hasattr(
+                    training_runner_hello_world_with_checkpoint.container.model, "on_run_extra_validation_epoch"
+                )
+                assert (
+                    training_runner_hello_world_with_checkpoint.container.model._on_extra_val_epoch
+                    == run_extra_val_epoch
+                )

                 mock_create_trainer.assert_called_once()
-                assert ml_runner_with_run_id.trainer == mock_trainer
-                assert ml_runner_with_run_id.container.max_num_gpus == max_num_gpus_inf
-                assert mock_create_trainer.call_args[1]["container"] == ml_runner_with_run_id.container
+                assert training_runner_hello_world_with_checkpoint.trainer == mock_trainer
+                assert training_runner_hello_world_with_checkpoint.container.max_num_gpus == max_num_gpus_inf
+                assert (
+                    mock_create_trainer.call_args[1]["container"]
+                    == training_runner_hello_world_with_checkpoint.container
+                )
                 assert mock_create_trainer.call_args[1]["num_nodes"] == 1
                 assert mock_create_trainer.call_args[1]["mlflow_run_for_logging"] == expected_mlflow_run_id
                 mock_get_data_module.assert_called_once()
-                assert ml_runner_with_run_id.data_module == "dummy_data_module"
+                assert training_runner_hello_world_with_checkpoint.data_module == "dummy_data_module"


 @pytest.mark.parametrize("run_inference_only", [True, False])
 @pytest.mark.parametrize("run_extra_val_epoch", [True, False])
 def test_run_validation(
-    run_extra_val_epoch: bool, run_inference_only: bool, ml_runner_with_run_id: MLRunner, caplog: LogCaptureFixture
+    run_extra_val_epoch: bool,
+    run_inference_only: bool,
+    training_runner_hello_world_with_checkpoint: TrainingRunner,
+    caplog: LogCaptureFixture,
 ) -> None:
-    ml_runner_with_run_id.container.run_extra_val_epoch = run_extra_val_epoch
-    ml_runner_with_run_id.container.run_inference_only = run_inference_only
-    ml_runner_with_run_id.init_training()
+    training_runner_hello_world_with_checkpoint.container.run_extra_val_epoch = run_extra_val_epoch
+    training_runner_hello_world_with_checkpoint.container.run_inference_only = run_inference_only
+    training_runner_hello_world_with_checkpoint.init_training()
     mock_datamodule = MagicMock()
-    with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
-        with patch.object(ml_runner_with_run_id.container, "get_data_module", return_value=mock_datamodule):
+    create_mlflow_trash_folder(training_runner_hello_world_with_checkpoint)
+    with patch("health_ml.runner_base.create_lightning_trainer") as mock_create_trainer:
+        with patch.object(
+            training_runner_hello_world_with_checkpoint.container, "get_data_module", return_value=mock_datamodule
+        ):
             mock_trainer = MagicMock()
             mock_create_trainer.return_value = mock_trainer, MagicMock()
-            ml_runner_with_run_id.init_inference()
-            assert ml_runner_with_run_id.trainer == mock_trainer
+            training_runner_hello_world_with_checkpoint.init_inference()
+            assert training_runner_hello_world_with_checkpoint.trainer == mock_trainer
             mock_trainer.validate = Mock()
-            ml_runner_with_run_id.run_validation()
+            training_runner_hello_world_with_checkpoint.run_validation()
             if run_extra_val_epoch or run_inference_only:
                 mock_trainer.validate.assert_called_once()
-                assert mock_trainer.validate.call_args[1]["ckpt_path"] == ml_runner_with_run_id.inference_checkpoint
+                assert (
+                    mock_trainer.validate.call_args[1]["ckpt_path"]
+                    == training_runner_hello_world_with_checkpoint.inference_checkpoint
+                )
                 assert mock_trainer.validate.call_args[1]["datamodule"] == mock_datamodule
             else:
                 assert "Skipping extra validation" in caplog.messages[-1]
@@ -332,12 +395,12 @@ def test_model_extra_val_epoch_missing_hook(caplog: LogCaptureFixture) -> None:
     container = HelloWorld()
     container.create_lightning_module_and_store()
     container.run_extra_val_epoch = True
-    runner = MLRunner(experiment_config=experiment_config, container=container)
+    runner = TrainingRunner(experiment_config=experiment_config, container=container)
     runner.setup()
     runner.checkpoint_handler.additional_training_done()
     runner.container.outputs_folder.mkdir(parents=True, exist_ok=True)
     with patch.object(container, "get_data_module"):
-        with patch("health_ml.run_ml.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
+        with patch("health_ml.runner_base.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
             with patch.object(runner.container, "get_checkpoint_to_test") as mock_get_checkpoint_to_test:
                 mock_get_checkpoint_to_test.return_value = MagicMock(is_file=MagicMock(return_value=True))
                 runner.init_inference()
@@ -346,34 +409,34 @@ def test_model_extra_val_epoch_missing_hook(caplog: LogCaptureFixture) -> None:
     assert "Hook `on_run_extra_validation_epoch` is not implemented" in latest_message


-def test_run_inference(ml_runner_with_container: MLRunner, regression_datadir: Path) -> None:
+def test_run_inference(training_runner_hello_world: TrainingRunner, regression_datadir: Path) -> None:
     """
     Test that run_inference gets called as expected.
     """
-    ml_runner_with_container.container.max_num_gpus = 0
+    training_runner_hello_world.container.max_num_gpus = 0

     def _expected_files_exist() -> bool:
-        output_dir = ml_runner_with_container.container.outputs_folder
+        output_dir = training_runner_hello_world.container.outputs_folder
         if not output_dir.is_dir():
             return False
-        expected_files = ["test_mse.txt", "test_mae.txt"]
+        expected_files = [TEST_MSE_FILE, TEST_MAE_FILE]
         return all([(output_dir / p).exists() for p in expected_files])

-    expected_ckpt_path = ml_runner_with_container.container.outputs_folder / "checkpoints" / "last.ckpt"
+    expected_ckpt_path = training_runner_hello_world.container.outputs_folder / "checkpoints" / "last.ckpt"
     assert not expected_ckpt_path.exists()
     # update the container to look for test data at this location
-    ml_runner_with_container.container.local_dataset_dir = regression_datadir
+    training_runner_hello_world.container.local_dataset_dir = regression_datadir
     assert not _expected_files_exist()

-    actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
+    actual_train_ckpt_path = training_runner_hello_world.checkpoint_handler.get_recovery_or_checkpoint_path_train()
     assert actual_train_ckpt_path is None

-    ml_runner_with_container.run()
+    training_runner_hello_world.run()

-    actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
+    actual_train_ckpt_path = training_runner_hello_world.checkpoint_handler.get_recovery_or_checkpoint_path_train()
     assert actual_train_ckpt_path == expected_ckpt_path

-    actual_test_ckpt_path = ml_runner_with_container.checkpoint_handler.get_checkpoint_to_test()
+    actual_test_ckpt_path = training_runner_hello_world.checkpoint_handler.get_checkpoint_to_test()
     assert actual_test_ckpt_path == expected_ckpt_path
     assert actual_test_ckpt_path.is_file()
     # After training, the outputs directory should now exist and contain the 2 error files
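This hunk replaces the hard-coded file names `"test_mse.txt"` and `"test_mae.txt"` with the constants `TEST_MSE_FILE` and `TEST_MAE_FILE`. Judging from the removed literals, the constants presumably resolve to the same names; their defining module is not shown in this diff:

```python
# Presumed definitions, inferred from the removed hard-coded file names
TEST_MSE_FILE = "test_mse.txt"
TEST_MAE_FILE = "test_mae.txt"
```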
@@ -382,25 +445,25 @@ def test_run_inference(ml_runner_with_container: MLRunner, regression_datadir: P

 @pytest.mark.parametrize("run_extra_val_epoch", [True, False])
 @pytest.mark.parametrize("run_inference_only", [True, False])
-def test_run(run_inference_only: bool, run_extra_val_epoch: bool, ml_runner_with_container: MLRunner) -> None:
+def test_run(run_inference_only: bool, run_extra_val_epoch: bool, training_runner_hello_world: TrainingRunner) -> None:
     """Test that model runner gets called"""
-    ml_runner_with_container.container.run_inference_only = run_inference_only
-    ml_runner_with_container.container.run_extra_val_epoch = run_extra_val_epoch
-    ml_runner_with_container.setup()
-    assert not ml_runner_with_container.checkpoint_handler.has_continued_training
+    training_runner_hello_world.container.run_inference_only = run_inference_only
+    training_runner_hello_world.container.run_extra_val_epoch = run_extra_val_epoch
+    training_runner_hello_world.setup()
+    assert not training_runner_hello_world.checkpoint_handler.has_continued_training

-    with patch("health_ml.run_ml.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
+    with patch("health_ml.runner_base.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
         with patch.multiple(
-            ml_runner_with_container,
+            training_runner_hello_world,
             checkpoint_handler=mock.DEFAULT,
             run_training=mock.DEFAULT,
             run_validation=mock.DEFAULT,
             run_inference=mock.DEFAULT,
             end_training=mock.DEFAULT,
         ) as mocks:
-            ml_runner_with_container.run()
-            assert ml_runner_with_container.container.has_custom_test_step()
-            assert ml_runner_with_container._has_setup_run
+            training_runner_hello_world.run()
+            assert training_runner_hello_world.container.has_custom_test_step()
+            assert training_runner_hello_world._has_setup_run
             assert mocks["end_training"].called != run_inference_only
             assert mocks["run_training"].called != run_inference_only
             mocks["run_validation"].assert_called_once()
@@ -408,45 +471,59 @@ def test_run(run_inference_only: bool, run_extra_val_epoch: bool, ml_runner_with


 @pytest.mark.parametrize("run_extra_val_epoch", [True, False])
-def test_run_inference_only(run_extra_val_epoch: bool, ml_runner_with_run_id: MLRunner) -> None:
+def test_run_inference_only(
+    run_extra_val_epoch: bool, training_runner_hello_world_with_checkpoint: TrainingRunner
+) -> None:
     """Test inference only mode. Validation should be run regardless of run_extra_val_epoch status."""
-    ml_runner_with_run_id.container.run_inference_only = True
-    ml_runner_with_run_id.container.run_extra_val_epoch = run_extra_val_epoch
-    assert ml_runner_with_run_id.checkpoint_handler.trained_weights_path
+    training_runner_hello_world_with_checkpoint.container.run_inference_only = True
+    training_runner_hello_world_with_checkpoint.container.run_extra_val_epoch = run_extra_val_epoch
+    assert training_runner_hello_world_with_checkpoint.checkpoint_handler.trained_weights_path
     mock_datamodule = MagicMock()
-    with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
-        with patch.object(ml_runner_with_run_id.container, "get_data_module", return_value=mock_datamodule):
+    with patch("health_ml.runner_base.create_lightning_trainer") as mock_create_trainer:
+        with patch.object(
+            training_runner_hello_world_with_checkpoint.container, "get_data_module", return_value=mock_datamodule
+        ):
             with patch.multiple(
-                ml_runner_with_run_id,
+                training_runner_hello_world_with_checkpoint,
                 run_training=mock.DEFAULT,
             ) as mocks:
                 mock_trainer = MagicMock()
                 mock_create_trainer.return_value = mock_trainer, MagicMock()
-                ml_runner_with_run_id.run()
+                training_runner_hello_world_with_checkpoint.run()
                 mock_create_trainer.assert_called_once()
                 mocks["run_training"].assert_not_called()

                 mock_trainer.validate.assert_called_once()
-                assert mock_trainer.validate.call_args[1]["ckpt_path"] == ml_runner_with_run_id.inference_checkpoint
+                assert (
+                    mock_trainer.validate.call_args[1]["ckpt_path"]
+                    == training_runner_hello_world_with_checkpoint.inference_checkpoint
+                )
                 assert mock_trainer.validate.call_args[1]["datamodule"] == mock_datamodule
                 mock_trainer.test.assert_called_once()
-                assert mock_trainer.test.call_args[1]["ckpt_path"] == ml_runner_with_run_id.inference_checkpoint
+                assert (
+                    mock_trainer.test.call_args[1]["ckpt_path"]
+                    == training_runner_hello_world_with_checkpoint.inference_checkpoint
+                )
                 assert mock_trainer.test.call_args[1]["datamodule"] == mock_datamodule


 @pytest.mark.parametrize("run_extra_val_epoch", [True, False])
-def test_resume_training_from_run_id(run_extra_val_epoch: bool, ml_runner_with_run_id: MLRunner) -> None:
-    ml_runner_with_run_id.container.run_extra_val_epoch = run_extra_val_epoch
-    ml_runner_with_run_id.container.max_num_gpus = 0
-    ml_runner_with_run_id.container.max_epochs += 10
-    assert ml_runner_with_run_id.checkpoint_handler.trained_weights_path
+def test_resume_training_from_run_id(
+    run_extra_val_epoch: bool, training_runner_hello_world_with_checkpoint: TrainingRunner
+) -> None:
+    training_runner_hello_world_with_checkpoint.container.run_extra_val_epoch = run_extra_val_epoch
+    training_runner_hello_world_with_checkpoint.container.max_num_gpus = 0
+    training_runner_hello_world_with_checkpoint.container.max_epochs += 10
+    assert training_runner_hello_world_with_checkpoint.checkpoint_handler.trained_weights_path
     mock_trainer = MagicMock()
-    with patch("health_ml.run_ml.create_lightning_trainer", return_value=(mock_trainer, MagicMock())):
-        with patch.object(ml_runner_with_run_id.container, "get_checkpoint_to_test") as mock_get_checkpoint_to_test:
-            with patch.object(ml_runner_with_run_id, "run_inference") as mock_run_inference:
-                with patch("health_ml.run_ml.cleanup_checkpoints") as mock_cleanup_ckpt:
+    with patch("health_ml.runner_base.create_lightning_trainer", return_value=(mock_trainer, MagicMock())):
+        with patch.object(
+            training_runner_hello_world_with_checkpoint.container, "get_checkpoint_to_test"
+        ) as mock_get_checkpoint_to_test:
+            with patch.object(training_runner_hello_world_with_checkpoint, "run_inference") as mock_run_inference:
+                with patch("health_ml.training_runner.cleanup_checkpoints") as mock_cleanup_ckpt:
                     mock_get_checkpoint_to_test.return_value = MagicMock(is_file=MagicMock(return_value=True))
-                    ml_runner_with_run_id.run()
+                    training_runner_hello_world_with_checkpoint.run()
                     mock_get_checkpoint_to_test.assert_called_once()
                     mock_cleanup_ckpt.assert_called_once()
                     assert mock_trainer.validate.called == run_extra_val_epoch
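Several of these tests rely on `patch.multiple(..., name=mock.DEFAULT)`, which replaces each named attribute with a `MagicMock` and returns the mocks in a dict keyed by attribute name. A self-contained sketch of the pattern; the `Runner` class here is a toy stand-in, not hi-ml code:

```python
from unittest import mock


class Runner:
    def run_training(self) -> None:
        raise RuntimeError("real training must not run in a unit test")


runner = Runner()
with mock.patch.multiple(runner, run_training=mock.DEFAULT) as mocks:
    runner.run_training()  # calls the substituted MagicMock, not the real method
    mocks["run_training"].assert_called_once()
```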
@@ -457,12 +534,12 @@ def test_model_weights_when_resume_training() -> None:
     experiment_config = ExperimentConfig(model="HelloWorld")
     container = HelloWorld()
     container.max_num_gpus = 0
-    container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
+    container.src_checkpoint = CheckpointParser(str(hello_world_checkpoint))
     container.resume_training = True
-    runner = MLRunner(experiment_config=experiment_config, container=container)
+    runner = TrainingRunner(experiment_config=experiment_config, container=container)
     runner.setup()
     assert runner.checkpoint_handler.trained_weights_path.is_file()  # type: ignore
-    with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
+    with patch("health_ml.training_runner.create_lightning_trainer") as mock_create_trainer:
         mock_create_trainer.return_value = MagicMock(), MagicMock()
         runner.init_training()
         mock_create_trainer.assert_called_once()
@@ -482,14 +559,14 @@ def test_log_on_vm(log_from_vm: bool) -> None:
     tag = f"test_log_on_vm [{log_from_vm}]"
     container.tag = tag
     container.log_from_vm = log_from_vm
-    runner = MLRunner(experiment_config=experiment_config, container=container)
+    runner = TrainingRunner(experiment_config=experiment_config, container=container)
     # When logging to AzureML, need to provide the unit test AML workspace.
     # When not logging to AzureML, no workspace (and no authentication) should be needed.
     if log_from_vm:
         with patch("health_azure.utils.get_workspace", return_value=DEFAULT_WORKSPACE.workspace):
-            runner.run()
+            runner.run_and_cleanup()
     else:
-        runner.run()
+        runner.run_and_cleanup()
     # The PL trainer object is created in the init_training method.
     # Check that the AzureML logger is set up correctly.
     assert runner.trainer is not None
@@ -536,13 +613,15 @@ def test_get_mlflow_run_id_from_trainer() -> None:
     assert run_id == mock_run_id


-def test_inference_only_metrics_correctness(ml_runner_with_run_id: MLRunner, regression_datadir: Path) -> None:
-    ml_runner_with_run_id.container.run_inference_only = True
-    ml_runner_with_run_id.container.local_dataset_dir = regression_datadir
-    ml_runner_with_run_id.run()
-    with open(ml_runner_with_run_id.container.outputs_folder / "test_mse.txt") as f:
+def test_inference_only_metrics_correctness(
+    training_runner_hello_world_with_checkpoint: TrainingRunner, regression_datadir: Path
+) -> None:
+    training_runner_hello_world_with_checkpoint.container.run_inference_only = True
+    training_runner_hello_world_with_checkpoint.container.local_dataset_dir = regression_datadir
+    training_runner_hello_world_with_checkpoint.run()
+    with open(training_runner_hello_world_with_checkpoint.container.outputs_folder / TEST_MSE_FILE) as f:
         mse = float(f.readlines()[0])
         assert isclose(mse, 0.010806690901517868, abs_tol=1e-3)
-    with open(ml_runner_with_run_id.container.outputs_folder / "test_mae.txt") as f:
+    with open(training_runner_hello_world_with_checkpoint.container.outputs_folder / TEST_MAE_FILE) as f:
         mae = float(f.readlines()[0])
         assert isclose(mae, 0.08260975033044815, abs_tol=1e-3)
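Taken together, the rename means downstream scripts and tests drive training through `TrainingRunner` and `run_and_cleanup()`. A minimal end-to-end sketch; the import paths are inferred from the patch targets in this diff (`health_ml.training_runner`, `health_ml.runner_base`) and from the `HelloWorld`/`ExperimentConfig` usage above, so they may not match the actual module layout exactly:

```python
from health_ml.configs.hello_world import HelloWorld  # assumed module path
from health_ml.experiment_config import ExperimentConfig  # assumed module path
from health_ml.training_runner import TrainingRunner

container = HelloWorld()
container.max_num_gpus = 0  # CPU-only, mirroring the tests above
runner = TrainingRunner(experiment_config=ExperimentConfig(model="HelloWorld"), container=container)
runner.setup()
runner.run_and_cleanup()  # replaces the bare run() call, per the test_log_on_vm change
```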