Add inference to the hi-ml runner, and downgrade lightning-bolts to 0.4.0 in line with InnerEye
mebristo 2022-02-22 13:31:26 +00:00 committed by GitHub
Parent 2bc397b470
Commit fed8220456
No key found that matches this signature
GPG key ID: 4AEE18F83AFDEB23
15 changed files: 307 additions and 129 deletions
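The inference step added by this commit is exercised end to end in the new `test_run_inference` test further down. For orientation, here is a minimal sketch of the same flow using the `ExperimentConfig`, `HelloContainer` and `MLRunner` classes that appear in this diff; the data directory and synthetic CSV are illustrative only.

```python
from pathlib import Path

import numpy as np
import torch

from health_ml.configs.hello_container import HelloContainer
from health_ml.experiment_config import ExperimentConfig
from health_ml.run_ml import MLRunner

# Synthetic regression data for HelloContainer, mirroring the new test below.
data_dir = Path("/tmp/hello_data")  # illustrative location
data_dir.mkdir(parents=True, exist_ok=True)
x = torch.rand((100, 1)) * 10
y = 0.2 * x + 0.1 * torch.randn(x.size())
np.savetxt(data_dir / "hellocontainer.csv", torch.cat((x, y), dim=1).numpy(), delimiter=",")

runner = MLRunner(experiment_config=ExperimentConfig(model="HelloContainer"),
                  container=HelloContainer())
runner.setup()
runner.container.local_dataset_dir = data_dir
# run() trains the model and then, via the new code in run_ml.py, runs inference on the
# checkpoint returned by the checkpoint handler.
runner.run()
print(runner.checkpoint_handler.get_checkpoints_to_test())
```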

View file

@@ -18,6 +18,7 @@ the environment file since it is necessary for the augmentations.
- ([#178](https://github.com/microsoft/hi-ml/pull/178)) Add runner script for running ML experiments.
- ([#181](https://github.com/microsoft/hi-ml/pull/181)) Add computational pathology tools in hi-ml-histopathology folder.
- ([#187](https://github.com/microsoft/hi-ml/pull/187)) Add mean pooling layer for MIL.
- ([#186](https://github.com/microsoft/hi-ml/pull/186)) Add inference to hi-ml runner.
### Changed

View file

@@ -41,6 +41,12 @@ The `hi-ml` toolbox provides
histopathology.md
.. toctree::
:maxdepth: 1
:caption: Self supervised learning
self_supervised_models.md
.. toctree::
:maxdepth: 1
:caption: Developers

View file

@@ -57,6 +57,10 @@ ENV_NODE_RANK = "NODE_RANK"
ENV_GLOBAL_RANK = "GLOBAL_RANK"
ENV_LOCAL_RANK = "LOCAL_RANK"
ENVIRONMENT_VERSION = "1"
FINAL_MODEL_FOLDER = "final_model"
MODEL_ID_KEY_NAME = "model_id"
PYTHON_ENVIRONMENT_NAME = "python_environment_name"
RUN_CONTEXT = Run.get_context()
PARENT_RUN_CONTEXT = getattr(RUN_CONTEXT, "parent", None)
WORKSPACE_CONFIG_JSON = "config.json"

View file

@@ -228,11 +228,11 @@ def test_split_recovery_id_fails() -> None:
with pytest.raises(ValueError) as e:
id = util.EXPERIMENT_RUN_SEPARATOR.join([str(i) for i in range(3)])
util.split_recovery_id(id)
assert str(e.value) == f"recovery_id must be in the format: 'experiment_name:run_id', but got: {id}"
assert str(e.value) == f"recovery_id must be in the format: 'experiment_name:run_id', but got: {id}"
with pytest.raises(ValueError) as e:
id = "foo_bar"
util.split_recovery_id(id)
assert str(e.value) == f"The recovery ID was not in the expected format: {id}"
assert str(e.value) == f"The recovery ID was not in the expected format: {id}"
@pytest.mark.parametrize(["id", "expected1", "expected2"],
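From the assertions above, a recovery ID is the experiment name and the run ID joined by `util.EXPERIMENT_RUN_SEPARATOR` (rendered as ':' in the error message). A small illustrative sketch, assuming `split_recovery_id` returns those two parts; the experiment name and run ID are hypothetical.

```python
from health_azure import utils as util

# Hypothetical experiment name and AzureML run ID, joined by the separator tested above.
recovery_id = "my_experiment" + util.EXPERIMENT_RUN_SEPARATOR + "my_experiment_1234567890_abcd1234"
experiment_name, run_id = util.split_recovery_id(recovery_id)
print(experiment_name, run_id)
```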

View file

@@ -11,46 +11,31 @@ dependencies:
- python-blosc==1.7.0
- torchvision=0.11.1
- pip:
- -r ../test_requirements.txt
- azureml-sdk==1.36.0
- conda-merge==0.1.5
- cryptography==3.3.2
- docker==4.3.1
- gitpython==3.1.7
- flask==2.0.1
- gputil==1.4.0
- hi-ml>=0.1.12
- joblib==0.16.0
- jupyter==1.0.0
- jupyter-client==6.1.5
- lightning-bolts==0.5.0
- matplotlib==3.3.0
- mlflow==1.17.0
- monai==0.6.0
- more-itertools==8.10.0
- mypy==0.910
- mypy-extensions==0.4.3
- numba==0.51.2
- numpy==1.19.1
- opencv-python-headless==4.5.1.48
- pandas==1.3.4
- param==1.9.3
- pillow==8.3.2
- psutil==5.7.2
- pydicom==2.0.0
- pyflakes==2.2.0
- PyJWT==1.7.1
- pytorch-lightning==1.5.5
- rich==5.1.1
- rpdb==0.1.6
- ruamel.yaml==0.16.12
- runstats==1.8.0
- scikit-image==0.17.2
- scikit-learn==0.23.2
- scipy==1.5.2
- simpleitk==1.2.4
- six==1.15.0
- stopit==1.1.2
- tabulate==0.8.7
- torchprof==1.3.3
- torchmetrics==0.6.0
- umap-learn==0.5.2
- yacs==0.1.8

View file

@@ -44,7 +44,7 @@ class NIH_RSNA_SimCLR(SSLContainer):
super().__init__(ssl_training_dataset_name=SSLDatasetName.NIHCXR,
linear_head_dataset_name=SSLDatasetName.RSNAKaggleCXR,
# the first Azure dataset is for training, the second is for the linear head
azure_dataset_id=[NIH_AZURE_DATASET_ID, RSNA_AZURE_DATASET_ID],
azure_datasets=[NIH_AZURE_DATASET_ID, RSNA_AZURE_DATASET_ID],
random_seed=1,
max_epochs=1000,
# We usually train this model with 16 GPUs, giving an effective batch size of 1200

View file

@@ -142,7 +142,8 @@ class SslOnlineEvaluatorHiml(SSLOnlineEvaluator):
batch = batch[SSLDataModuleType.LINEAR_HEAD] if isinstance(batch, dict) else batch
x, y = self.to_device(batch, pl_module.device)
with torch.no_grad():
representations = pl_module(x).flatten(start_dim=1)
representations = self.get_representations(pl_module, x)
representations = representations.detach()
# Run the linear-head with SSL embeddings.
mlp_preds = self.evaluator(representations)

View file

@@ -312,7 +312,8 @@ class DeepMILModule(LightningModule):
list_slide_dicts.append(slide_dict)
list_encoded_features.append(results[ResultsKey.IMAGE][slide_idx])
outputs_path = fixed_paths.repository_parent_directory() / 'outputs'
outputs_path = fixed_paths.repository_root_directory() / 'outputs'
assert outputs_path.is_dir(), f"No such dir: {outputs_path}"
print(f"Metrics results will be output to {outputs_path}")
outputs_fig_path = outputs_path / 'fig'
csv_filename = outputs_path / 'test_output.csv'
@@ -324,7 +325,7 @@
slide_dict = self.normalize_dict_for_df(slide_dict, use_gpu=False)
df_list.append(pd.DataFrame.from_dict(slide_dict))
df = pd.concat(df_list, ignore_index=True)
df.to_csv(csv_filename, mode='w', header=True)
df.to_csv(csv_filename, mode='w+', header=True)
# Collect all features in a list and save
features_list = self.move_list_to_device(list_encoded_features, use_gpu=False)

View file

@@ -1,3 +1,4 @@
dataclasses-json==0.5.2
hi-ml-azure>=0.1.8
jinja2==3.0.2
matplotlib==3.4.3

View file

@@ -104,6 +104,7 @@ class ExperimentFolderHandler(Parameterized):
else:
logging.info("All results will be written to a subfolder of the project root folder.")
root = project_root.absolute() / DEFAULT_AML_UPLOAD_DIR
timestamp = create_unique_timestamp_id()
run_folder = root / f"{timestamp}_{model_name}"
outputs_folder = run_folder
@@ -114,6 +115,7 @@ class ExperimentFolderHandler(Parameterized):
run_folder = project_root
outputs_folder = project_root / DEFAULT_AML_UPLOAD_DIR
logs_folder = project_root / DEFAULT_LOGS_DIR_NAME
logging.info(f"Run outputs folder: {outputs_folder}")
logging.info(f"Logs folder: {logs_folder}")
return ExperimentFolderHandler(

View file

@@ -5,14 +5,14 @@
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional
from typing import List, Optional
import torch.multiprocessing
from pytorch_lightning import seed_everything
from pytorch_lightning import LightningModule, seed_everything
from health_azure import AzureRunInfo
from health_azure.utils import (ENV_OMPI_COMM_WORLD_RANK, RUN_CONTEXT, create_run_recovery_id,
PARENT_RUN_CONTEXT, is_running_in_azure_ml)
from health_azure.utils import (create_run_recovery_id, ENV_OMPI_COMM_WORLD_RANK,
is_running_in_azure_ml, PARENT_RUN_CONTEXT, RUN_CONTEXT)
from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
@@ -20,8 +20,8 @@ from health_ml.model_trainer import create_lightning_trainer, model_train
from health_ml.utils import fixed_paths
from health_ml.utils.checkpoint_utils import CheckpointHandler
from health_ml.utils.common_utils import (
change_working_directory, logging_section, RUN_RECOVERY_ID_KEY,
EFFECTIVE_RANDOM_SEED_KEY_NAME, RUN_RECOVERY_FROM_ID_KEY_NAME)
EFFECTIVE_RANDOM_SEED_KEY_NAME, logging_section,
RUN_RECOVERY_ID_KEY, RUN_RECOVERY_FROM_ID_KEY_NAME)
from health_ml.utils.lightning_loggers import StoringLogger
from health_ml.utils.type_annotations import PathOrString
@@ -142,35 +142,49 @@ class MLRunner:
container=self.container)
self.storing_logger = storing_logger
def run_inference_for_lightning_models(self, checkpoint_paths: List[Path]) -> List[Dict[str, float]]:
# Since we have trained the model, let the checkpoint_handler object know so it can handle
# checkpoints correctly.
if self.checkpoint_handler is not None:
self.checkpoint_handler.additional_training_done()
checkpoint_paths_for_testing = self.checkpoint_handler.get_checkpoints_to_test()
else:
checkpoint_paths_for_testing = []
with logging_section("Model inference"):
self.run_inference(checkpoint_paths_for_testing)
def run_inference(self, checkpoint_paths: List[Path]) -> None:
"""
Run inference on the test set for all models that are specified via a LightningContainer.
Run inference on the test set for all models.
:param checkpoint_paths: The path to the checkpoint that should be used for inference.
"""
if len(checkpoint_paths) != 1:
raise ValueError(f"This method expects exactly 1 checkpoint for inference, but got {len(checkpoint_paths)}")
# lightning_model = self.container.model
# Run Lightning's built-in test procedure if the `test_step` method has been overridden
logging.info("Running inference via the LightningModule.test_step method")
# Lightning does not cope with having two calls to .fit or .test in the same script. As a workaround for
# now, restrict number of GPUs to 1, meaning that it will not start DDP.
self.container.max_num_gpus = 1
# Without this, the trainer will think it should still operate in multi-node mode, and wrongly start
# searching for Horovod
if ENV_OMPI_COMM_WORLD_RANK in os.environ:
del os.environ[ENV_OMPI_COMM_WORLD_RANK]
# From the training setup, torch still thinks that it should run in a distributed manner,
# and would block on some GPU operations. Hence, clean up distributed training.
if torch.distributed.is_initialized(): # type: ignore
torch.distributed.destroy_process_group() # type: ignore
lightning_model = self.container.model
if type(lightning_model).test_step != LightningModule.test_step:
# Run Lightning's built-in test procedure if the `test_step` method has been overridden
logging.info("Running inference via the LightningModule.test_step method")
# Lightning does not cope with having two calls to .fit or .test in the same script. As a workaround for
# now, restrict number of GPUs to 1, meaning that it will not start DDP.
self.container.max_num_gpus = 1
# Without this, the trainer will think it should still operate in multi-node mode, and wrongly start
# searching for Horovod
if ENV_OMPI_COMM_WORLD_RANK in os.environ:
del os.environ[ENV_OMPI_COMM_WORLD_RANK]
# From the training setup, torch still thinks that it should run in a distributed manner,
# and would block on some GPU operations. Hence, clean up distributed training.
if torch.distributed.is_initialized(): # type: ignore
torch.distributed.destroy_process_group() # type: ignore
trainer, _ = create_lightning_trainer(self.container, num_nodes=1)
trainer, _ = create_lightning_trainer(self.container, num_nodes=1)
self.container.load_model_checkpoint(checkpoint_path=checkpoint_paths[0])
# Change the current working directory to ensure that test files go to the right folder
data_module = self.container.get_data_module()
with change_working_directory(self.container.outputs_folder):
results = trainer.test(self.container.model, datamodule=data_module)
return results
self.container.load_model_checkpoint(checkpoint_path=checkpoint_paths[0])
# Change the current working directory to ensure that test files go to the right folder
data_module = self.container.get_data_module()
_ = trainer.test(self.container.model, datamodule=data_module)
else:
logging.warning("None of the suitable test methods is overridden. Skipping inference completely.")

View file

@@ -6,6 +6,7 @@ import logging
import os
import tempfile
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
from urllib.parse import urlparse
@@ -19,7 +20,8 @@ from health_azure.utils import RUN_CONTEXT, download_files_from_run_id, get_run_
is_running_in_azure_ml
from health_ml.deep_learning_config import OutputParams
from health_ml.lightning_container import LightningContainer
from health_ml.utils.common_utils import AUTOSAVE_CHECKPOINT_CANDIDATES, CHECKPOINT_FOLDER, DEFAULT_AML_UPLOAD_DIR
from health_ml.utils.common_utils import AUTOSAVE_CHECKPOINT_CANDIDATES, CHECKPOINT_FOLDER, DEFAULT_AML_UPLOAD_DIR, \
check_properties_are_not_none
CHECKPOINT_SUFFIX = ".ckpt"
# This is a constant that must match a filename defined in pytorch_lightning.ModelCheckpoint, but we don't want
@@ -31,16 +33,42 @@ MODEL_INFERENCE_JSON_FILE_NAME = "model_inference_config.json"
MODEL_WEIGHTS_DIR_NAME = "trained_models"
@dataclass(frozen=True)
class RunRecovery:
"""
Class to encapsulate information relating to run recovery (e.g. checkpoint paths for parent and child runs)
"""
checkpoints_roots: List[Path]
def get_recovery_checkpoint_paths(self) -> List[Path]:
return [get_recovery_checkpoint_path(x) for x in self.checkpoints_roots]
def get_best_checkpoint_paths(self) -> List[Path]:
return [get_best_checkpoint_path(x) for x in self.checkpoints_roots]
def _validate(self) -> None:
check_properties_are_not_none(self)
if len(self.checkpoints_roots) == 0:
raise ValueError("checkpoints_roots must not be empty")
def __post_init__(self) -> None:
self._validate()
logging.info(f"Storing {len(self.checkpoints_roots)}checkpoints roots:")
for p in self.checkpoints_roots:
logging.info(str(p))
class CheckpointHandler:
"""
This class handles which checkpoints are used to initialize the model during train or test time based on the
azure config and model config.
This class handles which checkpoints are used to initialize the model during train or test time
"""
def __init__(self, container: LightningContainer,
project_root: Path, run_context: Optional[Run] = None):
def __init__(self,
container: LightningContainer,
project_root: Path,
run_context: Optional[Run] = None):
self.container = container
# self.run_recovery: Optional[RunRecovery] = None
self.run_recovery: Optional[RunRecovery] = None
self.project_root = project_root
self.run_context = run_context
self.trained_weights_paths: List[Path] = []
@@ -58,6 +86,7 @@ class CheckpointHandler:
Download checkpoints from a run recovery object or from a weights url. Set the checkpoints path based on the
run_recovery_object, weights_url or local_weights_path.
This is called at the start of training.
:param only_return_path: if True, return a RunRecovery object with the path to the checkpoint without actually
downloading the checkpoints. This is useful to avoid duplicating checkpoint download when running on multiple
nodes. If False, return the RunRecovery object and download the checkpoint to disk.
@@ -77,6 +106,7 @@ checkpoint folder. If run_recovery is provided, the checkpoints will have been downloaded to this folder
checkpoint folder. If run_recovery is provided, the checkpoints will have been downloaded to this folder
prior to calling this function. Else, if the run gets pre-empted and automatically restarted in AML,
the latest checkpoint will be present in this folder too.
:return: Constructed checkpoint path to recover from.
"""
if is_global_rank_zero():
@@ -86,6 +116,66 @@
logging.info(f)
return find_recovery_checkpoint_on_disk_or_cloud(self.container.checkpoint_folder)
def get_best_checkpoints(self) -> List[Path]:
"""
Get a list of checkpoints per epoch for testing/registration from the current training run.
This function also checks that the checkpoint at the returned checkpoint path exists.
"""
if not self.run_recovery and not self.has_continued_training:
raise ValueError("Cannot recover checkpoint, no run recovery object provided and "
"no training has been done in this run.")
checkpoint_paths = []
if self.run_recovery:
checkpoint_paths = self.run_recovery.get_best_checkpoint_paths()
checkpoint_exists = []
# Discard any checkpoint paths that do not exist - they will make inference/registration fail.
# This can happen when some child runs in a hyperdrive run fail; it may still be worth running inference
# or registering the model.
for path in checkpoint_paths:
if path.is_file():
checkpoint_exists.append(path)
else:
logging.warning(f"Could not recover checkpoint path {path}")
checkpoint_paths = checkpoint_exists
if self.has_continued_training:
# Checkpoint is from the current run, whether a new run or a run recovery which has been doing more
# training, so we look for it there.
# checkpoint_from_current_run = self.output_params.get_path_to_best_checkpoint()
checkpoint_from_current_run = get_recovery_checkpoint_path(Path(self.container.checkpoint_folder))
if checkpoint_from_current_run.is_file():
logging.info("Using checkpoints from current run.")
checkpoint_paths = [checkpoint_from_current_run]
else:
logging.info("Training has continued, but not yet written a checkpoint. Using recovery checkpoints.")
else:
logging.info("Using checkpoints from run recovery")
return checkpoint_paths
def get_checkpoints_to_test(self) -> List[Path]:
"""
Find the checkpoints to test. If a run recovery is provided, or if the model has been training, look for
checkpoints corresponding to the epochs in get_test_epochs(). If there is no run recovery and the model was
not trained in this run, then return the checkpoint from the local_weights_path.
"""
checkpoints = []
# If model was trained, look for the best checkpoint
if self.run_recovery or self.has_continued_training:
checkpoints = self.get_best_checkpoints()
elif self.trained_weights_paths:
# Model was not trained, check if there is a local weight path.
logging.info(f"Using model weights from {self.trained_weights_paths} to initialize model")
checkpoints = self.trained_weights_paths
else:
logging.warning("Could not find any local_weights_path, model_weights or model_id to get checkpoints from")
return checkpoints
@staticmethod
def download_weights(urls: List[str], download_folder: Path) -> List[Path]:
"""
@@ -195,6 +285,7 @@ def find_recovery_checkpoint_on_disk_or_cloud(path: Path) -> Optional[Path]:
Looks at all the checkpoint files and returns the path to the one that should be used for recovery.
If no checkpoint files are found on disk, the function attempts to download from the current AzureML
run.
:param path: The folder to start searching in.
:return: None if there is no suitable recovery checkpoints, or else a full path to the checkpoint file.
"""
@@ -212,6 +303,19 @@ def find_recovery_checkpoint_on_disk_or_cloud(path: Path) -> Optional[Path]:
return recovery_checkpoint
def get_recovery_checkpoint_path(path: Path) -> Path:
"""
Returns the path to the last recovery checkpoint in the given folder or the provided filename. Raises a
FileNotFoundError if no recovery checkpoint file is present.
:param path: Path to checkpoint folder
"""
recovery_checkpoint = find_recovery_checkpoint(path)
if recovery_checkpoint is None:
files = [f.name for f in path.glob("*")]
raise FileNotFoundError(f"No checkpoint files found in {path}. Existing files: {' '.join(files)}")
return recovery_checkpoint
def find_recovery_checkpoint(path: Path) -> Optional[Path]:
"""
Finds the checkpoint file in the given path that can be used for re-starting the present job.
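Taken together, a short sketch of how the `RunRecovery` dataclass and the checkpoint lookup helpers above fit together; the folder path is illustrative, and the import assumes these names live in `health_ml.utils.checkpoint_utils`, the module imported elsewhere in this diff.

```python
from pathlib import Path

from health_ml.utils.checkpoint_utils import (RunRecovery, find_recovery_checkpoint,
                                              get_recovery_checkpoint_path)

# Illustrative checkpoint folder of a recovered run (one entry per child run for hyperdrive runs).
checkpoint_roots = [Path("outputs/checkpoints")]
run_recovery = RunRecovery(checkpoints_roots=checkpoint_roots)  # __post_init__ rejects an empty list
print(run_recovery.get_best_checkpoint_paths())

# find_recovery_checkpoint returns None when no suitable checkpoint exists, whereas
# get_recovery_checkpoint_path raises FileNotFoundError in that case.
if find_recovery_checkpoint(checkpoint_roots[0]) is not None:
    print(get_recovery_checkpoint_path(checkpoint_roots[0]))
```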

View file

@@ -22,17 +22,22 @@ MAX_PATH_LENGTH = 260
empty_string_to_none = lambda x: None if (x is None or len(x.strip()) == 0) else x
string_to_path = lambda x: None if (x is None or len(x.strip()) == 0) else Path(x)
EXPERIMENT_SUMMARY_FILE = "experiment_summary.txt"
CHECKPOINT_FOLDER = "checkpoints"
# file and directory names
CHECKPOINT_SUFFIX = ".ckpt"
AUTOSAVE_CHECKPOINT_FILE_NAME = "autosave"
AUTOSAVE_CHECKPOINT_CANDIDATES = [AUTOSAVE_CHECKPOINT_FILE_NAME + CHECKPOINT_SUFFIX,
AUTOSAVE_CHECKPOINT_FILE_NAME + "-v1" + CHECKPOINT_SUFFIX]
RUN_RECOVERY_ID_KEY = 'run_recovery_id'
EFFECTIVE_RANDOM_SEED_KEY_NAME = "effective_random_seed"
RUN_RECOVERY_FROM_ID_KEY_NAME = "recovered_from"
CHECKPOINT_FOLDER = "checkpoints"
DEFAULT_AML_UPLOAD_DIR = "outputs"
DEFAULT_LOGS_DIR_NAME = "logs"
EXPERIMENT_SUMMARY_FILE = "experiment_summary.txt"
# run recovery
RUN_RECOVERY_ID_KEY = 'run_recovery_id'
RUN_RECOVERY_FROM_ID_KEY_NAME = "recovered_from"
# other
EFFECTIVE_RANDOM_SEED_KEY_NAME = "effective_random_seed"
@unique
@@ -281,3 +286,13 @@ def set_model_to_eval_mode(model: Module) -> Generator:
model.eval()
yield
model.train(old_mode)
def is_long_path(path: PathOrString) -> bool:
"""
A long path is a path that has more than MAX_PATH_LENGTH characters
:param path: The path to check the length of
:return: True if the length of the path is greater than MAX_PATH_LENGTH, else False
"""
return len(str(path)) > MAX_PATH_LENGTH
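A tiny usage sketch for the new `is_long_path` helper, assuming it is imported from `health_ml.utils.common_utils` (the module shown in this hunk); the path itself is illustrative.

```python
from pathlib import Path

from health_ml.utils.common_utils import is_long_path

output_file = Path("outputs") / ("x" * 300) / "metrics.csv"  # illustrative, deliberately long path
if is_long_path(output_file):
    # Paths longer than MAX_PATH_LENGTH (260) characters can break file APIs, notably on Windows.
    print(f"Path exceeds the maximum supported length: {output_file}")
```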

View file

@@ -1,40 +1,67 @@
import shutil
from pathlib import Path
import pytest
from typing import Tuple
from unittest.mock import patch, MagicMock, Mock
from pytorch_lightning import Callback
from typing import Generator, Tuple
from unittest.mock import patch
from health_ml.configs.hello_container import HelloContainer
from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
from health_ml.run_ml import MLRunner
@pytest.fixture
def ml_runner() -> MLRunner:
experiment_config = ExperimentConfig()
@pytest.fixture(scope="module")
def ml_runner_no_setup() -> MLRunner:
experiment_config = ExperimentConfig(model="HelloContainer")
container = LightningContainer(num_epochs=1)
return MLRunner(experiment_config=experiment_config, container=container)
runner = MLRunner(experiment_config=experiment_config, container=container)
return runner
def test_ml_runner_setup(ml_runner: MLRunner) -> None:
"""
Check that all the necessary methods get called during setup
"""
assert not ml_runner._has_setup_run
with patch.object(ml_runner, "container", spec=LightningContainer) as mock_container:
@pytest.fixture(scope="module")
def ml_runner() -> Generator:
experiment_config = ExperimentConfig(model="HelloContainer")
container = LightningContainer(num_epochs=1)
runner = MLRunner(experiment_config=experiment_config, container=container)
runner.setup()
yield runner
output_dir = runner.container.file_system_config.outputs_folder
if output_dir.exists():
shutil.rmtree(output_dir)
@pytest.fixture(scope="module")
def ml_runner_with_container() -> Generator:
experiment_config = ExperimentConfig(model="HelloContainer")
container = HelloContainer()
runner = MLRunner(experiment_config=experiment_config, container=container)
runner.setup()
yield runner
output_dir = runner.container.file_system_config.outputs_folder
if output_dir.exists():
shutil.rmtree(output_dir)
def _mock_model_train(chekpoint_path: Path, container: LightningContainer) -> Tuple[str, str]:
return "trainer", "storing_logger"
def test_ml_runner_setup(ml_runner_no_setup: MLRunner) -> None:
"""Check that all the necessary methods get called during setup"""
assert not ml_runner_no_setup._has_setup_run
with patch.object(ml_runner_no_setup, "container", spec=LightningContainer) as mock_container:
with patch("health_ml.run_ml.seed_everything") as mock_seed:
# mock_container.get_effectie_random_seed = Mock()
ml_runner.setup()
ml_runner_no_setup.setup()
mock_container.get_effective_random_seed.assert_called_once()
mock_container.setup.assert_called_once()
mock_container.create_lightning_module_and_store.assert_called_once()
assert ml_runner._has_setup_run
assert ml_runner_no_setup._has_setup_run
mock_seed.assert_called_once()
def test_set_run_tags_from_parent(ml_runner: MLRunner) -> None:
"""Test that set_run_tags_from_parents causes set_tags to get called"""
with pytest.raises(AssertionError) as ae:
ml_runner.set_run_tags_from_parent()
assert "should only be called in a Hyperdrive run" in str(ae)
@@ -47,47 +74,64 @@ def test_set_run_tags_from_parent(ml_runner: MLRunner) -> None:
def test_run(ml_runner: MLRunner) -> None:
def _mock_model_train(chekpoint_path: Path, container: LightningContainer) -> Tuple[str, str]:
return "trainer", dummy_storing_logger
dummy_storing_logger = "storing_logger"
with patch("health_ml.run_ml.model_train", new=_mock_model_train):
ml_runner.run()
assert ml_runner._has_setup_run
# expect _mock_model_train to be called and the result of ml_runner.storing_logger
# updated accordingly
assert ml_runner.storing_logger == dummy_storing_logger
"""Test that model runner gets called """
ml_runner.setup()
assert not ml_runner.checkpoint_handler.has_continued_training
with patch.object(ml_runner, "run_inference"):
with patch.object(ml_runner, "checkpoint_handler"):
with patch("health_ml.run_ml.model_train", new=_mock_model_train):
ml_runner.run()
assert ml_runner._has_setup_run
# expect _mock_model_train to be called and the result of ml_runner.storing_logger
# updated accordingly
assert ml_runner.storing_logger == "storing_logger"
assert ml_runner.checkpoint_handler.has_continued_training
@patch("health_ml.run_ml.create_lightning_trainer")
def test_run_inference_for_lightning_models(mock_create_trainer: MagicMock, ml_runner: MLRunner,
tmp_path: Path) -> None:
def test_run_inference(ml_runner_with_container: MLRunner, tmp_path: Path) -> None:
"""
Check that all expected methods are called during inference
Test that run_inference gets called as expected.
"""
mock_trainer = MagicMock()
mock_test_result = [{"result": 1.0}]
mock_trainer.test.return_value = mock_test_result
mock_create_trainer.return_value = mock_trainer, ""
def _expected_files_exist() -> int:
output_dir = ml_runner_with_container.container.outputs_folder
expected_files = [Path("test_mse.txt"), Path("test_mae.txt")]
return sum([p.exists() for p in expected_files] + [output_dir.is_dir()])
with patch.object(ml_runner, "container") as mock_container:
mock_container.num_gpus_per_node.return_value = 0
mock_container.get_trainer_arguments.return_value = {"callbacks": Callback()}
mock_container.load_model_checkpoint.return_value = Mock()
mock_container.get_data_module.return_value = Mock()
mock_container.pl_progress_bar_refresh_rate = None
mock_container.detect_anomaly = False
mock_container.pl_limit_train_batches = 1.0
mock_container.pl_limit_val_batches = 1.0
mock_container.outputs_folder = tmp_path
# create the test data
import numpy as np
import torch
checkpoint_paths = [Path("dummy")]
result = ml_runner.run_inference_for_lightning_models(checkpoint_paths)
assert result == mock_test_result
N = 100
x = torch.rand((N, 1)) * 10
y = 0.2 * x + 0.1 * torch.randn(x.size())
xy = torch.cat((x, y), dim=1)
data_path = tmp_path / "hellocontainer.csv"
np.savetxt(data_path, xy.numpy(), delimiter=",")
mock_create_trainer.assert_called_once()
mock_container.load_model_checkpoint.assert_called_once()
mock_container.get_data_module.assert_called_once()
mock_trainer.test.assert_called_once()
expected_ckpt_path = ml_runner_with_container.container.outputs_folder / "checkpoints" / "last.ckpt"
assert not expected_ckpt_path.exists()
# update the container to look for test data at this location
ml_runner_with_container.container.local_dataset_dir = tmp_path
assert _expected_files_exist() == 0
actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
assert actual_train_ckpt_path is None
ml_runner_with_container.run()
actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
assert actual_train_ckpt_path == expected_ckpt_path
actual_test_ckpt_path = ml_runner_with_container.checkpoint_handler.get_checkpoints_to_test()
assert actual_test_ckpt_path == [expected_ckpt_path]
assert actual_test_ckpt_path[0].exists()
# After training, the outputs directory should now exist
assert _expected_files_exist() == 3
# if no checkpoint handler, no checkpoint paths will be saved and these are required for
# inference so ValueError will be raised
with pytest.raises(ValueError) as e:
ml_runner_with_container.checkpoint_handler = None # type: ignore
ml_runner_with_container.run()
assert "expects exactly 1 checkpoint for inference, but got 0" in str(e)
Path("test_mae.txt").unlink()
Path("test_mse.txt").unlink()

View file

@@ -1,24 +1,23 @@
coverage==5.5
conda-merge==0.1.5
flake8==3.8.4
gitpython==3.1.7
lightning-bolts==0.4.0
matplotlib==3.3.0
monai==0.6.0
more-itertools==8.10.0
mypy==0.910
opencv-python-headless==4.5.1.48
pandas==1.3.4
param==1.9.3
pillow==8.3.2
pydicom==2.0.0
pylint==2.9.5
pycobertura==2.0.1
pytest==6.2.2
pytest-cov==2.11.1
pytest-timeout==2.0.1
types-requests==2.25.6
conda-merge==0.1.5
gitpython==3.1.7
lightning-bolts==0.5.0
matplotlib==3.3.0
monai==0.6.0
more-itertools==8.10.0
opencv-python-headless==4.5.1.48
pandas==1.3.4
param==1.9.3
pydicom==2.0.0
pytorch-lightning==1.5.5
pillow==8.3.2
ruamel.yaml==0.16.12
rpdb==0.1.6
scikit-learn==1.0
@@ -27,4 +26,5 @@ simpleitk==1.2.4
torch==1.10.0
torchmetrics==0.6.0
torchvision==0.11.1
types-requests==2.25.6
yacs==0.1.8