Mirror of https://github.com/microsoft/hi-ml.git
Add inference to runner (#186)
Add inference to the hi-ml runner, and downgrade lightning-bolts to 0.4.0 in line with InnerEye
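The change wires inference into the runner's run() flow: after training, the checkpoint handler is told that training happened, the checkpoints to test are collected, and run_inference is invoked on them. A minimal sketch of that control flow, using names from the diff below but with the surrounding MLRunner plumbing simplified (not the verbatim implementation):

    # Sketch of the control flow this commit adds to MLRunner.run() (simplified).
    def run(self) -> None:
        self.setup()
        _, storing_logger = model_train(checkpoint_path, container=self.container)
        self.storing_logger = storing_logger
        # Tell the checkpoint handler that training happened, so that it can
        # return the right checkpoints for testing.
        self.checkpoint_handler.additional_training_done()
        checkpoint_paths = self.checkpoint_handler.get_checkpoints_to_test()
        with logging_section("Model inference"):
            self.run_inference(checkpoint_paths)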
This commit is contained in:
Parent: 2bc397b470
Commit: fed8220456
@@ -18,6 +18,7 @@ the environment file since it is necessary for the augmentations.
- ([#178](https://github.com/microsoft/hi-ml/pull/178)) Add runner script for running ML experiments.
- ([#181](https://github.com/microsoft/hi-ml/pull/181)) Add computational pathology tools in hi-ml-histopathology folder.
- ([#187](https://github.com/microsoft/hi-ml/pull/187)) Add mean pooling layer for MIL.
- ([#186](https://github.com/microsoft/hi-ml/pull/186)) Add inference to hi-ml runner.

### Changed

@@ -41,6 +41,12 @@ The `hi-ml` toolbox provides
   histopathology.md

.. toctree::
   :maxdepth: 1
   :caption: Self supervised learning

   self_supervised_models.md

.. toctree::
   :maxdepth: 1
   :caption: Developers
@@ -57,6 +57,10 @@ ENV_NODE_RANK = "NODE_RANK"
ENV_GLOBAL_RANK = "GLOBAL_RANK"
ENV_LOCAL_RANK = "LOCAL_RANK"

ENVIRONMENT_VERSION = "1"
FINAL_MODEL_FOLDER = "final_model"
MODEL_ID_KEY_NAME = "model_id"
PYTHON_ENVIRONMENT_NAME = "python_environment_name"
RUN_CONTEXT = Run.get_context()
PARENT_RUN_CONTEXT = getattr(RUN_CONTEXT, "parent", None)
WORKSPACE_CONFIG_JSON = "config.json"
@@ -228,11 +228,11 @@ def test_split_recovery_id_fails() -> None:
    with pytest.raises(ValueError) as e:
        id = util.EXPERIMENT_RUN_SEPARATOR.join([str(i) for i in range(3)])
        util.split_recovery_id(id)
    assert str(e.value) == f"recovery_id must be in the format: 'experiment_name:run_id', but got: {id}"
    assert str(e.value) == f"recovery_id must be in the format: 'experiment_name:run_id', but got: {id}"
    with pytest.raises(ValueError) as e:
        id = "foo_bar"
        util.split_recovery_id(id)
    assert str(e.value) == f"The recovery ID was not in the expected format: {id}"
    assert str(e.value) == f"The recovery ID was not in the expected format: {id}"


@pytest.mark.parametrize(["id", "expected1", "expected2"],
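For reference, the recovery ID exercised here joins the AzureML experiment name and run ID. A minimal sketch of the contract, inferred only from the test above (the real split_recovery_id also handles a bare run ID whose name encodes the experiment, hence the second error message):

    # Hedged sketch of split_recovery_id's contract, inferred from the test above.
    EXPERIMENT_RUN_SEPARATOR = ":"  # assumption based on the 'experiment_name:run_id' message

    def split_recovery_id(recovery_id: str):
        parts = recovery_id.split(EXPERIMENT_RUN_SEPARATOR)
        if len(parts) > 2:
            raise ValueError(f"recovery_id must be in the format: 'experiment_name:run_id', "
                             f"but got: {recovery_id}")
        return parts[0], parts[-1]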
@@ -11,46 +11,31 @@ dependencies:
  - python-blosc==1.7.0
  - torchvision=0.11.1
  - pip:
      - -r ../test_requirements.txt
      - azureml-sdk==1.36.0
      - conda-merge==0.1.5
      - cryptography==3.3.2
      - docker==4.3.1
      - gitpython==3.1.7
      - flask==2.0.1
      - gputil==1.4.0
      - hi-ml>=0.1.12
      - joblib==0.16.0
      - jupyter==1.0.0
      - jupyter-client==6.1.5
      - lightning-bolts==0.5.0
      - matplotlib==3.3.0
      - mlflow==1.17.0
      - monai==0.6.0
      - more-itertools==8.10.0
      - mypy==0.910
      - mypy-extensions==0.4.3
      - numba==0.51.2
      - numpy==1.19.1
      - opencv-python-headless==4.5.1.48
      - pandas==1.3.4
      - param==1.9.3
      - pillow==8.3.2
      - psutil==5.7.2
      - pydicom==2.0.0
      - pyflakes==2.2.0
      - PyJWT==1.7.1
      - pytorch-lightning==1.5.5
      - rich==5.1.1
      - rpdb==0.1.6
      - ruamel.yaml==0.16.12
      - runstats==1.8.0
      - scikit-image==0.17.2
      - scikit-learn==0.23.2
      - scipy==1.5.2
      - simpleitk==1.2.4
      - six==1.15.0
      - stopit==1.1.2
      - tabulate==0.8.7
      - torchprof==1.3.3
      - torchmetrics==0.6.0
      - umap-learn==0.5.2
      - yacs==0.1.8
@@ -44,7 +44,7 @@ class NIH_RSNA_SimCLR(SSLContainer):
        super().__init__(ssl_training_dataset_name=SSLDatasetName.NIHCXR,
                         linear_head_dataset_name=SSLDatasetName.RSNAKaggleCXR,
                         # the first Azure dataset is for training, the second is for the linear head
                         azure_dataset_id=[NIH_AZURE_DATASET_ID, RSNA_AZURE_DATASET_ID],
                         azure_datasets=[NIH_AZURE_DATASET_ID, RSNA_AZURE_DATASET_ID],
                         random_seed=1,
                         max_epochs=1000,
                         # We usually train this model with 16 GPUs, giving an effective batch size of 1200
@@ -142,7 +142,8 @@ class SslOnlineEvaluatorHiml(SSLOnlineEvaluator):
        batch = batch[SSLDataModuleType.LINEAR_HEAD] if isinstance(batch, dict) else batch
        x, y = self.to_device(batch, pl_module.device)
        with torch.no_grad():
            representations = pl_module(x).flatten(start_dim=1)
            representations = self.get_representations(pl_module, x)
        representations = representations.detach()

        # Run the linear-head with SSL embeddings.
        mlp_preds = self.evaluator(representations)
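The change above routes the forward pass through get_representations, but keeps the established online-evaluator pattern: embeddings are computed without gradients and detached so the linear head trains without backpropagating into the encoder. A self-contained illustration of that pattern, using generic modules rather than the SSL classes from this diff:

    import torch
    from torch import nn

    encoder = nn.Linear(16, 8)     # stand-in for the frozen SSL encoder
    linear_head = nn.Linear(8, 2)  # stand-in for self.evaluator

    x = torch.randn(4, 16)
    with torch.no_grad():
        representations = encoder(x).flatten(start_dim=1)
    representations = representations.detach()  # no graph ties back to the encoder
    mlp_preds = linear_head(representations)    # only the head receives gradients
    mlp_preds.sum().backward()
    assert encoder.weight.grad is None and linear_head.weight.grad is not None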
@@ -312,7 +312,8 @@ class DeepMILModule(LightningModule):
            list_slide_dicts.append(slide_dict)
            list_encoded_features.append(results[ResultsKey.IMAGE][slide_idx])

        outputs_path = fixed_paths.repository_parent_directory() / 'outputs'
        outputs_path = fixed_paths.repository_root_directory() / 'outputs'
        assert outputs_path.is_dir(), f"No such dir: {outputs_path}"
        print(f"Metrics results will be output to {outputs_path}")
        outputs_fig_path = outputs_path / 'fig'
        csv_filename = outputs_path / 'test_output.csv'
@@ -324,7 +325,7 @@ class DeepMILModule(LightningModule):
            slide_dict = self.normalize_dict_for_df(slide_dict, use_gpu=False)
            df_list.append(pd.DataFrame.from_dict(slide_dict))
        df = pd.concat(df_list, ignore_index=True)
        df.to_csv(csv_filename, mode='w', header=True)
        df.to_csv(csv_filename, mode='w+', header=True)

        # Collect all features in a list and save
        features_list = self.move_list_to_device(list_encoded_features, use_gpu=False)
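The block above collects one dict per slide, converts each to a DataFrame, concatenates them, and writes a single CSV. A small standalone illustration of that aggregation pattern, with hypothetical column names rather than the DeepMIL result keys:

    import pandas as pd

    # Hypothetical per-slide results, mirroring list_slide_dicts in the diff above.
    slide_dicts = [
        {"slide_id": ["s1"], "prob": [0.9]},
        {"slide_id": ["s2"], "prob": [0.4]},
    ]
    df_list = [pd.DataFrame.from_dict(d) for d in slide_dicts]
    df = pd.concat(df_list, ignore_index=True)  # re-number rows 0..N-1 across slides
    df.to_csv("test_output.csv", mode="w", header=True)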
@@ -1,3 +1,4 @@
dataclasses-json==0.5.2
hi-ml-azure>=0.1.8
jinja2==3.0.2
matplotlib==3.4.3
@@ -104,6 +104,7 @@ class ExperimentFolderHandler(Parameterized):
        else:
            logging.info("All results will be written to a subfolder of the project root folder.")
            root = project_root.absolute() / DEFAULT_AML_UPLOAD_DIR

        timestamp = create_unique_timestamp_id()
        run_folder = root / f"{timestamp}_{model_name}"
        outputs_folder = run_folder
@@ -114,6 +115,7 @@ class ExperimentFolderHandler(Parameterized):
            run_folder = project_root
            outputs_folder = project_root / DEFAULT_AML_UPLOAD_DIR
            logs_folder = project_root / DEFAULT_LOGS_DIR_NAME

        logging.info(f"Run outputs folder: {outputs_folder}")
        logging.info(f"Logs folder: {logs_folder}")
        return ExperimentFolderHandler(
@@ -5,14 +5,14 @@
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional
from typing import List, Optional

import torch.multiprocessing
from pytorch_lightning import seed_everything
from pytorch_lightning import LightningModule, seed_everything

from health_azure import AzureRunInfo
from health_azure.utils import (ENV_OMPI_COMM_WORLD_RANK, RUN_CONTEXT, create_run_recovery_id,
                                PARENT_RUN_CONTEXT, is_running_in_azure_ml)
from health_azure.utils import (create_run_recovery_id, ENV_OMPI_COMM_WORLD_RANK,
                                is_running_in_azure_ml, PARENT_RUN_CONTEXT, RUN_CONTEXT)

from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
@@ -20,8 +20,8 @@ from health_ml.model_trainer import create_lightning_trainer, model_train
from health_ml.utils import fixed_paths
from health_ml.utils.checkpoint_utils import CheckpointHandler
from health_ml.utils.common_utils import (
    change_working_directory, logging_section, RUN_RECOVERY_ID_KEY,
    EFFECTIVE_RANDOM_SEED_KEY_NAME, RUN_RECOVERY_FROM_ID_KEY_NAME)
    EFFECTIVE_RANDOM_SEED_KEY_NAME, logging_section,
    RUN_RECOVERY_ID_KEY, RUN_RECOVERY_FROM_ID_KEY_NAME)
from health_ml.utils.lightning_loggers import StoringLogger
from health_ml.utils.type_annotations import PathOrString
@@ -142,35 +142,49 @@ class MLRunner:
                                              container=self.container)
        self.storing_logger = storing_logger

    def run_inference_for_lightning_models(self, checkpoint_paths: List[Path]) -> List[Dict[str, float]]:
        # Since we have trained the model, let the checkpoint_handler object know so it can handle
        # checkpoints correctly.
        if self.checkpoint_handler is not None:
            self.checkpoint_handler.additional_training_done()
            checkpoint_paths_for_testing = self.checkpoint_handler.get_checkpoints_to_test()
        else:
            checkpoint_paths_for_testing = []

        with logging_section("Model inference"):
            self.run_inference(checkpoint_paths_for_testing)

    def run_inference(self, checkpoint_paths: List[Path]) -> None:
        """
        Run inference on the test set for all models that are specified via a LightningContainer.
        Run inference on the test set for all models.

        :param checkpoint_paths: The path to the checkpoint that should be used for inference.
        """
        if len(checkpoint_paths) != 1:
            raise ValueError(f"This method expects exactly 1 checkpoint for inference, but got {len(checkpoint_paths)}")
        # lightning_model = self.container.model

        # Run Lightning's built-in test procedure if the `test_step` method has been overridden
        logging.info("Running inference via the LightningModule.test_step method")
        # Lightning does not cope with having two calls to .fit or .test in the same script. As a workaround for
        # now, restrict number of GPUs to 1, meaning that it will not start DDP.
        self.container.max_num_gpus = 1
        # Without this, the trainer will think it should still operate in multi-node mode, and wrongly start
        # searching for Horovod
        if ENV_OMPI_COMM_WORLD_RANK in os.environ:
            del os.environ[ENV_OMPI_COMM_WORLD_RANK]
        # From the training setup, torch still thinks that it should run in a distributed manner,
        # and would block on some GPU operations. Hence, clean up distributed training.
        if torch.distributed.is_initialized():  # type: ignore
            torch.distributed.destroy_process_group()  # type: ignore
        lightning_model = self.container.model
        if type(lightning_model).test_step != LightningModule.test_step:
            # Run Lightning's built-in test procedure if the `test_step` method has been overridden
            logging.info("Running inference via the LightningModule.test_step method")
            # Lightning does not cope with having two calls to .fit or .test in the same script. As a workaround for
            # now, restrict number of GPUs to 1, meaning that it will not start DDP.
            self.container.max_num_gpus = 1
            # Without this, the trainer will think it should still operate in multi-node mode, and wrongly start
            # searching for Horovod
            if ENV_OMPI_COMM_WORLD_RANK in os.environ:
                del os.environ[ENV_OMPI_COMM_WORLD_RANK]
            # From the training setup, torch still thinks that it should run in a distributed manner,
            # and would block on some GPU operations. Hence, clean up distributed training.
            if torch.distributed.is_initialized():  # type: ignore
                torch.distributed.destroy_process_group()  # type: ignore

            trainer, _ = create_lightning_trainer(self.container, num_nodes=1)
            trainer, _ = create_lightning_trainer(self.container, num_nodes=1)

            self.container.load_model_checkpoint(checkpoint_path=checkpoint_paths[0])
            # Change the current working directory to ensure that test files go to the right folder
            data_module = self.container.get_data_module()
            with change_working_directory(self.container.outputs_folder):
                results = trainer.test(self.container.model, datamodule=data_module)
            return results
            self.container.load_model_checkpoint(checkpoint_path=checkpoint_paths[0])
            # Change the current working directory to ensure that test files go to the right folder
            data_module = self.container.get_data_module()

            _ = trainer.test(self.container.model, datamodule=data_module)

        else:
            logging.warning("None of the suitable test methods is overridden. Skipping inference completely.")
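The guard added here relies on a standard Python idiom: a subclass has overridden a method exactly when the attribute looked up on the subclass's type differs from the base class's function. A minimal illustration of the check, with plain classes standing in for LightningModule and the user's model:

    class Base:
        def test_step(self):  # corresponds to LightningModule.test_step
            ...

    class WithTestStep(Base):
        def test_step(self):  # user model that overrides test_step
            return 0

    class WithoutTestStep(Base):
        pass

    assert type(WithTestStep()).test_step != Base.test_step     # inference will run
    assert type(WithoutTestStep()).test_step == Base.test_step  # inference is skipped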
@@ -6,6 +6,7 @@ import logging
import os
import tempfile
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
from urllib.parse import urlparse
@@ -19,7 +20,8 @@ from health_azure.utils import RUN_CONTEXT, download_files_from_run_id, get_run_
    is_running_in_azure_ml
from health_ml.deep_learning_config import OutputParams
from health_ml.lightning_container import LightningContainer
from health_ml.utils.common_utils import AUTOSAVE_CHECKPOINT_CANDIDATES, CHECKPOINT_FOLDER, DEFAULT_AML_UPLOAD_DIR
from health_ml.utils.common_utils import AUTOSAVE_CHECKPOINT_CANDIDATES, CHECKPOINT_FOLDER, DEFAULT_AML_UPLOAD_DIR, \
    check_properties_are_not_none

CHECKPOINT_SUFFIX = ".ckpt"
# This is a constant that must match a filename defined in pytorch_lightning.ModelCheckpoint, but we don't want
@@ -31,16 +33,42 @@ MODEL_INFERENCE_JSON_FILE_NAME = "model_inference_config.json"
MODEL_WEIGHTS_DIR_NAME = "trained_models"


@dataclass(frozen=True)
class RunRecovery:
    """
    Class to encapsulate information relating to run recovery (e.g. checkpoint paths for parent and child runs)
    """
    checkpoints_roots: List[Path]

    def get_recovery_checkpoint_paths(self) -> List[Path]:
        return [get_recovery_checkpoint_path(x) for x in self.checkpoints_roots]

    def get_best_checkpoint_paths(self) -> List[Path]:
        return [get_best_checkpoint_path(x) for x in self.checkpoints_roots]

    def _validate(self) -> None:
        check_properties_are_not_none(self)
        if len(self.checkpoints_roots) == 0:
            raise ValueError("checkpoints_roots must not be empty")

    def __post_init__(self) -> None:
        self._validate()
        logging.info(f"Storing {len(self.checkpoints_roots)} checkpoint roots:")
        for p in self.checkpoints_roots:
            logging.info(str(p))


class CheckpointHandler:
    """
    This class handles which checkpoints are used to initialize the model during train or test time based on the
    azure config and model config.
    This class handles which checkpoints are used to initialize the model during train or test time
    """

    def __init__(self, container: LightningContainer,
                 project_root: Path, run_context: Optional[Run] = None):
    def __init__(self,
                 container: LightningContainer,
                 project_root: Path,
                 run_context: Optional[Run] = None):
        self.container = container
        # self.run_recovery: Optional[RunRecovery] = None
        self.run_recovery: Optional[RunRecovery] = None
        self.project_root = project_root
        self.run_context = run_context
        self.trained_weights_paths: List[Path] = []
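RunRecovery uses a frozen dataclass whose __post_init__ hook validates the fields as soon as an instance is built. A self-contained example of that pattern, with a generic field instead of the hi-ml helper functions:

    from dataclasses import dataclass
    from pathlib import Path
    from typing import List

    @dataclass(frozen=True)
    class Roots:
        # Mirrors RunRecovery.checkpoints_roots in the diff above.
        checkpoints_roots: List[Path]

        def __post_init__(self) -> None:
            # Runs automatically after __init__, even for frozen dataclasses.
            if len(self.checkpoints_roots) == 0:
                raise ValueError("checkpoints_roots must not be empty")

    Roots([Path("outputs/checkpoints")])  # fine
    # Roots([])  # would raise ValueError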
@@ -58,6 +86,7 @@ class CheckpointHandler:
        Download checkpoints from a run recovery object or from a weights url. Set the checkpoints path based on the
        run_recovery_object, weights_url or local_weights_path.
        This is called at the start of training.

        :param only_return_path: if True, return a RunRecovery object with the path to the checkpoint without actually
            downloading the checkpoints. This is useful to avoid duplicating checkpoint download when running on multiple
            nodes. If False, return the RunRecovery object and download the checkpoint to disk.
@@ -77,6 +106,7 @@ class CheckpointHandler:
        checkpoint folder. If run_recovery is provided, the checkpoints will have been downloaded to this folder
        prior to calling this function. Else, if the run gets pre-empted and automatically restarted in AML,
        the latest checkpoint will be present in this folder too.

        :return: Constructed checkpoint path to recover from.
        """
        if is_global_rank_zero():
@@ -86,6 +116,66 @@ class CheckpointHandler:
            logging.info(f)
        return find_recovery_checkpoint_on_disk_or_cloud(self.container.checkpoint_folder)

    def get_best_checkpoints(self) -> List[Path]:
        """
        Get a list of checkpoints per epoch for testing/registration from the current training run.
        This function also checks that the checkpoint at the returned checkpoint path exists.
        """
        if not self.run_recovery and not self.has_continued_training:
            raise ValueError("Cannot recover checkpoint, no run recovery object provided and "
                             "no training has been done in this run.")

        checkpoint_paths = []
        if self.run_recovery:
            checkpoint_paths = self.run_recovery.get_best_checkpoint_paths()

            checkpoint_exists = []
            # Discard any checkpoint paths that do not exist - they will make inference/registration fail.
            # This can happen when some child runs in a hyperdrive run fail; it may still be worth running inference
            # or registering the model.
            for path in checkpoint_paths:
                if path.is_file():
                    checkpoint_exists.append(path)
                else:
                    logging.warning(f"Could not recover checkpoint path {path}")
            checkpoint_paths = checkpoint_exists

        if self.has_continued_training:
            # Checkpoint is from the current run, whether a new run or a run recovery which has been doing more
            # training, so we look for it there.
            # checkpoint_from_current_run = self.output_params.get_path_to_best_checkpoint()
            checkpoint_from_current_run = get_recovery_checkpoint_path(Path(self.container.checkpoint_folder))
            if checkpoint_from_current_run.is_file():
                logging.info("Using checkpoints from current run.")
                checkpoint_paths = [checkpoint_from_current_run]
            else:
                logging.info("Training has continued, but not yet written a checkpoint. Using recovery checkpoints.")
        else:
            logging.info("Using checkpoints from run recovery")

        return checkpoint_paths

    def get_checkpoints_to_test(self) -> List[Path]:
        """
        Find the checkpoints to test. If a run recovery is provided, or if the model has been training, look for
        checkpoints corresponding to the epochs in get_test_epochs(). If there is no run recovery and the model was
        not trained in this run, then return the checkpoint from the local_weights_path.
        """

        checkpoints = []

        # If model was trained, look for the best checkpoint
        if self.run_recovery or self.has_continued_training:
            checkpoints = self.get_best_checkpoints()
        elif self.trained_weights_paths:
            # Model was not trained, check if there is a local weight path.
            logging.info(f"Using model weights from {self.trained_weights_paths} to initialize model")
            checkpoints = self.trained_weights_paths
        else:
            logging.warning("Could not find any local_weights_path, model_weights or model_id to get checkpoints from")

        return checkpoints

    @staticmethod
    def download_weights(urls: List[str], download_folder: Path) -> List[Path]:
        """
@@ -195,6 +285,7 @@ def find_recovery_checkpoint_on_disk_or_cloud(path: Path) -> Optional[Path]:
    Looks at all the checkpoint files and returns the path to the one that should be used for recovery.
    If no checkpoint files are found on disk, the function attempts to download from the current AzureML
    run.

    :param path: The folder to start searching in.
    :return: None if there is no suitable recovery checkpoint, or else a full path to the checkpoint file.
    """
@@ -212,6 +303,19 @@ def find_recovery_checkpoint_on_disk_or_cloud(path: Path) -> Optional[Path]:
    return recovery_checkpoint


def get_recovery_checkpoint_path(path: Path) -> Path:
    """
    Returns the path to the last recovery checkpoint in the given folder or the provided filename. Raises a
    FileNotFoundError if no recovery checkpoint file is present.

    :param path: Path to checkpoint folder
    """
    recovery_checkpoint = find_recovery_checkpoint(path)
    if recovery_checkpoint is None:
        files = [f.name for f in path.glob("*")]
        raise FileNotFoundError(f"No checkpoint files found in {path}. Existing files: {' '.join(files)}")
    return recovery_checkpoint


def find_recovery_checkpoint(path: Path) -> Optional[Path]:
    """
    Finds the checkpoint file in the given path that can be used for re-starting the present job.
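get_recovery_checkpoint_path wraps find_recovery_checkpoint and turns a missing checkpoint into a FileNotFoundError that lists the folder contents. A hedged sketch of how find_recovery_checkpoint might pick a candidate; the real function checks the AUTOSAVE_CHECKPOINT_CANDIDATES defined in the common_utils hunk below, and the most-recently-modified selection rule here is an assumption:

    from pathlib import Path
    from typing import Optional

    # Candidate names matching AUTOSAVE_CHECKPOINT_CANDIDATES in common_utils
    # ("autosave" + ".ckpt", "autosave" + "-v1" + ".ckpt").
    AUTOSAVE_CHECKPOINT_CANDIDATES = ["autosave.ckpt", "autosave-v1.ckpt"]

    def find_recovery_checkpoint(path: Path) -> Optional[Path]:
        # Sketch: return the most recently modified autosave checkpoint, if any.
        existing = [path / name for name in AUTOSAVE_CHECKPOINT_CANDIDATES
                    if (path / name).is_file()]
        if not existing:
            return None
        return max(existing, key=lambda c: c.stat().st_mtime)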
@@ -22,17 +22,22 @@ MAX_PATH_LENGTH = 260
empty_string_to_none = lambda x: None if (x is None or len(x.strip()) == 0) else x
string_to_path = lambda x: None if (x is None or len(x.strip()) == 0) else Path(x)

EXPERIMENT_SUMMARY_FILE = "experiment_summary.txt"
CHECKPOINT_FOLDER = "checkpoints"
# file and directory names
CHECKPOINT_SUFFIX = ".ckpt"
AUTOSAVE_CHECKPOINT_FILE_NAME = "autosave"
AUTOSAVE_CHECKPOINT_CANDIDATES = [AUTOSAVE_CHECKPOINT_FILE_NAME + CHECKPOINT_SUFFIX,
                                  AUTOSAVE_CHECKPOINT_FILE_NAME + "-v1" + CHECKPOINT_SUFFIX]
RUN_RECOVERY_ID_KEY = 'run_recovery_id'
EFFECTIVE_RANDOM_SEED_KEY_NAME = "effective_random_seed"
RUN_RECOVERY_FROM_ID_KEY_NAME = "recovered_from"
CHECKPOINT_FOLDER = "checkpoints"
DEFAULT_AML_UPLOAD_DIR = "outputs"
DEFAULT_LOGS_DIR_NAME = "logs"
EXPERIMENT_SUMMARY_FILE = "experiment_summary.txt"

# run recovery
RUN_RECOVERY_ID_KEY = 'run_recovery_id'
RUN_RECOVERY_FROM_ID_KEY_NAME = "recovered_from"

# other
EFFECTIVE_RANDOM_SEED_KEY_NAME = "effective_random_seed"


@unique
@@ -281,3 +286,13 @@ def set_model_to_eval_mode(model: Module) -> Generator:
    model.eval()
    yield
    model.train(old_mode)


def is_long_path(path: PathOrString) -> bool:
    """
    A long path is a path that has more than MAX_PATH_LENGTH characters

    :param path: The path to check the length of
    :return: True if the length of the path is greater than MAX_PATH_LENGTH, else False
    """
    return len(str(path)) > MAX_PATH_LENGTH
@@ -1,40 +1,67 @@
import shutil
from pathlib import Path

import pytest
from typing import Tuple
from unittest.mock import patch, MagicMock, Mock

from pytorch_lightning import Callback
from typing import Generator, Tuple
from unittest.mock import patch

from health_ml.configs.hello_container import HelloContainer
from health_ml.experiment_config import ExperimentConfig
from health_ml.lightning_container import LightningContainer
from health_ml.run_ml import MLRunner


@pytest.fixture
def ml_runner() -> MLRunner:
    experiment_config = ExperimentConfig()
@pytest.fixture(scope="module")
def ml_runner_no_setup() -> MLRunner:
    experiment_config = ExperimentConfig(model="HelloContainer")
    container = LightningContainer(num_epochs=1)
    return MLRunner(experiment_config=experiment_config, container=container)
    runner = MLRunner(experiment_config=experiment_config, container=container)
    return runner


def test_ml_runner_setup(ml_runner: MLRunner) -> None:
    """
    Check that all the necessary methods get called during setup
    """
    assert not ml_runner._has_setup_run
    with patch.object(ml_runner, "container", spec=LightningContainer) as mock_container:
@pytest.fixture(scope="module")
def ml_runner() -> Generator:
    experiment_config = ExperimentConfig(model="HelloContainer")
    container = LightningContainer(num_epochs=1)
    runner = MLRunner(experiment_config=experiment_config, container=container)
    runner.setup()
    yield runner
    output_dir = runner.container.file_system_config.outputs_folder
    if output_dir.exists():
        shutil.rmtree(output_dir)


@pytest.fixture(scope="module")
def ml_runner_with_container() -> Generator:
    experiment_config = ExperimentConfig(model="HelloContainer")
    container = HelloContainer()
    runner = MLRunner(experiment_config=experiment_config, container=container)
    runner.setup()
    yield runner
    output_dir = runner.container.file_system_config.outputs_folder
    if output_dir.exists():
        shutil.rmtree(output_dir)


def _mock_model_train(checkpoint_path: Path, container: LightningContainer) -> Tuple[str, str]:
    return "trainer", "storing_logger"


def test_ml_runner_setup(ml_runner_no_setup: MLRunner) -> None:
    """Check that all the necessary methods get called during setup"""
    assert not ml_runner_no_setup._has_setup_run
    with patch.object(ml_runner_no_setup, "container", spec=LightningContainer) as mock_container:
        with patch("health_ml.run_ml.seed_everything") as mock_seed:
            # mock_container.get_effective_random_seed = Mock()
            ml_runner.setup()
            ml_runner_no_setup.setup()
            mock_container.get_effective_random_seed.assert_called_once()
            mock_container.setup.assert_called_once()
            mock_container.create_lightning_module_and_store.assert_called_once()
            assert ml_runner._has_setup_run
            assert ml_runner_no_setup._has_setup_run
            mock_seed.assert_called_once()


def test_set_run_tags_from_parent(ml_runner: MLRunner) -> None:
    """Test that set_run_tags_from_parents causes set_tags to get called"""
    with pytest.raises(AssertionError) as ae:
        ml_runner.set_run_tags_from_parent()
        assert "should only be called in a Hyperdrive run" in str(ae)
@@ -47,47 +74,64 @@ def test_set_run_tags_from_parent(ml_runner: MLRunner) -> None:


def test_run(ml_runner: MLRunner) -> None:

    def _mock_model_train(checkpoint_path: Path, container: LightningContainer) -> Tuple[str, str]:
        return "trainer", dummy_storing_logger

    dummy_storing_logger = "storing_logger"

    with patch("health_ml.run_ml.model_train", new=_mock_model_train):
        ml_runner.run()
        assert ml_runner._has_setup_run
        # expect _mock_model_train to be called and the result of ml_runner.storing_logger
        # updated accordingly
        assert ml_runner.storing_logger == dummy_storing_logger
    """Test that the model runner gets called"""
    ml_runner.setup()
    assert not ml_runner.checkpoint_handler.has_continued_training
    with patch.object(ml_runner, "run_inference"):
        with patch.object(ml_runner, "checkpoint_handler"):
            with patch("health_ml.run_ml.model_train", new=_mock_model_train):
                ml_runner.run()
                assert ml_runner._has_setup_run
                # expect _mock_model_train to be called and the result of ml_runner.storing_logger
                # updated accordingly
                assert ml_runner.storing_logger == "storing_logger"
                assert ml_runner.checkpoint_handler.has_continued_training


@patch("health_ml.run_ml.create_lightning_trainer")
def test_run_inference_for_lightning_models(mock_create_trainer: MagicMock, ml_runner: MLRunner,
                                            tmp_path: Path) -> None:
def test_run_inference(ml_runner_with_container: MLRunner, tmp_path: Path) -> None:
    """
    Check that all expected methods are called during inference
    Test that run_inference gets called as expected.
    """
    mock_trainer = MagicMock()
    mock_test_result = [{"result": 1.0}]
    mock_trainer.test.return_value = mock_test_result
    mock_create_trainer.return_value = mock_trainer, ""
    def _expected_files_exist() -> int:
        output_dir = ml_runner_with_container.container.outputs_folder
        expected_files = [Path("test_mse.txt"), Path("test_mae.txt")]
        return sum([p.exists() for p in expected_files] + [output_dir.is_dir()])

    with patch.object(ml_runner, "container") as mock_container:
        mock_container.num_gpus_per_node.return_value = 0
        mock_container.get_trainer_arguments.return_value = {"callbacks": Callback()}
        mock_container.load_model_checkpoint.return_value = Mock()
        mock_container.get_data_module.return_value = Mock()
        mock_container.pl_progress_bar_refresh_rate = None
        mock_container.detect_anomaly = False
        mock_container.pl_limit_train_batches = 1.0
        mock_container.pl_limit_val_batches = 1.0
        mock_container.outputs_folder = tmp_path
    # create the test data
    import numpy as np
    import torch

        checkpoint_paths = [Path("dummy")]
        result = ml_runner.run_inference_for_lightning_models(checkpoint_paths)
        assert result == mock_test_result
    N = 100
    x = torch.rand((N, 1)) * 10
    y = 0.2 * x + 0.1 * torch.randn(x.size())
    xy = torch.cat((x, y), dim=1)
    data_path = tmp_path / "hellocontainer.csv"
    np.savetxt(data_path, xy.numpy(), delimiter=",")

        mock_create_trainer.assert_called_once()
        mock_container.load_model_checkpoint.assert_called_once()
        mock_container.get_data_module.assert_called_once()
        mock_trainer.test.assert_called_once()
    expected_ckpt_path = ml_runner_with_container.container.outputs_folder / "checkpoints" / "last.ckpt"
    assert not expected_ckpt_path.exists()
    # update the container to look for test data at this location
    ml_runner_with_container.container.local_dataset_dir = tmp_path
    assert _expected_files_exist() == 0

    actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
    assert actual_train_ckpt_path is None
    ml_runner_with_container.run()
    actual_train_ckpt_path = ml_runner_with_container.checkpoint_handler.get_recovery_or_checkpoint_path_train()
    assert actual_train_ckpt_path == expected_ckpt_path

    actual_test_ckpt_path = ml_runner_with_container.checkpoint_handler.get_checkpoints_to_test()
    assert actual_test_ckpt_path == [expected_ckpt_path]
    assert actual_test_ckpt_path[0].exists()
    # After training, the outputs directory should now exist
    assert _expected_files_exist() == 3

    # if no checkpoint handler, no checkpoint paths will be saved and these are required for
    # inference so ValueError will be raised
    with pytest.raises(ValueError) as e:
        ml_runner_with_container.checkpoint_handler = None  # type: ignore
        ml_runner_with_container.run()
        assert "expects exactly 1 checkpoint for inference, but got 0" in str(e)

    Path("test_mae.txt").unlink()
    Path("test_mse.txt").unlink()
@@ -1,24 +1,23 @@
coverage==5.5
conda-merge==0.1.5
flake8==3.8.4
gitpython==3.1.7
lightning-bolts==0.4.0
matplotlib==3.3.0
monai==0.6.0
more-itertools==8.10.0
mypy==0.910
opencv-python-headless==4.5.1.48
pandas==1.3.4
param==1.9.3
pillow==8.3.2
pydicom==2.0.0
pylint==2.9.5
pycobertura==2.0.1
pytest==6.2.2
pytest-cov==2.11.1
pytest-timeout==2.0.1
types-requests==2.25.6
conda-merge==0.1.5
gitpython==3.1.7
lightning-bolts==0.5.0
matplotlib==3.3.0
monai==0.6.0
more-itertools==8.10.0
opencv-python-headless==4.5.1.48
pandas==1.3.4
param==1.9.3
pydicom==2.0.0
pytorch-lightning==1.5.5
pillow==8.3.2
ruamel.yaml==0.16.12
rpdb==0.1.6
scikit-learn==1.0
@@ -27,4 +26,5 @@ simpleitk==1.2.4
torch==1.10.0
torchmetrics==0.6.0
torchvision==0.11.1
types-requests==2.25.6
yacs==0.1.8