ENH: Ability to evaluate model on a new dataset (#859)

Renamed the `MLRunner` to `TrainingRunner`. Introduced a new
`EvalRunner`. Shared functionality moved from `MLRunner` to
`RunnerBase`.
Anton Schwaighofer 2023-04-04 17:24:15 +01:00, committed by GitHub
Parent 2cc2511efd
Commit e73c92649d
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
16 changed files: 768 additions and 394 deletions


@ -7,6 +7,7 @@ use of these features:
- Working with different models in the same codebase, and selecting one by name
- Distributed training in AzureML
- Logging via AzureML's native capabilities
- Evaluation of the trained model on new datasets
This can be used by invoking the hi-ml runner and providing the name of the container class, like this:
`himl-runner --model=MyContainer`.
@ -215,7 +216,13 @@ and returns a tuple containing the Optimizer and LRScheduler objects
## Run inference with a pretrained model
You can use the hi-ml-runner in inference mode only by switching the `--run_inference_only` flag on and specifying
the model weights by setting `--src_checkpoint` argument that supports three types of checkpoints:
the model weights by setting the `--src_checkpoint` argument. With this flag, the model will be evaluated on the test set
only. There is also an option for evaluating the model on a full dataset, described further below.
### Specifying the checkpoint to use
When running inference on a trained model, you need to provide the checkpoint that should be used. This is done via
the `--src_checkpoint` argument, which supports three types of checkpoints:
- A local path where the checkpoint is stored `--src_checkpoint=local/path/to/my_checkpoint/model.ckpt`
- A remote URL from where to download the weights `--src_checkpoint=https://my_checkpoint_url.com/model.ckpt`
@ -228,13 +235,69 @@ the model weights by setting `--src_checkpoint` argument that supports three typ
Refer to [Checkpoints Utils](checkpoints.md) for more details on how checkpoints are parsed.
Running the following command line will run inference using `MyContainer` model with weights from the checkpoint saved
### Running inference on the test set
When supplying the flag `--run_inference_only` on the command line, no model training will be run, and only inference on the
test set will be done:
- The model weights will be loaded from the location specified by `--src_checkpoint`
- A PyTorch Lightning `Trainer` object will be instantiated.
- The test set will be read out from the data module specified by the `get_data_module` method of the
`LightningContainer` object.
- The model will be evaluated on the test set by running `trainer.test`. Any special logic to use during the test step
will need to be added to the model's `test_step` method (a minimal sketch is shown after the command line below).
Running the following command line will run inference using the `MyContainer` model, with weights from the checkpoint saved
in the AzureML run `MyContainer_XXXX_yyyy` at the epoch with the best validation loss (`/outputs/checkpoints/best_val_loss.ckpt`):
```bash
himl-runner --model=MyContainer --run_inference_only --src_checkpoint=MyContainer_XXXX_yyyy:best_val_loss.ckpt
```
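For illustration, here is a minimal sketch of such a `test_step` override. The model class `MyLightningModel` and the metric name are placeholders used only for this documentation, not part of hi-ml:
```python
import torch
from pytorch_lightning import LightningModule


class MyLightningModel(LightningModule):
    """Hypothetical 1-dim regression model, used only to illustrate a custom `test_step`."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)

    def test_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        # Any custom test-time logic goes here: this sketch computes the mean squared
        # error on the test batch and logs it as an epoch-level metric.
        inputs, targets = batch[:, 0:1], batch[:, 1:2]
        loss = torch.nn.functional.mse_loss(self.forward(inputs), targets)
        self.log("test_mse", loss, on_epoch=True, on_step=False)
        return loss
```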
### Running inference on a full dataset
When supplying the argument `--mode=eval_full` on the command line, no model training will be run. Instead, the model will be
evaluated on a dataset that is different from the training/validation/test dataset. This dataset is loaded via the
`get_eval_data_module` method of the container.
- The model weights will be loaded from the location specified by `--src_checkpoint`
- A PyTorch Lightning `Trainer` object will be instantiated.
- The test set will be read out from the data module specified by the `get_eval_data_module` method of the
`LightningContainer` object. The data module itself can read data from a mounted Azure dataset, which will be made
available to the container at the path `self.local_datasets`. In a typical use case, all the data in that dataset will be
returned by the data module's `test_dataloader` method.
- The model will be evaluated on the test set by running `trainer.test`. Any special logic to use during the test step
will need to be added to the model's `test_step` method.
Running the following command line will run inference using the `MyContainer` model, with weights from the checkpoint saved
in the AzureML run `MyContainer_XXXX_yyyy` at the epoch with the best validation loss (`/outputs/checkpoints/best_val_loss.ckpt`):
```bash
himl-runner --model=MyContainer --src_checkpoint=MyContainer_XXXX_yyyy:best_val_loss.ckpt --mode=eval_full --azure_datasets=my_new_dataset
```
The example code snippet here shows how to add a method that reads the inference dataset. In this example, we assume
that the `MyDataModule` class has an argument `splits` that specifies the fraction of data to go into the training,
validation, and test data loaders.
```python
from pathlib import Path

from pytorch_lightning import LightningDataModule, LightningModule

from health_ml.lightning_container import LightningContainer


class MyContainer(LightningContainer):
    def __init__(self):
        super().__init__()
        self.azure_datasets = ["folder_name_in_azure_blob_storage"]
        self.local_datasets = [Path("/some/local/path")]
        self.max_epochs = 42

    def create_model(self) -> LightningModule:
        return MyLightningModel()

    def get_data_module(self) -> LightningDataModule:
        # For training, split the data 70/20/10 into training, validation, and test sets.
        return MyDataModule(root_path=self.local_datasets[0], splits=(0.7, 0.2, 0.1))

    def get_eval_data_module(self) -> LightningDataModule:
        # For evaluation on a new dataset, put all the data into the test set.
        return MyDataModule(root_path=self.local_datasets[0], splits=(0.0, 0.0, 1.0))
```
## Resume training from a given checkpoint
Analogously, you can resume training from a given checkpoint, and train the model for longer, by setting `--src_checkpoint` and adding the `--resume_training` flag.
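For example, assuming the same `MyContainer` model and the checkpoint from the AzureML run used in the examples above, resuming training could look like this:
```bash
himl-runner --model=MyContainer --src_checkpoint=MyContainer_XXXX_yyyy:best_val_loss.ckpt --resume_training
```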
@ -298,5 +361,6 @@ seeds on a local machine, other than by manually starting runs with `--random_se
the limit that AML sets for snapshots. Solution: check for cache files, log files or other files that are not
necessary for running your experiment and add them to a `.amlignore` file in the root directory. Alternatively, you
can see Azure ML documentation for instructions on increasing this limit, although it will make your jobs slower.
2. `"FileNotFoundError"`. Possible cause: Symlinked files. Azure ML SDK v2 will resolve the symlink and attempt to upload
the resolved file. Solution: Remove symlinks from any files that should be uploaded to Azure ML.
1. `"FileNotFoundError"`. Possible cause: Symlinked files. Azure ML SDK v2 will resolve the symlink and attempt to upload
the resolved file. Solution: Remove symlinks from any files that should be uploaded to Azure ML.

hi-ml-azure/.vscode/settings.json (vendored)

@ -61,4 +61,9 @@
"${workspaceFolder}/src",
"${workspaceFolder}/testazure"
],
"workbench.colorCustomizations": {
"activityBar.background": "#5D091F",
"titleBar.activeBackground": "#830D2B",
"titleBar.activeForeground": "#FFFCFC"
},
}


@ -524,7 +524,7 @@ def create_dataset_configs(
count = num_azure
elif num_azure == 0 and num_mount == 0:
# No datasets in Azure at all: This is possible for runs that for example download their own data from the web.
# There can be any number of local datasets, but we are not checking that. In MLRunner.setup, there is a check
# There can be any number of local datasets, but we are not checking that. In TrainingRunner.setup, there is a check
# that leaves local datasets intact if there are no Azure datasets.
return []
else:


@ -2,8 +2,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from health_ml.run_ml import MLRunner
from health_ml.training_runner import TrainingRunner
from health_ml.runner import Runner
__all__ = ["MLRunner", "Runner"]
__all__ = ["TrainingRunner", "Runner"]


@ -16,6 +16,9 @@ from torch.utils.data import DataLoader, Dataset
from health_ml.lightning_container import LightningContainer
TEST_MSE_FILE = "test_mse.txt"
TEST_MAE_FILE = "test_mae.txt"
def _create_1d_regression_dataset(n: int = 100, seed: int = 0) -> torch.Tensor:
"""Creates a simple 1-D dataset of a noisy linear function.
@ -79,10 +82,10 @@ class HelloWorldDataModule(LightningDataModule):
A data module that gives the training, validation and test data for a simple 1-dim regression task.
"""
def __init__(self, crossval_count: int, crossval_index: Optional[int]) -> None:
def __init__(self, crossval_count: int, crossval_index: Optional[int] = None, seed: int = 0) -> None:
super().__init__()
n_total = 200
xy = _create_1d_regression_dataset(n=n_total)
xy = _create_1d_regression_dataset(n=n_total, seed=seed)
n_test = 40
n_val = 50
self.test = HelloWorldDataset(xy=xy[:n_test])
@ -229,8 +232,8 @@ class HelloRegression(LightningModule):
for example writing aggregate metrics to disk.
"""
average_mse = torch.mean(torch.stack(self.test_mse))
Path("test_mse.txt").write_text(str(average_mse.item()))
Path("test_mae.txt").write_text(str(self.test_mae.compute().item()))
Path(TEST_MSE_FILE).write_text(str(average_mse.item()))
Path(TEST_MAE_FILE).write_text(str(self.test_mae.compute().item()))
self.log("test_mse", average_mse, on_epoch=True, on_step=False)
def on_run_extra_validation_epoch(self) -> None:
@ -266,6 +269,11 @@ class HelloWorld(LightningContainer):
# datamodule must carry out appropriate splitting of the data.
return HelloWorldDataModule(crossval_count=self.crossval_count, crossval_index=self.crossval_index)
# This method is optional. Override it to supply a data module to use for evaluating the model on a new dataset.
# Only the test data loader from the data module will be used.
def get_eval_data_module(self) -> LightningDataModule:
return HelloWorldDataModule(crossval_count=1, seed=1)
def get_callbacks(self) -> List[Callback]:
if self.save_checkpoint:
return [


@ -0,0 +1,29 @@
from pytorch_lightning import LightningDataModule
from health_azure.logging import logging_section
from health_ml.runner_base import RunnerBase
class EvalRunner(RunnerBase):
"""A class to run the evaluation of a model on a new dataset. The initialization logic is taken from the base
class `RunnerBase`.
"""
def validate(self) -> None:
"""Checks if the fields of the class are set up correctly."""
if self.container.src_checkpoint is None or self.container.src_checkpoint.checkpoint == "":
raise ValueError(
"To use model evaluation, you need to provide a checkpoint to use, via the --src_checkpoint argument."
)
def run(self) -> None:
"""Start the core workflow that the class implements: Initialize a PL Trainer object and use that to run
inference on the inference dataset."""
self.container.outputs_folder.mkdir(exist_ok=True, parents=True)
self.init_inference()
with logging_section("Model inference"):
self.run_inference()
def get_data_module(self) -> LightningDataModule:
"""Reads the evaluation data module from the underlying container."""
return self.container.get_eval_data_module()


@ -10,6 +10,11 @@ class DebugDDPOptions(Enum):
DETAIL = "DETAIL"
class RunnerMode(Enum):
TRAIN = "train"
EVAL_FULL = "eval_full"
DEBUG_DDP_ENV_VAR = "TORCH_DISTRIBUTED_DEBUG"
@ -82,3 +87,10 @@ class ExperimentConfig(param.Parameterized):
doc="The maximum runtime that is allowed for this job in AzureML. This is given as a floating"
"point number with a string suffix s, m, h, d for seconds, minutes, hours, day. Examples: '3.5h', '2d'",
)
mode: str = param.ClassSelector(
class_=RunnerMode,
default=RunnerMode.TRAIN,
doc=f"The mode to run the experiment in. Can be one of '{RunnerMode.TRAIN}' (training and evaluation on the "
f"test set), or '{RunnerMode.EVAL_FULL}' for evaluation on the full dataset specified by the "
"'get_eval_data_module' method of the container.",
)


@ -65,6 +65,17 @@ class LightningContainer(WorkflowParams, DatasetParams, OutputParams, TrainerPar
"""
return None # type: ignore
def get_eval_data_module(self) -> LightningDataModule:
"""
Gets the data that is used when evaluating the model on a new dataset.
This data module should read datasets from the self.local_datasets folder or download from a web location.
Only the test dataloader is used, hence the method needs to put all data into the test dataloader, rather
than splitting into train/val/test.
:return: A LightningDataModule
"""
return None # type: ignore
def get_trainer_arguments(self) -> Dict[str, Any]:
"""
Gets additional parameters that will be passed on to the PyTorch Lightning trainer.


@ -31,9 +31,9 @@ from health_azure.datasets import create_dataset_configs # noqa: E402
from health_azure.himl import DEFAULT_DOCKER_BASE_IMAGE, OUTPUT_FOLDER # noqa: E402
from health_azure.logging import logging_to_stdout # noqa: E402
from health_azure.paths import is_himl_used_from_git_repo # noqa: E402
from health_azure.utils import (
from health_azure.utils import ( # noqa: E402
ENV_LOCAL_RANK,
ENV_NODE_RANK, # noqa: E402
ENV_NODE_RANK,
get_workspace,
get_ml_client,
is_local_rank_zero,
@ -43,9 +43,11 @@ from health_azure.utils import (
is_global_rank_zero,
)
from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, ExperimentConfig # noqa: E402
from health_ml.eval_runner import EvalRunner # noqa: E402
from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, ExperimentConfig, RunnerMode # noqa: E402
from health_ml.lightning_container import LightningContainer # noqa: E402
from health_ml.run_ml import MLRunner # noqa: E402
from health_ml.runner_base import RunnerBase # noqa: E402
from health_ml.training_runner import TrainingRunner # noqa: E402
from health_ml.utils import fixed_paths # noqa: E402
from health_ml.utils.logging import ConsoleAndFileOutput # noqa: E402
from health_ml.utils.common_utils import check_conda_environment, choose_conda_env_file, is_linux # noqa: E402
@ -101,8 +103,8 @@ class Runner:
self.project_root = project_root
self.experiment_config: ExperimentConfig = ExperimentConfig()
self.lightning_container: LightningContainer = None # type: ignore
# This field stores the MLRunner object that has been created in the most recent call to the run() method.
self.ml_runner: Optional[MLRunner] = None
# This field stores the TrainingRunner object that has been created in the most recent call to the run() method.
self.ml_runner: Optional[RunnerBase] = None
def parse_and_load_model(self) -> ParserResult:
"""
@ -322,15 +324,26 @@ class Runner:
assert azure_run_info.run is not None
azure_run_info.run.set_tags(self.additional_run_tags(sys.argv[1:]))
# Set environment variables for multi-node training if needed. This function will terminate early
# if it detects that it is not in a multi-node environment.
if self.experiment_config.num_nodes > 1:
set_environment_variables_for_multi_node()
self.ml_runner = MLRunner(
experiment_config=self.experiment_config, container=self.lightning_container, project_root=self.project_root
)
self.ml_runner.setup(azure_run_info)
self.ml_runner.run()
if self.experiment_config.mode == RunnerMode.TRAIN:
# Set environment variables for multi-node training if needed. This function will terminate early
# if it detects that it is not in a multi-node environment.
if self.experiment_config.num_nodes > 1:
set_environment_variables_for_multi_node()
self.ml_runner = TrainingRunner(
experiment_config=self.experiment_config,
container=self.lightning_container,
project_root=self.project_root,
)
elif self.experiment_config.mode == RunnerMode.EVAL_FULL:
self.ml_runner = EvalRunner(
experiment_config=self.experiment_config,
container=self.lightning_container,
project_root=self.project_root,
)
else:
raise ValueError(f"Unknown mode {self.experiment_config.mode}")
self.ml_runner.validate()
self.ml_runner.run_and_cleanup(azure_run_info)
def run(project_root: Path) -> Tuple[LightningContainer, AzureRunInfo]:


@ -0,0 +1,247 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
from pathlib import Path
from typing import List, Optional
from azureml.core import Run
from pytorch_lightning import LightningDataModule, Trainer, seed_everything
from health_azure import AzureRunInfo
from health_azure.utils import (
PARENT_RUN_CONTEXT,
RUN_CONTEXT,
create_aml_run_object,
create_run_recovery_id,
is_running_in_azure_ml,
)
from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
from health_ml.model_trainer import create_lightning_trainer
from health_ml.utils import fixed_paths
from health_ml.utils.checkpoint_handler import CheckpointHandler
from health_ml.utils.common_utils import (
EFFECTIVE_RANDOM_SEED_KEY_NAME,
RUN_RECOVERY_FROM_ID_KEY_NAME,
RUN_RECOVERY_ID_KEY,
change_working_directory,
seed_monai_if_available,
)
from health_ml.utils.lightning_loggers import StoringLogger, get_mlflow_run_id_from_trainer
from health_ml.utils.type_annotations import PathOrString
def check_dataset_folder_exists(local_dataset: PathOrString) -> Path:
"""
Checks if a folder with a local dataset exists. If it does exist, return the argument converted
to a Path instance. If it does not exist, raise a FileNotFoundError.
:param local_dataset: The dataset folder to check.
:return: The local_dataset argument, converted to a Path.
"""
expected_dir = Path(local_dataset)
if not expected_dir.is_dir():
raise FileNotFoundError(f"The model uses a dataset in {expected_dir}, but that does not exist.")
logging.info(f"Model will use the local dataset provided in {expected_dir}")
return expected_dir
class RunnerBase:
"""
A base class with operations that are shared between the training/test runner and the evaluation-only runner.
"""
def __init__(
self, experiment_config: ExperimentConfig, container: LightningContainer, project_root: Optional[Path] = None
) -> None:
"""
Driver class to run an ML experiment. Note that the project root argument MUST be supplied when using hi-ml
as a package!
:param experiment_config: The ExperimentConfig object to use for training.
:param container: The LightningContainer object to use for training.
:param project_root: Project root. This should only be omitted if calling run_ml from the test suite. Supplying
it is crucial when using hi-ml as a package or submodule!
"""
self.container = container
self.experiment_config = experiment_config
self.container.num_nodes = self.experiment_config.num_nodes
self.project_root: Path = project_root or fixed_paths.repository_root_directory()
self.storing_logger: Optional[StoringLogger] = None
self._has_setup_run = False
self.checkpoint_handler = CheckpointHandler(
container=self.container, project_root=self.project_root, run_context=RUN_CONTEXT
)
self.trainer: Optional[Trainer] = None
self.azureml_run_for_logging: Optional[Run] = None
self.mlflow_run_for_logging: Optional[str] = None
# This is passed to trainer.validate and trainer.test in inference mode
self.inference_checkpoint: Optional[str] = None
def validate(self) -> None:
"""
Checks if all arguments and settings of the object are correct.
"""
pass
def setup_azureml(self) -> None:
"""
Execute setup steps that are specific to AzureML.
"""
if PARENT_RUN_CONTEXT is not None:
# Set metadata for the run in AzureML if running in a Hyperdrive job.
run_tags_parent = PARENT_RUN_CONTEXT.get_tags()
tags_to_copy = [
"tag",
"model_name",
"execution_mode",
"recovered_from",
"friendly_name",
"build_number",
"build_user",
RUN_RECOVERY_FROM_ID_KEY_NAME,
]
new_tags = {tag: run_tags_parent.get(tag, "") for tag in tags_to_copy}
new_tags[RUN_RECOVERY_ID_KEY] = create_run_recovery_id(run=RUN_CONTEXT)
new_tags[EFFECTIVE_RANDOM_SEED_KEY_NAME] = str(self.container.get_effective_random_seed())
RUN_CONTEXT.set_tags(new_tags)
def setup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None:
"""
Sets the random seeds, calls the setup method on the LightningContainer and then creates the actual
Lightning modules.
:param azure_run_info: When running in AzureML or on a local VM, this contains the paths to the datasets.
This can be missing when running in unit tests, where the local dataset paths are already populated.
"""
if self._has_setup_run:
return
if azure_run_info:
# Set up the paths to the datasets. azure_run_info already has all necessary information, using either
# the provided local datasets for VM runs, or the AzureML mount points when running in AML.
# This must happen before container setup because that could already read datasets.
logging.info("Setting tags from parent run.")
input_datasets = azure_run_info.input_datasets
logging.info(f"Setting the following datasets as local datasets: {input_datasets}")
if len(input_datasets) > 0:
local_datasets: List[Path] = []
for i, dataset in enumerate(input_datasets):
if dataset is None:
raise ValueError(f"Invalid setup: The dataset at index {i} is None")
local_datasets.append(check_dataset_folder_exists(dataset))
self.container.local_datasets = local_datasets # type: ignore
# Ensure that we use fixed seeds before initializing the PyTorch models.
# MONAI needs a separate method to make all transforms deterministic by default
seed = self.container.get_effective_random_seed()
seed_monai_if_available(seed)
seed_everything(seed)
# Creating the folder structure must happen before the LightningModule is created, because the output
# parameters of the container will be copied into the module.
self.container.create_filesystem(self.project_root)
# configure recovery container if provided
self.checkpoint_handler.download_recovery_checkpoints_or_weights()
# Create an AzureML run for logging if running outside AzureML.
self.create_logger()
self.container.setup()
self.container.create_lightning_module_and_store()
self._has_setup_run = True
if is_running_in_azure_ml():
self.setup_azureml()
def create_logger(self) -> None:
"""
Create an AzureML run for logging if running outside AzureML. This run will be used for metrics logging
during both training and inference. We can't rely on the automatically generated run inside the AzureMLLogger
class because two of those logger objects will be created, so training and inference metrics would be logged
in different runs.
"""
if self.container.log_from_vm:
run = create_aml_run_object(experiment_name=self.container.effective_experiment_name)
# Display name should already be set when creating the Run object, but in some scenarios this
# does not happen. Hence, set it again.
run.display_name = self.container.tag if self.container.tag else None
self.azureml_run_for_logging = run
def get_data_module(self) -> LightningDataModule:
"""
Reads the datamodule that should be used for training or evaluation from the container. This must be
overridden in subclasses.
"""
raise NotImplementedError()
def set_trainer_for_inference(self) -> None:
"""Set the runner's PL Trainer object that should be used when running inference on the validation or test set.
We run inference on a single device because distributed strategies such as DDP use DistributedSampler
internally, which replicates some samples to make sure all devices have the same batch size in case of
uneven inputs which biases the results."""
mlflow_run_id = get_mlflow_run_id_from_trainer(self.trainer)
self.container.max_num_gpus = self.container.max_num_gpus_inference
self.trainer, _ = create_lightning_trainer(
container=self.container,
num_nodes=1,
azureml_run_for_logging=self.azureml_run_for_logging,
mlflow_run_for_logging=mlflow_run_id,
)
def init_inference(self) -> None:
"""Prepare the runner for inference on validation set, test set, or a full dataset.
The following steps are performed:
1. Get the checkpoint to use for inference. This is either the checkpoint from the last training epoch or the
one specified in src_checkpoint argument.
2. Create a new trainer instance for inference. This is necessary because the trainer is created with a single
device in contrast to training that uses DDP if multiple GPUs are available.
3. Create a new data module instance for inference to account for any requested changes in the dataloading
parameters (e.g. batch_size, max_num_workers, etc) as part of on_run_extra_validation_epoch.
"""
self.inference_checkpoint = str(self.checkpoint_handler.get_checkpoint_to_test())
self.set_trainer_for_inference()
self.data_module = self.get_data_module()
def run_inference(self) -> None:
"""Run inference on the test set for all models. This is done by calling the LightningModule.test_step method.
If the LightningModule.test_step method is not overridden, then this method does nothing. The cwd is changed to
the outputs folder so that the model can write to current working directory, and still everything is put into
the right place in AzureML (there, only the contents of the "outputs" folder is treated as a result file).
"""
if self.container.has_custom_test_step():
logging.info("Running inference via the LightningModule.test_step method")
with change_working_directory(self.container.outputs_folder):
assert self.trainer, "Trainer should be initialized before inference. Call self.init_inference()."
_ = self.trainer.test(
self.container.model, datamodule=self.data_module, ckpt_path=self.inference_checkpoint
)
else:
logging.warning("None of the suitable test methods is overridden. Skipping inference completely.")
def run(self) -> None:
"""
Run the training or evaluation. This method must be overridden in subclasses.
"""
pass
def run_and_cleanup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None:
"""
Run the training or evaluation via `self.run` and cleanup afterwards.
:param azure_run_info: When running in AzureML or on a local VM, this contains the paths to the datasets.
This can be missing when running in unit tests, where the local dataset paths are already populated.
"""
self.setup(azure_run_info)
try:
self.run()
finally:
if self.azureml_run_for_logging is not None:
try:
self.azureml_run_for_logging.complete()
except Exception as ex:
logging.error("Failed to complete AzureML run: %s", ex)


@ -6,14 +6,11 @@ import json
import logging
import os
import sys
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict
import torch
from azureml.core import Run
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning import LightningDataModule, seed_everything
from health_azure import AzureRunInfo
from health_azure.logging import logging_section, print_message_with_rank_pid
from health_azure.utils import (
ENV_GLOBAL_RANK,
@ -22,140 +19,24 @@ from health_azure.utils import (
ENV_OMPI_COMM_WORLD_RANK,
PARENT_RUN_CONTEXT,
RUN_CONTEXT,
create_aml_run_object,
create_run_recovery_id,
get_metrics_for_hyperdrive_run,
get_metrics_for_run,
is_global_rank_zero,
is_local_rank_zero,
is_running_in_azure_ml,
)
from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
from health_ml.model_trainer import create_lightning_trainer, write_experiment_summary_file
from health_ml.utils import fixed_paths
from health_ml.utils.checkpoint_handler import CheckpointHandler
from health_ml.runner_base import RunnerBase
from health_ml.utils.checkpoint_utils import cleanup_checkpoints
from health_ml.utils.common_utils import (
EFFECTIVE_RANDOM_SEED_KEY_NAME,
RUN_RECOVERY_FROM_ID_KEY_NAME,
RUN_RECOVERY_ID_KEY,
change_working_directory,
seed_monai_if_available,
)
from health_ml.utils.lightning_loggers import StoringLogger, get_mlflow_run_id_from_trainer
from health_ml.utils.regression_test_utils import REGRESSION_TEST_METRICS_FILENAME, compare_folders_and_run_outputs
from health_ml.utils.type_annotations import PathOrString
def check_dataset_folder_exists(local_dataset: PathOrString) -> Path:
"""
Checks if a folder with a local dataset exists. If it does exist, return the argument converted
to a Path instance. If it does not exist, raise a FileNotFoundError.
:param local_dataset: The dataset folder to check.
:return: The local_dataset argument, converted to a Path.
"""
expected_dir = Path(local_dataset)
if not expected_dir.is_dir():
raise FileNotFoundError(f"The model uses a dataset in {expected_dir}, but that does not exist.")
logging.info(f"Model training will use the local dataset provided in {expected_dir}")
return expected_dir
class MLRunner:
def __init__(
self, experiment_config: ExperimentConfig, container: LightningContainer, project_root: Optional[Path] = None
) -> None:
"""
Driver class to run a ML experiment. Note that the project root argument MUST be supplied when using hi-ml
as a package!
:param experiment_config: The ExperimentConfig object to use for training.
:param container: The LightningContainer object to use for training.
:param project_root: Project root. This should only be omitted if calling run_ml from the test suite. Supplying
it is crucial when using hi-ml as a package or submodule!
"""
self.container = container
self.experiment_config = experiment_config
self.container.num_nodes = self.experiment_config.num_nodes
self.project_root: Path = project_root or fixed_paths.repository_root_directory()
self.storing_logger: Optional[StoringLogger] = None
self._has_setup_run = False
self.checkpoint_handler = CheckpointHandler(
container=self.container, project_root=self.project_root, run_context=RUN_CONTEXT
)
self.trainer: Optional[Trainer] = None
self.azureml_run_for_logging: Optional[Run] = None
self.mlflow_run_for_logging: Optional[str] = None
self.inference_checkpoint: Optional[str] = None # Passed to trainer.validate and trainer.test in inference mode
def set_run_tags_from_parent(self) -> None:
"""
Set metadata for the run
"""
assert PARENT_RUN_CONTEXT, "This function should only be called in a Hyperdrive run."
run_tags_parent = PARENT_RUN_CONTEXT.get_tags()
tags_to_copy = [
"tag",
"model_name",
"execution_mode",
"recovered_from",
"friendly_name",
"build_number",
"build_user",
RUN_RECOVERY_FROM_ID_KEY_NAME,
]
new_tags = {tag: run_tags_parent.get(tag, "") for tag in tags_to_copy}
new_tags[RUN_RECOVERY_ID_KEY] = create_run_recovery_id(run=RUN_CONTEXT)
new_tags[EFFECTIVE_RANDOM_SEED_KEY_NAME] = str(self.container.get_effective_random_seed())
RUN_CONTEXT.set_tags(new_tags)
def setup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None:
"""
Sets the random seeds, calls the setup method on the LightningContainer and then creates the actual
Lightning modules.
:param azure_run_info: When running in AzureML or on a local VM, this contains the paths to the datasets.
This can be missing when running in unit tests, where the local dataset paths are already populated.
"""
if self._has_setup_run:
return
if azure_run_info:
# Set up the paths to the datasets. azure_run_info already has all necessary information, using either
# the provided local datasets for VM runs, or the AzureML mount points when running in AML.
# This must happen before container setup because that could already read datasets.
input_datasets = azure_run_info.input_datasets
logging.info(f"Setting the following datasets as local datasets: {input_datasets}")
if len(input_datasets) > 0:
local_datasets: List[Path] = []
for i, dataset in enumerate(input_datasets):
if dataset is None:
raise ValueError(f"Invalid setup: The dataset at index {i} is None")
local_datasets.append(check_dataset_folder_exists(dataset))
self.container.local_datasets = local_datasets # type: ignore
# Ensure that we use fixed seeds before initializing the PyTorch models.
# MONAI needs a separate method to make all transforms deterministic by default
seed = self.container.get_effective_random_seed()
seed_monai_if_available(seed)
seed_everything(seed)
# Creating the folder structure must happen before the LightningModule is created, because the output
# parameters of the container will be copied into the module.
self.container.create_filesystem(self.project_root)
# configure recovery container if provided
self.checkpoint_handler.download_recovery_checkpoints_or_weights() # type: ignore
self.container.setup()
self.container.create_lightning_module_and_store()
self._has_setup_run = True
is_offline_run = not is_running_in_azure_ml(RUN_CONTEXT)
# Get the AzureML context in which the script is running
if not is_offline_run and PARENT_RUN_CONTEXT is not None:
logging.info("Setting tags from parent run.")
self.set_run_tags_from_parent()
class TrainingRunner(RunnerBase):
def get_data_module(self) -> LightningDataModule:
return self.container.get_data_module()
def get_multiple_trainloader_mode(self) -> str:
# Workaround for a bug in PL 1.5.5: We need to pass the cycle mode for the training data as a trainer argument
@ -190,18 +71,7 @@ class MLRunner:
seed_everything(self.container.get_effective_random_seed(), workers=True)
# Get the container's datamodule
self.data_module = self.container.get_data_module()
# Create an AzureML run for logging if running outside AzureML. This run will be used for metrics logging
# during both training and inference. We can't rely on the automatically generated run inside the AzureMLLogger
# class because two of those logger objects will be created, so training and inference metrics would be logged
# in different runs.
if self.container.log_from_vm:
run = create_aml_run_object(experiment_name=self.container.effective_experiment_name)
# Display name should already be set when creating the Run object, but in some scenarios this
# does not happen. Hence, set it again.
run.display_name = self.container.tag if self.container.tag else None
self.azureml_run_for_logging = run
self.data_module = self.get_data_module()
if not self.container.run_inference_only:
checkpoint_path_for_recovery = self.checkpoint_handler.get_recovery_or_checkpoint_path_train()
@ -275,36 +145,6 @@ class MLRunner:
return self.container.crossval_index == 0
return True
def set_trainer_for_inference(self) -> None:
"""Set the runner's PL Trainer object that should be used when running inference on the validation or test set.
We run inference on a single device because distributed strategies such as DDP use DistributedSampler
internally, which replicates some samples to make sure all devices have the same batch size in case of
uneven inputs which biases the results."""
mlflow_run_id = get_mlflow_run_id_from_trainer(self.trainer)
self.container.max_num_gpus = self.container.max_num_gpus_inference
self.trainer, _ = create_lightning_trainer(
container=self.container,
num_nodes=1,
azureml_run_for_logging=self.azureml_run_for_logging,
mlflow_run_for_logging=mlflow_run_id,
)
def init_inference(self) -> None:
"""Prepare the runner for inference: validation or test. The following steps are performed:
1. Get the checkpoint to use for inference. This is either the checkpoint from the last training epoch or the
one specified in src_checkpoint argument.
2. If the container has a run_extra_val_epoch method, call it to run an extra validation epoch.
3. Create a new trainer instance for inference. This is necessary because the trainer is created with a single
device in contrast to training that uses DDP if multiple GPUs are available.
4. Create a new data module instance for inference to account for any requested changes in the dataloading
parameters (e.g. batch_size, max_num_workers, etc) as part of on_run_extra_validation_epoch.
"""
self.inference_checkpoint = str(self.checkpoint_handler.get_checkpoint_to_test())
if self.container.run_extra_val_epoch:
self.container.on_run_extra_validation_epoch()
self.set_trainer_for_inference()
self.data_module = self.container.get_data_module()
def run_training(self) -> None:
"""
The main training loop. It creates the Pytorch model based on the configuration options passed in,
@ -321,6 +161,16 @@ class MLRunner:
assert logger is not None
logger.finalize('success')
def init_inference(self) -> None:
"""
Prepare the trainer for running inference on the validation and test set. This chooses a checkpoint,
initializes the PL Trainer object, and chooses the right data module. Afterwards, the hook for running
inference on the validation set is run (`LightningContainer.on_run_extra_validation_epoch`)
"""
super().init_inference()
if self.container.run_extra_val_epoch:
self.container.on_run_extra_validation_epoch()
def run_validation(self) -> None:
"""Run validation on the validation set for all models to save time/memory consuming outputs. This is done in
inference only mode or when the user has requested an extra validation epoch. The cwd is changed to the outputs
@ -334,22 +184,6 @@ class MLRunner:
else:
logging.info("Skipping extra validation because the user has not requested it.")
def run_inference(self) -> None:
"""Run inference on the test set for all models. This is done by calling the LightningModule.test_step method.
If the LightningModule.test_step method is not overridden, then this method does nothing. The cwd is changed to
the outputs folder so that the model can write to current working directory, and still everything is put into
the right place in AzureML (there, only the contents of the "outputs" folder is treated as a result file).
"""
if self.container.has_custom_test_step():
logging.info("Running inference via the LightningModule.test_step method")
with change_working_directory(self.container.outputs_folder):
assert self.trainer, "Trainer should be initialized before inference. Call self.init_inference()."
_ = self.trainer.test(
self.container.model, datamodule=self.data_module, ckpt_path=self.inference_checkpoint
)
else:
logging.warning("None of the suitable test methods is overridden. Skipping inference completely.")
def run_regression_test(self) -> None:
if self.container.regression_test_folder:
with logging_section("Regression Test"):
@ -392,32 +226,23 @@ class MLRunner:
"""
Driver function to run a ML experiment
"""
self.setup()
try:
self.init_training()
self.init_training()
if not self.container.run_inference_only:
# Backup the environment variables in case we need to run a second training in the unit tests.
environ_before_training = dict(os.environ)
if not self.container.run_inference_only:
# Backup the environment variables in case we need to run a second training in the unit tests.
environ_before_training = dict(os.environ)
with logging_section("Model training"):
self.run_training()
with logging_section("Model training"):
self.run_training()
self.end_training(environ_before_training)
self.end_training(environ_before_training)
self.init_inference()
self.init_inference()
with logging_section("Model validation"):
self.run_validation()
with logging_section("Model validation"):
self.run_validation()
with logging_section("Model inference"):
self.run_inference()
with logging_section("Model inference"):
self.run_inference()
self.run_regression_test()
finally:
if self.azureml_run_for_logging is not None:
try:
self.azureml_run_for_logging.complete()
except Exception as ex:
logging.error("Failed to complete AzureML run: %s", ex)
self.run_regression_test()


@ -1,4 +1,15 @@
from pathlib import Path
import pytest
from health_ml.runner import Runner
from health_ml.utils import health_ml_package_setup
# Reduce logging noise in DEBUG mode
health_ml_package_setup()
@pytest.fixture
def mock_runner(tmp_path: Path) -> Runner:
"""A test fixture that creates a Runner object in a temporary folder."""
return Runner(project_root=tmp_path)


@ -0,0 +1,78 @@
from pathlib import Path
import sys
from unittest.mock import patch
import pytest
from health_ml import TrainingRunner
from health_ml.configs.hello_world import (
TEST_MAE_FILE,
TEST_MSE_FILE,
HelloWorld,
HelloWorldDataModule,
)
from health_ml.eval_runner import EvalRunner
from health_ml.experiment_config import ExperimentConfig, RunnerMode
from health_ml.runner import Runner
from health_ml.utils.checkpoint_utils import CheckpointParser
from testhiml.test_training_runner import training_runner_hello_world
from testhiml.utils.fixed_paths_for_tests import full_test_data_path
hello_world_checkpoint = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
def test_eval_runner_no_checkpoint(mock_runner: Runner) -> None:
"""Test of the evaluation mode fails if no checkpoint source is provided"""
arguments = ["", f"--model=HelloWorld", f"--mode={RunnerMode.EVAL_FULL.value}"]
with pytest.raises(ValueError, match="To use model evaluation, you need to provide a checkpoint to use"):
with patch.object(sys, "argv", arguments):
mock_runner.run()
def test_eval_runner_end_to_end(mock_runner: Runner) -> None:
"""Test the end-to-end integration of the EvalRunner class into the overall Runner"""
arguments = [
"",
f"--model=HelloWorld",
f"--mode={RunnerMode.EVAL_FULL.value}",
f"--src_checkpoint={hello_world_checkpoint}",
]
with patch("health_ml.training_runner.TrainingRunner.run_and_cleanup") as mock_training_run:
with patch.object(sys, "argv", arguments):
mock_runner.run()
# The training runner should not be invoked
mock_training_run.assert_not_called()
# The eval runner should have been invoked. The test step writes two files with metrics, check that
# they exist
output_folder = mock_runner.lightning_container.outputs_folder
for file_name in [TEST_MSE_FILE, TEST_MAE_FILE]:
assert (output_folder / file_name).exists()
def test_eval_runner_methods_called(tmp_path: Path) -> None:
"""Test if the eval runner uses the right data module from the HelloWorld model"""
container = HelloWorld()
container.src_checkpoint = CheckpointParser(str(hello_world_checkpoint))
eval_runner = EvalRunner(
container=container, experiment_config=ExperimentConfig(mode=RunnerMode.EVAL_FULL), project_root=tmp_path
)
with patch("health_ml.configs.hello_world.HelloWorld.get_eval_data_module") as mock_get_data_module:
mock_get_data_module.return_value = HelloWorldDataModule(crossval_count=1, seed=1)
eval_runner.run_and_cleanup()
mock_get_data_module.assert_called_once_with()
def test_eval_runner_no_extra_validation_epoch_called(tmp_path: Path) -> None:
"""
Ensure that the eval runner does not invoke the hook for the extra validation epoch that is used by the training runner.
"""
container = HelloWorld()
container.run_extra_val_epoch = True
container.src_checkpoint = CheckpointParser(str(hello_world_checkpoint))
eval_runner = EvalRunner(
container=container, experiment_config=ExperimentConfig(mode=RunnerMode.EVAL_FULL), project_root=tmp_path
)
with patch("health_ml.configs.hello_world.HelloRegression.on_run_extra_validation_epoch") as mock_hook:
eval_runner.run_and_cleanup()
mock_hook.assert_not_called()


@ -13,7 +13,7 @@ import pytest
from health_azure.utils import create_aml_run_object
from health_ml.experiment_config import ExperimentConfig
from health_ml.run_ml import MLRunner
from health_ml.training_runner import TrainingRunner
from health_ml.configs.hello_world import HelloWorld
from health_ml.utils.regression_test_utils import (
CONTENTS_MISMATCH,
@ -54,7 +54,7 @@ def test_regression_test() -> None:
"""
container = HelloWorld()
container.regression_test_folder = Path(str(uuid.uuid4().hex))
runner = MLRunner(container=container, experiment_config=ExperimentConfig())
runner = TrainingRunner(container=container, experiment_config=ExperimentConfig())
runner.setup()
with pytest.raises(ValueError) as ex:
runner.run()


@ -19,7 +19,7 @@ from health_azure import AzureRunInfo, DatasetConfig
from health_azure.himl import OUTPUT_FOLDER
from health_azure.utils import ENV_LOCAL_RANK, ENV_NODE_RANK
from health_azure.paths import ENVIRONMENT_YAML_FILE_NAME
from health_ml.configs.hello_world import HelloWorld # type: ignore
from health_ml.configs.hello_world import TEST_MAE_FILE, TEST_MSE_FILE, HelloWorld # type: ignore
from health_ml.deep_learning_config import WorkflowParams
from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, DebugDDPOptions
from health_ml.lightning_container import LightningContainer
@ -28,13 +28,6 @@ from health_ml.utils.common_utils import change_working_directory
from health_ml.utils.fixed_paths import repository_root_directory
@pytest.fixture
def mock_runner(tmp_path: Path) -> Runner:
"""A test fixture that creates a Runner object in a temporary folder."""
return Runner(project_root=tmp_path)
@contextmanager
def change_working_folder_and_add_environment(tmp_path: Path) -> Generator:
# Use a special simplified environment file only for the tests here. Copy that to a temp folder, then let the runner
@ -407,7 +400,7 @@ def test_run_hello_world(mock_runner: Runner) -> None:
# time-consuming auth
mock_get_workspace.assert_not_called()
# Summary.txt is written at start, the other files during inference
expected_files = ["experiment_summary.txt", "test_mae.txt", "test_mse.txt"]
expected_files = ["experiment_summary.txt", TEST_MSE_FILE, TEST_MAE_FILE]
for file in expected_files:
assert (mock_runner.lightning_container.outputs_folder / file).is_file(), f"Missing file: {file}"
@ -428,9 +421,8 @@ def test_invalid_profiler(mock_runner: Runner) -> None:
invalid_profile = "--pl_profiler=foo"
arguments = ["", "--model=HelloWorld", invalid_profile]
with patch.object(sys, "argv", arguments):
with pytest.raises(ValueError) as ex:
with pytest.raises(ValueError, match="Unsupported profiler."):
mock_runner.run()
assert "Unsupported profiler." in str(ex)
def test_custom_datastore_outside_aml(mock_runner: Runner) -> None:


@ -18,34 +18,38 @@ from pytorch_lightning import LightningModule
import mlflow
from pytorch_lightning import Trainer
from health_ml.configs.hello_world import HelloWorld # type: ignore
from health_ml.configs.hello_world import TEST_MAE_FILE, TEST_MSE_FILE, HelloWorld # type: ignore
from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
from health_ml.run_ml import MLRunner
from health_ml.runner_base import RunnerBase
from health_ml.training_runner import TrainingRunner
from health_ml.utils.checkpoint_handler import CheckpointHandler
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_ml.utils.common_utils import is_gpu_available
from health_ml.utils.common_utils import EFFECTIVE_RANDOM_SEED_KEY_NAME, is_gpu_available
from health_ml.utils.lightning_loggers import HimlMLFlowLogger, StoringLogger, get_mlflow_run_id_from_trainer
from health_azure.utils import ENV_EXPERIMENT_NAME, is_global_rank_zero
from testazure.utils_testazure import DEFAULT_WORKSPACE, experiment_for_unittests
from testhiml.utils.fixed_paths_for_tests import mock_run_id
from testhiml.utils.fixed_paths_for_tests import full_test_data_path
no_gpu = not is_gpu_available()
hello_world_checkpoint = full_test_data_path(suffix="hello_world_checkpoint.ckpt")
@pytest.fixture(scope="module")
def ml_runner_no_setup() -> MLRunner:
@pytest.fixture()
def training_runner_no_setup(tmp_path: Path) -> TrainingRunner:
experiment_config = ExperimentConfig(model="HelloWorld")
container = LightningContainer(num_epochs=1)
runner = MLRunner(experiment_config=experiment_config, container=container)
container.set_output_to(tmp_path)
runner = TrainingRunner(experiment_config=experiment_config, container=container)
return runner
@pytest.fixture(scope="module")
def ml_runner() -> Generator:
@pytest.fixture()
def training_runner(tmp_path: Path) -> Generator:
experiment_config = ExperimentConfig(model="HelloWorld")
container = LightningContainer(num_epochs=1)
runner = MLRunner(experiment_config=experiment_config, container=container)
container.set_output_to(tmp_path)
runner = TrainingRunner(experiment_config=experiment_config, container=container)
runner.setup()
yield runner
output_dir = runner.container.file_system_config.outputs_folder
@ -54,10 +58,11 @@ def ml_runner() -> Generator:
@pytest.fixture()
def ml_runner_with_container() -> Generator:
def training_runner_hello_world(tmp_path: Path) -> Generator:
experiment_config = ExperimentConfig(model="HelloWorld")
container = HelloWorld()
runner = MLRunner(experiment_config=experiment_config, container=container)
container.set_output_to(tmp_path)
runner = TrainingRunner(experiment_config=experiment_config, container=container)
runner.setup()
yield runner
output_dir = runner.container.file_system_config.outputs_folder
@ -66,12 +71,15 @@ def ml_runner_with_container() -> Generator:
@pytest.fixture()
def ml_runner_with_run_id() -> Generator:
def training_runner_hello_world_with_checkpoint() -> Generator:
"""
A fixture with a training runner for the HelloWorld model that has a src_checkpoint set.
"""
experiment_config = ExperimentConfig(model="HelloWorld")
container = HelloWorld()
container.save_checkpoint = True
container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
runner = MLRunner(experiment_config=experiment_config, container=container)
container.src_checkpoint = CheckpointParser(str(hello_world_checkpoint))
runner = TrainingRunner(experiment_config=experiment_config, container=container)
runner.setup()
yield runner
output_dir = runner.container.file_system_config.outputs_folder
@ -92,14 +100,27 @@ def regression_datadir(tmp_path: Path) -> Generator:
shutil.rmtree(tmp_path)
def test_ml_runner_setup(ml_runner_no_setup: MLRunner) -> None:
def create_mlflow_trash_folder(runner: RunnerBase) -> None:
"""Create a trash folder where MLFlow expects its deleted runs.
This is a workaround for sporadic test failures: When reading out the run_id, MLFlow checks its own
deleted runs folder, but that (or one of its parents) does not exist
"""
trash_folder = runner.container.outputs_folder / "mlruns" / ".trash"
trash_folder.mkdir(exist_ok=True, parents=True)
def test_ml_runner_setup(training_runner_no_setup: TrainingRunner) -> None:
"""Check that all the necessary methods get called during setup"""
assert not ml_runner_no_setup._has_setup_run
with patch.object(ml_runner_no_setup, "container", spec=LightningContainer) as mock_container:
with patch.object(ml_runner_no_setup, "checkpoint_handler", spec=CheckpointHandler) as mock_checkpoint_handler:
with patch("health_ml.run_ml.seed_everything") as mock_seed:
with patch("health_ml.run_ml.seed_monai_if_available") as mock_seed_monai:
ml_runner_no_setup.setup()
assert not training_runner_no_setup._has_setup_run
with patch.object(training_runner_no_setup, "container", spec=LightningContainer) as mock_container:
# Without that, it would try to create a local run object for logging and fail there.
mock_container.log_from_vm = False
with patch.object(
training_runner_no_setup, "checkpoint_handler", spec=CheckpointHandler
) as mock_checkpoint_handler:
with patch("health_ml.runner_base.seed_everything") as mock_seed:
with patch("health_ml.runner_base.seed_monai_if_available") as mock_seed_monai:
training_runner_no_setup.setup()
mock_container.get_effective_random_seed.assert_called()
mock_seed.assert_called_once()
mock_seed_monai.assert_called_once()
@ -107,47 +128,58 @@ def test_ml_runner_setup(ml_runner_no_setup: MLRunner) -> None:
mock_checkpoint_handler.download_recovery_checkpoints_or_weights.assert_called_once()
mock_container.setup.assert_called_once()
mock_container.create_lightning_module_and_store.assert_called_once()
assert ml_runner_no_setup._has_setup_run
assert training_runner_no_setup._has_setup_run
def test_set_run_tags_from_parent(ml_runner: MLRunner) -> None:
"""Test that set_run_tags_from_parents causes set_tags to get called"""
with pytest.raises(AssertionError) as ae:
ml_runner.set_run_tags_from_parent()
assert "should only be called in a Hyperdrive run" in str(ae)
def test_setup_azureml(training_runner: TrainingRunner) -> None:
"""Test that setup_azureml causes set_tags to get called when running in Hyperdrive"""
with patch("health_ml.runner_base.RUN_CONTEXT") as mock_run_context:
training_runner.setup_azureml()
# Tests always run outside of a Hyperdrive run. In those cases, no tags should be set on
# the current run.
mock_run_context.set_tags.assert_not_called()
with patch("health_ml.run_ml.PARENT_RUN_CONTEXT") as mock_parent_run_context:
with patch("health_ml.run_ml.RUN_CONTEXT") as mock_run_context:
mock_parent_run_context.get_tags.return_value = {"tag": "dummy_tag"}
ml_runner.set_run_tags_from_parent()
mock_run_context.set_tags.assert_called()
with patch("health_ml.runner_base.PARENT_RUN_CONTEXT") as mock_parent_run_context:
# Mock the presence of a parent run, and tags that are present there
tag_name = "tag"
tag_value = "dummy_tag"
mock_parent_run_context.get_tags.return_value = {tag_name: tag_value}
training_runner.setup_azureml()
# The function should read out tags from the parent run, and set them on the current run
mock_parent_run_context.get_tags.assert_called_once_with()
mock_run_context.set_tags.assert_called_once()
call_args = mock_run_context.set_tags.call_args[0][0]
assert tag_name in call_args
assert call_args[tag_name] == tag_value
assert EFFECTIVE_RANDOM_SEED_KEY_NAME in call_args
def test_get_multiple_trainloader_mode(ml_runner: MLRunner) -> None:
multiple_trainloader_mode = ml_runner.get_multiple_trainloader_mode()
def test_get_multiple_trainloader_mode(training_runner: TrainingRunner) -> None:
training_runner.init_training()
multiple_trainloader_mode = training_runner.get_multiple_trainloader_mode()
assert multiple_trainloader_mode == "max_size_cycle", "train_loader_cycle_mode is available now, "
"`get_multiple_trainloader_mode` workaround can be safely removed."
def _test_init_training(run_inference_only: bool, ml_runner: MLRunner, caplog: LogCaptureFixture) -> None:
def _test_init_training(run_inference_only: bool, training_runner: TrainingRunner, caplog: LogCaptureFixture) -> None:
"""Test that training is initialized correctly"""
ml_runner.container.run_inference_only = run_inference_only
ml_runner.setup()
assert not ml_runner.checkpoint_handler.has_continued_training
assert ml_runner.trainer is None
assert ml_runner.storing_logger is None
training_runner.container.run_inference_only = run_inference_only
training_runner.setup()
assert not training_runner.checkpoint_handler.has_continued_training
assert training_runner.trainer is None
assert training_runner.storing_logger is None
with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
with patch.object(ml_runner.container, "get_data_module") as mock_get_data_module:
with patch("health_ml.run_ml.write_experiment_summary_file") as mock_write_experiment_summary_file:
with patch("health_ml.training_runner.create_lightning_trainer") as mock_create_trainer:
with patch.object(training_runner.container, "get_data_module") as mock_get_data_module:
with patch("health_ml.training_runner.write_experiment_summary_file") as mock_write_experiment_summary_file:
with patch.object(
ml_runner.checkpoint_handler, "get_recovery_or_checkpoint_path_train"
training_runner.checkpoint_handler, "get_recovery_or_checkpoint_path_train"
) as mock_get_recovery_or_checkpoint_path_train:
with patch("health_ml.run_ml.seed_everything") as mock_seed:
with patch("health_ml.training_runner.seed_everything") as mock_seed:
mock_create_trainer.return_value = MagicMock(), MagicMock()
mock_get_recovery_or_checkpoint_path_train.return_value = "dummy_path"
ml_runner.init_training()
training_runner.init_training()
# Make sure write_experiment_summary_file is only called on rank 0
if is_global_rank_zero():
@ -157,47 +189,51 @@ def _test_init_training(run_inference_only: bool, ml_runner: MLRunner, caplog: L
# Make sure seed is set correctly with workers=True
mock_seed.assert_called_once()
assert mock_seed.call_args[0][0] == ml_runner.container.get_effective_random_seed()
assert mock_seed.call_args[0][0] == training_runner.container.get_effective_random_seed()
assert mock_seed.call_args[1]["workers"]
mock_get_data_module.assert_called_once()
assert ml_runner.data_module is not None
assert training_runner.data_module is not None
if not run_inference_only:
mock_get_recovery_or_checkpoint_path_train.assert_called_once()
# Validate that the trainer is created correctly
assert mock_create_trainer.call_args[1]["resume_from_checkpoint"] == "dummy_path"
assert ml_runner.storing_logger is not None
assert ml_runner.trainer is not None
assert training_runner.storing_logger is not None
assert training_runner.trainer is not None
assert "Environment variables:" in caplog.messages[-1]
else:
assert ml_runner.trainer is None
assert ml_runner.storing_logger is None
assert training_runner.trainer is None
assert training_runner.storing_logger is None
mock_get_recovery_or_checkpoint_path_train.assert_not_called()
@pytest.mark.parametrize("run_inference_only", [True, False])
def test_init_training_cpu(run_inference_only: bool, ml_runner: MLRunner, caplog: LogCaptureFixture) -> None:
def test_init_training_cpu(
run_inference_only: bool, training_runner: TrainingRunner, caplog: LogCaptureFixture
) -> None:
"""Test that training is initialized correctly"""
ml_runner.container.max_num_gpus = 0
_test_init_training(run_inference_only, ml_runner, caplog)
training_runner.container.max_num_gpus = 0
_test_init_training(run_inference_only, training_runner, caplog)
@pytest.mark.skipif(no_gpu, reason="Test requires GPU")
@pytest.mark.gpu
@pytest.mark.parametrize("run_inference_only", [True, False])
def test_init_training_gpu(run_inference_only: bool, ml_runner: MLRunner, caplog: LogCaptureFixture) -> None:
def test_init_training_gpu(
run_inference_only: bool, training_runner: TrainingRunner, caplog: LogCaptureFixture
) -> None:
"""Test that training is initialized correctly in DDP mode"""
_test_init_training(run_inference_only, ml_runner, caplog)
_test_init_training(run_inference_only, training_runner, caplog)


 def test_run_training() -> None:
     experiment_config = ExperimentConfig(model="HelloWorld")
     container = HelloWorld()
-    runner = MLRunner(experiment_config=experiment_config, container=container)
+    runner = TrainingRunner(experiment_config=experiment_config, container=container)

     with patch.object(container, "get_data_module") as mock_get_data_module:
-        with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
+        with patch("health_ml.training_runner.create_lightning_trainer") as mock_create_trainer:
             runner.setup()
             mock_trainer = MagicMock()
             mock_storing_logger = MagicMock()
@@ -225,17 +261,17 @@ def test_end_training(max_num_gpus_inf: int) -> None:
     experiment_config = ExperimentConfig(model="HelloWorld")
     container = HelloWorld()
     container.max_num_gpus_inference = max_num_gpus_inf
-    runner = MLRunner(experiment_config=experiment_config, container=container)
+    runner = TrainingRunner(experiment_config=experiment_config, container=container)

     with patch.object(container, "get_data_module"):
-        with patch("health_ml.run_ml.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
+        with patch("health_ml.training_runner.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
             runner.setup()
             runner.init_training()
             runner.run_training()

     with patch.object(runner.checkpoint_handler, "additional_training_done") as mock_additional_training_done:
         with patch.object(runner, "after_ddp_cleanup") as mock_after_ddp_cleanup:
-            with patch("health_ml.run_ml.cleanup_checkpoints") as mock_cleanup_ckpt:
+            with patch("health_ml.training_runner.cleanup_checkpoints") as mock_cleanup_ckpt:
                 environ_before_training = {"old": "environ"}
                 runner.end_training(environ_before_training=environ_before_training)
                 mock_additional_training_done.assert_called_once()
@@ -251,71 +287,98 @@ def test_end_training(max_num_gpus_inf: int) -> None:
 @pytest.mark.parametrize("run_extra_val_epoch", [True, False])
 @pytest.mark.parametrize("run_inference_only", [True, False])
 def test_init_inference(
-    run_inference_only: bool, run_extra_val_epoch: bool, max_num_gpus_inf: int, ml_runner_with_run_id: MLRunner
+    run_inference_only: bool,
+    run_extra_val_epoch: bool,
+    max_num_gpus_inf: int,
+    training_runner_hello_world_with_checkpoint: TrainingRunner,
 ) -> None:
-    ml_runner_with_run_id.container.run_inference_only = run_inference_only
-    ml_runner_with_run_id.container.run_extra_val_epoch = run_extra_val_epoch
-    ml_runner_with_run_id.container.max_num_gpus_inference = max_num_gpus_inf
-    assert ml_runner_with_run_id.container.max_num_gpus == -1  # This is the default value of max_num_gpus
-    ml_runner_with_run_id.init_training()
+    training_runner_hello_world_with_checkpoint.container.run_inference_only = run_inference_only
+    training_runner_hello_world_with_checkpoint.container.run_extra_val_epoch = run_extra_val_epoch
+    training_runner_hello_world_with_checkpoint.container.max_num_gpus_inference = max_num_gpus_inf
+    assert (
+        training_runner_hello_world_with_checkpoint.container.max_num_gpus == -1
+    )  # This is the default value of max_num_gpus
+    training_runner_hello_world_with_checkpoint.init_training()
     if run_inference_only:
         expected_mlflow_run_id = None
     else:
-        assert ml_runner_with_run_id.trainer is not None
-        expected_mlflow_run_id = ml_runner_with_run_id.trainer.loggers[1].run_id  # type: ignore
+        assert training_runner_hello_world_with_checkpoint.trainer is not None
+        create_mlflow_trash_folder(training_runner_hello_world_with_checkpoint)
+        expected_mlflow_run_id = training_runner_hello_world_with_checkpoint.trainer.loggers[1].run_id  # type: ignore

     if not run_inference_only:
-        ml_runner_with_run_id.checkpoint_handler.additional_training_done()
-    with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
-        with patch.object(ml_runner_with_run_id.container, "get_checkpoint_to_test") as mock_get_checkpoint_to_test:
-            with patch.object(ml_runner_with_run_id.container, "get_data_module") as mock_get_data_module:
+        training_runner_hello_world_with_checkpoint.checkpoint_handler.additional_training_done()
+    with patch("health_ml.runner_base.create_lightning_trainer") as mock_create_trainer:
+        with patch.object(
+            training_runner_hello_world_with_checkpoint.container, "get_checkpoint_to_test"
+        ) as mock_get_checkpoint_to_test:
+            with patch.object(
+                training_runner_hello_world_with_checkpoint.container, "get_data_module"
+            ) as mock_get_data_module:
                 mock_checkpoint = MagicMock(is_file=MagicMock(return_value=True))
                 mock_get_checkpoint_to_test.return_value = mock_checkpoint
                 mock_trainer = MagicMock()
                 mock_create_trainer.return_value = mock_trainer, MagicMock()
                 mock_get_data_module.return_value = "dummy_data_module"
-                assert ml_runner_with_run_id.inference_checkpoint is None
-                assert not ml_runner_with_run_id.container.model._on_extra_val_epoch
+                assert training_runner_hello_world_with_checkpoint.inference_checkpoint is None
+                assert not training_runner_hello_world_with_checkpoint.container.model._on_extra_val_epoch

-                ml_runner_with_run_id.init_inference()
+                training_runner_hello_world_with_checkpoint.init_inference()

-                expected_ckpt = str(ml_runner_with_run_id.checkpoint_handler.trained_weights_path)
+                expected_ckpt = str(training_runner_hello_world_with_checkpoint.checkpoint_handler.trained_weights_path)
                 expected_ckpt = expected_ckpt if run_inference_only else str(mock_checkpoint)
-                assert ml_runner_with_run_id.inference_checkpoint == expected_ckpt
+                assert training_runner_hello_world_with_checkpoint.inference_checkpoint == expected_ckpt

-                assert hasattr(ml_runner_with_run_id.container.model, "on_run_extra_validation_epoch")
-                assert ml_runner_with_run_id.container.model._on_extra_val_epoch == run_extra_val_epoch
+                assert hasattr(
+                    training_runner_hello_world_with_checkpoint.container.model, "on_run_extra_validation_epoch"
+                )
+                assert (
+                    training_runner_hello_world_with_checkpoint.container.model._on_extra_val_epoch
+                    == run_extra_val_epoch
+                )

                 mock_create_trainer.assert_called_once()
-                assert ml_runner_with_run_id.trainer == mock_trainer
-                assert ml_runner_with_run_id.container.max_num_gpus == max_num_gpus_inf
-                assert mock_create_trainer.call_args[1]["container"] == ml_runner_with_run_id.container
+                assert training_runner_hello_world_with_checkpoint.trainer == mock_trainer
+                assert training_runner_hello_world_with_checkpoint.container.max_num_gpus == max_num_gpus_inf
+                assert (
+                    mock_create_trainer.call_args[1]["container"]
+                    == training_runner_hello_world_with_checkpoint.container
+                )
                 assert mock_create_trainer.call_args[1]["num_nodes"] == 1
                 assert mock_create_trainer.call_args[1]["mlflow_run_for_logging"] == expected_mlflow_run_id
                 mock_get_data_module.assert_called_once()
-                assert ml_runner_with_run_id.data_module == "dummy_data_module"
+                assert training_runner_hello_world_with_checkpoint.data_module == "dummy_data_module"
@pytest.mark.parametrize("run_inference_only", [True, False])
@pytest.mark.parametrize("run_extra_val_epoch", [True, False])
def test_run_validation(
run_extra_val_epoch: bool, run_inference_only: bool, ml_runner_with_run_id: MLRunner, caplog: LogCaptureFixture
run_extra_val_epoch: bool,
run_inference_only: bool,
training_runner_hello_world_with_checkpoint: TrainingRunner,
caplog: LogCaptureFixture,
) -> None:
ml_runner_with_run_id.container.run_extra_val_epoch = run_extra_val_epoch
ml_runner_with_run_id.container.run_inference_only = run_inference_only
ml_runner_with_run_id.init_training()
training_runner_hello_world_with_checkpoint.container.run_extra_val_epoch = run_extra_val_epoch
training_runner_hello_world_with_checkpoint.container.run_inference_only = run_inference_only
training_runner_hello_world_with_checkpoint.init_training()
mock_datamodule = MagicMock()
with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
with patch.object(ml_runner_with_run_id.container, "get_data_module", return_value=mock_datamodule):
create_mlflow_trash_folder(training_runner_hello_world_with_checkpoint)
with patch("health_ml.runner_base.create_lightning_trainer") as mock_create_trainer:
with patch.object(
training_runner_hello_world_with_checkpoint.container, "get_data_module", return_value=mock_datamodule
):
mock_trainer = MagicMock()
mock_create_trainer.return_value = mock_trainer, MagicMock()
ml_runner_with_run_id.init_inference()
assert ml_runner_with_run_id.trainer == mock_trainer
training_runner_hello_world_with_checkpoint.init_inference()
assert training_runner_hello_world_with_checkpoint.trainer == mock_trainer
mock_trainer.validate = Mock()
ml_runner_with_run_id.run_validation()
training_runner_hello_world_with_checkpoint.run_validation()
if run_extra_val_epoch or run_inference_only:
mock_trainer.validate.assert_called_once()
assert mock_trainer.validate.call_args[1]["ckpt_path"] == ml_runner_with_run_id.inference_checkpoint
assert (
mock_trainer.validate.call_args[1]["ckpt_path"]
== training_runner_hello_world_with_checkpoint.inference_checkpoint
)
assert mock_trainer.validate.call_args[1]["datamodule"] == mock_datamodule
else:
assert "Skipping extra validation" in caplog.messages[-1]
@@ -332,12 +395,12 @@ def test_model_extra_val_epoch_missing_hook(caplog: LogCaptureFixture) -> None:
     container = HelloWorld()
     container.create_lightning_module_and_store()
     container.run_extra_val_epoch = True
-    runner = MLRunner(experiment_config=experiment_config, container=container)
+    runner = TrainingRunner(experiment_config=experiment_config, container=container)
     runner.setup()
     runner.checkpoint_handler.additional_training_done()
     runner.container.outputs_folder.mkdir(parents=True, exist_ok=True)
     with patch.object(container, "get_data_module"):
-        with patch("health_ml.run_ml.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
+        with patch("health_ml.runner_base.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
             with patch.object(runner.container, "get_checkpoint_to_test") as mock_get_checkpoint_to_test:
                 mock_get_checkpoint_to_test.return_value = MagicMock(is_file=MagicMock(return_value=True))
                 runner.init_inference()
@@ -346,34 +409,34 @@ def test_model_extra_val_epoch_missing_hook(caplog: LogCaptureFixture) -> None:
     assert "Hook `on_run_extra_validation_epoch` is not implemented" in latest_message


-def test_run_inference(ml_runner_with_container: MLRunner, regression_datadir: Path) -> None:
+def test_run_inference(training_runner_hello_world: TrainingRunner, regression_datadir: Path) -> None:
     """
     Test that run_inference gets called as expected.
     """
-    ml_runner_with_container.container.max_num_gpus = 0
+    training_runner_hello_world.container.max_num_gpus = 0

     def _expected_files_exist() -> bool:
-        output_dir = ml_runner_with_container.container.outputs_folder
+        output_dir = training_runner_hello_world.container.outputs_folder
         if not output_dir.is_dir():
             return False
-        expected_files = ["test_mse.txt", "test_mae.txt"]
+        expected_files = [TEST_MSE_FILE, TEST_MAE_FILE]
         return all([(output_dir / p).exists() for p in expected_files])

-    expected_ckpt_path = ml_runner_with_container.container.outputs_folder / "checkpoints" / "last.ckpt"
+    expected_ckpt_path = training_runner_hello_world.container.outputs_folder / "checkpoints" / "last.ckpt"
     assert not expected_ckpt_path.exists()
     # update the container to look for test data at this location
-    ml_runner_with_container.container.local_dataset_dir = regression_datadir
+    training_runner_hello_world.container.local_dataset_dir = regression_datadir
     assert not _expected_files_exist()

-    actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
+    actual_train_ckpt_path = training_runner_hello_world.checkpoint_handler.get_recovery_or_checkpoint_path_train()
     assert actual_train_ckpt_path is None

-    ml_runner_with_container.run()
+    training_runner_hello_world.run()

-    actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
+    actual_train_ckpt_path = training_runner_hello_world.checkpoint_handler.get_recovery_or_checkpoint_path_train()
     assert actual_train_ckpt_path == expected_ckpt_path

-    actual_test_ckpt_path = ml_runner_with_container.checkpoint_handler.get_checkpoint_to_test()
+    actual_test_ckpt_path = training_runner_hello_world.checkpoint_handler.get_checkpoint_to_test()
     assert actual_test_ckpt_path == expected_ckpt_path
     assert actual_test_ckpt_path.is_file()
     # After training, the outputs directory should now exist and contain the 2 error files
@@ -382,25 +445,25 @@ def test_run_inference(ml_runner_with_container: MLRunner, regression_datadir: P
 @pytest.mark.parametrize("run_extra_val_epoch", [True, False])
 @pytest.mark.parametrize("run_inference_only", [True, False])
-def test_run(run_inference_only: bool, run_extra_val_epoch: bool, ml_runner_with_container: MLRunner) -> None:
+def test_run(run_inference_only: bool, run_extra_val_epoch: bool, training_runner_hello_world: TrainingRunner) -> None:
     """Test that model runner gets called"""
-    ml_runner_with_container.container.run_inference_only = run_inference_only
-    ml_runner_with_container.container.run_extra_val_epoch = run_extra_val_epoch
-    ml_runner_with_container.setup()
-    assert not ml_runner_with_container.checkpoint_handler.has_continued_training
+    training_runner_hello_world.container.run_inference_only = run_inference_only
+    training_runner_hello_world.container.run_extra_val_epoch = run_extra_val_epoch
+    training_runner_hello_world.setup()
+    assert not training_runner_hello_world.checkpoint_handler.has_continued_training

-    with patch("health_ml.run_ml.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
+    with patch("health_ml.runner_base.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
         with patch.multiple(
-            ml_runner_with_container,
+            training_runner_hello_world,
             checkpoint_handler=mock.DEFAULT,
             run_training=mock.DEFAULT,
             run_validation=mock.DEFAULT,
             run_inference=mock.DEFAULT,
             end_training=mock.DEFAULT,
         ) as mocks:
-            ml_runner_with_container.run()
-            assert ml_runner_with_container.container.has_custom_test_step()
-            assert ml_runner_with_container._has_setup_run
+            training_runner_hello_world.run()
+            assert training_runner_hello_world.container.has_custom_test_step()
+            assert training_runner_hello_world._has_setup_run
             assert mocks["end_training"].called != run_inference_only
             assert mocks["run_training"].called != run_inference_only
             mocks["run_validation"].assert_called_once()
@@ -408,45 +471,59 @@ def test_run(run_inference_only: bool, run_extra_val_epoch: bool, ml_runner_with
 @pytest.mark.parametrize("run_extra_val_epoch", [True, False])
-def test_run_inference_only(run_extra_val_epoch: bool, ml_runner_with_run_id: MLRunner) -> None:
+def test_run_inference_only(
+    run_extra_val_epoch: bool, training_runner_hello_world_with_checkpoint: TrainingRunner
+) -> None:
     """Test inference only mode. Validation should be run regardless of run_extra_val_epoch status."""
-    ml_runner_with_run_id.container.run_inference_only = True
-    ml_runner_with_run_id.container.run_extra_val_epoch = run_extra_val_epoch
-    assert ml_runner_with_run_id.checkpoint_handler.trained_weights_path
+    training_runner_hello_world_with_checkpoint.container.run_inference_only = True
+    training_runner_hello_world_with_checkpoint.container.run_extra_val_epoch = run_extra_val_epoch
+    assert training_runner_hello_world_with_checkpoint.checkpoint_handler.trained_weights_path
     mock_datamodule = MagicMock()
-    with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
-        with patch.object(ml_runner_with_run_id.container, "get_data_module", return_value=mock_datamodule):
+    with patch("health_ml.runner_base.create_lightning_trainer") as mock_create_trainer:
+        with patch.object(
+            training_runner_hello_world_with_checkpoint.container, "get_data_module", return_value=mock_datamodule
+        ):
             with patch.multiple(
-                ml_runner_with_run_id,
+                training_runner_hello_world_with_checkpoint,
                 run_training=mock.DEFAULT,
             ) as mocks:
                 mock_trainer = MagicMock()
                 mock_create_trainer.return_value = mock_trainer, MagicMock()
-                ml_runner_with_run_id.run()
+                training_runner_hello_world_with_checkpoint.run()
                 mock_create_trainer.assert_called_once()
                 mocks["run_training"].assert_not_called()
                 mock_trainer.validate.assert_called_once()
-                assert mock_trainer.validate.call_args[1]["ckpt_path"] == ml_runner_with_run_id.inference_checkpoint
+                assert (
+                    mock_trainer.validate.call_args[1]["ckpt_path"]
+                    == training_runner_hello_world_with_checkpoint.inference_checkpoint
+                )
                 assert mock_trainer.validate.call_args[1]["datamodule"] == mock_datamodule
                 mock_trainer.test.assert_called_once()
-                assert mock_trainer.test.call_args[1]["ckpt_path"] == ml_runner_with_run_id.inference_checkpoint
+                assert (
+                    mock_trainer.test.call_args[1]["ckpt_path"]
+                    == training_runner_hello_world_with_checkpoint.inference_checkpoint
+                )
                 assert mock_trainer.test.call_args[1]["datamodule"] == mock_datamodule
@pytest.mark.parametrize("run_extra_val_epoch", [True, False])
def test_resume_training_from_run_id(run_extra_val_epoch: bool, ml_runner_with_run_id: MLRunner) -> None:
ml_runner_with_run_id.container.run_extra_val_epoch = run_extra_val_epoch
ml_runner_with_run_id.container.max_num_gpus = 0
ml_runner_with_run_id.container.max_epochs += 10
assert ml_runner_with_run_id.checkpoint_handler.trained_weights_path
def test_resume_training_from_run_id(
run_extra_val_epoch: bool, training_runner_hello_world_with_checkpoint: TrainingRunner
) -> None:
training_runner_hello_world_with_checkpoint.container.run_extra_val_epoch = run_extra_val_epoch
training_runner_hello_world_with_checkpoint.container.max_num_gpus = 0
training_runner_hello_world_with_checkpoint.container.max_epochs += 10
assert training_runner_hello_world_with_checkpoint.checkpoint_handler.trained_weights_path
mock_trainer = MagicMock()
with patch("health_ml.run_ml.create_lightning_trainer", return_value=(mock_trainer, MagicMock())):
with patch.object(ml_runner_with_run_id.container, "get_checkpoint_to_test") as mock_get_checkpoint_to_test:
with patch.object(ml_runner_with_run_id, "run_inference") as mock_run_inference:
with patch("health_ml.run_ml.cleanup_checkpoints") as mock_cleanup_ckpt:
with patch("health_ml.runner_base.create_lightning_trainer", return_value=(mock_trainer, MagicMock())):
with patch.object(
training_runner_hello_world_with_checkpoint.container, "get_checkpoint_to_test"
) as mock_get_checkpoint_to_test:
with patch.object(training_runner_hello_world_with_checkpoint, "run_inference") as mock_run_inference:
with patch("health_ml.training_runner.cleanup_checkpoints") as mock_cleanup_ckpt:
mock_get_checkpoint_to_test.return_value = MagicMock(is_file=MagicMock(return_value=True))
ml_runner_with_run_id.run()
training_runner_hello_world_with_checkpoint.run()
mock_get_checkpoint_to_test.assert_called_once()
mock_cleanup_ckpt.assert_called_once()
assert mock_trainer.validate.called == run_extra_val_epoch
@@ -457,12 +534,12 @@ def test_model_weights_when_resume_training() -> None:
     experiment_config = ExperimentConfig(model="HelloWorld")
     container = HelloWorld()
     container.max_num_gpus = 0
-    container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
+    container.src_checkpoint = CheckpointParser(str(hello_world_checkpoint))
     container.resume_training = True
-    runner = MLRunner(experiment_config=experiment_config, container=container)
+    runner = TrainingRunner(experiment_config=experiment_config, container=container)
     runner.setup()
     assert runner.checkpoint_handler.trained_weights_path.is_file()  # type: ignore
-    with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
+    with patch("health_ml.training_runner.create_lightning_trainer") as mock_create_trainer:
         mock_create_trainer.return_value = MagicMock(), MagicMock()
         runner.init_training()
         mock_create_trainer.assert_called_once()
@@ -482,14 +559,14 @@ def test_log_on_vm(log_from_vm: bool) -> None:
     tag = f"test_log_on_vm [{log_from_vm}]"
     container.tag = tag
     container.log_from_vm = log_from_vm
-    runner = MLRunner(experiment_config=experiment_config, container=container)
+    runner = TrainingRunner(experiment_config=experiment_config, container=container)
     # When logging to AzureML, need to provide the unit test AML workspace.
     # When not logging to AzureML, no workspace (and no authentication) should be needed.
     if log_from_vm:
         with patch("health_azure.utils.get_workspace", return_value=DEFAULT_WORKSPACE.workspace):
-            runner.run()
+            runner.run_and_cleanup()
     else:
-        runner.run()
+        runner.run_and_cleanup()
     # The PL trainer object is created in the init_training method.
     # Check that the AzureML logger is set up correctly.
     assert runner.trainer is not None
@@ -536,13 +613,15 @@ def test_get_mlflow_run_id_from_trainer() -> None:
     assert run_id == mock_run_id


-def test_inference_only_metrics_correctness(ml_runner_with_run_id: MLRunner, regression_datadir: Path) -> None:
-    ml_runner_with_run_id.container.run_inference_only = True
-    ml_runner_with_run_id.container.local_dataset_dir = regression_datadir
-    ml_runner_with_run_id.run()
-    with open(ml_runner_with_run_id.container.outputs_folder / "test_mse.txt") as f:
+def test_inference_only_metrics_correctness(
+    training_runner_hello_world_with_checkpoint: TrainingRunner, regression_datadir: Path
+) -> None:
+    training_runner_hello_world_with_checkpoint.container.run_inference_only = True
+    training_runner_hello_world_with_checkpoint.container.local_dataset_dir = regression_datadir
+    training_runner_hello_world_with_checkpoint.run()
+    with open(training_runner_hello_world_with_checkpoint.container.outputs_folder / TEST_MSE_FILE) as f:
         mse = float(f.readlines()[0])
     assert isclose(mse, 0.010806690901517868, abs_tol=1e-3)
-    with open(ml_runner_with_run_id.container.outputs_folder / "test_mae.txt") as f:
+    with open(training_runner_hello_world_with_checkpoint.container.outputs_folder / TEST_MAE_FILE) as f:
         mae = float(f.readlines()[0])
     assert isclose(mae, 0.08260975033044815, abs_tol=1e-3)
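
The fixtures used throughout these tests (`training_runner_hello_world`, `training_runner_hello_world_with_checkpoint`) and the `TEST_MSE_FILE` / `TEST_MAE_FILE` constants are defined outside the hunks shown above. As orientation only, here is a minimal, hypothetical sketch of what such a fixture could look like, assuming the construction pattern used in `test_run_training`; the import paths and the fixture body are assumptions, not the actual conftest code of this commit.

```python
# Hypothetical sketch only, not part of this commit: builds a TrainingRunner around
# the HelloWorld container, mirroring the construction pattern in test_run_training.
import pytest

from health_ml.configs.hello_world import HelloWorld  # assumed import path
from health_ml.experiment_config import ExperimentConfig  # assumed import path
from health_ml.training_runner import TrainingRunner  # assumed import path


@pytest.fixture
def training_runner_hello_world() -> TrainingRunner:
    """A TrainingRunner wrapping the HelloWorld toy model, with setup() already run."""
    container = HelloWorld()
    runner = TrainingRunner(experiment_config=ExperimentConfig(model="HelloWorld"), container=container)
    runner.setup()
    return runner
```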