Run inference using checkpoints from registered models (#509)

This commit is contained in:
Shruthi42 2021-07-15 15:31:15 +01:00 committed by GitHub
Parent 732ddfcb34
Commit 9fcc08f6cd
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
25 changed files: 525 additions and 427 deletions

View file

@@ -18,7 +18,8 @@ module on test data with partial ground truth files. (Also [522](https://github.
- ([#502](https://github.com/microsoft/InnerEye-DeepLearning/pull/502)) More flags for fine control of when to run inference.
- ([#492](https://github.com/microsoft/InnerEye-DeepLearning/pull/492)) Adding capability for regression tests for test
jobs that run in AzureML.
- ([#509](https://github.com/microsoft/InnerEye-DeepLearning/pull/509)) Run inference on registered models (single and
ensemble) using the parameter `model_id`.
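  A registered model can then be evaluated without retraining. A sketch of the invocation (model name and version are illustrative, and `--train=False` assumes the usual boolean flag syntax):

      python InnerEye/ML/runner.py --model=MyModel --model_id=MyModel:1 --train=False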
### Changed
- ([#531](https://github.com/microsoft/InnerEye-DeepLearning/pull/531)) Updated PL to 1.3.8, torchmetrics and pl-bolts, and changed relevant metrics and SSL code API.
- ([#533](https://github.com/microsoft/InnerEye-DeepLearning/pull/533)) Better defaults for inference on ensemble children.
@@ -26,6 +27,8 @@ jobs that run in AzureML.
- ([#496](https://github.com/microsoft/InnerEye-DeepLearning/pull/496)) All plots are now saved as PNG, rather than JPG.
- ([#497](https://github.com/microsoft/InnerEye-DeepLearning/pull/497)) Reducing the size of the code snapshot that
gets uploaded to AzureML, by skipping all test folders.
- ([#509](https://github.com/microsoft/InnerEye-DeepLearning/pull/509)) Parameter `extra_downloaded_run_id` has been
renamed to `pretraining_run_checkpoints`.
### Fixed
- ([#525](https://github.com/microsoft/InnerEye-DeepLearning/pull/525)) Enable --store_dataset_sample
@@ -40,11 +43,15 @@ LightningContainer models can get stuck at test set inference.
multiple large checkpoints can time out.
- ([#515](https://github.com/microsoft/InnerEye-DeepLearning/pull/515)) Workaround for occasional issues with dataset
mounting and running matplotlib on some machines. Re-instantiated a disabled test.
- ([#509](https://github.com/microsoft/InnerEye-DeepLearning/pull/509)) Fix issue where model checkpoints were not loaded
in inference-only runs when using lightning containers.
### Removed
- ([#542](https://github.com/microsoft/InnerEye-DeepLearning/pull/542)) Removed Windows test leg from build pipeline.
- ([#520](https://github.com/microsoft/InnerEye-DeepLearning/pull/520)) Disable glaucoma job from Azure pipeline.
- ([#509](https://github.com/microsoft/InnerEye-DeepLearning/pull/509)) Parameters `local_weights_path` and
`weights_url` can no longer be used to initialize a training run, only inference runs.
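  With this change, `weights_url` and `local_weights_path` drive inference-only runs. A sketch with an illustrative model name and URL, again assuming the usual boolean flag syntax:

      python InnerEye/ML/runner.py --model=MyClassifier --weights_url=https://example.com/checkpoint.ckpt --train=False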
### Deprecated

View file

@@ -77,14 +77,9 @@ class AzureConfig(GenericConfig):
pytest_mark: str = param.String(doc="If provided, run pytest instead of model training. pytest will only "
"run the tests that have the mark given in this argument "
"('--pytest_mark gpu' will run all tests marked with 'pytest.mark.gpu')")
run_recovery_id: str = param.String(doc="A run recovery id string in the form 'experiment name:run id'"
" to use for inference or recovering a model training run.")
pretraining_run_recovery_id: str = param.String(default=None,
allow_None=True,
doc="Extra run recovery id to download checkpoints from,"
"for custom modules (e.g. for loading pretrained weights)."
"Warning: this argument will be ignored for InnerEyeContainer"
"models.")
run_recovery_id: str = param.String(doc="A run recovery id string in the form 'experiment name:run id' "
"to use for inference, recovering a model training run or to register "
"a model.")
experiment_name: str = param.String(doc="If provided, use this string as the name of the AzureML experiment. "
"If not provided, create the experiment off the git branch name.")
build_number: int = param.Integer(0, doc="The numeric ID of the Azure pipeline that triggered this training run.")

View file

@@ -40,9 +40,6 @@ DEFAULT_MODEL_SUMMARIES_DIR_PATH = Path(DEFAULT_LOGS_DIR_NAME) / "model_summarie
# The folder at the project root directory that holds datasets for local execution.
DATASETS_DIR_NAME = "datasets"
# Points to a folder at the project root directory that holds model weights downloaded from URLs.
MODEL_WEIGHTS_DIR_NAME = "modelweights"
ML_RELATIVE_SOURCE_PATH = os.path.join("ML")
ML_RELATIVE_RUNNER_PATH = os.path.join(ML_RELATIVE_SOURCE_PATH, "runner.py")
ML_FULL_SOURCE_FOLDER_PATH = str(repository_root_directory() / ML_RELATIVE_SOURCE_PATH)

View file

@@ -170,21 +170,21 @@ class CovidHierarchicalModel(ScalarModelBase):
def _get_ssl_checkpoint_path(self) -> Path:
# Get the SSL weights from the AML run provided via "pretraining_run_recovery_id" command line argument.
# Accessible via extra_downloaded_run_id field of the config.
assert self.extra_downloaded_run_id is not None
assert isinstance(self.extra_downloaded_run_id, RunRecovery)
# Accessible via pretraining_run_checkpoints field of the config.
assert self.pretraining_run_checkpoints is not None
assert isinstance(self.pretraining_run_checkpoints, RunRecovery)
ssl_path = self.checkpoint_folder / "ssl_checkpoint.ckpt"
if not ssl_path.exists(): # for test (when it is already present) we don't need to redo this.
if self.name_of_checkpoint is not None:
logging.info(f"Using checkpoint: {self.name_of_checkpoint} as starting point.")
path_to_checkpoint = self.extra_downloaded_run_id.checkpoints_roots[0] / self.name_of_checkpoint
path_to_checkpoint = self.pretraining_run_checkpoints.checkpoints_roots[0] / self.name_of_checkpoint
else:
path_to_checkpoint = self.extra_downloaded_run_id.get_best_checkpoint_paths()[0]
path_to_checkpoint = self.pretraining_run_checkpoints.get_best_checkpoint_paths()[0]
if not path_to_checkpoint.exists():
logging.info("No best checkpoint found for this model. Getting the latest recovery "
"checkpoint instead.")
path_to_checkpoint = self.extra_downloaded_run_id.get_recovery_checkpoint_paths()[0]
path_to_checkpoint = self.pretraining_run_checkpoints.get_recovery_checkpoint_paths()[0]
assert path_to_checkpoint.exists()
path_to_checkpoint.rename(ssl_path)
return ssl_path

View file

@@ -230,7 +230,7 @@ class HelloRegression(LightningModule):
"""
average_mse = torch.mean(torch.stack(self.test_mse))
Path("test_mse.txt").write_text(str(average_mse.item()))
Path("test_mae.txt").write_text(str(self.test_mae.compute()))
Path("test_mae.txt").write_text(str(self.test_mae.compute().item()))
class HelloContainer(LightningContainer):

View file

@@ -33,7 +33,6 @@ VISUALIZATION_FOLDER = "visualizations"
EXTRA_RUN_SUBFOLDER = "extra_run_id"
ARGS_TXT = "args.txt"
WEIGHTS_FILE = "weights.pth"
@unique
@@ -216,16 +215,25 @@ class WorkflowParams(param.Parameterized):
ensemble_inference_on_test_set: Optional[bool] = \
param.Boolean(None,
doc="If set, enable/disable full image inference on test set after ensemble training.")
weights_url: str = param.String(doc="If provided, a url from which weights will be downloaded and used for model "
"initialization.")
local_weights_path: Optional[Path] = param.ClassSelector(class_=Path,
default=None,
allow_None=True,
doc="The path to the weights to use for model "
"initialization, when training outside AzureML.")
weights_url: List[str] = param.List(default=[], class_=str,
doc="If provided, a set of urls from which checkpoints will be downloaded "
"and used for inference.")
local_weights_path: List[Path] = param.List(default=[], class_=Path,
doc="A list of checkpoints paths to use for inference, "
"when the job is running outside Azure.")
model_id: str = param.String(default="",
doc="A model id string in the form 'model name:version' "
"to use a registered model for inference.")
generate_report: bool = param.Boolean(default=True,
doc="If True (default), write a modelling report in HTML format. If False,"
"do not write that report.")
pretraining_run_recovery_id: str = param.String(default=None,
allow_None=True,
doc="Extra run recovery id to download checkpoints from,"
"for custom modules (e.g. for loading pretrained weights)."
"The downloaded RunRecovery object will be available in"
"pretraining_run_checkpoints.")
# The default multiprocessing start_method in both PyTorch and the Python standard library is "fork" for Linux and
# "spawn" (the only available method) for Windows. There is some evidence that using "forkserver" on Linux
# can reduce the chance of stuck jobs.
@@ -248,8 +256,13 @@
"be relative to the repository root directory.")
def validate(self) -> None:
if self.weights_url and self.local_weights_path:
raise ValueError("Cannot specify both local_weights_path and weights_url.")
if sum([bool(param) for param in [self.weights_url, self.local_weights_path, self.model_id]]) > 1:
raise ValueError("Cannot specify more than one of local_weights_path, weights_url or model_id.")
if self.model_id:
if len(self.model_id.split(":")) != 2:
raise ValueError(
f"model_id should be in the form 'model_name:version', got {self.model_id}")
if self.number_of_cross_validation_splits == 1:
raise ValueError("At least two splits required to perform cross validation, but got "
@@ -713,7 +726,7 @@ class DeepLearningConfig(WorkflowParams,
self.create_filesystem(fixed_paths.repository_root_directory())
# Disable the PL progress bar because all InnerEye models have their own console output
self.pl_progress_bar_refresh_rate = 0
self.extra_downloaded_run_id: Optional[Any] = None
self.pretraining_run_checkpoints: Optional[Any] = None
def validate(self) -> None:
"""

View file

@@ -3,8 +3,8 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import abc
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple
from pathlib import Path
import param
import torch
@@ -19,7 +19,7 @@ from InnerEye.Common.generic_parsing import GenericConfig, create_from_matching_
from InnerEye.Common.metrics_constants import TrackedMetrics
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.deep_learning_config import DatasetParams, OptimizerParams, OutputParams, TrainerParams, \
WorkflowParams, load_checkpoint
WorkflowParams
from InnerEye.ML.utils import model_util
from InnerEye.ML.utils.lr_scheduler import SchedulerWithWarmUp
from InnerEye.ML.utils.run_recovery import RunRecovery
@@ -150,7 +150,7 @@ class LightningContainer(GenericConfig,
super().__init__(**kwargs)
self._model: Optional[LightningModule] = None
self._model_name = type(self).__name__
self.extra_downloaded_run_id: Optional[RunRecovery] = None
self.pretraining_run_checkpoints: Optional[RunRecovery] = None
self.num_nodes = 1
def validate(self) -> None:
@@ -249,36 +249,6 @@ class LightningContainer(GenericConfig,
"""
pass
def load_checkpoint_and_modify(self, path_to_checkpoint: Path) -> Dict[str, Any]:
"""
This method is called when a file with weights for network initialization is supplied at container level,
in the self.weights_url or self.local_weights_path fields. It can load that file as a Torch checkpoint,
and rename parameters.
By default, uses torch.load to read and return the state dict from the checkpoint file, and does no modification
of the checkpoint file.
Overloading this function:
When weights_url or local_weights_path is set, the file downloaded may not be in the exact
format expected by the model's load_state_dict() - for example, pretrained Imagenet weights for networks
may have mismatched layer names in different implementations.
In such cases, you can overload this function to extract the state dict from the checkpoint.
NOTE: The model checkpoint will be loaded using the torch function load_state_dict() with argument strict=False,
so extra care needs to be taken to check that the state dict is valid.
Check the logs for warnings related to missing and unexpected keys.
See https://pytorch.org/tutorials/beginner/saving_loading_models.html#warmstarting-model-using-parameters
-from-a-different-model
for an explanation on why strict=False is useful when loading parameters from other models.
:param path_to_checkpoint: Path to the checkpoint file.
:return: Dictionary with model and optimizer state dicts. The dict should have at least the following keys:
1. Key ModelAndInfo.MODEL_STATE_DICT_KEY and value set to the model state dict.
2. Key ModelAndInfo.EPOCH_KEY and value set to the checkpoint epoch.
Other (optional) entries corresponding to keys ModelAndInfo.OPTIMIZER_STATE_DICT_KEY and
ModelAndInfo.MEAN_TEACHER_STATE_DICT_KEY are also supported.
"""
return load_checkpoint(path_to_checkpoint=path_to_checkpoint, use_gpu=self.use_gpu)
# The code from here on does not need to be modified.
@property
@@ -333,6 +303,15 @@ class LightningContainer(GenericConfig,
else:
return self.get_parameter_search_hyperdrive_config(run_config)
def load_model_checkpoint(self, checkpoint_path: Path) -> None:
"""
Load a checkpoint from the given path. We need to define a separate method since PyTorch Lightning cannot
access the _model attribute to modify it.
"""
if self._model is None:
raise ValueError("No Lightning module has been set yet.")
self._model = type(self._model).load_from_checkpoint(checkpoint_path=str(checkpoint_path))
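PyTorch Lightning's load_from_checkpoint is a classmethod that builds a fresh module instance, which is why the container swaps out _model rather than loading state in place. A usage sketch (the container class is hypothetical; the call order mirrors MLRunner.setup below):

    from pathlib import Path

    container = MyContainer()  # any LightningContainer subclass
    container.setup()
    container.create_lightning_module_and_store()
    container.load_model_checkpoint(checkpoint_path=Path("outputs/checkpoints/best.ckpt"))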
def __str__(self) -> str:
"""Returns a string describing the present object, as a list of key: value strings."""
arguments_str = "\nContainer:\n"

View file

@@ -35,7 +35,6 @@ from InnerEye.ML.reports.segmentation_report import boxplot_per_structure
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.sequence_config import SequenceModelBase
from InnerEye.ML.utils import io_util, ml_util
from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler
from InnerEye.ML.utils.image_util import binaries_from_multi_label_array
from InnerEye.ML.utils.io_util import ImageHeader, MedicalImageFileType, load_nifti_image, save_lines_to_file
from InnerEye.ML.utils.metrics_util import MetricsPerPatientWriter
@@ -47,7 +46,7 @@ MODEL_OUTPUT_CSV = "model_outputs.csv"
def model_test(config: ModelConfigBase,
data_split: ModelExecutionMode,
checkpoint_handler: CheckpointHandler,
checkpoint_paths: List[Path],
model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> Optional[InferenceMetrics]:
"""
Runs model inference on segmentation or classification models, using a given dataset (that could be training,
@@ -55,7 +54,7 @@ def model_test(config: ModelConfigBase,
differ for model categories (classification, segmentation).
:param config: The configuration of the model
:param data_split: Indicates which of the 3 sets (training, test, or validation) is being processed.
:param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
:param checkpoint_paths: Checkpoint paths to initialize model.
:param model_proc: whether we are testing an ensemble or single model; this affects where results are written.
:return: The metrics that the model achieved on the given data set, or None if the data set is empty.
"""
@@ -67,17 +66,19 @@ def model_test(config: ModelConfigBase,
"and additional data loaders are likely to block.")
return None
with logging_section(f"Running {model_proc.value} model on {data_split.name.lower()} set"):
if not checkpoint_paths:
raise ValueError("There were no checkpoints available for model testing.")
if isinstance(config, SegmentationModelBase):
return segmentation_model_test(config, data_split, checkpoint_handler, model_proc)
return segmentation_model_test(config, data_split, checkpoint_paths, model_proc)
if isinstance(config, ScalarModelBase):
return classification_model_test(config, data_split, checkpoint_handler, model_proc,
return classification_model_test(config, data_split, checkpoint_paths, model_proc,
config.cross_validation_split_index)
raise ValueError(f"There is no testing code for models of type {type(config)}")
def segmentation_model_test(config: SegmentationModelBase,
execution_mode: ModelExecutionMode,
checkpoint_handler: CheckpointHandler,
checkpoint_paths: List[Path],
model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> InferenceMetricsForSegmentation:
"""
The main testing loop for segmentation models.
@@ -89,10 +90,6 @@ def segmentation_model_test(config: SegmentationModelBase,
:param patient_id: String which contains subject identifier.
:return: InferenceMetric object that contains metrics related for all of the checkpoint epochs.
"""
checkpoints_to_test = checkpoint_handler.get_checkpoints_to_test()
if not checkpoints_to_test:
raise ValueError("There were no checkpoints available for model testing.")
epoch_results_folder = config.outputs_folder / get_best_epoch_results_path(execution_mode, model_proc)
# save the datasets.csv used
@@ -100,7 +97,7 @@
epoch_and_split = f"{execution_mode.value} set"
epoch_dice_per_image = segmentation_model_test_epoch(config=copy.deepcopy(config),
execution_mode=execution_mode,
checkpoint_paths=checkpoints_to_test,
checkpoint_paths=checkpoint_paths,
results_folder=epoch_results_folder,
epoch_and_split=epoch_and_split)
if epoch_dice_per_image is None:
@@ -410,7 +407,7 @@ def create_metrics_dict_for_scalar_models(config: ScalarModelBase) -> \
def classification_model_test(config: ScalarModelBase,
data_split: ModelExecutionMode,
checkpoint_handler: CheckpointHandler,
checkpoint_paths: List[Path],
model_proc: ModelProcessing,
cross_val_split_index: int) -> InferenceMetricsForClassification:
"""
@@ -419,16 +416,12 @@ def classification_model_test(config: ScalarModelBase,
:param config: The model configuration.
:param data_split: The name of the folder to store the results inside each epoch folder in the outputs_dir,
used mainly in model evaluation using different dataset splits.
:param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
:param checkpoint_paths: Checkpoint paths to initialize model
:param model_proc: whether we are testing an ensemble or single model
:return: InferenceMetricsForClassification object that contains metrics related for all of the checkpoint epochs.
"""
posthoc_label_transform = config.get_posthoc_label_transform()
checkpoint_paths = checkpoint_handler.get_checkpoints_to_test()
if not checkpoint_paths:
raise ValueError("There were no checkpoints available for model testing.")
pipeline = create_inference_pipeline(config=config,
checkpoint_paths=checkpoint_paths)
if pipeline is None:

View file

@@ -28,7 +28,6 @@ from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.lightning_loggers import AzureMLLogger, StoringLogger
from InnerEye.ML.lightning_models import SUBJECT_OUTPUT_PER_RANK_PREFIX, ScalarLightning, \
get_subject_output_file_per_rank
from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler
TEMP_PREFIX = "temp/"
@@ -204,14 +203,14 @@ def start_resource_monitor(config: LightningContainer) -> ResourceMonitor:
return resource_monitor
def model_train(checkpoint_handler: CheckpointHandler,
def model_train(checkpoint_path: Optional[Path],
container: LightningContainer,
num_nodes: int = 1) -> Tuple[Trainer, Optional[StoringLogger]]:
"""
The main training loop. It creates the Pytorch model based on the configuration options passed in,
creates a Pytorch Lightning trainer, and trains the model.
If a checkpoint was specified, then it loads the checkpoint before resuming training.
:param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
:param checkpoint_path: Checkpoint path for model initialization
:param num_nodes: The number of nodes to use in distributed training.
:param container: A container object that holds the training data in PyTorch Lightning format
and the model to train.
@@ -219,8 +218,6 @@ def model_train(checkpoint_handler: CheckpointHandler,
the model. The StoringLogger object is returned when training an InnerEye built-in model, this is None when
fitting other models.
"""
# Get the path to the checkpoint to recover from
checkpoint_path = checkpoint_handler.get_recovery_path_train()
lightning_model = container.model
resource_monitor: Optional[ResourceMonitor] = None
@@ -303,10 +300,6 @@ def model_train(checkpoint_handler: CheckpointHandler,
logging.info("Finished training")
# Since we have trained the model further, let the checkpoint_handler object know so it can handle
# checkpoints correctly.
checkpoint_handler.additional_training_done()
# Upload visualization directory to AML run context to be able to see it in the Azure UI.
if isinstance(container, InnerEyeContainer):
if container.config.max_batch_grad_cam > 0 and container.visualization_folder.exists():

View file

@@ -39,7 +39,7 @@ from InnerEye.ML.baselines_util import compare_folders_and_run_outputs
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.deep_learning_config import CHECKPOINT_FOLDER, DeepLearningConfig, FINAL_ENSEMBLE_MODEL_FOLDER, \
FINAL_MODEL_FOLDER, ModelCategory, MultiprocessingStartMethod, load_checkpoint
FINAL_MODEL_FOLDER, ModelCategory, MultiprocessingStartMethod, load_checkpoint, EXTRA_RUN_SUBFOLDER
from InnerEye.ML.lightning_base import InnerEyeContainer
from InnerEye.ML.lightning_container import InnerEyeInference, LightningContainer
from InnerEye.ML.metrics import InferenceMetrics, InferenceMetricsForSegmentation
@@ -53,6 +53,7 @@ from InnerEye.ML.reports.notebook_report import generate_classification_crossval
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.sequence_config import SequenceModelBase
from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler
from InnerEye.ML.utils.run_recovery import RunRecovery
from InnerEye.ML.visualizers import activation_maps
from InnerEye.ML.visualizers.plot_cross_validation import \
get_config_and_results_for_offline_runs, plot_cross_validation_from_files
@@ -223,11 +224,19 @@ class MLRunner:
run_context=RUN_CONTEXT)
self.checkpoint_handler.download_recovery_checkpoints_or_weights(only_return_path=not is_global_rank_zero())
if self.container.pretraining_run_recovery_id is not None:
run_to_recover = self.azure_config.fetch_run(self.container.pretraining_run_recovery_id.strip())
run_recovery_object = RunRecovery.download_all_checkpoints_from_run(self.container,
run_to_recover,
EXTRA_RUN_SUBFOLDER,
only_return_path=not is_global_rank_zero())
self.container.pretraining_run_checkpoints = run_recovery_object
# A lot of the code for the built-in InnerEye models expects the output paths directly in the config files.
if isinstance(self.container, InnerEyeContainer):
self.container.config.local_dataset = self.container.local_dataset
self.container.config.file_system_config = self.container.file_system_config
self.container.config.extra_downloaded_run_id = self.container.extra_downloaded_run_id
self.container.config.pretraining_run_checkpoints = self.container.pretraining_run_checkpoints
self.container.setup()
self.container.create_lightning_module_and_store()
self._has_setup_run = True
@@ -361,9 +370,12 @@ class MLRunner:
# train a new model if required
if self.azure_config.train:
with logging_section("Model training"):
model_train(self.checkpoint_handler,
model_train(self.checkpoint_handler.get_recovery_or_checkpoint_path_train(),
container=self.container,
num_nodes=self.azure_config.num_nodes)
# Since we have trained the model further, let the checkpoint_handler object know so it can handle
# checkpoints correctly.
self.checkpoint_handler.additional_training_done()
# log the number of epochs used for model training
RUN_CONTEXT.log(name="Train epochs", value=self.container.num_epochs)
elif isinstance(self.container, InnerEyeContainer):
@@ -374,14 +386,15 @@ class MLRunner:
# AzureML.
if not self.is_offline_run:
if self.should_register_model():
self.register_model(self.checkpoint_handler, ModelProcessing.DEFAULT)
self.register_model(self.checkpoint_handler.get_best_checkpoints(), ModelProcessing.DEFAULT)
if not self.azure_config.only_register_model:
checkpoint_paths_for_testing = self.checkpoint_handler.get_checkpoints_to_test()
if isinstance(self.container, InnerEyeContainer):
# Inference for the InnerEye built-in models
# We specify the ModelProcessing as DEFAULT here even if the run_recovery points to an ensemble run,
# because the current run is a single one. See the documentation of ModelProcessing for more details.
self.run_inference(self.checkpoint_handler, ModelProcessing.DEFAULT)
self.run_inference(checkpoint_paths_for_testing, ModelProcessing.DEFAULT)
if self.container.generate_report:
self.generate_report(ModelProcessing.DEFAULT)
@@ -397,7 +410,7 @@ class MLRunner:
else:
# Inference for all models that are specified via LightningContainers.
with logging_section("Model inference"):
self.run_inference_for_lightning_models(self.checkpoint_handler.get_checkpoints_to_test())
self.run_inference_for_lightning_models(checkpoint_paths_for_testing)
# We can't enforce that files are written to the output folder, hence change the working directory
# manually
with change_working_directory(self.container.outputs_folder):
@@ -483,26 +496,26 @@ class MLRunner:
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
trainer, _ = create_lightning_trainer(self.container, num_nodes=1)
self.container.load_model_checkpoint(checkpoint_path=checkpoint_paths[0])
# When training models that are not built-in InnerEye models, we have no guarantee that they write
# files to the right folder. Best guess is to change the current working directory to where files should go.
with change_working_directory(self.container.outputs_folder):
trainer.test(self.container.model,
test_dataloaders=self.container.get_data_module().test_dataloader(),
ckpt_path=str(checkpoint_paths[0]))
test_dataloaders=self.container.get_data_module().test_dataloader())
else:
logging.warning("None of the suitable test methods is overridden. Skipping inference completely.")
def run_inference(self, checkpoint_handler: CheckpointHandler,
def run_inference(self, checkpoint_paths: List[Path],
model_proc: ModelProcessing) -> None:
"""
Run inference on InnerEyeContainer models
:param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
:param checkpoint_paths: Checkpoint paths to initialize model
:param model_proc: whether we are running an ensemble model from within a child run with index 0. If we are,
then outputs will be written to OTHER_RUNS/ENSEMBLE under the main outputs directory.
"""
# run full image inference on existing or newly trained model on the training, and testing set
self.model_inference_train_and_test(checkpoint_handler=checkpoint_handler,
self.model_inference_train_and_test(checkpoint_paths=checkpoint_paths,
model_proc=model_proc)
self.try_compare_scores_against_baselines(model_proc)
@@ -594,12 +607,12 @@ class MLRunner:
compare_scores_against_baselines(self.model_config, self.azure_config, model_proc)
def register_model(self,
checkpoint_handler: CheckpointHandler,
checkpoint_paths: List[Path],
model_proc: ModelProcessing) -> Tuple[model.Model, Any]:
"""
Registers a new model in the workspace's model registry on AzureML to be deployed further.
The AzureML run's tags are updated with information about ensemble creation and the parent run ID.
:param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model registration.
:param checkpoint_paths: Checkpoint paths to register.
:param model_proc: whether it's a single or ensemble model.
:returns Tuple element 1: AML model object, or None if no model could be registered.
Tuple element 2: The result of running the model_deployment_hook, or None if no hook was supplied.
@@ -607,7 +620,6 @@ class MLRunner:
if self.is_offline_run:
raise ValueError("Cannot register models when InnerEye is running outside of AzureML.")
checkpoint_paths = checkpoint_handler.get_checkpoints_to_test()
if not checkpoint_paths:
raise ValueError("Model registration failed: No checkpoints found")
@@ -761,7 +773,7 @@ class MLRunner:
raise ValueError(f"Checkpoint file {checkpoint_source} does not exist")
def model_inference_train_and_test(self,
checkpoint_handler: CheckpointHandler,
checkpoint_paths: List[Path],
model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> \
Dict[ModelExecutionMode, InferenceMetrics]:
metrics: Dict[ModelExecutionMode, InferenceMetrics] = {}
@@ -770,7 +782,7 @@ class MLRunner:
for data_split in ModelExecutionMode:
if self.container.inference_on_set(model_proc, data_split):
opt_metrics = model_test(config, data_split=data_split, checkpoint_handler=checkpoint_handler,
opt_metrics = model_test(config, data_split=data_split, checkpoint_paths=checkpoint_paths,
model_proc=model_proc)
if opt_metrics is not None:
metrics[data_split] = opt_metrics
@@ -831,10 +843,10 @@ class MLRunner:
# AzureML.
if not self.is_offline_run:
if self.should_register_model():
self.register_model(checkpoint_handler, ModelProcessing.ENSEMBLE_CREATION)
self.register_model(checkpoint_handler.get_best_checkpoints(), ModelProcessing.ENSEMBLE_CREATION)
if not self.azure_config.only_register_model:
self.run_inference(checkpoint_handler=checkpoint_handler,
self.run_inference(checkpoint_paths=checkpoint_handler.get_checkpoints_to_test(),
model_proc=ModelProcessing.ENSEMBLE_CREATION)
crossval_dir = self.plot_cross_validation_and_upload_results()

View file

@@ -11,15 +11,18 @@ from typing import List, Optional
from urllib.parse import urlparse
import requests
import torch
from azureml.core import Run
from azureml.core import Run, Workspace, Model
from InnerEye.Azure.azure_config import AzureConfig
from InnerEye.Common import fixed_paths
from InnerEye.Common.fixed_paths import MODEL_INFERENCE_JSON_FILE_NAME
from InnerEye.ML.common import find_recovery_checkpoint_and_epoch
from InnerEye.ML.deep_learning_config import EXTRA_RUN_SUBFOLDER, OutputParams, WEIGHTS_FILE
from InnerEye.ML.deep_learning_config import OutputParams
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.utils.run_recovery import RunRecovery
from InnerEye.ML.model_inference_config import read_model_inference_config
MODEL_WEIGHTS_DIR_NAME = "trained_models"
class CheckpointHandler:
@@ -35,7 +38,7 @@ class CheckpointHandler:
self.run_recovery: Optional[RunRecovery] = None
self.project_root = project_root
self.run_context = run_context
self.local_weights_path: Optional[Path] = None
self.trained_weights_paths: List[Path] = []
self.has_continued_training = False
@property
@@ -47,7 +50,7 @@ class CheckpointHandler:
def download_checkpoints_from_hyperdrive_child_runs(self, hyperdrive_parent_run: Run) -> None:
"""
Downloads the best checkpoints from all child runs of a Hyperdrive parent runs. This is used to gather results
Downloads the best checkpoints from all child runs of a Hyperdrive parent run. This is used to gather results
for ensemble creation.
"""
self.run_recovery = RunRecovery.download_best_checkpoints_from_child_runs(self.output_params,
@@ -70,21 +73,9 @@ class CheckpointHandler:
run_to_recover = self.azure_config.fetch_run(self.azure_config.run_recovery_id.strip())
self.run_recovery = RunRecovery.download_all_checkpoints_from_run(self.output_params, run_to_recover,
only_return_path=only_return_path)
else:
self.run_recovery = None
if self.azure_config.pretraining_run_recovery_id is not None:
run_to_recover = self.azure_config.fetch_run(self.azure_config.pretraining_run_recovery_id.strip())
run_recovery_object = RunRecovery.download_all_checkpoints_from_run(self.output_params,
run_to_recover,
EXTRA_RUN_SUBFOLDER,
only_return_path=only_return_path)
self.container.extra_downloaded_run_id = run_recovery_object
else:
self.container.extra_downloaded_run_id = None
if self.container.weights_url or self.container.local_weights_path:
self.local_weights_path = self.get_and_save_modified_weights()
if self.container.weights_url or self.container.local_weights_path or self.container.model_id:
self.trained_weights_paths = self.get_local_checkpoints_path_or_download()
def additional_training_done(self) -> None:
"""
@@ -92,7 +83,7 @@ class CheckpointHandler:
"""
self.has_continued_training = True
def get_recovery_path_train(self) -> Optional[Path]:
def get_recovery_or_checkpoint_path_train(self) -> Optional[Path]:
"""
Decides the checkpoint path to use for the current training run. Looks for the latest checkpoint in the
checkpoint folder. If run_recovery is provided, the checkpoints will have been downloaded to this folder
@@ -105,29 +96,19 @@ class CheckpointHandler:
local_recovery_path, recovery_epoch = recovery
self.container._start_epoch = recovery_epoch
return local_recovery_path
elif self.local_weights_path:
return self.local_weights_path
else:
return None
def get_best_checkpoint(self) -> List[Path]:
def get_best_checkpoints(self) -> List[Path]:
"""
Get a list of checkpoints per epoch for testing/registration.
1. If a run recovery object is used and no training was done in this run, use checkpoints from run recovery.
2. If a run recovery object is used, and training was done in this run, but the start epoch is larger than
the epoch parameter provided, use checkpoints from run recovery.
3. If a run recovery object is used, and training was done in this run, but the start epoch is smaller than
the epoch parameter provided, use checkpoints from the current training run.
This function also checks that all the checkpoints at the returned checkpoint paths exist,
and drops any that do not.
Get a list of checkpoints per epoch for testing/registration from the current training run.
This function also checks that the checkpoint at the returned checkpoint path exists.
"""
if not self.run_recovery and not self.has_continued_training:
raise ValueError("Cannot recover checkpoint, no run recovery object provided and"
raise ValueError("Cannot recover checkpoint, no run recovery object provided and "
"no training has been done in this run.")
checkpoint_paths = []
if self.run_recovery:
checkpoint_paths = self.run_recovery.get_best_checkpoint_paths()
@@ -165,77 +146,81 @@ class CheckpointHandler:
checkpoints = []
# If recovery object exists, or model was trained, look for checkpoints by epoch
# If model was trained, look for the best checkpoint
if self.run_recovery or self.has_continued_training:
checkpoints = self.get_best_checkpoint()
elif self.local_weights_path and not self.has_continued_training:
# No recovery object and model was not trained, check if there is a local weight path.
if self.local_weights_path.exists():
logging.info(f"Using model weights at {self.local_weights_path} to initialize model")
checkpoints = [self.local_weights_path]
else:
logging.warning(f"local_weights_path does not exist, "
f"cannot recover from {self.local_weights_path}")
checkpoints = self.get_best_checkpoints()
elif self.trained_weights_paths:
# Model was not trained, check if there is a local weight path.
logging.info(f"Using model weights from {self.trained_weights_paths} to initialize model")
checkpoints = self.trained_weights_paths
else:
logging.warning("Could not find any run recovery object or local_weights_path to get checkpoints from")
logging.warning("Could not find any local_weights_path, model_weights or model_id to get checkpoints from")
return checkpoints
def download_weights(self) -> Path:
@staticmethod
def download_weights(urls: List[str], download_folder: Path) -> List[Path]:
"""
Download checkpoints from the given list of URLs into the given download folder.
"""
target_folder = self.project_root / fixed_paths.MODEL_WEIGHTS_DIR_NAME
target_folder.mkdir(exist_ok=True)
checkpoint_paths = []
for url in urls:
# assign the same filename as in the download url if possible, so that we can check for duplicates
# If that fails, map to a random uuid
file_name = os.path.basename(urlparse(url).path) or str(uuid.uuid4().hex)
result_file = download_folder / file_name
checkpoint_paths.append(result_file)
# only download if hasn't already been downloaded
if result_file.exists():
logging.info(f"File already exists, skipping download: {result_file}")
else:
logging.info(f"Downloading weights from URL {url}")
url = self.container.weights_url
response = requests.get(url, stream=True)
response.raise_for_status()
with open(result_file, "wb") as file:
for chunk in response.iter_content(chunk_size=1024):
file.write(chunk)
# assign the same filename as in the download url if possible, so that we can check for duplicates
# If that fails, map to a random uuid
file_name = os.path.basename(urlparse(url).path) or str(uuid.uuid4().hex)
result_file = target_folder / file_name
return checkpoint_paths
# only download if hasn't already been downloaded
if result_file.exists():
logging.info(f"File already exists, skipping download: {result_file}")
return result_file
@staticmethod
def get_checkpoints_from_model(model_id: str, workspace: Workspace, download_path: Path) -> List[Path]:
if len(model_id.split(":")) != 2:
raise ValueError(
f"model_id should be in the form 'model_name:version', got {model_id}")
logging.info(f"Downloading weights from URL {url}")
model_name, model_version = model_id.split(":")
model = Model(workspace=workspace, name=model_name, version=int(model_version))
model_path = Path(model.download(str(download_path), exist_ok=True))
model_inference_config = read_model_inference_config(model_path / MODEL_INFERENCE_JSON_FILE_NAME)
checkpoint_paths = [model_path / x for x in model_inference_config.checkpoint_paths]
return checkpoint_paths
response = requests.get(url, stream=True)
response.raise_for_status()
with open(result_file, "wb") as file:
for chunk in response.iter_content(chunk_size=1024):
file.write(chunk)
return result_file
def get_local_weights_path_or_download(self) -> Optional[Path]:
def get_local_checkpoints_path_or_download(self) -> List[Path]:
"""
Get the paths of the checkpoints to use locally, downloading them first if necessary.
"""
if self.container.local_weights_path:
weights_path = self.container.local_weights_path
elif self.container.weights_url:
weights_path = self.download_weights()
else:
raise ValueError("Cannot download/modify weights - neither local_weights_path nor weights_url is set in"
if not self.container.model_id and not self.container.local_weights_path and not self.container.weights_url:
raise ValueError("Cannot download weights - none of model_id, local_weights_path or weights_url is set in "
"the model config.")
return weights_path
if self.container.local_weights_path:
checkpoint_paths = self.container.local_weights_path
else:
download_folder = self.output_params.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME
download_folder.mkdir(exist_ok=True, parents=True)
def get_and_save_modified_weights(self) -> Path:
"""
Downloads the checkpoint weights if needed.
Then passes the downloaded or local checkpoint to the modify_checkpoint function from the model_config and saves
the modified state dict from the function in the outputs folder with the name weights.pth.
"""
weights_path = self.get_local_weights_path_or_download()
if self.container.model_id:
checkpoint_paths = CheckpointHandler.get_checkpoints_from_model(model_id=self.container.model_id,
workspace=self.azure_config.get_workspace(),
download_path=download_folder)
elif self.container.weights_url:
urls = self.container.weights_url
checkpoint_paths = CheckpointHandler.download_weights(urls=urls,
download_folder=download_folder)
if not weights_path or not weights_path.is_file():
raise FileNotFoundError(f"Could not find the weights file at {weights_path}")
modified_weights = self.container.load_checkpoint_and_modify(weights_path)
target_file = self.output_params.outputs_folder / WEIGHTS_FILE
torch.save(modified_weights, target_file)
return target_file
for checkpoint_path in checkpoint_paths:
if not checkpoint_path or not checkpoint_path.is_file():
raise FileNotFoundError(f"Could not find the weights file at {checkpoint_path}")
return checkpoint_paths
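A sketch of how the new static helpers fit together for a registered model (model id and paths are illustrative; azure_config is assumed to be a populated AzureConfig):

    from pathlib import Path
    from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler, MODEL_WEIGHTS_DIR_NAME

    download_folder = Path("outputs/checkpoints") / MODEL_WEIGHTS_DIR_NAME
    download_folder.mkdir(exist_ok=True, parents=True)
    checkpoint_paths = CheckpointHandler.get_checkpoints_from_model(
        model_id="MyModel:2",
        workspace=azure_config.get_workspace(),
        download_path=download_folder)
    # Each returned path is a checkpoint file listed in the model's model_inference_config.json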

View file

@@ -36,16 +36,18 @@ from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.Common.spawn_subprocess import spawn_and_monitor_subprocess
from InnerEye.ML.common import DATASET_CSV_FILE_NAME, ModelExecutionMode
from InnerEye.ML.configs.segmentation.BasicModel2Epochs import BasicModel2Epochs
from InnerEye.ML.configs.other.HelloContainer import HelloContainer
from InnerEye.ML.deep_learning_config import CHECKPOINT_FOLDER, ModelCategory
from InnerEye.ML.model_inference_config import read_model_inference_config
from InnerEye.ML.model_testing import THUMBNAILS_FOLDER
from InnerEye.ML.reports.notebook_report import get_html_report_name
from InnerEye.ML.runner import main
from InnerEye.ML.run_ml import MLRunner
from InnerEye.ML.utils.config_loader import ModelConfigLoader
from InnerEye.ML.utils.image_util import get_unit_image_header
from InnerEye.ML.utils.io_util import zip_random_dicom_series
from InnerEye.Scripts import submit_for_inference
from Tests.ML.util import assert_nifti_content, get_default_azure_config, get_nifti_shape
from Tests.ML.util import assert_nifti_content, get_default_azure_config, get_nifti_shape, get_default_workspace
FALLBACK_SINGLE_RUN = "refs_pull_498_merge:refs_pull_498_merge_1624292750_743430ab"
FALLBACK_ENSEMBLE_RUN = "refs_pull_498_merge:HD_4bf4efc3-182a-4596-8f93-76f128418142"
@@ -87,10 +89,10 @@ def get_most_recent_run(fallback_run_id_for_local_execution: str = FALLBACK_SING
return get_default_azure_config().fetch_run(run_recovery_id=run_recovery_id)
def get_most_recent_model(fallback_run_id_for_local_execution: str = FALLBACK_SINGLE_RUN) -> Model:
def get_most_recent_model_id(fallback_run_id_for_local_execution: str = FALLBACK_SINGLE_RUN) -> str:
"""
Gets the string name of the most recently executed AzureML run, extracts which model that run had registered,
and return the instantiated model object.
and returns the model id.
:param fallback_run_id_for_local_execution: A hardcoded AzureML run ID that is used when executing this code
on a local box, outside of Azure build agents.
"""
@@ -101,7 +103,18 @@ def get_most_recent_model(fallback_run_id_for_local_execution: str = FALLBACK_SI
tags = run.get_tags()
model_id = tags.get(MODEL_ID_KEY_NAME, None)
assert model_id, f"No model_id tag was found on run {most_recent_run}"
return Model(workspace=azure_config.get_workspace(), id=model_id)
return model_id
def get_most_recent_model(fallback_run_id_for_local_execution: str = FALLBACK_SINGLE_RUN) -> Model:
"""
Gets the string name of the most recently executed AzureML run, extracts which model that run had registered,
and returns the instantiated model object.
:param fallback_run_id_for_local_execution: A hardcoded AzureML run ID that is used when executing this code
on a local box, outside of Azure build agents.
"""
model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=fallback_run_id_for_local_execution)
return Model(workspace=get_default_workspace(), id=model_id)
def get_experiment_name_from_environment() -> str:
@@ -433,3 +446,44 @@ def test_download_outputs_skipped(test_output_dirs: OutputFolderForTests) -> Non
download_run_outputs_by_prefix(prefix, test_output_dirs.root_dir, run=run)
all_files = list(test_output_dirs.root_dir.rglob("*"))
assert len(all_files) == 0
@pytest.mark.after_training_hello_container
def test_model_inference_on_single_run(test_output_dirs: OutputFolderForTests) -> None:
fallback_run_id_for_local_execution = FALLBACK_HELLO_CONTAINER_RUN
files_to_check = ["test_mse.txt", "test_mae.txt"]
training_run = get_most_recent_run(fallback_run_id_for_local_execution=fallback_run_id_for_local_execution)
all_training_files = training_run.get_file_names()
for file in files_to_check:
assert f"outputs/{file}" in all_training_files, f"{file} is missing"
training_folder = test_output_dirs.root_dir / "training"
training_folder.mkdir()
training_files = [training_folder / file for file in files_to_check]
for file, download_path in zip(files_to_check, training_files):
training_run.download_file(f"outputs/{file}", output_file_path=str(download_path))
container = HelloContainer()
container.set_output_to(test_output_dirs.root_dir)
container.model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=fallback_run_id_for_local_execution)
azure_config = get_default_azure_config()
azure_config.train = False
ml_runner = MLRunner(container=container, azure_config=azure_config, project_root=test_output_dirs.root_dir)
ml_runner.setup()
ml_runner.start_logging_to_file()
ml_runner.run()
inference_files = [container.outputs_folder / file for file in files_to_check]
for inference_file in inference_files:
assert inference_file.exists(), f"{inference_file} is missing"
for training_file, inference_file in zip(training_files, inference_files):
training_lines = training_file.read_text().splitlines()
inference_lines = inference_file.read_text().splitlines()
# We expect all the files we are reading to have a single float value
assert len(training_lines) == 1
train_value = float(training_lines[0].strip())
assert len(inference_lines) == 1
inference_value = float(inference_lines[0].strip())
assert inference_value == pytest.approx(train_value, 1e-6)

View file

@@ -75,7 +75,7 @@ def test_non_image_encoder(test_output_dirs: OutputFolderForTests,
# run model inference
runner = MLRunner(config)
runner.setup()
runner.model_inference_train_and_test(checkpoint_handler=checkpoint_handler)
runner.model_inference_train_and_test(checkpoint_paths=checkpoint_handler.get_checkpoints_to_test())
assert config.get_total_number_of_non_imaging_features() == 18

View file

@@ -44,5 +44,6 @@ def test_train_2d_classification_model(test_output_dirs: OutputFolderForTests,
assert actual_train_loss == pytest.approx(expected_train_loss, abs=1e-6)
assert actual_val_loss == pytest.approx(expected_val_loss, abs=1e-6)
assert actual_lr == pytest.approx(expected_learning_rates, rel=1e-5)
test_results = model_testing.model_test(config, ModelExecutionMode.TRAIN, checkpoint_handler=checkpoint_handler)
test_results = model_testing.model_test(config, ModelExecutionMode.TRAIN,
checkpoint_paths=checkpoint_handler.get_checkpoints_to_test())
assert isinstance(test_results, InferenceMetricsForClassification)

View file

@@ -92,7 +92,7 @@ def test_train_classification_model(class_name: str, test_output_dirs: OutputFol
assert actual_val_loss == pytest.approx(expected_val_loss, abs=1e-6), "Validation loss"
assert actual_lr == pytest.approx(expected_learning_rates, rel=1e-5), "Learning rates"
test_results = model_testing.model_test(config, ModelExecutionMode.TRAIN,
checkpoint_handler=checkpoint_handler)
checkpoint_paths=checkpoint_handler.get_checkpoints_to_test())
assert isinstance(test_results, InferenceMetricsForClassification)
expected_metrics = [0.636085, 0.735952]
assert test_results.metrics.values(class_name)[MetricType.CROSS_ENTROPY.value] == \
@@ -205,7 +205,7 @@ def test_train_classification_multilabel_model(test_output_dirs: OutputFolderFor
assert actual_val_loss == pytest.approx(expected_val_loss, abs=1e-6), "Validation loss"
assert actual_lr == pytest.approx(expected_learning_rates, rel=1e-5), "Learning rates"
test_results = model_testing.model_test(config, ModelExecutionMode.TRAIN,
checkpoint_handler=checkpoint_handler)
checkpoint_paths=checkpoint_handler.get_checkpoints_to_test())
assert isinstance(test_results, InferenceMetricsForClassification)
expected_metrics = {MetricType.CROSS_ENTROPY: [1.3996, 5.2966, 1.4020, 0.3553, 0.6908],
@@ -387,7 +387,7 @@ def test_runner_restart(test_output_dirs: OutputFolderForTests) -> None:
checkpoint_handler = CheckpointHandler(azure_config=azure_config,
container=runner.container,
project_root=test_output_dirs.root_dir)
_, storing_logger = model_train(checkpoint_handler=checkpoint_handler,
_, storing_logger = model_train(checkpoint_path=checkpoint_handler.get_recovery_or_checkpoint_path_train(),
container=runner.container)
# We expect to have 4 checkpoints, FIXED_EPOCH (recovery), FIXED_EPOCH+1, FIXED_EPOCH and best.
assert len(os.listdir(runner.container.checkpoint_folder)) == 4

View file

@@ -233,7 +233,7 @@ def run_model_inference_train_and_test(test_output_dirs: OutputFolderForTests,
with mock.patch("InnerEye.ML.model_testing.PARENT_RUN_CONTEXT", Mock()) as m:
metrics = MLRunner(config).model_inference_train_and_test(
checkpoint_handler=checkpoint_handler,
checkpoint_paths=checkpoint_handler.get_checkpoints_to_test(),
model_proc=model_proc)
if model_proc == ModelProcessing.DEFAULT:

View file

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:352b6e6e3dc074c7574b50892faa0474aef9a826e25e4fb8f8f41c7d97f5d8b0
size 6447

View file

@@ -7,73 +7,107 @@ import pytest
from pathlib import Path
from InnerEye.ML.deep_learning_config import DeepLearningConfig
from InnerEye.ML.deep_learning_config import DatasetParams, WorkflowParams
def test_validate_dataset_params() -> None:
# DatasetParams fails validation when neither of these parameters is set
with pytest.raises(ValueError) as ex:
DeepLearningConfig(local_dataset=None, azure_dataset_id="")
DatasetParams(local_dataset=None, azure_dataset_id="").validate()
assert ex.value.args[0] == "Either of local_dataset or azure_dataset_id must be set."
# The following should be okay
DeepLearningConfig(local_dataset=Path("foo"))
DeepLearningConfig(azure_dataset_id="bar")
DatasetParams(local_dataset=Path("foo")).validate()
DatasetParams(azure_dataset_id="bar").validate()
config = DeepLearningConfig(local_dataset=Path("foo"),
azure_dataset_id="",
extra_azure_dataset_ids=[])
config = DatasetParams(local_dataset=Path("foo"),
azure_dataset_id="",
extra_azure_dataset_ids=[])
config.validate()
assert not config.all_azure_dataset_ids()
config = DeepLearningConfig(azure_dataset_id="foo",
extra_azure_dataset_ids=[])
config = DatasetParams(azure_dataset_id="foo",
extra_azure_dataset_ids=[])
config.validate()
assert len(config.all_azure_dataset_ids()) == 1
config = DeepLearningConfig(local_dataset=Path("foo"),
azure_dataset_id="",
extra_azure_dataset_ids=["bar"])
config = DatasetParams(local_dataset=Path("foo"),
azure_dataset_id="",
extra_azure_dataset_ids=["bar"])
config.validate()
assert len(config.all_azure_dataset_ids()) == 1
config = DeepLearningConfig(azure_dataset_id="foo",
extra_azure_dataset_ids=["bar"])
config = DatasetParams(azure_dataset_id="foo",
extra_azure_dataset_ids=["bar"])
config.validate()
assert len(config.all_azure_dataset_ids()) == 2
config = DeepLearningConfig(azure_dataset_id="foo",
dataset_mountpoint="",
extra_dataset_mountpoints=[])
config = DatasetParams(azure_dataset_id="foo",
dataset_mountpoint="",
extra_dataset_mountpoints=[])
config.validate()
assert not config.all_dataset_mountpoints()
config = DeepLearningConfig(azure_dataset_id="foo",
dataset_mountpoint="foo",
extra_dataset_mountpoints=[])
config = DatasetParams(azure_dataset_id="foo",
dataset_mountpoint="foo",
extra_dataset_mountpoints=[])
config.validate()
assert len(config.all_dataset_mountpoints()) == 1
config = DeepLearningConfig(azure_dataset_id="foo",
dataset_mountpoint="",
extra_dataset_mountpoints=["bar"])
config = DatasetParams(azure_dataset_id="foo",
dataset_mountpoint="",
extra_dataset_mountpoints=["bar"])
config.validate()
assert len(config.all_dataset_mountpoints()) == 1
config = DeepLearningConfig(azure_dataset_id="foo",
extra_azure_dataset_ids=["bar"],
dataset_mountpoint="foo",
extra_dataset_mountpoints=["bar"])
config = DatasetParams(azure_dataset_id="foo",
extra_azure_dataset_ids=["bar"],
dataset_mountpoint="foo",
extra_dataset_mountpoints=["bar"])
config.validate()
assert len(config.all_dataset_mountpoints()) == 2
with pytest.raises(ValueError) as ex:
DeepLearningConfig(azure_dataset_id="foo",
dataset_mountpoint="foo",
extra_dataset_mountpoints=["bar"])
DatasetParams(azure_dataset_id="foo",
dataset_mountpoint="foo",
extra_dataset_mountpoints=["bar"]).validate()
assert "Expected the number of azure datasets to equal the number of mountpoints" in ex.value.args[0]
def test_validate_deep_learning_config() -> None:
def test_validate_workflow_params() -> None:
# DeepLearningConfig cannot be initialized with both these parameters set
# WorkflowParams fails validation when more than one of these parameters is set
with pytest.raises(ValueError) as ex:
DeepLearningConfig(local_dataset=Path("foo"),
local_weights_path=Path("foo"), weights_url="bar")
assert ex.value.args[0] == "Cannot specify both local_weights_path and weights_url."
WorkflowParams(local_dataset=Path("foo"),
local_weights_path=[Path("foo")],
weights_url=["bar"]).validate()
assert ex.value.args[0] == "Cannot specify more than one of local_weights_path, weights_url or model_id."
with pytest.raises(ValueError) as ex:
WorkflowParams(local_dataset=Path("foo"),
local_weights_path=[Path("foo")],
model_id="foo:1").validate()
assert ex.value.args[0] == "Cannot specify more than one of local_weights_path, weights_url or model_id."
with pytest.raises(ValueError) as ex:
WorkflowParams(local_dataset=Path("foo"),
weights_url=["foo"],
model_id="foo:1").validate()
assert ex.value.args[0] == "Cannot specify more than one of local_weights_path, weights_url or model_id."
with pytest.raises(ValueError) as ex:
WorkflowParams(local_dataset=Path("foo"),
local_weights_path=[Path("foo")],
weights_url=["foo"],
model_id="foo:1").validate()
assert ex.value.args[0] == "Cannot specify more than one of local_weights_path, weights_url or model_id."
with pytest.raises(ValueError) as ex:
WorkflowParams(local_dataset=Path("foo"),
model_id="foo").validate()
assert "model_id should be in the form 'model_name:version'" in ex.value.args[0]
# The following should be okay
DeepLearningConfig(local_dataset=Path("foo"), local_weights_path=Path("foo"))
DeepLearningConfig(local_dataset=Path("foo"), weights_url="bar")
WorkflowParams(local_dataset=Path("foo"), local_weights_path=[Path("foo")]).validate()
WorkflowParams(local_dataset=Path("foo"), weights_url=["foo"]).validate()
WorkflowParams(local_dataset=Path("foo"), model_id="foo:1").validate()

View file

@@ -98,13 +98,13 @@ def test_innereye_container_init() -> None:
"""
# The constructor should copy all fields that belong to either WorkflowParams or DatasetParams from the
# config object to the container.
for (attrib, type_) in [("weights_url", WorkflowParams), ("azure_dataset_id", DatasetParams)]:
for (attrib, type_) in [("weights_url", WorkflowParams), ("extra_dataset_mountpoints", DatasetParams)]:
config = ModelConfigBase(should_validate=False)
assert hasattr(type_, attrib)
assert hasattr(config, attrib)
setattr(config, attrib, "foo")
setattr(config, attrib, ["foo"])
container = InnerEyeContainer(config)
assert getattr(container, attrib) == "foo"
assert getattr(container, attrib) == ["foo"]
def test_copied_properties() -> None:

View file

@@ -102,7 +102,7 @@ def test_model_test(
checkpoint_handler.additional_training_done()
inference_results = model_testing.segmentation_model_test(config,
execution_mode=execution_mode,
checkpoint_handler=checkpoint_handler)
checkpoint_paths=checkpoint_handler.get_checkpoints_to_test())
epoch_dir = config.outputs_folder / get_best_epoch_results_path(execution_mode)
total_num_patients_column_name = f"total_{MetricsFileColumns.Patient.value}".lower()
if not total_num_patients_column_name.endswith("s"):

View file

@@ -38,7 +38,8 @@ def test_recover_testing_from_run_recovery(mean_teacher_model: bool,
assert len(train_results.train_results_per_epoch()) == config.num_epochs
# Run inference on this
test_results = model_test(config=config, data_split=ModelExecutionMode.TEST, checkpoint_handler=checkpoint_handler)
test_results = model_test(config=config, data_split=ModelExecutionMode.TEST,
checkpoint_paths=checkpoint_handler.get_checkpoints_to_test())
assert isinstance(test_results, InferenceMetricsForClassification)
# Mimic using a run recovery and see if it is the same
@@ -55,7 +56,7 @@ def test_recover_testing_from_run_recovery(mean_teacher_model: bool,
shutil.copytree(str(config.checkpoint_folder), str(checkpoint_root))
checkpoint_handler_run_recovery.run_recovery = RunRecovery([checkpoint_root])
test_results_run_recovery = model_test(config_run_recovery, data_split=ModelExecutionMode.TEST,
checkpoint_handler=checkpoint_handler_run_recovery)
checkpoint_paths=checkpoint_handler_run_recovery.get_checkpoints_to_test())
assert isinstance(test_results_run_recovery, InferenceMetricsForClassification)
assert test_results.metrics.values()[MetricType.CROSS_ENTROPY.value] == \
test_results_run_recovery.metrics.values()[MetricType.CROSS_ENTROPY.value]
@@ -70,13 +71,13 @@ def test_recover_testing_from_run_recovery(mean_teacher_model: bool,
local_weights_path = test_output_dirs.root_dir / "local_weights_file.pth"
shutil.copyfile(str(config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX),
local_weights_path)
config_local_weights.local_weights_path = local_weights_path
config_local_weights.local_weights_path = [local_weights_path]
checkpoint_handler_local_weights = get_default_checkpoint_handler(model_config=config_local_weights,
project_root=test_output_dirs.root_dir)
checkpoint_handler_local_weights.download_recovery_checkpoints_or_weights()
test_results_local_weights = model_test(config_local_weights, data_split=ModelExecutionMode.TEST,
checkpoint_handler=checkpoint_handler_local_weights)
checkpoint_paths=checkpoint_handler_local_weights.get_checkpoints_to_test())
assert isinstance(test_results_local_weights, InferenceMetricsForClassification)
assert test_results.metrics.values()[MetricType.CROSS_ENTROPY.value] == \
test_results_local_weights.metrics.values()[MetricType.CROSS_ENTROPY.value]

View file

@ -285,8 +285,9 @@ def model_train_unittest(config: Optional[DeepLearningConfig],
checkpoint_handler = CheckpointHandler(azure_config=azure_config,
container=runner.container,
project_root=dirs.root_dir)
_, storing_logger = model_train(checkpoint_handler=checkpoint_handler,
_, storing_logger = model_train(checkpoint_path=checkpoint_handler.get_recovery_or_checkpoint_path_train(),
container=runner.container)
checkpoint_handler.additional_training_done()
return storing_logger, checkpoint_handler # type: ignore

View file

@ -5,22 +5,23 @@
import os
from pathlib import Path
from unittest import mock
from urllib.parse import urlparse
import pytest
import torch
from InnerEye.Common.common_util import OTHER_RUNS_SUBDIR_NAME
from InnerEye.Common.fixed_paths import MODEL_WEIGHTS_DIR_NAME
from InnerEye.Common.fixed_paths import MODEL_INFERENCE_JSON_FILE_NAME
from InnerEye.ML.utils.checkpoint_handling import MODEL_WEIGHTS_DIR_NAME
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, get_recovery_checkpoint_path
from InnerEye.ML.deep_learning_config import WEIGHTS_FILE
from InnerEye.ML.deep_learning_config import FINAL_MODEL_FOLDER, FINAL_ENSEMBLE_MODEL_FOLDER
from InnerEye.ML.model_config_base import ModelConfigBase
from InnerEye.ML.model_inference_config import read_model_inference_config
from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler
from Tests.AfterTraining.test_after_training import FALLBACK_ENSEMBLE_RUN, FALLBACK_SINGLE_RUN, get_most_recent_run, \
get_most_recent_run_id
from Tests.ML.configs.DummyModel import DummyModel
from Tests.ML.util import get_default_checkpoint_handler
get_most_recent_run_id, get_most_recent_model_id
from Tests.ML.util import get_default_checkpoint_handler, get_default_workspace
EXTERNAL_WEIGHTS_URL_EXAMPLE = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
@ -36,10 +37,9 @@ def create_checkpoint_file(file: Path) -> None:
assert loaded, "Unable to read the checkpoint file that was just created"
def test_use_local_weights_file(test_output_dirs: OutputFolderForTests) -> None:
def test_use_checkpoint_paths_or_urls(test_output_dirs: OutputFolderForTests) -> None:
config = ModelConfigBase(should_validate=False)
config.set_output_to(test_output_dirs.root_dir)
config.outputs_folder.mkdir()
# No checkpoint handling options set.
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
@ -47,29 +47,31 @@ def test_use_local_weights_file(test_output_dirs: OutputFolderForTests) -> None:
checkpoint_handler.download_recovery_checkpoints_or_weights()
assert not checkpoint_handler.run_recovery
assert not checkpoint_handler.local_weights_path
assert not checkpoint_handler.trained_weights_paths
# weights from local_weights_path and weights_url will be modified if needed and stored at this location
expected_path = checkpoint_handler.output_params.outputs_folder / WEIGHTS_FILE
# Set a weights_path
checkpoint_handler.azure_config.run_recovery_id = ""
checkpoint_handler.container.weights_url = EXTERNAL_WEIGHTS_URL_EXAMPLE
checkpoint_handler.container.weights_url = [EXTERNAL_WEIGHTS_URL_EXAMPLE]
checkpoint_handler.download_recovery_checkpoints_or_weights()
assert checkpoint_handler.local_weights_path == expected_path
assert checkpoint_handler.local_weights_path.is_file()
expected_download_path = checkpoint_handler.output_params.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME /\
os.path.basename(urlparse(EXTERNAL_WEIGHTS_URL_EXAMPLE).path)
assert checkpoint_handler.trained_weights_paths[0] == expected_download_path
assert checkpoint_handler.trained_weights_paths[0].is_file()
# set a local_weights_path
checkpoint_handler.container.weights_url = ""
checkpoint_handler.container.weights_url = []
local_weights_path = test_output_dirs.root_dir / "exist.pth"
create_checkpoint_file(local_weights_path)
checkpoint_handler.container.local_weights_path = local_weights_path
checkpoint_handler.container.local_weights_path = [local_weights_path]
checkpoint_handler.download_recovery_checkpoints_or_weights()
assert checkpoint_handler.local_weights_path == expected_path
assert checkpoint_handler.trained_weights_paths[0] == local_weights_path
assert checkpoint_handler.trained_weights_paths[0].is_file()
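
The expected download location in this test is pure path arithmetic: the file name is the last component of the URL's path, placed in a weights directory under the checkpoint folder. A standalone sketch of that arithmetic (the value of the directory-name constant is an assumption here):

```python
import os
from pathlib import Path
from urllib.parse import urlparse

MODEL_WEIGHTS_DIR_NAME = "modelweights"  # assumed value of the constant imported above


def expected_download_path(checkpoint_folder: Path, url: str) -> Path:
    # E.g. ".../modelweights/resnet18-5c106cde.pth" for the example URL in these tests.
    file_name = os.path.basename(urlparse(url).path)
    return checkpoint_folder / MODEL_WEIGHTS_DIR_NAME / file_name
```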
@pytest.mark.after_training_single_run
def test_download_checkpoints_from_single_run(test_output_dirs: OutputFolderForTests) -> None:
def test_download_recovery_checkpoints_from_single_run(test_output_dirs: OutputFolderForTests) -> None:
config = ModelConfigBase(should_validate=False)
config.set_output_to(test_output_dirs.root_dir)
@ -92,7 +94,7 @@ def test_download_checkpoints_from_single_run(test_output_dirs: OutputFolderForTests) -> None:
@pytest.mark.after_training_ensemble_run
def test_download_checkpoints_from_ensemble_run(test_output_dirs: OutputFolderForTests) -> None:
def test_download_recovery_checkpoints_from_ensemble_run(test_output_dirs: OutputFolderForTests) -> None:
config = ModelConfigBase(should_validate=False)
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
project_root=test_output_dirs.root_dir)
@ -104,33 +106,63 @@ def test_download_checkpoints_from_ensemble_run(test_output_dirs: OutputFolderForTests) -> None:
assert "has child runs" in str(ex)
@pytest.mark.after_training_single_run
def test_download_model_from_single_run(test_output_dirs: OutputFolderForTests) -> None:
config = ModelConfigBase(should_validate=False)
config.set_output_to(test_output_dirs.root_dir)
# No checkpoint handling options set.
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
project_root=test_output_dirs.root_dir)
model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
# Set the model_id of a registered model from a non-ensemble run
checkpoint_handler.container.model_id = model_id
checkpoint_handler.download_recovery_checkpoints_or_weights()
assert checkpoint_handler.trained_weights_paths
expected_model_root = config.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME / FINAL_MODEL_FOLDER
model_inference_config = read_model_inference_config(expected_model_root / MODEL_INFERENCE_JSON_FILE_NAME)
expected_paths = [expected_model_root / x for x in model_inference_config.checkpoint_paths]
assert len(expected_paths) == 1 # A registered model for a non-ensemble run should contain only one checkpoint
assert len(checkpoint_handler.trained_weights_paths) == 1
assert expected_paths[0] == checkpoint_handler.trained_weights_paths[0]
assert expected_paths[0].is_file()
@pytest.mark.after_training_ensemble_run
def test_download_model_from_ensemble_run(test_output_dirs: OutputFolderForTests) -> None:
config = ModelConfigBase(should_validate=False)
config.set_output_to(test_output_dirs.root_dir)
# No checkpoint handling options set.
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
project_root=test_output_dirs.root_dir)
model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
# Set the model_id of a registered model from an ensemble run
checkpoint_handler.container.model_id = model_id
checkpoint_handler.download_recovery_checkpoints_or_weights()
assert checkpoint_handler.trained_weights_paths
expected_model_root = config.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME / FINAL_ENSEMBLE_MODEL_FOLDER
model_inference_config = read_model_inference_config(expected_model_root / MODEL_INFERENCE_JSON_FILE_NAME)
expected_paths = [expected_model_root / x for x in model_inference_config.checkpoint_paths]
assert len(checkpoint_handler.trained_weights_paths) == len(expected_paths)
assert set(checkpoint_handler.trained_weights_paths) == set(expected_paths)
for path in expected_paths:
assert path.is_file()
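
Both of the tests above resolve checkpoints through the model-inference manifest registered alongside the model. A minimal sketch of that lookup, assuming the manifest is a JSON file whose `checkpoint_paths` entries are relative to the model root (file name and schema here are assumptions mirroring the tests' usage, not a verified API):

```python
import json
from pathlib import Path
from typing import List


def checkpoints_from_manifest(model_root: Path, manifest_name: str) -> List[Path]:
    # Each entry in checkpoint_paths is relative to the model root; resolving
    # them gives the files inference should load (one for a single model,
    # several for an ensemble).
    manifest = json.loads((model_root / manifest_name).read_text())
    return [model_root / p for p in manifest["checkpoint_paths"]]
```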
def test_get_recovery_path_train(test_output_dirs: OutputFolderForTests) -> None:
config = ModelConfigBase(should_validate=False)
config.set_output_to(test_output_dirs.root_dir)
config.outputs_folder.mkdir()
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
project_root=test_output_dirs.root_dir)
assert checkpoint_handler.get_recovery_path_train() is None
# weights from local_weights_path and weights_url will be modified if needed and stored at this location
expected_path = checkpoint_handler.output_params.outputs_folder / WEIGHTS_FILE
# Set a weights_url to get checkpoint from
checkpoint_handler.azure_config.run_recovery_id = ""
checkpoint_handler.container.weights_url = EXTERNAL_WEIGHTS_URL_EXAMPLE
checkpoint_handler.download_recovery_checkpoints_or_weights()
assert checkpoint_handler.local_weights_path == expected_path
assert checkpoint_handler.get_recovery_path_train() == expected_path
# Set a local_weights_path to get checkpoint from
checkpoint_handler.container.weights_url = ""
local_weights_path = test_output_dirs.root_dir / "exist.pth"
create_checkpoint_file(local_weights_path)
checkpoint_handler.container.local_weights_path = local_weights_path
checkpoint_handler.download_recovery_checkpoints_or_weights()
assert checkpoint_handler.local_weights_path == expected_path
assert checkpoint_handler.get_recovery_path_train() == expected_path
assert checkpoint_handler.get_recovery_or_checkpoint_path_train() is None
@pytest.mark.after_training_single_run
@ -147,7 +179,7 @@ def test_get_recovery_path_train_single_run(test_output_dirs: OutputFolderForTests) -> None:
# Run recovery with start epoch provided should succeed
expected_path = get_recovery_checkpoint_path(path=config.checkpoint_folder)
assert checkpoint_handler.get_recovery_path_train() == expected_path
assert checkpoint_handler.get_recovery_or_checkpoint_path_train() == expected_path
@pytest.mark.after_training_single_run
@ -159,8 +191,8 @@ def test_get_best_checkpoint_single_run(test_output_dirs: OutputFolderForTests) -> None:
# We have not set a run_recovery, nor have we trained, so this should fail to get a checkpoint
with pytest.raises(ValueError) as ex:
checkpoint_handler.get_best_checkpoint()
assert "no run recovery object provided and no training has been done in this run" in ex.value.args[0]
checkpoint_handler.get_best_checkpoints()
assert "no run recovery object provided and no training has been done in this run" in ex.value.args[0]
run_recovery_id = get_most_recent_run_id(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
@ -169,7 +201,7 @@ def test_get_best_checkpoint_single_run(test_output_dirs: OutputFolderForTests) -> None:
checkpoint_handler.azure_config.run_recovery_id = run_recovery_id
checkpoint_handler.download_recovery_checkpoints_or_weights()
expected_checkpoint = config.checkpoint_folder / f"{BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX}"
checkpoint_paths = checkpoint_handler.get_best_checkpoint()
checkpoint_paths = checkpoint_handler.get_best_checkpoints()
assert checkpoint_paths
assert len(checkpoint_paths) == 1
assert expected_checkpoint == checkpoint_paths[0]
@ -183,7 +215,7 @@ def test_get_best_checkpoint_single_run(test_output_dirs: OutputFolderForTests) -> None:
checkpoint_handler.download_recovery_checkpoints_or_weights()
# There is no checkpoint in the current run - use the one from run_recovery
checkpoint_paths = checkpoint_handler.get_best_checkpoint()
checkpoint_paths = checkpoint_handler.get_best_checkpoints()
expected_checkpoint = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
assert checkpoint_paths
assert len(checkpoint_paths) == 1
@ -192,24 +224,23 @@ def test_get_best_checkpoint_single_run(test_output_dirs: OutputFolderForTests) -> None:
# Copy over checkpoints to make it look like training has happened and a better checkpoint written
expected_checkpoint = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
expected_checkpoint.touch()
checkpoint_paths = checkpoint_handler.get_best_checkpoint()
checkpoint_paths = checkpoint_handler.get_best_checkpoints()
assert checkpoint_paths
assert len(checkpoint_paths) == 1
assert expected_checkpoint == checkpoint_paths[0]
@pytest.mark.after_training_ensemble_run
def test_get_all_checkpoints_from_ensemble_run(test_output_dirs: OutputFolderForTests) -> None:
def test_download_checkpoints_from_hyperdrive_child_runs(test_output_dirs: OutputFolderForTests) -> None:
config = ModelConfigBase(should_validate=False)
config.set_output_to(test_output_dirs.root_dir)
config.outputs_folder.mkdir()
manage_recovery = get_default_checkpoint_handler(model_config=config,
project_root=test_output_dirs.root_dir)
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
project_root=test_output_dirs.root_dir)
hyperdrive_run = get_most_recent_run(fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
manage_recovery.download_checkpoints_from_hyperdrive_child_runs(hyperdrive_run)
checkpoint_handler.download_checkpoints_from_hyperdrive_child_runs(hyperdrive_run)
expected_checkpoints = [config.checkpoint_folder / OTHER_RUNS_SUBDIR_NAME / str(i)
/ BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX for i in range(2)]
checkpoint_paths = manage_recovery.get_best_checkpoint()
checkpoint_paths = checkpoint_handler.get_best_checkpoints()
assert checkpoint_paths
assert len(checkpoint_paths) == 2
assert set(expected_checkpoints) == set(checkpoint_paths)
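
The expected layout for hyperdrive children is one best checkpoint per child, stored in an index-named subfolder. A sketch of that expectation (the folder and file names below are assumptions standing in for the constants used in the test):

```python
from pathlib import Path
from typing import List

OTHER_RUNS_SUBDIR_NAME = "OTHER_RUNS"  # assumed value of the constant
BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX = "best_checkpoint.ckpt"  # assumed value


def expected_child_checkpoints(checkpoint_folder: Path, num_children: int) -> List[Path]:
    # One best checkpoint per hyperdrive child run, indexed 0..num_children-1.
    return [checkpoint_folder / OTHER_RUNS_SUBDIR_NAME / str(i) / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
            for i in range(num_children)]
```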
@ -218,28 +249,27 @@ def test_get_all_checkpoints_from_ensemble_run(test_output_dirs: OutputFolderForTests) -> None:
def test_get_checkpoints_to_test(test_output_dirs: OutputFolderForTests) -> None:
config = ModelConfigBase(should_validate=False)
config.set_output_to(test_output_dirs.root_dir)
config.outputs_folder.mkdir()
manage_recovery = get_default_checkpoint_handler(model_config=config,
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
project_root=test_output_dirs.root_dir)
# Set a local_weights_path to get the checkpoint from. The model has not been trained and no run recovery is
# provided, so the local weights should be used, ignoring any epochs to test.
local_weights_path = test_output_dirs.root_dir / "exist.pth"
create_checkpoint_file(local_weights_path)
manage_recovery.container.local_weights_path = local_weights_path
manage_recovery.download_recovery_checkpoints_or_weights()
checkpoint_and_paths = manage_recovery.get_checkpoints_to_test()
checkpoint_handler.container.local_weights_path = [local_weights_path]
checkpoint_handler.download_recovery_checkpoints_or_weights()
checkpoint_and_paths = checkpoint_handler.get_checkpoints_to_test()
assert checkpoint_and_paths
assert len(checkpoint_and_paths) == 1
assert checkpoint_and_paths[0] == manage_recovery.output_params.outputs_folder / WEIGHTS_FILE
assert checkpoint_and_paths[0] == local_weights_path
manage_recovery.additional_training_done()
manage_recovery.container.checkpoint_folder.mkdir()
checkpoint_handler.additional_training_done()
checkpoint_handler.container.checkpoint_folder.mkdir(parents=True)
# Copy checkpoint to make it seem like training has happened
expected_checkpoint = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
expected_checkpoint.touch()
checkpoint_and_paths = manage_recovery.get_checkpoints_to_test()
checkpoint_and_paths = checkpoint_handler.get_checkpoints_to_test()
assert checkpoint_and_paths
assert len(checkpoint_and_paths) == 1
@ -250,20 +280,19 @@ def test_get_checkpoints_to_test(test_output_dirs: OutputFolderForTests) -> None:
def test_get_checkpoints_to_test_single_run(test_output_dirs: OutputFolderForTests) -> None:
config = ModelConfigBase(should_validate=False)
config.set_output_to(test_output_dirs.root_dir)
config.outputs_folder.mkdir()
manage_recovery = get_default_checkpoint_handler(model_config=config,
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
project_root=test_output_dirs.root_dir)
run_recovery_id = get_most_recent_run_id(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
# Now set a run recovery object and set the start epoch to 1, so we get one epoch from
# run recovery and one from the training checkpoints
manage_recovery.azure_config.run_recovery_id = run_recovery_id
checkpoint_handler.azure_config.run_recovery_id = run_recovery_id
manage_recovery.additional_training_done()
manage_recovery.download_recovery_checkpoints_or_weights()
checkpoint_handler.additional_training_done()
checkpoint_handler.download_recovery_checkpoints_or_weights()
checkpoint_and_paths = manage_recovery.get_checkpoints_to_test()
checkpoint_and_paths = checkpoint_handler.get_checkpoints_to_test()
assert checkpoint_and_paths
assert len(checkpoint_and_paths) == 1
@ -272,7 +301,7 @@ def test_get_checkpoints_to_test_single_run(test_output_dirs: OutputFolderForTests) -> None:
# Copy checkpoint to make it seem like training has happened
expected_checkpoint = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
expected_checkpoint.touch()
checkpoint_and_paths = manage_recovery.get_checkpoints_to_test()
checkpoint_and_paths = checkpoint_handler.get_checkpoints_to_test()
assert checkpoint_and_paths
assert len(checkpoint_and_paths) == 1
@ -281,108 +310,94 @@ def test_get_checkpoints_to_test_single_run(test_output_dirs: OutputFolderForTests) -> None:
def test_download_model_weights(test_output_dirs: OutputFolderForTests) -> None:
# Download a sample ResNet model from a URL given in the PyTorch docs
# The downloaded model does not match the architecture, which is okay since we are only testing the download here.
result_path = CheckpointHandler.download_weights(urls=[EXTERNAL_WEIGHTS_URL_EXAMPLE],
download_folder=test_output_dirs.root_dir)
assert len(result_path) == 1
assert result_path[0] == test_output_dirs.root_dir / os.path.basename(urlparse(EXTERNAL_WEIGHTS_URL_EXAMPLE).path)
assert result_path[0].is_file()
model_config = DummyModel(weights_url=EXTERNAL_WEIGHTS_URL_EXAMPLE)
manage_recovery = get_default_checkpoint_handler(model_config=model_config,
project_root=test_output_dirs.root_dir)
result_path = manage_recovery.download_weights()
assert result_path.is_file()
modified_time = result_path[0].stat().st_mtime
result_path = CheckpointHandler.download_weights(urls=[EXTERNAL_WEIGHTS_URL_EXAMPLE, EXTERNAL_WEIGHTS_URL_EXAMPLE],
download_folder=test_output_dirs.root_dir)
assert len(result_path) == 2
assert len(list(test_output_dirs.root_dir.glob("*"))) == 1
assert result_path[0].samefile(result_path[1])
assert result_path[0] == test_output_dirs.root_dir / os.path.basename(urlparse(EXTERNAL_WEIGHTS_URL_EXAMPLE).path)
assert result_path[0].is_file()
# This call should not re-download the files, just return the existing ones
assert result_path[0].stat().st_mtime == modified_time
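
The behaviour asserted here — duplicate URLs collapse to a single file, and an existing file is never re-downloaded, so its modification time stays unchanged — can be sketched as follows. This is an illustrative stand-in, not the actual `CheckpointHandler.download_weights`:

```python
import os
from pathlib import Path
from typing import List
from urllib.parse import urlparse
from urllib.request import urlretrieve


def download_weights_sketch(urls: List[str], download_folder: Path) -> List[Path]:
    result = []
    for url in urls:
        # Duplicate URLs map to the same target file name.
        target = download_folder / os.path.basename(urlparse(url).path)
        if not target.is_file():
            # Skipping existing files keeps st_mtime stable across calls.
            urlretrieve(url, target)
        result.append(target)
    return result
```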
@pytest.mark.after_training_single_run
def test_get_checkpoints_from_model_single_run(test_output_dirs: OutputFolderForTests) -> None:
model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
downloaded_checkpoints = CheckpointHandler.get_checkpoints_from_model(model_id=model_id,
workspace=get_default_workspace(),
download_path=test_output_dirs.root_dir)
# Check a single checkpoint has been downloaded
expected_model_root = test_output_dirs.root_dir / FINAL_MODEL_FOLDER
assert expected_model_root.is_dir()
model_inference_config = read_model_inference_config(expected_model_root / MODEL_INFERENCE_JSON_FILE_NAME)
expected_paths = [expected_model_root / x for x in model_inference_config.checkpoint_paths]
assert len(expected_paths) == 1 # A registered model for a non-ensemble run should contain only one checkpoint
assert len(downloaded_checkpoints) == 1
assert expected_paths[0] == downloaded_checkpoints[0]
assert expected_paths[0].is_file()
@pytest.mark.after_training_ensemble_run
def test_get_checkpoints_from_model_ensemble_run(test_output_dirs: OutputFolderForTests) -> None:
model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
downloaded_checkpoints = CheckpointHandler.get_checkpoints_from_model(model_id=model_id,
workspace=get_default_workspace(),
download_path=test_output_dirs.root_dir)
# Check that all the ensemble checkpoints have been downloaded
expected_model_root = test_output_dirs.root_dir / FINAL_ENSEMBLE_MODEL_FOLDER
assert expected_model_root.is_dir()
model_inference_config = read_model_inference_config(expected_model_root / MODEL_INFERENCE_JSON_FILE_NAME)
expected_paths = [expected_model_root / x for x in model_inference_config.checkpoint_paths]
assert len(expected_paths) == len(downloaded_checkpoints)
assert set(expected_paths) == set(downloaded_checkpoints)
for expected_path in expected_paths:
assert expected_path.is_file()
def test_get_local_weights_path_or_download(test_output_dirs: OutputFolderForTests) -> None:
config = ModelConfigBase(should_validate=False)
manage_recovery = get_default_checkpoint_handler(model_config=config,
project_root=test_output_dirs.root_dir)
config.set_output_to(test_output_dirs.root_dir)
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
project_root=test_output_dirs.root_dir)
# If the model has none of model_id, local_weights_path or weights_url set, this should fail.
with pytest.raises(ValueError) as ex:
manage_recovery.get_local_weights_path_or_download()
assert "neither local_weights_path nor weights_url is set in the model config" in ex.value.args[0]
checkpoint_handler.get_local_checkpoints_path_or_download()
assert "none of model_id, local_weights_path or weights_url is set in the model config." in ex.value.args[0]
# If local_weights_path folder exists, get_local_weights_path_or_download should not do anything.
local_weights_path = manage_recovery.project_root / "exist.pth"
# If the file at local_weights_path exists, get_local_checkpoints_path_or_download should not download anything.
local_weights_path = test_output_dirs.root_dir / "exist.pth"
create_checkpoint_file(local_weights_path)
manage_recovery.container.local_weights_path = local_weights_path
returned_weights_path = manage_recovery.get_local_weights_path_or_download()
assert local_weights_path == returned_weights_path
checkpoint_handler.container.local_weights_path = [local_weights_path]
returned_weights_path = checkpoint_handler.get_local_checkpoints_path_or_download()
assert local_weights_path == returned_weights_path[0]
# Pointing the model to a URL should trigger a download
manage_recovery.container.local_weights_path = None
manage_recovery.container.weights_url = EXTERNAL_WEIGHTS_URL_EXAMPLE
downloaded_weights = manage_recovery.get_local_weights_path_or_download()
# Download goes into <project_root> / "modelweights" / "resnet18-5c106cde.pth"
expected_path = manage_recovery.project_root / MODEL_WEIGHTS_DIR_NAME / \
checkpoint_handler.container.local_weights_path = []
checkpoint_handler.container.weights_url = [EXTERNAL_WEIGHTS_URL_EXAMPLE]
downloaded_weights = checkpoint_handler.get_local_checkpoints_path_or_download()
expected_path = checkpoint_handler.output_params.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME / \
os.path.basename(urlparse(EXTERNAL_WEIGHTS_URL_EXAMPLE).path)
assert downloaded_weights
assert downloaded_weights.is_file()
assert expected_path == downloaded_weights
assert len(downloaded_weights) == 1
assert downloaded_weights[0].is_file()
assert expected_path == downloaded_weights[0]
# try again, should not re-download
modified_time = downloaded_weights.stat().st_mtime
downloaded_weights_new = manage_recovery.get_local_weights_path_or_download()
assert downloaded_weights_new
assert downloaded_weights_new.stat().st_mtime == modified_time
def test_get_and_modify_local_weights(test_output_dirs: OutputFolderForTests) -> None:
config = ModelConfigBase(should_validate=False)
config.set_output_to(test_output_dirs.root_dir)
config.outputs_folder.mkdir()
manage_recovery = get_default_checkpoint_handler(model_config=config,
project_root=test_output_dirs.root_dir)
# If the model has neither local_weights_path nor weights_url set, this should fail.
with pytest.raises(ValueError) as ex:
manage_recovery.get_and_save_modified_weights()
assert "neither local_weights_path nor weights_url is set in the model config" in ex.value.args[0]
# Pointing the model to a local_weights_path that does not exist will raise an error.
manage_recovery.container.local_weights_path = manage_recovery.project_root / "non_exist"
with pytest.raises(FileNotFoundError) as file_ex:
manage_recovery.get_and_save_modified_weights()
assert "Could not find the weights file" in file_ex.value.args[0]
# Test that weights are properly modified when a local_weights_path is set
# set a method to modify weights:
with mock.patch.object(ModelConfigBase,
'load_checkpoint_and_modify',
lambda self, path_to_checkpoint: {"modified": "local", # type: ignore
"path": path_to_checkpoint}):
# Set the local_weights_path to an empty file, which will be passed to modify_checkpoint
local_weights_path = manage_recovery.project_root / "exist.pth"
create_checkpoint_file(local_weights_path)
manage_recovery.container.local_weights_path = local_weights_path
weights_path = manage_recovery.get_and_save_modified_weights()
expected_path = manage_recovery.output_params.outputs_folder / WEIGHTS_FILE
# read from weights_path and check that the dict has been written
assert weights_path.is_file()
assert expected_path == weights_path
read = torch.load(str(weights_path))
assert read.keys() == {"modified", "path"}
assert read["modified"] == "local"
assert read["path"] == local_weights_path
# clean up
weights_path.unlink()
# Test that weights are properly modified when weights_url is set
# set a different method to modify weights, to avoid using old files from other tests:
with mock.patch.object(ModelConfigBase,
'load_checkpoint_and_modify',
lambda self, path_to_checkpoint: {"modified": "url", "path": path_to_checkpoint}):
# Set the weights_url to the sample pytorch URL, which will be passed to modify_checkpoint
manage_recovery.container.local_weights_path = None
manage_recovery.container.weights_url = EXTERNAL_WEIGHTS_URL_EXAMPLE
weights_path = manage_recovery.get_and_save_modified_weights()
expected_path = manage_recovery.output_params.outputs_folder / WEIGHTS_FILE
# read from weights_path and check that the dict has been written
assert weights_path.is_file()
assert expected_path == weights_path
read = torch.load(str(weights_path))
assert read.keys() == {"modified", "path"}
assert read["modified"] == "url"
assert read["path"] == manage_recovery.project_root / MODEL_WEIGHTS_DIR_NAME / \
os.path.basename(urlparse(EXTERNAL_WEIGHTS_URL_EXAMPLE).path)
modified_time = downloaded_weights[0].stat().st_mtime
downloaded_weights_new = checkpoint_handler.get_local_checkpoints_path_or_download()
assert len(downloaded_weights_new) == 1
assert downloaded_weights_new[0].stat().st_mtime == modified_time

View file

@ -36,7 +36,7 @@ steps:
# hence don't set PYTHONPATH
- bash: |
source activate InnerEye
pytest ./Tests/ -m "not (gpu or azureml or after_training_single_run or after_training_ensemble_run or inference or after_training_2node or after_training_glaucoma_cv_run)" --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-config=.coveragerc --cov-report=xml -n 2 --dist=loadscope --verbose
pytest ./Tests/ -m "not (gpu or azureml or after_training_single_run or after_training_ensemble_run or inference or after_training_2node or after_training_glaucoma_cv_run or after_training_hello_container)" --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-config=.coveragerc --cov-report=xml -n 2 --dist=loadscope --verbose
env:
APPLICATION_KEY: $(InnerEyeDeepLearningServicePrincipalKey)
DATASETS_ACCOUNT_KEY: $(InnerEyePublicDatasetsStorageKey)

View file

@ -139,9 +139,9 @@ empty, in which case the union of all validation sets for the `N` child runs will be used.
### Recovering failed runs and continuing training
To train further with an already-created model, give the above command with additional switches like these:
To train further with an already-created model, issue the above command again, adding the `run_recovery_id` argument:
```
--run_recovery_id=foo_bar:foo_bar_12345_abcd --start_epoch=120
--run_recovery_id=foo_bar:foo_bar_12345_abcd
```
The run recovery ID is of the form "experiment_id:run_id". When you trained your original model, it will have been
queued as a "Run" inside of an "Experiment". The experiment will be given a name derived from the branch name - for
@ -160,24 +160,45 @@ of the AzureML UI. The easiest way to get it is to go to any of the child runs and use its
run recovery ID without the final underscore and digit.
### Testing an existing model
To evaluate an existing model on a test set, you can use models from previous runs in AzureML or from local checkpoints.
To evaluate an existing model on a test set, you can use registered models from previous runs in AzureML, a set of
local checkpoints, or a set of URLs pointing to model checkpoints. For all these options, you will need to set the
flag `--no-train`, along with additional command line arguments to specify the checkpoints.
#### From a previous run in AzureML:
This is similar to continuing training using a run_recovery object, but you will need to set `--no-train`.
Thus your command should look like this:
#### From a registered model on AzureML:
You will need to specify the registered model to run on using the `model_id` argument. You can find the model name and
version by clicking on `Registered Models` on the Details tab of a run in the AzureML UI.
The model ID is of the form "model_name:model_version". Thus your command should look like this:
```shell script
python InnerEye/ML/runner.py --azureml --model=Prostate --no-train --cluster=my_cluster_name \
--run_recovery_id=foo_bar:foo_bar_12345_abcd --start_epoch=120
python InnerEye/ML/runner.py --azureml --model=Prostate --cluster=my_cluster_name \
--no-train --model_id=Prostate:1
```
#### From a local checkpoint:
To evaluate a model using a local checkpoint, use the `local_weights_path` argument to specify the path to the model checkpoint
and set train to `False`.
#### From local checkpoints:
To evaluate a model using one or more local checkpoints, use the `local_weights_path` argument to specify the path(s) to the
model checkpoint(s) on the local disk.
```shell script
python InnerEye/ML/runner.py --model=Prostate --no-train --local_weights_path=path_to_your_checkpoint
```
To run on multiple checkpoints (for example, if you have trained an ensemble model), pass the checkpoints as a
comma-separated list in the argument `local_weights_path`:
```shell script
python InnerEye/ML/runner.py --model=Prostate --no-train --local_weights_path=path_to_first_checkpoint,path_to_second_checkpoint
```
Alternatively, to submit an AzureML run to apply a model to a single image on your local disk,
#### From URLs:
To evaluate a model using one or more checkpoints each specified by a URL, use the `weights_url` argument to specify the
URL(s) from which the model checkpoint(s) should be downloaded.
```shell script
python InnerEye/ML/runner.py --model=Prostate --no-train --weights_url=url_for_your_checkpoint
```
To run on multiple checkpoints (for example, if you have trained an ensemble model), pass the URLs as a
comma-separated list in the argument `weights_url`:
```shell script
python InnerEye/ML/runner.py --model=Prostate --no-train --weights_url=url_for_first_checkpoint,url_for_second_checkpoint
```
#### Running a registered AzureML model on a single image on the local disk
To submit an AzureML run to apply a model to a single image on your local disk,
you can use the script `submit_for_inference.py`, with a command of this form:
```shell script
python InnerEye/Scripts/submit_for_inference.py --image_file ~/somewhere/ct.nii.gz --model_id Prostate:555 \