diff --git a/docs/source/runner.md b/docs/source/runner.md
index 090822c7..903d9f69 100644
--- a/docs/source/runner.md
+++ b/docs/source/runner.md
@@ -7,6 +7,7 @@ use of these features:
 - Working with different models in the same codebase, and selecting one by name
 - Distributed training in AzureML
 - Logging via AzureML's native capabilities
+- Evaluation of the trained model on new datasets
 
 This can be used by invoking the hi-ml runner and providing the name of the container class, like this:
 `himl-runner --model=MyContainer`.
@@ -215,7 +216,13 @@ and returns a tuple containing the Optimizer and LRScheduler objects
 ## Run inference with a pretrained model
 
 You can use the hi-ml-runner in inference mode only by switching the `--run_inference_only` flag on and specifying
-the model weights by setting `--src_checkpoint` argument that supports three types of checkpoints:
+the model weights by setting the `--src_checkpoint` argument. With this flag, the model will be evaluated on the test set
+only. There is also an option for evaluating the model on a full dataset, described further below.
+
+### Specifying the checkpoint to use
+
+When running inference on a trained model, you need to provide a model checkpoint that should be used. This is done via
+the `--src_checkpoint` argument. This supports three types of checkpoints:
 
 - A local path where the checkpoint is stored `--src_checkpoint=local/path/to/my_checkpoint/model.ckpt`
 - A remote URL from where to download the weights `--src_checkpoint=https://my_checkpoint_url.com/model.ckpt`
@@ -228,13 +235,69 @@ the model weights by setting `--src_checkpoint` argument that supports three typ
 
 Refer to [Checkpoints Utils](checkpoints.md) for more details on how checkpoints are parsed.
 
-Running the following command line will run inference using `MyContainer` model with weights from the checkpoint saved
+### Running inference on the test set
+
+When supplying the flag `--run_inference_only` on the command line, no model training will be run, and only inference on the
+test set will be done:
+
+- The model weights will be loaded from the location specified by `--src_checkpoint`
+- A PyTorch Lightning `Trainer` object will be instantiated.
+- The test set will be read out from the data module specified by the `get_data_module` method of the
+  `LightningContainer` object.
+- The model will be evaluated on the test set by running `trainer.test`. Any special logic to use during the test step
+  will need to be added to the model's `test_step` method.
+
+Running the following command line will run inference using the `MyContainer` model with weights from the checkpoint saved
 in the AzureMl run `MyContainer_XXXX_yyyy` at the best validation loss epoch `/outputs/checkpoints/best_val_loss.ckpt`.
 
 ```bash
 himl-runner --model=Mycontainer --run_inference_only --src_checkpoint=MyContainer_XXXX_yyyy:best_val_loss.ckpt
 ```
 
+### Running inference on a full dataset
+
+When supplying the flag `--mode=eval_full` on the command line, no model training will be run, and the model will be
+evaluated on a dataset different from the training/validation/test dataset. This dataset is loaded via the
+`get_eval_data_module` method of the container.
+
+- The model weights will be loaded from the location specified by `--src_checkpoint`
+- A PyTorch Lightning `Trainer` object will be instantiated.
+- The test set will be read out from the data module specified by the `get_eval_data_module` method of the
+  `LightningContainer` object.
+  The data module itself can read data from a mounted Azure dataset, which will be made available for the container
+  at the path `self.local_datasets`. In a typical use case, all the data in that dataset will be returned by the
+  `test_dataloader` method of the data module.
+- The model will be evaluated on the test set by running `trainer.test`. Any special logic to use during the test step
+  will need to be added to the model's `test_step` method.
+
+Running the following command line will run inference using the `MyContainer` model with weights from the checkpoint saved
+in the AzureML run `MyContainer_XXXX_yyyy` at the best validation loss epoch `/outputs/checkpoints/best_val_loss.ckpt`.
+
+```bash
+himl-runner --model=MyContainer --src_checkpoint=MyContainer_XXXX_yyyy:best_val_loss.ckpt --mode=eval_full --azure_datasets=my_new_dataset
+```
+
+The example code snippet here shows how to add a method that reads the inference dataset. In this example, we assume
+that the `MyDataModule` class has an argument `splits` that specifies the fraction of data to go into the training,
+validation, and test data loaders.
+
+```python
+class MyContainer(LightningContainer):
+    def __init__(self):
+        super().__init__()
+        self.azure_datasets = ["folder_name_in_azure_blob_storage"]
+        self.local_datasets = [Path("/some/local/path")]
+        self.max_epochs = 42
+
+    def create_model(self) -> LightningModule:
+        return MyLightningModel()
+
+    def get_data_module(self) -> LightningDataModule:
+        return MyDataModule(root_path=self.local_datasets[0], splits=(0.7, 0.2, 0.1))
+
+    def get_eval_data_module(self) -> LightningDataModule:
+        return MyDataModule(root_path=self.local_datasets[0], splits=(0.0, 0.0, 1.0))
+```
+
 ## Resume training from a given checkpoint
 
 Analogously, one can resume training by setting `--src_checkpoint` and `--resume_training` to train a model longer.
@@ -298,5 +361,6 @@ seeds on a local machine, other than by manually starting runs with `--random_se
 the limit that AML sets for snapshots. Solution: check for cache files, log files or other files that are not necessary
 for running your experiment and add them to a `.amlignore` file in the root directory. Alternatively, you can see Azure
 ML documentation for instructions on increasing this limit, although it will make your jobs slower.
-2. `"FileNotFoundError"`. Possible cause: Symlinked files. Azure ML SDK v2 will resolve the symlink and attempt to upload
-the resolved file. Solution: Remove symlinks from any files that should be uploaded to Azure ML.
+
+1. `"FileNotFoundError"`. Possible cause: Symlinked files. Azure ML SDK v2 will resolve the symlink and attempt to upload
+   the resolved file. Solution: Remove symlinks from any files that should be uploaded to Azure ML.
diff --git a/hi-ml-azure/.vscode/settings.json b/hi-ml-azure/.vscode/settings.json index 67acf59e..88eeadbf 100644 --- a/hi-ml-azure/.vscode/settings.json +++ b/hi-ml-azure/.vscode/settings.json @@ -61,4 +61,9 @@ "${workspaceFolder}/src", "${workspaceFolder}/testazure" ], + "workbench.colorCustomizations": { + "activityBar.background": "#5D091F", + "titleBar.activeBackground": "#830D2B", + "titleBar.activeForeground": "#FFFCFC" + }, } diff --git a/hi-ml-azure/src/health_azure/datasets.py b/hi-ml-azure/src/health_azure/datasets.py index 544e3ef2..7ae7497e 100644 --- a/hi-ml-azure/src/health_azure/datasets.py +++ b/hi-ml-azure/src/health_azure/datasets.py @@ -524,7 +524,7 @@ def create_dataset_configs( count = num_azure elif num_azure == 0 and num_mount == 0: # No datasets in Azure at all: This is possible for runs that for example download their own data from the web. - # There can be any number of local datasets, but we are not checking that. In MLRunner.setup, there is a check + # There can be any number of local datasets, but we are not checking that. In TrainingRunner.setup, there is a check # that leaves local datasets intact if there are no Azure datasets. return [] else: diff --git a/hi-ml/src/health_ml/__init__.py b/hi-ml/src/health_ml/__init__.py index 3e5aab17..a8920681 100644 --- a/hi-ml/src/health_ml/__init__.py +++ b/hi-ml/src/health_ml/__init__.py @@ -2,8 +2,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ -from health_ml.run_ml import MLRunner +from health_ml.training_runner import TrainingRunner from health_ml.runner import Runner -__all__ = ["MLRunner", "Runner"] +__all__ = ["TrainingRunner", "Runner"] diff --git a/hi-ml/src/health_ml/configs/hello_world.py b/hi-ml/src/health_ml/configs/hello_world.py index 97265163..a9cc8562 100644 --- a/hi-ml/src/health_ml/configs/hello_world.py +++ b/hi-ml/src/health_ml/configs/hello_world.py @@ -16,6 +16,9 @@ from torch.utils.data import DataLoader, Dataset from health_ml.lightning_container import LightningContainer +TEST_MSE_FILE = "test_mse.txt" +TEST_MAE_FILE = "test_mae.txt" + def _create_1d_regression_dataset(n: int = 100, seed: int = 0) -> torch.Tensor: """Creates a simple 1-D dataset of a noisy linear function. @@ -79,10 +82,10 @@ class HelloWorldDataModule(LightningDataModule): A data module that gives the training, validation and test data for a simple 1-dim regression task. """ - def __init__(self, crossval_count: int, crossval_index: Optional[int]) -> None: + def __init__(self, crossval_count: int, crossval_index: Optional[int] = None, seed: int = 0) -> None: super().__init__() n_total = 200 - xy = _create_1d_regression_dataset(n=n_total) + xy = _create_1d_regression_dataset(n=n_total, seed=seed) n_test = 40 n_val = 50 self.test = HelloWorldDataset(xy=xy[:n_test]) @@ -229,8 +232,8 @@ class HelloRegression(LightningModule): for example writing aggregate metrics to disk. 
""" average_mse = torch.mean(torch.stack(self.test_mse)) - Path("test_mse.txt").write_text(str(average_mse.item())) - Path("test_mae.txt").write_text(str(self.test_mae.compute().item())) + Path(TEST_MSE_FILE).write_text(str(average_mse.item())) + Path(TEST_MAE_FILE).write_text(str(self.test_mae.compute().item())) self.log("test_mse", average_mse, on_epoch=True, on_step=False) def on_run_extra_validation_epoch(self) -> None: @@ -266,6 +269,11 @@ class HelloWorld(LightningContainer): # datamodule must carry out appropriate splitting of the data. return HelloWorldDataModule(crossval_count=self.crossval_count, crossval_index=self.crossval_index) + # This method is optional. Override it to supply a data module to use for evaluating the model on a new dataset. + # Only the test data loader from the data module will be used. + def get_eval_data_module(self) -> LightningDataModule: + return HelloWorldDataModule(crossval_count=1, seed=1) + def get_callbacks(self) -> List[Callback]: if self.save_checkpoint: return [ diff --git a/hi-ml/src/health_ml/eval_runner.py b/hi-ml/src/health_ml/eval_runner.py new file mode 100644 index 00000000..68fcd564 --- /dev/null +++ b/hi-ml/src/health_ml/eval_runner.py @@ -0,0 +1,29 @@ +from pytorch_lightning import LightningDataModule + +from health_azure.logging import logging_section +from health_ml.runner_base import RunnerBase + + +class EvalRunner(RunnerBase): + """A class to run the evaluation of a model on a new dataset. The initialization logic is taken from the base + class `RunnerBase`. + """ + + def validate(self) -> None: + """Checks if the fields of the class are set up correctly.""" + if self.container.src_checkpoint is None or self.container.src_checkpoint.checkpoint == "": + raise ValueError( + "To use model evaluation, you need to provide a checkpoint to use, via the --src_checkpoint argument." + ) + + def run(self) -> None: + """Start the core workflow that the class implements: Initialize a PL Trainer object and use that to run + inference on the inference dataset.""" + self.container.outputs_folder.mkdir(exist_ok=True, parents=True) + self.init_inference() + with logging_section("Model inference"): + self.run_inference() + + def get_data_module(self) -> LightningDataModule: + """Reads the evaluation data module from the underlying container.""" + return self.container.get_eval_data_module() diff --git a/hi-ml/src/health_ml/experiment_config.py b/hi-ml/src/health_ml/experiment_config.py index 3b87e0bc..b63baf96 100644 --- a/hi-ml/src/health_ml/experiment_config.py +++ b/hi-ml/src/health_ml/experiment_config.py @@ -10,6 +10,11 @@ class DebugDDPOptions(Enum): DETAIL = "DETAIL" +class RunnerMode(Enum): + TRAIN = "train" + EVAL_FULL = "eval_full" + + DEBUG_DDP_ENV_VAR = "TORCH_DISTRIBUTED_DEBUG" @@ -82,3 +87,10 @@ class ExperimentConfig(param.Parameterized): doc="The maximum runtime that is allowed for this job in AzureML. This is given as a floating" "point number with a string suffix s, m, h, d for seconds, minutes, hours, day. Examples: '3.5h', '2d'", ) + mode: str = param.ClassSelector( + class_=RunnerMode, + default=RunnerMode.TRAIN, + doc=f"The mode to run the experiment in. 
Can be one of '{RunnerMode.TRAIN}' (training and evaluation on the " + f"test set), or '{RunnerMode.EVAL_FULL}' for evaluation on the full dataset specified by the " + "'get_eval_data_module' method of the container.", + ) diff --git a/hi-ml/src/health_ml/lightning_container.py b/hi-ml/src/health_ml/lightning_container.py index 7323ba0d..eed350c4 100644 --- a/hi-ml/src/health_ml/lightning_container.py +++ b/hi-ml/src/health_ml/lightning_container.py @@ -65,6 +65,17 @@ class LightningContainer(WorkflowParams, DatasetParams, OutputParams, TrainerPar """ return None # type: ignore + def get_eval_data_module(self) -> LightningDataModule: + """ + Gets the data that is used when evaluating the model on a new dataset. + This data module should read datasets from the self.local_datasets folder or download from a web location. + Only the test dataloader is used, hence the method needs to put all data into the test dataloader, rather + than splitting into train/val/test. + + :return: A LightningDataModule + """ + return None # type: ignore + def get_trainer_arguments(self) -> Dict[str, Any]: """ Gets additional parameters that will be passed on to the PyTorch Lightning trainer. diff --git a/hi-ml/src/health_ml/runner.py b/hi-ml/src/health_ml/runner.py index d87e26fd..de1ff448 100755 --- a/hi-ml/src/health_ml/runner.py +++ b/hi-ml/src/health_ml/runner.py @@ -31,9 +31,9 @@ from health_azure.datasets import create_dataset_configs # noqa: E402 from health_azure.himl import DEFAULT_DOCKER_BASE_IMAGE, OUTPUT_FOLDER # noqa: E402 from health_azure.logging import logging_to_stdout # noqa: E402 from health_azure.paths import is_himl_used_from_git_repo # noqa: E402 -from health_azure.utils import ( +from health_azure.utils import ( # noqa: E402 ENV_LOCAL_RANK, - ENV_NODE_RANK, # noqa: E402 + ENV_NODE_RANK, get_workspace, get_ml_client, is_local_rank_zero, @@ -43,9 +43,11 @@ from health_azure.utils import ( is_global_rank_zero, ) -from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, ExperimentConfig # noqa: E402 +from health_ml.eval_runner import EvalRunner # noqa: E402 +from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, ExperimentConfig, RunnerMode # noqa: E402 from health_ml.lightning_container import LightningContainer # noqa: E402 -from health_ml.run_ml import MLRunner # noqa: E402 +from health_ml.runner_base import RunnerBase # noqa: E402 +from health_ml.training_runner import TrainingRunner # noqa: E402 from health_ml.utils import fixed_paths # noqa: E402 from health_ml.utils.logging import ConsoleAndFileOutput # noqa: E402 from health_ml.utils.common_utils import check_conda_environment, choose_conda_env_file, is_linux # noqa: E402 @@ -101,8 +103,8 @@ class Runner: self.project_root = project_root self.experiment_config: ExperimentConfig = ExperimentConfig() self.lightning_container: LightningContainer = None # type: ignore - # This field stores the MLRunner object that has been created in the most recent call to the run() method. - self.ml_runner: Optional[MLRunner] = None + # This field stores the TrainingRunner object that has been created in the most recent call to the run() method. + self.ml_runner: Optional[RunnerBase] = None def parse_and_load_model(self) -> ParserResult: """ @@ -322,15 +324,26 @@ class Runner: assert azure_run_info.run is not None azure_run_info.run.set_tags(self.additional_run_tags(sys.argv[1:])) - # Set environment variables for multi-node training if needed. This function will terminate early - # if it detects that it is not in a multi-node environment. 
- if self.experiment_config.num_nodes > 1: - set_environment_variables_for_multi_node() - self.ml_runner = MLRunner( - experiment_config=self.experiment_config, container=self.lightning_container, project_root=self.project_root - ) - self.ml_runner.setup(azure_run_info) - self.ml_runner.run() + if self.experiment_config.mode == RunnerMode.TRAIN: + # Set environment variables for multi-node training if needed. This function will terminate early + # if it detects that it is not in a multi-node environment. + if self.experiment_config.num_nodes > 1: + set_environment_variables_for_multi_node() + self.ml_runner = TrainingRunner( + experiment_config=self.experiment_config, + container=self.lightning_container, + project_root=self.project_root, + ) + elif self.experiment_config.mode == RunnerMode.EVAL_FULL: + self.ml_runner = EvalRunner( + experiment_config=self.experiment_config, + container=self.lightning_container, + project_root=self.project_root, + ) + else: + raise ValueError(f"Unknown mode {self.experiment_config.mode}") + self.ml_runner.validate() + self.ml_runner.run_and_cleanup(azure_run_info) def run(project_root: Path) -> Tuple[LightningContainer, AzureRunInfo]: diff --git a/hi-ml/src/health_ml/runner_base.py b/hi-ml/src/health_ml/runner_base.py new file mode 100644 index 00000000..db2181b6 --- /dev/null +++ b/hi-ml/src/health_ml/runner_base.py @@ -0,0 +1,247 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ +import logging +from pathlib import Path +from typing import List, Optional + +from azureml.core import Run +from pytorch_lightning import LightningDataModule, Trainer, seed_everything + +from health_azure import AzureRunInfo +from health_azure.utils import ( + PARENT_RUN_CONTEXT, + RUN_CONTEXT, + create_aml_run_object, + create_run_recovery_id, + is_running_in_azure_ml, +) +from health_ml.experiment_config import ExperimentConfig +from health_ml.lightning_container import LightningContainer +from health_ml.model_trainer import create_lightning_trainer +from health_ml.utils import fixed_paths +from health_ml.utils.checkpoint_handler import CheckpointHandler +from health_ml.utils.common_utils import ( + EFFECTIVE_RANDOM_SEED_KEY_NAME, + RUN_RECOVERY_FROM_ID_KEY_NAME, + RUN_RECOVERY_ID_KEY, + change_working_directory, + seed_monai_if_available, +) +from health_ml.utils.lightning_loggers import StoringLogger, get_mlflow_run_id_from_trainer +from health_ml.utils.type_annotations import PathOrString + + +def check_dataset_folder_exists(local_dataset: PathOrString) -> Path: + """ + Checks if a folder with a local dataset exists. If it does exist, return the argument converted + to a Path instance. If it does not exist, raise a FileNotFoundError. + + :param local_dataset: The dataset folder to check. + :return: The local_dataset argument, converted to a Path. + """ + expected_dir = Path(local_dataset) + if not expected_dir.is_dir(): + raise FileNotFoundError(f"The model uses a dataset in {expected_dir}, but that does not exist.") + logging.info(f"Model will use the local dataset provided in {expected_dir}") + return expected_dir + + +class RunnerBase: + """ + A base class with operations that are shared between the training/test runner and the evaluation-only runner. 
+ """ + + def __init__( + self, experiment_config: ExperimentConfig, container: LightningContainer, project_root: Optional[Path] = None + ) -> None: + """ + Driver class to run an ML experiment. Note that the project root argument MUST be supplied when using hi-ml + as a package! + + :param experiment_config: The ExperimentConfig object to use for training. + :param container: The LightningContainer object to use for training. + :param project_root: Project root. This should only be omitted if calling run_ml from the test suite. Supplying + it is crucial when using hi-ml as a package or submodule! + """ + self.container = container + self.experiment_config = experiment_config + self.container.num_nodes = self.experiment_config.num_nodes + self.project_root: Path = project_root or fixed_paths.repository_root_directory() + self.storing_logger: Optional[StoringLogger] = None + self._has_setup_run = False + self.checkpoint_handler = CheckpointHandler( + container=self.container, project_root=self.project_root, run_context=RUN_CONTEXT + ) + self.trainer: Optional[Trainer] = None + self.azureml_run_for_logging: Optional[Run] = None + self.mlflow_run_for_logging: Optional[str] = None + # This is passed to trainer.validate and trainer.test in inference mode + self.inference_checkpoint: Optional[str] = None + + def validate(self) -> None: + """ + Checks if all arguments and settings of the object are correct. + """ + pass + + def setup_azureml(self) -> None: + """ + Execute setup steps that are specific to AzureML. + """ + if PARENT_RUN_CONTEXT is not None: + # Set metadata for the run in AzureML if running in a Hyperdrive job. + run_tags_parent = PARENT_RUN_CONTEXT.get_tags() + tags_to_copy = [ + "tag", + "model_name", + "execution_mode", + "recovered_from", + "friendly_name", + "build_number", + "build_user", + RUN_RECOVERY_FROM_ID_KEY_NAME, + ] + new_tags = {tag: run_tags_parent.get(tag, "") for tag in tags_to_copy} + new_tags[RUN_RECOVERY_ID_KEY] = create_run_recovery_id(run=RUN_CONTEXT) + new_tags[EFFECTIVE_RANDOM_SEED_KEY_NAME] = str(self.container.get_effective_random_seed()) + RUN_CONTEXT.set_tags(new_tags) + + def setup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None: + """ + Sets the random seeds, calls the setup method on the LightningContainer and then creates the actual + Lightning modules. + + :param azure_run_info: When running in AzureML or on a local VM, this contains the paths to the datasets. + This can be missing when running in unit tests, where the local dataset paths are already populated. + """ + if self._has_setup_run: + return + if azure_run_info: + # Set up the paths to the datasets. azure_run_info already has all necessary information, using either + # the provided local datasets for VM runs, or the AzureML mount points when running in AML. + # This must happen before container setup because that could already read datasets. + logging.info("Setting tags from parent run.") + input_datasets = azure_run_info.input_datasets + logging.info(f"Setting the following datasets as local datasets: {input_datasets}") + if len(input_datasets) > 0: + local_datasets: List[Path] = [] + for i, dataset in enumerate(input_datasets): + if dataset is None: + raise ValueError(f"Invalid setup: The dataset at index {i} is None") + local_datasets.append(check_dataset_folder_exists(dataset)) + self.container.local_datasets = local_datasets # type: ignore + # Ensure that we use fixed seeds before initializing the PyTorch models. 
+ # MONAI needs a separate method to make all transforms deterministic by default + seed = self.container.get_effective_random_seed() + seed_monai_if_available(seed) + seed_everything(seed) + + # Creating the folder structure must happen before the LightningModule is created, because the output + # parameters of the container will be copied into the module. + self.container.create_filesystem(self.project_root) + + # configure recovery container if provided + self.checkpoint_handler.download_recovery_checkpoints_or_weights() + + # Create an AzureML run for logging if running outside AzureML. + self.create_logger() + + self.container.setup() + self.container.create_lightning_module_and_store() + self._has_setup_run = True + + if is_running_in_azure_ml(): + self.setup_azureml() + + def create_logger(self) -> None: + """ + Create an AzureML run for logging if running outside AzureML. This run will be used for metrics logging + during both training and inference. We can't rely on the automatically generated run inside the AzureMLLogger + class because two of those logger objects will be created, so training and inference metrics would be logged + in different runs. + """ + if self.container.log_from_vm: + run = create_aml_run_object(experiment_name=self.container.effective_experiment_name) + # Display name should already be set when creating the Run object, but in some scenarios this + # does not happen. Hence, set it again. + run.display_name = self.container.tag if self.container.tag else None + self.azureml_run_for_logging = run + + def get_data_module(self) -> LightningDataModule: + """ + Reads the datamodule that should be used for training or valuation from the container. This must be + overridden in subclasses. + """ + raise NotImplementedError() + + def set_trainer_for_inference(self) -> None: + """Set the runner's PL Trainer object that should be used when running inference on the validation or test set. + We run inference on a single device because distributed strategies such as DDP use DistributedSampler + internally, which replicates some samples to make sure all devices have the same batch size in case of + uneven inputs which biases the results.""" + mlflow_run_id = get_mlflow_run_id_from_trainer(self.trainer) + self.container.max_num_gpus = self.container.max_num_gpus_inference + self.trainer, _ = create_lightning_trainer( + container=self.container, + num_nodes=1, + azureml_run_for_logging=self.azureml_run_for_logging, + mlflow_run_for_logging=mlflow_run_id, + ) + + def init_inference(self) -> None: + """Prepare the runner for inference on validation set, test set, or a full dataset. + The following steps are performed: + + 1. Get the checkpoint to use for inference. This is either the checkpoint from the last training epoch or the + one specified in src_checkpoint argument. + + 2. Create a new trainer instance for inference. This is necessary because the trainer is created with a single + device in contrast to training that uses DDP if multiple GPUs are available. + + 3. Create a new data module instance for inference to account for any requested changes in the dataloading + parameters (e.g. batch_size, max_num_workers, etc) as part of on_run_extra_validation_epoch. + """ + self.inference_checkpoint = str(self.checkpoint_handler.get_checkpoint_to_test()) + self.set_trainer_for_inference() + self.data_module = self.get_data_module() + + def run_inference(self) -> None: + """Run inference on the test set for all models. 
This is done by calling the LightningModule.test_step method. + If the LightningModule.test_step method is not overridden, then this method does nothing. The cwd is changed to + the outputs folder so that the model can write to current working directory, and still everything is put into + the right place in AzureML (there, only the contents of the "outputs" folder is treated as a result file). + """ + if self.container.has_custom_test_step(): + logging.info("Running inference via the LightningModule.test_step method") + with change_working_directory(self.container.outputs_folder): + assert self.trainer, "Trainer should be initialized before inference. Call self.init_inference()." + _ = self.trainer.test( + self.container.model, datamodule=self.data_module, ckpt_path=self.inference_checkpoint + ) + else: + logging.warning("None of the suitable test methods is overridden. Skipping inference completely.") + + def run(self) -> None: + """ + Run the training or evaluation. This method must be overridden in subclasses. + """ + pass + + def run_and_cleanup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None: + """ + Run the training or evaluation via `self.run` and cleanup afterwards. + + :param azure_run_info: When running in AzureML or on a local VM, this contains the paths to the datasets. + This can be missing when running in unit tests, where the local dataset paths are already populated. + """ + self.setup(azure_run_info) + try: + self.run() + finally: + if self.azureml_run_for_logging is not None: + try: + self.azureml_run_for_logging.complete() + except Exception as ex: + logging.error("Failed to complete AzureML run: %s", ex) diff --git a/hi-ml/src/health_ml/run_ml.py b/hi-ml/src/health_ml/training_runner.py similarity index 50% rename from hi-ml/src/health_ml/run_ml.py rename to hi-ml/src/health_ml/training_runner.py index 97456d67..01c84334 100644 --- a/hi-ml/src/health_ml/run_ml.py +++ b/hi-ml/src/health_ml/training_runner.py @@ -6,14 +6,11 @@ import json import logging import os import sys -from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict import torch -from azureml.core import Run -from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning import LightningDataModule, seed_everything -from health_azure import AzureRunInfo from health_azure.logging import logging_section, print_message_with_rank_pid from health_azure.utils import ( ENV_GLOBAL_RANK, @@ -22,140 +19,24 @@ from health_azure.utils import ( ENV_OMPI_COMM_WORLD_RANK, PARENT_RUN_CONTEXT, RUN_CONTEXT, - create_aml_run_object, - create_run_recovery_id, get_metrics_for_hyperdrive_run, get_metrics_for_run, is_global_rank_zero, is_local_rank_zero, is_running_in_azure_ml, ) -from health_ml.experiment_config import ExperimentConfig -from health_ml.lightning_container import LightningContainer from health_ml.model_trainer import create_lightning_trainer, write_experiment_summary_file -from health_ml.utils import fixed_paths -from health_ml.utils.checkpoint_handler import CheckpointHandler +from health_ml.runner_base import RunnerBase from health_ml.utils.checkpoint_utils import cleanup_checkpoints from health_ml.utils.common_utils import ( - EFFECTIVE_RANDOM_SEED_KEY_NAME, - RUN_RECOVERY_FROM_ID_KEY_NAME, - RUN_RECOVERY_ID_KEY, change_working_directory, - seed_monai_if_available, ) -from health_ml.utils.lightning_loggers import StoringLogger, get_mlflow_run_id_from_trainer from health_ml.utils.regression_test_utils import REGRESSION_TEST_METRICS_FILENAME, 
compare_folders_and_run_outputs -from health_ml.utils.type_annotations import PathOrString -def check_dataset_folder_exists(local_dataset: PathOrString) -> Path: - """ - Checks if a folder with a local dataset exists. If it does exist, return the argument converted - to a Path instance. If it does not exist, raise a FileNotFoundError. - - :param local_dataset: The dataset folder to check. - :return: The local_dataset argument, converted to a Path. - """ - expected_dir = Path(local_dataset) - if not expected_dir.is_dir(): - raise FileNotFoundError(f"The model uses a dataset in {expected_dir}, but that does not exist.") - logging.info(f"Model training will use the local dataset provided in {expected_dir}") - return expected_dir - - -class MLRunner: - def __init__( - self, experiment_config: ExperimentConfig, container: LightningContainer, project_root: Optional[Path] = None - ) -> None: - """ - Driver class to run a ML experiment. Note that the project root argument MUST be supplied when using hi-ml - as a package! - - :param experiment_config: The ExperimentConfig object to use for training. - :param container: The LightningContainer object to use for training. - :param project_root: Project root. This should only be omitted if calling run_ml from the test suite. Supplying - it is crucial when using hi-ml as a package or submodule! - """ - self.container = container - self.experiment_config = experiment_config - self.container.num_nodes = self.experiment_config.num_nodes - self.project_root: Path = project_root or fixed_paths.repository_root_directory() - self.storing_logger: Optional[StoringLogger] = None - self._has_setup_run = False - self.checkpoint_handler = CheckpointHandler( - container=self.container, project_root=self.project_root, run_context=RUN_CONTEXT - ) - self.trainer: Optional[Trainer] = None - self.azureml_run_for_logging: Optional[Run] = None - self.mlflow_run_for_logging: Optional[str] = None - self.inference_checkpoint: Optional[str] = None # Passed to trainer.validate and trainer.test in inference mode - - def set_run_tags_from_parent(self) -> None: - """ - Set metadata for the run - """ - assert PARENT_RUN_CONTEXT, "This function should only be called in a Hyperdrive run." - run_tags_parent = PARENT_RUN_CONTEXT.get_tags() - tags_to_copy = [ - "tag", - "model_name", - "execution_mode", - "recovered_from", - "friendly_name", - "build_number", - "build_user", - RUN_RECOVERY_FROM_ID_KEY_NAME, - ] - new_tags = {tag: run_tags_parent.get(tag, "") for tag in tags_to_copy} - new_tags[RUN_RECOVERY_ID_KEY] = create_run_recovery_id(run=RUN_CONTEXT) - new_tags[EFFECTIVE_RANDOM_SEED_KEY_NAME] = str(self.container.get_effective_random_seed()) - RUN_CONTEXT.set_tags(new_tags) - - def setup(self, azure_run_info: Optional[AzureRunInfo] = None) -> None: - """ - Sets the random seeds, calls the setup method on the LightningContainer and then creates the actual - Lightning modules. - - :param azure_run_info: When running in AzureML or on a local VM, this contains the paths to the datasets. - This can be missing when running in unit tests, where the local dataset paths are already populated. - """ - if self._has_setup_run: - return - if azure_run_info: - # Set up the paths to the datasets. azure_run_info already has all necessary information, using either - # the provided local datasets for VM runs, or the AzureML mount points when running in AML. - # This must happen before container setup because that could already read datasets. 
- input_datasets = azure_run_info.input_datasets - logging.info(f"Setting the following datasets as local datasets: {input_datasets}") - if len(input_datasets) > 0: - local_datasets: List[Path] = [] - for i, dataset in enumerate(input_datasets): - if dataset is None: - raise ValueError(f"Invalid setup: The dataset at index {i} is None") - local_datasets.append(check_dataset_folder_exists(dataset)) - self.container.local_datasets = local_datasets # type: ignore - # Ensure that we use fixed seeds before initializing the PyTorch models. - # MONAI needs a separate method to make all transforms deterministic by default - seed = self.container.get_effective_random_seed() - seed_monai_if_available(seed) - seed_everything(seed) - - # Creating the folder structure must happen before the LightningModule is created, because the output - # parameters of the container will be copied into the module. - self.container.create_filesystem(self.project_root) - - # configure recovery container if provided - self.checkpoint_handler.download_recovery_checkpoints_or_weights() # type: ignore - - self.container.setup() - self.container.create_lightning_module_and_store() - self._has_setup_run = True - - is_offline_run = not is_running_in_azure_ml(RUN_CONTEXT) - # Get the AzureML context in which the script is running - if not is_offline_run and PARENT_RUN_CONTEXT is not None: - logging.info("Setting tags from parent run.") - self.set_run_tags_from_parent() +class TrainingRunner(RunnerBase): + def get_data_module(self) -> LightningDataModule: + return self.container.get_data_module() def get_multiple_trainloader_mode(self) -> str: # Workaround for a bug in PL 1.5.5: We need to pass the cycle mode for the training data as a trainer argument @@ -190,18 +71,7 @@ class MLRunner: seed_everything(self.container.get_effective_random_seed(), workers=True) # Get the container's datamodule - self.data_module = self.container.get_data_module() - - # Create an AzureML run for logging if running outside AzureML. This run will be used for metrics logging - # during both training and inference. We can't rely on the automatically generated run inside the AzureMLLogger - # class because two of those logger objects will be created, so training and inference metrics would be logged - # in different runs. - if self.container.log_from_vm: - run = create_aml_run_object(experiment_name=self.container.effective_experiment_name) - # Display name should already be set when creating the Run object, but in some scenarios this - # does not happen. Hence, set it again. - run.display_name = self.container.tag if self.container.tag else None - self.azureml_run_for_logging = run + self.data_module = self.get_data_module() if not self.container.run_inference_only: checkpoint_path_for_recovery = self.checkpoint_handler.get_recovery_or_checkpoint_path_train() @@ -275,36 +145,6 @@ class MLRunner: return self.container.crossval_index == 0 return True - def set_trainer_for_inference(self) -> None: - """Set the runner's PL Trainer object that should be used when running inference on the validation or test set. 
- We run inference on a single device because distributed strategies such as DDP use DistributedSampler - internally, which replicates some samples to make sure all devices have the same batch size in case of - uneven inputs which biases the results.""" - mlflow_run_id = get_mlflow_run_id_from_trainer(self.trainer) - self.container.max_num_gpus = self.container.max_num_gpus_inference - self.trainer, _ = create_lightning_trainer( - container=self.container, - num_nodes=1, - azureml_run_for_logging=self.azureml_run_for_logging, - mlflow_run_for_logging=mlflow_run_id, - ) - - def init_inference(self) -> None: - """Prepare the runner for inference: validation or test. The following steps are performed: - 1. Get the checkpoint to use for inference. This is either the checkpoint from the last training epoch or the - one specified in src_checkpoint argument. - 2. If the container has a run_extra_val_epoch method, call it to run an extra validation epoch. - 3. Create a new trainer instance for inference. This is necessary because the trainer is created with a single - device in contrast to training that uses DDP if multiple GPUs are available. - 4. Create a new data module instance for inference to account for any requested changes in the dataloading - parameters (e.g. batch_size, max_num_workers, etc) as part of on_run_extra_validation_epoch. - """ - self.inference_checkpoint = str(self.checkpoint_handler.get_checkpoint_to_test()) - if self.container.run_extra_val_epoch: - self.container.on_run_extra_validation_epoch() - self.set_trainer_for_inference() - self.data_module = self.container.get_data_module() - def run_training(self) -> None: """ The main training loop. It creates the Pytorch model based on the configuration options passed in, @@ -321,6 +161,16 @@ class MLRunner: assert logger is not None logger.finalize('success') + def init_inference(self) -> None: + """ + Prepare the trainer for running inference on the validation and test set. This chooses a checkpoint, + initializes the PL Trainer object, and chooses the right data module. Afterwards, the hook for running + inference on the validation set is run (`LightningContainer.on_run_extra_validation_epoch`) + """ + super().init_inference() + if self.container.run_extra_val_epoch: + self.container.on_run_extra_validation_epoch() + def run_validation(self) -> None: """Run validation on the validation set for all models to save time/memory consuming outputs. This is done in inference only mode or when the user has requested an extra validation epoch. The cwd is changed to the outputs @@ -334,22 +184,6 @@ class MLRunner: else: logging.info("Skipping extra validation because the user has not requested it.") - def run_inference(self) -> None: - """Run inference on the test set for all models. This is done by calling the LightningModule.test_step method. - If the LightningModule.test_step method is not overridden, then this method does nothing. The cwd is changed to - the outputs folder so that the model can write to current working directory, and still everything is put into - the right place in AzureML (there, only the contents of the "outputs" folder is treated as a result file). - """ - if self.container.has_custom_test_step(): - logging.info("Running inference via the LightningModule.test_step method") - with change_working_directory(self.container.outputs_folder): - assert self.trainer, "Trainer should be initialized before inference. Call self.init_inference()." 
- _ = self.trainer.test( - self.container.model, datamodule=self.data_module, ckpt_path=self.inference_checkpoint - ) - else: - logging.warning("None of the suitable test methods is overridden. Skipping inference completely.") - def run_regression_test(self) -> None: if self.container.regression_test_folder: with logging_section("Regression Test"): @@ -392,32 +226,23 @@ class MLRunner: """ Driver function to run a ML experiment """ - self.setup() - try: - self.init_training() + self.init_training() - if not self.container.run_inference_only: - # Backup the environment variables in case we need to run a second training in the unit tests. - environ_before_training = dict(os.environ) + if not self.container.run_inference_only: + # Backup the environment variables in case we need to run a second training in the unit tests. + environ_before_training = dict(os.environ) - with logging_section("Model training"): - self.run_training() + with logging_section("Model training"): + self.run_training() - self.end_training(environ_before_training) + self.end_training(environ_before_training) - self.init_inference() + self.init_inference() - with logging_section("Model validation"): - self.run_validation() + with logging_section("Model validation"): + self.run_validation() - with logging_section("Model inference"): - self.run_inference() + with logging_section("Model inference"): + self.run_inference() - self.run_regression_test() - - finally: - if self.azureml_run_for_logging is not None: - try: - self.azureml_run_for_logging.complete() - except Exception as ex: - logging.error("Failed to complete AzureML run: %s", ex) + self.run_regression_test() diff --git a/hi-ml/testhiml/conftest.py b/hi-ml/testhiml/conftest.py index c6368583..273bb4d9 100644 --- a/hi-ml/testhiml/conftest.py +++ b/hi-ml/testhiml/conftest.py @@ -1,4 +1,15 @@ +from pathlib import Path +import pytest + +from health_ml.runner import Runner from health_ml.utils import health_ml_package_setup # Reduce logging noise in DEBUG mode health_ml_package_setup() + + +@pytest.fixture +def mock_runner(tmp_path: Path) -> Runner: + """A test fixture that creates a Runner object in a temporary folder.""" + + return Runner(project_root=tmp_path) diff --git a/hi-ml/testhiml/testhiml/test_eval_runner.py b/hi-ml/testhiml/testhiml/test_eval_runner.py new file mode 100644 index 00000000..26bb2601 --- /dev/null +++ b/hi-ml/testhiml/testhiml/test_eval_runner.py @@ -0,0 +1,78 @@ +from pathlib import Path +import sys +from unittest.mock import patch + +import pytest + +from health_ml import TrainingRunner +from health_ml.configs.hello_world import ( + TEST_MAE_FILE, + TEST_MSE_FILE, + HelloWorld, + HelloWorldDataModule, +) +from health_ml.eval_runner import EvalRunner +from health_ml.experiment_config import ExperimentConfig, RunnerMode +from health_ml.runner import Runner +from health_ml.utils.checkpoint_utils import CheckpointParser +from testhiml.test_training_runner import training_runner_hello_world +from testhiml.utils.fixed_paths_for_tests import full_test_data_path + + +hello_world_checkpoint = full_test_data_path(suffix="hello_world_checkpoint.ckpt") + + +def test_eval_runner_no_checkpoint(mock_runner: Runner) -> None: + """Test of the evaluation mode fails if no checkpoint source is provided""" + arguments = ["", f"--model=HelloWorld", f"--mode={RunnerMode.EVAL_FULL.value}"] + with pytest.raises(ValueError, match="To use model evaluation, you need to provide a checkpoint to use"): + with patch.object(sys, "argv", arguments): + mock_runner.run() + + 
+def test_eval_runner_end_to_end(mock_runner: Runner) -> None: + """Test the end-to-end integration of the EvalRunner class into the overall Runner""" + arguments = [ + "", + f"--model=HelloWorld", + f"--mode={RunnerMode.EVAL_FULL.value}", + f"--src_checkpoint={hello_world_checkpoint}", + ] + with patch("health_ml.training_runner.TrainingRunner.run_and_cleanup") as mock_training_run: + with patch.object(sys, "argv", arguments): + mock_runner.run() + # The training runner should not be invoked + mock_training_run.assert_not_called() + # The eval runner should have been invoked. The test step writes two files with metrics, check that + # they exist + output_folder = mock_runner.lightning_container.outputs_folder + for file_name in [TEST_MSE_FILE, TEST_MAE_FILE]: + assert (output_folder / file_name).exists() + + +def test_eval_runner_methods_called(tmp_path: Path) -> None: + """Test if the eval runner uses the right data module from the HelloWorld model""" + container = HelloWorld() + container.src_checkpoint = CheckpointParser(str(hello_world_checkpoint)) + eval_runner = EvalRunner( + container=container, experiment_config=ExperimentConfig(mode=RunnerMode.EVAL_FULL), project_root=tmp_path + ) + with patch("health_ml.configs.hello_world.HelloWorld.get_eval_data_module") as mock_get_data_module: + mock_get_data_module.return_value = HelloWorldDataModule(crossval_count=1, seed=1) + eval_runner.run_and_cleanup() + mock_get_data_module.assert_called_once_with() + + +def test_eval_runner_no_extra_validation_epoch_called(tmp_path: Path) -> None: + """ + Ensure that the eval runner does not invoke the hook the extra validation epoch that is used by the training runner. + """ + container = HelloWorld() + container.run_extra_val_epoch = True + container.src_checkpoint = CheckpointParser(str(hello_world_checkpoint)) + eval_runner = EvalRunner( + container=container, experiment_config=ExperimentConfig(mode=RunnerMode.EVAL_FULL), project_root=tmp_path + ) + with patch("health_ml.configs.hello_world.HelloRegression.on_run_extra_validation_epoch") as mock_hook: + eval_runner.run_and_cleanup() + mock_hook.assert_not_called() diff --git a/hi-ml/testhiml/testhiml/test_regression_test_utils.py b/hi-ml/testhiml/testhiml/test_regression_test_utils.py index c71c0a25..99725709 100644 --- a/hi-ml/testhiml/testhiml/test_regression_test_utils.py +++ b/hi-ml/testhiml/testhiml/test_regression_test_utils.py @@ -13,7 +13,7 @@ import pytest from health_azure.utils import create_aml_run_object from health_ml.experiment_config import ExperimentConfig -from health_ml.run_ml import MLRunner +from health_ml.training_runner import TrainingRunner from health_ml.configs.hello_world import HelloWorld from health_ml.utils.regression_test_utils import ( CONTENTS_MISMATCH, @@ -54,7 +54,7 @@ def test_regression_test() -> None: """ container = HelloWorld() container.regression_test_folder = Path(str(uuid.uuid4().hex)) - runner = MLRunner(container=container, experiment_config=ExperimentConfig()) + runner = TrainingRunner(container=container, experiment_config=ExperimentConfig()) runner.setup() with pytest.raises(ValueError) as ex: runner.run() diff --git a/hi-ml/testhiml/testhiml/test_runner.py b/hi-ml/testhiml/testhiml/test_runner.py index 4838f9c7..7ad39b90 100644 --- a/hi-ml/testhiml/testhiml/test_runner.py +++ b/hi-ml/testhiml/testhiml/test_runner.py @@ -19,7 +19,7 @@ from health_azure import AzureRunInfo, DatasetConfig from health_azure.himl import OUTPUT_FOLDER from health_azure.utils import ENV_LOCAL_RANK, ENV_NODE_RANK from 
health_azure.paths import ENVIRONMENT_YAML_FILE_NAME -from health_ml.configs.hello_world import HelloWorld # type: ignore +from health_ml.configs.hello_world import TEST_MAE_FILE, TEST_MSE_FILE, HelloWorld # type: ignore from health_ml.deep_learning_config import WorkflowParams from health_ml.experiment_config import DEBUG_DDP_ENV_VAR, DebugDDPOptions from health_ml.lightning_container import LightningContainer @@ -28,13 +28,6 @@ from health_ml.utils.common_utils import change_working_directory from health_ml.utils.fixed_paths import repository_root_directory -@pytest.fixture -def mock_runner(tmp_path: Path) -> Runner: - """A test fixture that creates a Runner object in a temporary folder.""" - - return Runner(project_root=tmp_path) - - @contextmanager def change_working_folder_and_add_environment(tmp_path: Path) -> Generator: # Use a special simplified environment file only for the tests here. Copy that to a temp folder, then let the runner @@ -407,7 +400,7 @@ def test_run_hello_world(mock_runner: Runner) -> None: # time-consuming auth mock_get_workspace.assert_not_called() # Summary.txt is written at start, the other files during inference - expected_files = ["experiment_summary.txt", "test_mae.txt", "test_mse.txt"] + expected_files = ["experiment_summary.txt", TEST_MSE_FILE, TEST_MAE_FILE] for file in expected_files: assert (mock_runner.lightning_container.outputs_folder / file).is_file(), f"Missing file: {file}" @@ -428,9 +421,8 @@ def test_invalid_profiler(mock_runner: Runner) -> None: invalid_profile = "--pl_profiler=foo" arguments = ["", "--model=HelloWorld", invalid_profile] with patch.object(sys, "argv", arguments): - with pytest.raises(ValueError) as ex: + with pytest.raises(ValueError, match="Unsupported profiler."): mock_runner.run() - assert "Unsupported profiler." 
in str(ex) def test_custom_datastore_outside_aml(mock_runner: Runner) -> None: diff --git a/hi-ml/testhiml/testhiml/test_run_ml.py b/hi-ml/testhiml/testhiml/test_training_runner.py similarity index 52% rename from hi-ml/testhiml/testhiml/test_run_ml.py rename to hi-ml/testhiml/testhiml/test_training_runner.py index e2a233b7..462d1cb0 100644 --- a/hi-ml/testhiml/testhiml/test_run_ml.py +++ b/hi-ml/testhiml/testhiml/test_training_runner.py @@ -18,34 +18,38 @@ from pytorch_lightning import LightningModule import mlflow from pytorch_lightning import Trainer -from health_ml.configs.hello_world import HelloWorld # type: ignore +from health_ml.configs.hello_world import TEST_MAE_FILE, TEST_MSE_FILE, HelloWorld # type: ignore from health_ml.experiment_config import ExperimentConfig from health_ml.lightning_container import LightningContainer -from health_ml.run_ml import MLRunner +from health_ml.runner_base import RunnerBase +from health_ml.training_runner import TrainingRunner from health_ml.utils.checkpoint_handler import CheckpointHandler from health_ml.utils.checkpoint_utils import CheckpointParser -from health_ml.utils.common_utils import is_gpu_available +from health_ml.utils.common_utils import EFFECTIVE_RANDOM_SEED_KEY_NAME, is_gpu_available from health_ml.utils.lightning_loggers import HimlMLFlowLogger, StoringLogger, get_mlflow_run_id_from_trainer from health_azure.utils import ENV_EXPERIMENT_NAME, is_global_rank_zero from testazure.utils_testazure import DEFAULT_WORKSPACE, experiment_for_unittests -from testhiml.utils.fixed_paths_for_tests import mock_run_id +from testhiml.utils.fixed_paths_for_tests import full_test_data_path no_gpu = not is_gpu_available() +hello_world_checkpoint = full_test_data_path(suffix="hello_world_checkpoint.ckpt") -@pytest.fixture(scope="module") -def ml_runner_no_setup() -> MLRunner: +@pytest.fixture() +def training_runner_no_setup(tmp_path: Path) -> TrainingRunner: experiment_config = ExperimentConfig(model="HelloWorld") container = LightningContainer(num_epochs=1) - runner = MLRunner(experiment_config=experiment_config, container=container) + container.set_output_to(tmp_path) + runner = TrainingRunner(experiment_config=experiment_config, container=container) return runner -@pytest.fixture(scope="module") -def ml_runner() -> Generator: +@pytest.fixture() +def training_runner(tmp_path: Path) -> Generator: experiment_config = ExperimentConfig(model="HelloWorld") container = LightningContainer(num_epochs=1) - runner = MLRunner(experiment_config=experiment_config, container=container) + container.set_output_to(tmp_path) + runner = TrainingRunner(experiment_config=experiment_config, container=container) runner.setup() yield runner output_dir = runner.container.file_system_config.outputs_folder @@ -54,10 +58,11 @@ def ml_runner() -> Generator: @pytest.fixture() -def ml_runner_with_container() -> Generator: +def training_runner_hello_world(tmp_path: Path) -> Generator: experiment_config = ExperimentConfig(model="HelloWorld") container = HelloWorld() - runner = MLRunner(experiment_config=experiment_config, container=container) + container.set_output_to(tmp_path) + runner = TrainingRunner(experiment_config=experiment_config, container=container) runner.setup() yield runner output_dir = runner.container.file_system_config.outputs_folder @@ -66,12 +71,15 @@ def ml_runner_with_container() -> Generator: @pytest.fixture() -def ml_runner_with_run_id() -> Generator: +def training_runner_hello_world_with_checkpoint() -> Generator: + """ + A fixture with a training runner for 
the HelloWorld model that has a src_checkpoint set.
+    """
     experiment_config = ExperimentConfig(model="HelloWorld")
     container = HelloWorld()
     container.save_checkpoint = True
-    container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
-    runner = MLRunner(experiment_config=experiment_config, container=container)
+    container.src_checkpoint = CheckpointParser(str(hello_world_checkpoint))
+    runner = TrainingRunner(experiment_config=experiment_config, container=container)
     runner.setup()
     yield runner
     output_dir = runner.container.file_system_config.outputs_folder
@@ -92,14 +100,27 @@ def regression_datadir(tmp_path: Path) -> Generator:
     shutil.rmtree(tmp_path)
 
 
-def test_ml_runner_setup(ml_runner_no_setup: MLRunner) -> None:
+def create_mlflow_trash_folder(runner: RunnerBase) -> None:
+    """Create a trash folder where MLFlow expects its deleted runs.
+    This is a workaround for sporadic test failures: When reading out the run_id, MLFlow checks its own
+    deleted runs folder, but that (or one of its parents) does not exist
+    """
+    trash_folder = runner.container.outputs_folder / "mlruns" / ".trash"
+    trash_folder.mkdir(exist_ok=True, parents=True)
+
+
+def test_ml_runner_setup(training_runner_no_setup: TrainingRunner) -> None:
     """Check that all the necessary methods get called during setup"""
-    assert not ml_runner_no_setup._has_setup_run
-    with patch.object(ml_runner_no_setup, "container", spec=LightningContainer) as mock_container:
-        with patch.object(ml_runner_no_setup, "checkpoint_handler", spec=CheckpointHandler) as mock_checkpoint_handler:
-            with patch("health_ml.run_ml.seed_everything") as mock_seed:
-                with patch("health_ml.run_ml.seed_monai_if_available") as mock_seed_monai:
-                    ml_runner_no_setup.setup()
+    assert not training_runner_no_setup._has_setup_run
+    with patch.object(training_runner_no_setup, "container", spec=LightningContainer) as mock_container:
+        # Without that, it would try to create a local run object for logging and fail there.
+        mock_container.log_from_vm = False
+        with patch.object(
+            training_runner_no_setup, "checkpoint_handler", spec=CheckpointHandler
+        ) as mock_checkpoint_handler:
+            with patch("health_ml.runner_base.seed_everything") as mock_seed:
+                with patch("health_ml.runner_base.seed_monai_if_available") as mock_seed_monai:
+                    training_runner_no_setup.setup()
                     mock_container.get_effective_random_seed.assert_called()
                     mock_seed.assert_called_once()
                     mock_seed_monai.assert_called_once()
@@ -107,47 +128,58 @@ def test_ml_runner_setup(ml_runner_no_setup: MLRunner) -> None:
                     mock_checkpoint_handler.download_recovery_checkpoints_or_weights.assert_called_once()
                     mock_container.setup.assert_called_once()
                     mock_container.create_lightning_module_and_store.assert_called_once()
-    assert ml_runner_no_setup._has_setup_run
+    assert training_runner_no_setup._has_setup_run
 
 
-def test_set_run_tags_from_parent(ml_runner: MLRunner) -> None:
-    """Test that set_run_tags_from_parents causes set_tags to get called"""
-    with pytest.raises(AssertionError) as ae:
-        ml_runner.set_run_tags_from_parent()
-        assert "should only be called in a Hyperdrive run" in str(ae)
+def test_setup_azureml(training_runner: TrainingRunner) -> None:
+    """Test that setup_azureml causes set_tags to get called when running in Hyperdrive"""
+    with patch("health_ml.runner_base.RUN_CONTEXT") as mock_run_context:
+        training_runner.setup_azureml()
+        # Tests always run outside of a Hyperdrive run. In those cases, no tags should be set on
+        # the current run.
+        mock_run_context.set_tags.assert_not_called()
 
-    with patch("health_ml.run_ml.PARENT_RUN_CONTEXT") as mock_parent_run_context:
-        with patch("health_ml.run_ml.RUN_CONTEXT") as mock_run_context:
-            mock_parent_run_context.get_tags.return_value = {"tag": "dummy_tag"}
-            ml_runner.set_run_tags_from_parent()
-            mock_run_context.set_tags.assert_called()
+        with patch("health_ml.runner_base.PARENT_RUN_CONTEXT") as mock_parent_run_context:
+            # Mock the presence of a parent run, and tags that are present there
+            tag_name = "tag"
+            tag_value = "dummy_tag"
+            mock_parent_run_context.get_tags.return_value = {tag_name: tag_value}
+            training_runner.setup_azureml()
+            # The function should read out tags from the parent run, and set them on the current run
+            mock_parent_run_context.get_tags.assert_called_once_with()
+            mock_run_context.set_tags.assert_called_once()
+            call_args = mock_run_context.set_tags.call_args[0][0]
+            assert tag_name in call_args
+            assert call_args[tag_name] == tag_value
+            assert EFFECTIVE_RANDOM_SEED_KEY_NAME in call_args
 
 
-def test_get_multiple_trainloader_mode(ml_runner: MLRunner) -> None:
-    multiple_trainloader_mode = ml_runner.get_multiple_trainloader_mode()
+def test_get_multiple_trainloader_mode(training_runner: TrainingRunner) -> None:
+    training_runner.init_training()
+    multiple_trainloader_mode = training_runner.get_multiple_trainloader_mode()
     assert multiple_trainloader_mode == "max_size_cycle", "train_loader_cycle_mode is available now, "
     "`get_multiple_trainloader_mode` workaround can be safely removed."
 
 
-def _test_init_training(run_inference_only: bool, ml_runner: MLRunner, caplog: LogCaptureFixture) -> None:
+def _test_init_training(run_inference_only: bool, training_runner: TrainingRunner, caplog: LogCaptureFixture) -> None:
     """Test that training is initialized correctly"""
-    ml_runner.container.run_inference_only = run_inference_only
-    ml_runner.setup()
-    assert not ml_runner.checkpoint_handler.has_continued_training
-    assert ml_runner.trainer is None
-    assert ml_runner.storing_logger is None
+    training_runner.container.run_inference_only = run_inference_only
+    training_runner.setup()
+    assert not training_runner.checkpoint_handler.has_continued_training
+    assert training_runner.trainer is None
+    assert training_runner.storing_logger is None
 
-    with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
-        with patch.object(ml_runner.container, "get_data_module") as mock_get_data_module:
-            with patch("health_ml.run_ml.write_experiment_summary_file") as mock_write_experiment_summary_file:
+    with patch("health_ml.training_runner.create_lightning_trainer") as mock_create_trainer:
+        with patch.object(training_runner.container, "get_data_module") as mock_get_data_module:
+            with patch("health_ml.training_runner.write_experiment_summary_file") as mock_write_experiment_summary_file:
                 with patch.object(
-                    ml_runner.checkpoint_handler, "get_recovery_or_checkpoint_path_train"
+                    training_runner.checkpoint_handler, "get_recovery_or_checkpoint_path_train"
                 ) as mock_get_recovery_or_checkpoint_path_train:
-                    with patch("health_ml.run_ml.seed_everything") as mock_seed:
+                    with patch("health_ml.training_runner.seed_everything") as mock_seed:
                         mock_create_trainer.return_value = MagicMock(), MagicMock()
                         mock_get_recovery_or_checkpoint_path_train.return_value = "dummy_path"
-                        ml_runner.init_training()
+                        training_runner.init_training()
 
                         # Make sure write_experiment_summary_file is only called on rank 0
                         if is_global_rank_zero():
@@ -157,47 +189,51 @@ def _test_init_training(run_inference_only: bool, ml_runner: MLRunner, caplog: L
                         # Make sure seed is set correctly with workers=True
                         mock_seed.assert_called_once()
-                        assert mock_seed.call_args[0][0] == ml_runner.container.get_effective_random_seed()
+                        assert mock_seed.call_args[0][0] == training_runner.container.get_effective_random_seed()
                         assert mock_seed.call_args[1]["workers"]
 
                         mock_get_data_module.assert_called_once()
-                        assert ml_runner.data_module is not None
+                        assert training_runner.data_module is not None
 
                         if not run_inference_only:
                             mock_get_recovery_or_checkpoint_path_train.assert_called_once()
                             # Validate that the trainer is created correctly
                             assert mock_create_trainer.call_args[1]["resume_from_checkpoint"] == "dummy_path"
-                            assert ml_runner.storing_logger is not None
-                            assert ml_runner.trainer is not None
+                            assert training_runner.storing_logger is not None
+                            assert training_runner.trainer is not None
                            assert "Environment variables:" in caplog.messages[-1]
                         else:
-                            assert ml_runner.trainer is None
-                            assert ml_runner.storing_logger is None
+                            assert training_runner.trainer is None
+                            assert training_runner.storing_logger is None
                             mock_get_recovery_or_checkpoint_path_train.assert_not_called()
 
 
 @pytest.mark.parametrize("run_inference_only", [True, False])
-def test_init_training_cpu(run_inference_only: bool, ml_runner: MLRunner, caplog: LogCaptureFixture) -> None:
+def test_init_training_cpu(
+    run_inference_only: bool, training_runner: TrainingRunner, caplog: LogCaptureFixture
+) -> None:
     """Test that training is initialized correctly"""
-    ml_runner.container.max_num_gpus = 0
-    _test_init_training(run_inference_only, ml_runner, caplog)
+    training_runner.container.max_num_gpus = 0
+    _test_init_training(run_inference_only, training_runner, caplog)
 
 
 @pytest.mark.skipif(no_gpu, reason="Test requires GPU")
 @pytest.mark.gpu
 @pytest.mark.parametrize("run_inference_only", [True, False])
-def test_init_training_gpu(run_inference_only: bool, ml_runner: MLRunner, caplog: LogCaptureFixture) -> None:
+def test_init_training_gpu(
+    run_inference_only: bool, training_runner: TrainingRunner, caplog: LogCaptureFixture
+) -> None:
     """Test that training is initialized correctly in DDP mode"""
-    _test_init_training(run_inference_only, ml_runner, caplog)
+    _test_init_training(run_inference_only, training_runner, caplog)
 
 
 def test_run_training() -> None:
     experiment_config = ExperimentConfig(model="HelloWorld")
     container = HelloWorld()
-    runner = MLRunner(experiment_config=experiment_config, container=container)
+    runner = TrainingRunner(experiment_config=experiment_config, container=container)
     with patch.object(container, "get_data_module") as mock_get_data_module:
-        with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
+        with patch("health_ml.training_runner.create_lightning_trainer") as mock_create_trainer:
             runner.setup()
             mock_trainer = MagicMock()
             mock_storing_logger = MagicMock()
@@ -225,17 +261,17 @@ def test_end_training(max_num_gpus_inf: int) -> None:
     experiment_config = ExperimentConfig(model="HelloWorld")
     container = HelloWorld()
     container.max_num_gpus_inference = max_num_gpus_inf
-    runner = MLRunner(experiment_config=experiment_config, container=container)
+    runner = TrainingRunner(experiment_config=experiment_config, container=container)
 
     with patch.object(container, "get_data_module"):
-        with patch("health_ml.run_ml.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
+        with patch("health_ml.training_runner.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
             runner.setup()
             runner.init_training()
             runner.run_training()
 
             with patch.object(runner.checkpoint_handler, "additional_training_done") as mock_additional_training_done:
                 with patch.object(runner, "after_ddp_cleanup") as mock_after_ddp_cleanup:
-                    with patch("health_ml.run_ml.cleanup_checkpoints") as mock_cleanup_ckpt:
+                    with patch("health_ml.training_runner.cleanup_checkpoints") as mock_cleanup_ckpt:
                         environ_before_training = {"old": "environ"}
                         runner.end_training(environ_before_training=environ_before_training)
                         mock_additional_training_done.assert_called_once()
@@ -251,71 +287,98 @@ def test_end_training(max_num_gpus_inf: int) -> None:
 @pytest.mark.parametrize("run_extra_val_epoch", [True, False])
 @pytest.mark.parametrize("run_inference_only", [True, False])
 def test_init_inference(
-    run_inference_only: bool, run_extra_val_epoch: bool, max_num_gpus_inf: int, ml_runner_with_run_id: MLRunner
+    run_inference_only: bool,
+    run_extra_val_epoch: bool,
+    max_num_gpus_inf: int,
+    training_runner_hello_world_with_checkpoint: TrainingRunner,
 ) -> None:
-    ml_runner_with_run_id.container.run_inference_only = run_inference_only
-    ml_runner_with_run_id.container.run_extra_val_epoch = run_extra_val_epoch
-    ml_runner_with_run_id.container.max_num_gpus_inference = max_num_gpus_inf
-    assert ml_runner_with_run_id.container.max_num_gpus == -1  # This is the default value of max_num_gpus
-    ml_runner_with_run_id.init_training()
+    training_runner_hello_world_with_checkpoint.container.run_inference_only = run_inference_only
+    training_runner_hello_world_with_checkpoint.container.run_extra_val_epoch = run_extra_val_epoch
+    training_runner_hello_world_with_checkpoint.container.max_num_gpus_inference = max_num_gpus_inf
+    assert (
+        training_runner_hello_world_with_checkpoint.container.max_num_gpus == -1
+    )  # This is the default value of max_num_gpus
+    training_runner_hello_world_with_checkpoint.init_training()
     if run_inference_only:
         expected_mlflow_run_id = None
     else:
-        assert ml_runner_with_run_id.trainer is not None
-        expected_mlflow_run_id = ml_runner_with_run_id.trainer.loggers[1].run_id  # type: ignore
+        assert training_runner_hello_world_with_checkpoint.trainer is not None
+        create_mlflow_trash_folder(training_runner_hello_world_with_checkpoint)
+        expected_mlflow_run_id = training_runner_hello_world_with_checkpoint.trainer.loggers[1].run_id  # type: ignore
     if not run_inference_only:
-        ml_runner_with_run_id.checkpoint_handler.additional_training_done()
-    with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
-        with patch.object(ml_runner_with_run_id.container, "get_checkpoint_to_test") as mock_get_checkpoint_to_test:
-            with patch.object(ml_runner_with_run_id.container, "get_data_module") as mock_get_data_module:
+        training_runner_hello_world_with_checkpoint.checkpoint_handler.additional_training_done()
+    with patch("health_ml.runner_base.create_lightning_trainer") as mock_create_trainer:
+        with patch.object(
+            training_runner_hello_world_with_checkpoint.container, "get_checkpoint_to_test"
+        ) as mock_get_checkpoint_to_test:
+            with patch.object(
+                training_runner_hello_world_with_checkpoint.container, "get_data_module"
+            ) as mock_get_data_module:
                 mock_checkpoint = MagicMock(is_file=MagicMock(return_value=True))
                 mock_get_checkpoint_to_test.return_value = mock_checkpoint
                 mock_trainer = MagicMock()
                 mock_create_trainer.return_value = mock_trainer, MagicMock()
                 mock_get_data_module.return_value = "dummy_data_module"
-                assert ml_runner_with_run_id.inference_checkpoint is None
-                assert not ml_runner_with_run_id.container.model._on_extra_val_epoch
+                assert training_runner_hello_world_with_checkpoint.inference_checkpoint is None
+                assert not training_runner_hello_world_with_checkpoint.container.model._on_extra_val_epoch
 
-                ml_runner_with_run_id.init_inference()
+                training_runner_hello_world_with_checkpoint.init_inference()
 
-                expected_ckpt = str(ml_runner_with_run_id.checkpoint_handler.trained_weights_path)
+                expected_ckpt = str(training_runner_hello_world_with_checkpoint.checkpoint_handler.trained_weights_path)
                 expected_ckpt = expected_ckpt if run_inference_only else str(mock_checkpoint)
-                assert ml_runner_with_run_id.inference_checkpoint == expected_ckpt
+                assert training_runner_hello_world_with_checkpoint.inference_checkpoint == expected_ckpt
 
-                assert hasattr(ml_runner_with_run_id.container.model, "on_run_extra_validation_epoch")
-                assert ml_runner_with_run_id.container.model._on_extra_val_epoch == run_extra_val_epoch
+                assert hasattr(
+                    training_runner_hello_world_with_checkpoint.container.model, "on_run_extra_validation_epoch"
+                )
+                assert (
+                    training_runner_hello_world_with_checkpoint.container.model._on_extra_val_epoch
+                    == run_extra_val_epoch
+                )
 
                 mock_create_trainer.assert_called_once()
-                assert ml_runner_with_run_id.trainer == mock_trainer
-                assert ml_runner_with_run_id.container.max_num_gpus == max_num_gpus_inf
-                assert mock_create_trainer.call_args[1]["container"] == ml_runner_with_run_id.container
+                assert training_runner_hello_world_with_checkpoint.trainer == mock_trainer
+                assert training_runner_hello_world_with_checkpoint.container.max_num_gpus == max_num_gpus_inf
+                assert (
+                    mock_create_trainer.call_args[1]["container"]
+                    == training_runner_hello_world_with_checkpoint.container
+                )
                 assert mock_create_trainer.call_args[1]["num_nodes"] == 1
                 assert mock_create_trainer.call_args[1]["mlflow_run_for_logging"] == expected_mlflow_run_id
                 mock_get_data_module.assert_called_once()
-                assert ml_runner_with_run_id.data_module == "dummy_data_module"
+                assert training_runner_hello_world_with_checkpoint.data_module == "dummy_data_module"
 
 
 @pytest.mark.parametrize("run_inference_only", [True, False])
 @pytest.mark.parametrize("run_extra_val_epoch", [True, False])
 def test_run_validation(
-    run_extra_val_epoch: bool, run_inference_only: bool, ml_runner_with_run_id: MLRunner, caplog: LogCaptureFixture
+    run_extra_val_epoch: bool,
+    run_inference_only: bool,
+    training_runner_hello_world_with_checkpoint: TrainingRunner,
+    caplog: LogCaptureFixture,
 ) -> None:
-    ml_runner_with_run_id.container.run_extra_val_epoch = run_extra_val_epoch
-    ml_runner_with_run_id.container.run_inference_only = run_inference_only
-    ml_runner_with_run_id.init_training()
+    training_runner_hello_world_with_checkpoint.container.run_extra_val_epoch = run_extra_val_epoch
+    training_runner_hello_world_with_checkpoint.container.run_inference_only = run_inference_only
+    training_runner_hello_world_with_checkpoint.init_training()
     mock_datamodule = MagicMock()
-    with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
-        with patch.object(ml_runner_with_run_id.container, "get_data_module", return_value=mock_datamodule):
+    create_mlflow_trash_folder(training_runner_hello_world_with_checkpoint)
+    with patch("health_ml.runner_base.create_lightning_trainer") as mock_create_trainer:
+        with patch.object(
+            training_runner_hello_world_with_checkpoint.container, "get_data_module", return_value=mock_datamodule
+        ):
            mock_trainer = MagicMock()
            mock_create_trainer.return_value = mock_trainer, MagicMock()
-            ml_runner_with_run_id.init_inference()
-            assert ml_runner_with_run_id.trainer == mock_trainer
+            training_runner_hello_world_with_checkpoint.init_inference()
+            assert training_runner_hello_world_with_checkpoint.trainer == mock_trainer
            mock_trainer.validate = Mock()
-            ml_runner_with_run_id.run_validation()
+            training_runner_hello_world_with_checkpoint.run_validation()
            if run_extra_val_epoch or run_inference_only:
                mock_trainer.validate.assert_called_once()
-                assert mock_trainer.validate.call_args[1]["ckpt_path"] == ml_runner_with_run_id.inference_checkpoint
+                assert (
+                    mock_trainer.validate.call_args[1]["ckpt_path"]
+                    == training_runner_hello_world_with_checkpoint.inference_checkpoint
+                )
                assert mock_trainer.validate.call_args[1]["datamodule"] == mock_datamodule
            else:
                assert "Skipping extra validation" in caplog.messages[-1]
@@ -332,12 +395,12 @@ def test_model_extra_val_epoch_missing_hook(caplog: LogCaptureFixture) -> None:
     container = HelloWorld()
     container.create_lightning_module_and_store()
     container.run_extra_val_epoch = True
-    runner = MLRunner(experiment_config=experiment_config, container=container)
+    runner = TrainingRunner(experiment_config=experiment_config, container=container)
     runner.setup()
     runner.checkpoint_handler.additional_training_done()
     runner.container.outputs_folder.mkdir(parents=True, exist_ok=True)
     with patch.object(container, "get_data_module"):
-        with patch("health_ml.run_ml.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
+        with patch("health_ml.runner_base.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
             with patch.object(runner.container, "get_checkpoint_to_test") as mock_get_checkpoint_to_test:
                 mock_get_checkpoint_to_test.return_value = MagicMock(is_file=MagicMock(return_value=True))
                 runner.init_inference()
@@ -346,34 +409,34 @@ def test_model_extra_val_epoch_missing_hook(caplog: LogCaptureFixture) -> None:
     assert "Hook `on_run_extra_validation_epoch` is not implemented" in latest_message
 
 
-def test_run_inference(ml_runner_with_container: MLRunner, regression_datadir: Path) -> None:
+def test_run_inference(training_runner_hello_world: TrainingRunner, regression_datadir: Path) -> None:
     """
     Test that run_inference gets called as expected.
     """
-    ml_runner_with_container.container.max_num_gpus = 0
+    training_runner_hello_world.container.max_num_gpus = 0
 
     def _expected_files_exist() -> bool:
-        output_dir = ml_runner_with_container.container.outputs_folder
+        output_dir = training_runner_hello_world.container.outputs_folder
         if not output_dir.is_dir():
             return False
-        expected_files = ["test_mse.txt", "test_mae.txt"]
+        expected_files = [TEST_MSE_FILE, TEST_MAE_FILE]
         return all([(output_dir / p).exists() for p in expected_files])
 
-    expected_ckpt_path = ml_runner_with_container.container.outputs_folder / "checkpoints" / "last.ckpt"
+    expected_ckpt_path = training_runner_hello_world.container.outputs_folder / "checkpoints" / "last.ckpt"
     assert not expected_ckpt_path.exists()
     # update the container to look for test data at this location
-    ml_runner_with_container.container.local_dataset_dir = regression_datadir
+    training_runner_hello_world.container.local_dataset_dir = regression_datadir
     assert not _expected_files_exist()
 
-    actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
+    actual_train_ckpt_path = training_runner_hello_world.checkpoint_handler.get_recovery_or_checkpoint_path_train()
     assert actual_train_ckpt_path is None
-    ml_runner_with_container.run()
+    training_runner_hello_world.run()
 
-    actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
+    actual_train_ckpt_path = training_runner_hello_world.checkpoint_handler.get_recovery_or_checkpoint_path_train()
     assert actual_train_ckpt_path == expected_ckpt_path
 
-    actual_test_ckpt_path = ml_runner_with_container.checkpoint_handler.get_checkpoint_to_test()
+    actual_test_ckpt_path = training_runner_hello_world.checkpoint_handler.get_checkpoint_to_test()
     assert actual_test_ckpt_path == expected_ckpt_path
     assert actual_test_ckpt_path.is_file()
     # After training, the outputs directory should now exist and contain the 2 error files
@@ -382,25 +445,25 @@ def test_run_inference(ml_runner_with_container: MLRunner, regression_datadir: P
 
 @pytest.mark.parametrize("run_extra_val_epoch", [True, False])
 @pytest.mark.parametrize("run_inference_only", [True, False])
-def test_run(run_inference_only: bool, run_extra_val_epoch: bool, ml_runner_with_container: MLRunner) -> None:
+def test_run(run_inference_only: bool, run_extra_val_epoch: bool, training_runner_hello_world: TrainingRunner) -> None:
     """Test that model runner gets called"""
-    ml_runner_with_container.container.run_inference_only = run_inference_only
-    ml_runner_with_container.container.run_extra_val_epoch = run_extra_val_epoch
-    ml_runner_with_container.setup()
-    assert not ml_runner_with_container.checkpoint_handler.has_continued_training
+    training_runner_hello_world.container.run_inference_only = run_inference_only
+    training_runner_hello_world.container.run_extra_val_epoch = run_extra_val_epoch
+    training_runner_hello_world.setup()
+    assert not training_runner_hello_world.checkpoint_handler.has_continued_training
 
-    with patch("health_ml.run_ml.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
+    with patch("health_ml.runner_base.create_lightning_trainer", return_value=(MagicMock(), MagicMock())):
         with patch.multiple(
-            ml_runner_with_container,
+            training_runner_hello_world,
             checkpoint_handler=mock.DEFAULT,
             run_training=mock.DEFAULT,
             run_validation=mock.DEFAULT,
             run_inference=mock.DEFAULT,
            end_training=mock.DEFAULT,
        ) as mocks:
-            ml_runner_with_container.run()
-            assert ml_runner_with_container.container.has_custom_test_step()
-            assert ml_runner_with_container._has_setup_run
+            training_runner_hello_world.run()
+            assert training_runner_hello_world.container.has_custom_test_step()
+            assert training_runner_hello_world._has_setup_run
            assert mocks["end_training"] != run_inference_only
            assert mocks["run_training"].called != run_inference_only
            mocks["run_validation"].assert_called_once()
@@ -408,45 +471,59 @@ def test_run(run_inference_only: bool, run_extra_val_epoch: bool, ml_runner_with
 
 @pytest.mark.parametrize("run_extra_val_epoch", [True, False])
-def test_run_inference_only(run_extra_val_epoch: bool, ml_runner_with_run_id: MLRunner) -> None:
+def test_run_inference_only(
+    run_extra_val_epoch: bool, training_runner_hello_world_with_checkpoint: TrainingRunner
+) -> None:
     """Test inference only mode. Validation should be run regardless of run_extra_val_epoch status."""
-    ml_runner_with_run_id.container.run_inference_only = True
-    ml_runner_with_run_id.container.run_extra_val_epoch = run_extra_val_epoch
-    assert ml_runner_with_run_id.checkpoint_handler.trained_weights_path
+    training_runner_hello_world_with_checkpoint.container.run_inference_only = True
+    training_runner_hello_world_with_checkpoint.container.run_extra_val_epoch = run_extra_val_epoch
+    assert training_runner_hello_world_with_checkpoint.checkpoint_handler.trained_weights_path
     mock_datamodule = MagicMock()
-    with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
-        with patch.object(ml_runner_with_run_id.container, "get_data_module", return_value=mock_datamodule):
+    with patch("health_ml.runner_base.create_lightning_trainer") as mock_create_trainer:
+        with patch.object(
+            training_runner_hello_world_with_checkpoint.container, "get_data_module", return_value=mock_datamodule
+        ):
            with patch.multiple(
-                ml_runner_with_run_id,
+                training_runner_hello_world_with_checkpoint,
                run_training=mock.DEFAULT,
            ) as mocks:
                mock_trainer = MagicMock()
                mock_create_trainer.return_value = mock_trainer, MagicMock()
-                ml_runner_with_run_id.run()
+                training_runner_hello_world_with_checkpoint.run()
                mock_create_trainer.assert_called_once()
                mocks["run_training"].assert_not_called()
                mock_trainer.validate.assert_called_once()
-                assert mock_trainer.validate.call_args[1]["ckpt_path"] == ml_runner_with_run_id.inference_checkpoint
+                assert (
+                    mock_trainer.validate.call_args[1]["ckpt_path"]
+                    == training_runner_hello_world_with_checkpoint.inference_checkpoint
+                )
                assert mock_trainer.validate.call_args[1]["datamodule"] == mock_datamodule
                mock_trainer.test.assert_called_once()
-                assert mock_trainer.test.call_args[1]["ckpt_path"] == ml_runner_with_run_id.inference_checkpoint
+                assert (
+                    mock_trainer.test.call_args[1]["ckpt_path"]
+                    == training_runner_hello_world_with_checkpoint.inference_checkpoint
+                )
                assert mock_trainer.test.call_args[1]["datamodule"] == mock_datamodule
 
 
 @pytest.mark.parametrize("run_extra_val_epoch", [True, False])
-def test_resume_training_from_run_id(run_extra_val_epoch: bool, ml_runner_with_run_id: MLRunner) -> None:
-    ml_runner_with_run_id.container.run_extra_val_epoch = run_extra_val_epoch
-    ml_runner_with_run_id.container.max_num_gpus = 0
-    ml_runner_with_run_id.container.max_epochs += 10
-    assert ml_runner_with_run_id.checkpoint_handler.trained_weights_path
+def test_resume_training_from_run_id(
+    run_extra_val_epoch: bool, training_runner_hello_world_with_checkpoint: TrainingRunner
+) -> None:
+    training_runner_hello_world_with_checkpoint.container.run_extra_val_epoch = run_extra_val_epoch
+    training_runner_hello_world_with_checkpoint.container.max_num_gpus = 0
+    training_runner_hello_world_with_checkpoint.container.max_epochs += 10
+    assert training_runner_hello_world_with_checkpoint.checkpoint_handler.trained_weights_path
     mock_trainer = MagicMock()
-    with patch("health_ml.run_ml.create_lightning_trainer", return_value=(mock_trainer, MagicMock())):
-        with patch.object(ml_runner_with_run_id.container, "get_checkpoint_to_test") as mock_get_checkpoint_to_test:
-            with patch.object(ml_runner_with_run_id, "run_inference") as mock_run_inference:
-                with patch("health_ml.run_ml.cleanup_checkpoints") as mock_cleanup_ckpt:
+    with patch("health_ml.runner_base.create_lightning_trainer", return_value=(mock_trainer, MagicMock())):
+        with patch.object(
+            training_runner_hello_world_with_checkpoint.container, "get_checkpoint_to_test"
+        ) as mock_get_checkpoint_to_test:
+            with patch.object(training_runner_hello_world_with_checkpoint, "run_inference") as mock_run_inference:
+                with patch("health_ml.training_runner.cleanup_checkpoints") as mock_cleanup_ckpt:
                    mock_get_checkpoint_to_test.return_value = MagicMock(is_file=MagicMock(return_value=True))
-                    ml_runner_with_run_id.run()
+                    training_runner_hello_world_with_checkpoint.run()
                    mock_get_checkpoint_to_test.assert_called_once()
                    mock_cleanup_ckpt.assert_called_once()
                    assert mock_trainer.validate.called == run_extra_val_epoch
@@ -457,12 +534,12 @@ def test_model_weights_when_resume_training() -> None:
     experiment_config = ExperimentConfig(model="HelloWorld")
     container = HelloWorld()
     container.max_num_gpus = 0
-    container.src_checkpoint = CheckpointParser(mock_run_id(id=0))
+    container.src_checkpoint = CheckpointParser(str(hello_world_checkpoint))
     container.resume_training = True
-    runner = MLRunner(experiment_config=experiment_config, container=container)
+    runner = TrainingRunner(experiment_config=experiment_config, container=container)
     runner.setup()
     assert runner.checkpoint_handler.trained_weights_path.is_file()  # type: ignore
-    with patch("health_ml.run_ml.create_lightning_trainer") as mock_create_trainer:
+    with patch("health_ml.training_runner.create_lightning_trainer") as mock_create_trainer:
         mock_create_trainer.return_value = MagicMock(), MagicMock()
         runner.init_training()
         mock_create_trainer.assert_called_once()
@@ -482,14 +559,14 @@ def test_log_on_vm(log_from_vm: bool) -> None:
     tag = f"test_log_on_vm [{log_from_vm}]"
     container.tag = tag
     container.log_from_vm = log_from_vm
-    runner = MLRunner(experiment_config=experiment_config, container=container)
+    runner = TrainingRunner(experiment_config=experiment_config, container=container)
     # When logging to AzureML, need to provide the unit test AML workspace.
     # When not logging to AzureML, no workspace (and no authentication) should be needed.
     if log_from_vm:
         with patch("health_azure.utils.get_workspace", return_value=DEFAULT_WORKSPACE.workspace):
-            runner.run()
+            runner.run_and_cleanup()
     else:
-        runner.run()
+        runner.run_and_cleanup()
     # The PL trainer object is created in the init_training method.
     # Check that the AzureML logger is set up correctly.
     assert runner.trainer is not None
@@ -536,13 +613,15 @@ def test_get_mlflow_run_id_from_trainer() -> None:
     assert run_id == mock_run_id
 
 
-def test_inference_only_metrics_correctness(ml_runner_with_run_id: MLRunner, regression_datadir: Path) -> None:
-    ml_runner_with_run_id.container.run_inference_only = True
-    ml_runner_with_run_id.container.local_dataset_dir = regression_datadir
-    ml_runner_with_run_id.run()
-    with open(ml_runner_with_run_id.container.outputs_folder / "test_mse.txt") as f:
+def test_inference_only_metrics_correctness(
+    training_runner_hello_world_with_checkpoint: TrainingRunner, regression_datadir: Path
+) -> None:
+    training_runner_hello_world_with_checkpoint.container.run_inference_only = True
+    training_runner_hello_world_with_checkpoint.container.local_dataset_dir = regression_datadir
+    training_runner_hello_world_with_checkpoint.run()
+    with open(training_runner_hello_world_with_checkpoint.container.outputs_folder / TEST_MSE_FILE) as f:
        mse = float(f.readlines()[0])
        assert isclose(mse, 0.010806690901517868, abs_tol=1e-3)
-    with open(ml_runner_with_run_id.container.outputs_folder / "test_mae.txt") as f:
+    with open(training_runner_hello_world_with_checkpoint.container.outputs_folder / TEST_MAE_FILE) as f:
        mae = float(f.readlines()[0])
        assert isclose(mae, 0.08260975033044815, abs_tol=1e-3)