Run inference using checkpoints from registered models (#509)
This commit is contained in:
Parent
732ddfcb34
Commit
9fcc08f6cd
13
CHANGELOG.md
@@ -18,7 +18,8 @@ module on test data with partial ground truth files. (Also [522](https://github.
 - ([#502](https://github.com/microsoft/InnerEye-DeepLearning/pull/502)) More flags for fine control of when to run inference.
 - ([#492](https://github.com/microsoft/InnerEye-DeepLearning/pull/492)) Adding capability for regression tests for test
 jobs that run in AzureML.
-
+- ([#509](https://github.com/microsoft/InnerEye-DeepLearning/pull/509)) Run inference on registered models (single and
+ensemble) using the parameter `model_id`.
 ### Changed
 - ([#531])(https://github.com/microsoft/InnerEye-DeepLearning/pull/531)) Updated PL to 1.3.8, torchmetrics and pl-bolts and changed relevant metrics and SSL code API.
 - ([#533](https://github.com/microsoft/InnerEye-DeepLearning/pull/533)) Better defaults for inference on ensemble children.
@@ -26,6 +27,8 @@ jobs that run in AzureML.
 - ([#496](https://github.com/microsoft/InnerEye-DeepLearning/pull/496)) All plots are now saved as PNG, rather than JPG.
 - ([#497](https://github.com/microsoft/InnerEye-DeepLearning/pull/497)) Reducing the size of the code snapshot that
 gets uploaded to AzureML, by skipping all test folders.
+- ([#509](https://github.com/microsoft/InnerEye-DeepLearning/pull/509)) Parameter `extra_downloaded_run_id` has been
+renamed to `pretraining_run_checkpoints`.
 
 ### Fixed
 - ([#525](https://github.com/microsoft/InnerEye-DeepLearning/pull/525)) Enable --store_dataset_sample
@@ -40,11 +43,15 @@ LightningContainer models can get stuck at test set inference.
 multiple large checkpoints can time out.
 - ([#515](https://github.com/microsoft/InnerEye-DeepLearning/pull/515)) Workaround for occasional issues with dataset
 mounting and running matplotblib on some machines. Re-instantiated a disabled test.
+- ([#509](https://github.com/microsoft/InnerEye-DeepLearning/pull/509)) Fix issue where model checkpoints were not loaded
+in inference-only runs when using lightning containers.
 
 ### Removed
 
-- ([#542](https://github.com/microsoft/InnerEye-DeepLearning/pull/542)) Removed Windows test leg from build pipeline.
-- ([#520](https://github.com/microsoft/InnerEye-DeepLearning/pull/520)) Disable glaucoma job from Azure pipeline.
+- ([#542](https://github.com/microsoft/InnerEye-DeepLearning/pull/542)) Removed Windows test leg from build pipeline.
+- ([#520](https://github.com/microsoft/InnerEye-DeepLearning/pull/520)) Disable glaucoma job from Azure pipeline.
+- ([#509](https://github.com/microsoft/InnerEye-DeepLearning/pull/509)) Parameters `local_weights_path` and
+`weights_url` can no longer be used to initialize a training run, only inference runs.
 
 ### Deprecated
 
|
|
@@ -77,14 +77,9 @@ class AzureConfig(GenericConfig):
     pytest_mark: str = param.String(doc="If provided, run pytest instead of model training. pytest will only "
                                         "run the tests that have the mark given in this argument "
                                         "('--pytest_mark gpu' will run all tests marked with 'pytest.mark.gpu')")
-    run_recovery_id: str = param.String(doc="A run recovery id string in the form 'experiment name:run id'"
-                                            " to use for inference or recovering a model training run.")
-    pretraining_run_recovery_id: str = param.String(default=None,
-                                                    allow_None=True,
-                                                    doc="Extra run recovery id to download checkpoints from,"
-                                                        "for custom modules (e.g. for loading pretrained weights)."
-                                                        "Warning: this argument will be ignored for InnerEyeContainer"
-                                                        "models.")
+    run_recovery_id: str = param.String(doc="A run recovery id string in the form 'experiment name:run id' "
+                                            "to use for inference, recovering a model training run or to register "
+                                            "a model.")
     experiment_name: str = param.String(doc="If provided, use this string as the name of the AzureML experiment. "
                                             "If not provided, create the experiment off the git branch name.")
     build_number: int = param.Integer(0, doc="The numeric ID of the Azure pipeline that triggered this training run.")
|
|
@@ -40,9 +40,6 @@ DEFAULT_MODEL_SUMMARIES_DIR_PATH = Path(DEFAULT_LOGS_DIR_NAME) / "model_summarie
 # The folder at the project root directory that holds datasets for local execution.
 DATASETS_DIR_NAME = "datasets"
 
-# Points to a folder at the project root directory that holds model weights downloaded from URLs.
-MODEL_WEIGHTS_DIR_NAME = "modelweights"
-
 ML_RELATIVE_SOURCE_PATH = os.path.join("ML")
 ML_RELATIVE_RUNNER_PATH = os.path.join(ML_RELATIVE_SOURCE_PATH, "runner.py")
 ML_FULL_SOURCE_FOLDER_PATH = str(repository_root_directory() / ML_RELATIVE_SOURCE_PATH)
|
|
@@ -170,21 +170,21 @@ class CovidHierarchicalModel(ScalarModelBase):
 
     def _get_ssl_checkpoint_path(self) -> Path:
         # Get the SSL weights from the AML run provided via "pretraining_run_recovery_id" command line argument.
-        # Accessible via extra_downloaded_run_id field of the config.
-        assert self.extra_downloaded_run_id is not None
-        assert isinstance(self.extra_downloaded_run_id, RunRecovery)
+        # Accessible via pretraining_run_checkpoints field of the config.
+        assert self.pretraining_run_checkpoints is not None
+        assert isinstance(self.pretraining_run_checkpoints, RunRecovery)
         ssl_path = self.checkpoint_folder / "ssl_checkpoint.ckpt"
 
         if not ssl_path.exists():  # for test (when it is already present) we don't need to redo this.
             if self.name_of_checkpoint is not None:
                 logging.info(f"Using checkpoint: {self.name_of_checkpoint} as starting point.")
-                path_to_checkpoint = self.extra_downloaded_run_id.checkpoints_roots[0] / self.name_of_checkpoint
+                path_to_checkpoint = self.pretraining_run_checkpoints.checkpoints_roots[0] / self.name_of_checkpoint
             else:
-                path_to_checkpoint = self.extra_downloaded_run_id.get_best_checkpoint_paths()[0]
+                path_to_checkpoint = self.pretraining_run_checkpoints.get_best_checkpoint_paths()[0]
                 if not path_to_checkpoint.exists():
                     logging.info("No best checkpoint found for this model. Getting the latest recovery "
                                  "checkpoint instead.")
-                    path_to_checkpoint = self.extra_downloaded_run_id.get_recovery_checkpoint_paths()[0]
+                    path_to_checkpoint = self.pretraining_run_checkpoints.get_recovery_checkpoint_paths()[0]
             assert path_to_checkpoint.exists()
             path_to_checkpoint.rename(ssl_path)
         return ssl_path
|
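Reviewer note: the fallback order in `_get_ssl_checkpoint_path` (explicitly named checkpoint, else best checkpoint, else latest recovery checkpoint) can be sketched in isolation. The helper below is hypothetical and not part of InnerEye; its arguments stand in for `checkpoints_roots[0]`, `get_best_checkpoint_paths()` and `get_recovery_checkpoint_paths()` on the `RunRecovery` object, and it raises an exception where the diff uses an `assert`.

```python
from pathlib import Path
from typing import List, Optional


def pick_pretraining_checkpoint(checkpoints_root: Path,
                                best_paths: List[Path],
                                recovery_paths: List[Path],
                                name_of_checkpoint: Optional[str] = None) -> Path:
    """Hypothetical helper that mirrors the selection order shown in the diff."""
    if name_of_checkpoint is not None:
        # An explicitly named checkpoint takes priority; there is no fallback for it.
        candidate = checkpoints_root / name_of_checkpoint
    else:
        candidate = best_paths[0]
        if not candidate.exists():
            # No best checkpoint on disk: fall back to the latest recovery checkpoint.
            candidate = recovery_paths[0]
    if not candidate.exists():
        raise FileNotFoundError(f"Checkpoint {candidate} does not exist")
    return candidate
```

Under these assumptions, a run that only produced recovery checkpoints still yields a usable path, while a misspelled `name_of_checkpoint` fails immediately instead of silently picking a different file.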
|
@@ -230,7 +230,7 @@ class HelloRegression(LightningModule):
         """
         average_mse = torch.mean(torch.stack(self.test_mse))
         Path("test_mse.txt").write_text(str(average_mse.item()))
-        Path("test_mae.txt").write_text(str(self.test_mae.compute()))
+        Path("test_mae.txt").write_text(str(self.test_mae.compute().item()))
 
 
 class HelloContainer(LightningContainer):
|
|
@@ -33,7 +33,6 @@ VISUALIZATION_FOLDER = "visualizations"
 EXTRA_RUN_SUBFOLDER = "extra_run_id"
 
 ARGS_TXT = "args.txt"
-WEIGHTS_FILE = "weights.pth"
 
 
 @unique
@@ -216,16 +215,25 @@ class WorkflowParams(param.Parameterized):
     ensemble_inference_on_test_set: Optional[bool] = \
         param.Boolean(None,
                       doc="If set, enable/disable full image inference on test set after ensemble training.")
-    weights_url: str = param.String(doc="If provided, a url from which weights will be downloaded and used for model "
-                                        "initialization.")
-    local_weights_path: Optional[Path] = param.ClassSelector(class_=Path,
-                                                             default=None,
-                                                             allow_None=True,
-                                                             doc="The path to the weights to use for model "
-                                                                 "initialization, when training outside AzureML.")
+    weights_url: List[str] = param.List(default=[], class_=str,
+                                        doc="If provided, a set of urls from which checkpoints will be downloaded"
+                                            "and used for inference.")
+    local_weights_path: List[Path] = param.List(default=[], class_=Path,
+                                                doc="A list of checkpoints paths to use for inference, "
+                                                    "when the job is running outside Azure.")
+    model_id: str = param.String(default="",
+                                 doc="A model id string in the form 'model name:version' "
+                                     "to use a registered model for inference.")
     generate_report: bool = param.Boolean(default=True,
                                           doc="If True (default), write a modelling report in HTML format. If False,"
                                               "do not write that report.")
+    pretraining_run_recovery_id: str = param.String(default=None,
+                                                    allow_None=True,
+                                                    doc="Extra run recovery id to download checkpoints from,"
+                                                        "for custom modules (e.g. for loading pretrained weights)."
+                                                        "The downloaded RunRecovery object will be available in"
+                                                        "pretraining_run_checkpoints.")
+
     # The default multiprocessing start_method in both PyTorch and the Python standard library is "fork" for Linux and
     # "spawn" (the only available method) for Windows. There is some evidence that using "forkserver" on Linux
     # can reduce the chance of stuck jobs.
@@ -248,8 +256,13 @@ class WorkflowParams(param.Parameterized):
                                                 "be relative to the repository root directory.")
 
     def validate(self) -> None:
-        if self.weights_url and self.local_weights_path:
-            raise ValueError("Cannot specify both local_weights_path and weights_url.")
+        if sum([bool(param) for param in [self.weights_url, self.local_weights_path, self.model_id]]) > 1:
+            raise ValueError("Cannot specify more than one of local_weights_path, weights_url or model_id.")
+
+        if self.model_id:
+            if len(self.model_id.split(":")) != 2:
+                raise ValueError(
+                    f"model_id should be in the form 'model_name:version', got {self.model_id}")
 
         if self.number_of_cross_validation_splits == 1:
             raise ValueError("At least two splits required to perform cross validation, but got "
@@ -713,7 +726,7 @@ class DeepLearningConfig(WorkflowParams,
         self.create_filesystem(fixed_paths.repository_root_directory())
         # Disable the PL progress bar because all InnerEye models have their own console output
         self.pl_progress_bar_refresh_rate = 0
-        self.extra_downloaded_run_id: Optional[Any] = None
+        self.pretraining_run_checkpoints: Optional[Any] = None
 
     def validate(self) -> None:
         """
|
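Reviewer note: the new checks in `WorkflowParams.validate` amount to "at most one checkpoint source, and `model_id` must contain exactly one colon". A minimal standalone sketch of that rule follows; the free function `validate_weight_sources` is a hypothetical name, while the two error messages are copied from the diff.

```python
from pathlib import Path
from typing import List


def validate_weight_sources(weights_url: List[str],
                            local_weights_path: List[Path],
                            model_id: str) -> None:
    """Hypothetical free-function version of the checks added to validate()."""
    sources = [weights_url, local_weights_path, model_id]
    # Empty lists and the empty string are falsy, so this counts the sources that are actually set.
    if sum(bool(source) for source in sources) > 1:
        raise ValueError("Cannot specify more than one of local_weights_path, weights_url or model_id.")
    if model_id and len(model_id.split(":")) != 2:
        raise ValueError(f"model_id should be in the form 'model_name:version', got {model_id}")


# Accepted: a single source in the expected 'name:version' form.
validate_weight_sources([], [], "MyModel:3")
```

One property worth knowing: the format check only counts colons, so a value such as `"MyModel:"` (empty version) passes it. Whether a stricter check is needed is a design decision for the project, not something this diff settles.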
|
@@ -3,8 +3,8 @@
 #  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 #  ------------------------------------------------------------------------------------------
 import abc
-from pathlib import Path
 from typing import Any, Dict, Iterator, List, Optional, Tuple
+from pathlib import Path
 
 import param
 import torch
@@ -19,7 +19,7 @@ from InnerEye.Common.generic_parsing import GenericConfig, create_from_matching_
 from InnerEye.Common.metrics_constants import TrackedMetrics
 from InnerEye.ML.common import ModelExecutionMode
 from InnerEye.ML.deep_learning_config import DatasetParams, OptimizerParams, OutputParams, TrainerParams, \
-    WorkflowParams, load_checkpoint
+    WorkflowParams
 from InnerEye.ML.utils import model_util
 from InnerEye.ML.utils.lr_scheduler import SchedulerWithWarmUp
 from InnerEye.ML.utils.run_recovery import RunRecovery
@@ -150,7 +150,7 @@ class LightningContainer(GenericConfig,
         super().__init__(**kwargs)
         self._model: Optional[LightningModule] = None
         self._model_name = type(self).__name__
-        self.extra_downloaded_run_id: Optional[RunRecovery] = None
+        self.pretraining_run_checkpoints: Optional[RunRecovery] = None
         self.num_nodes = 1
 
     def validate(self) -> None:
@@ -249,36 +249,6 @@ class LightningContainer(GenericConfig,
         """
         pass
 
-    def load_checkpoint_and_modify(self, path_to_checkpoint: Path) -> Dict[str, Any]:
-        """
-        This method is called when a file with weights for network initialization is supplied at container level,
-        in the self.weights_url or self.local_weights_path fields. It can load that file as a Torch checkpoint,
-        and rename parameters.
-
-        By default, uses torch.load to read and return the state dict from the checkpoint file, and does no modification
-        of the checkpoint file.
-
-        Overloading this function:
-        When weights_url or local_weights_path is set, the file downloaded may not be in the exact
-        format expected by the model's load_state_dict() - for example, pretrained Imagenet weights for networks
-        may have mismatched layer names in different implementations.
-        In such cases, you can overload this function to extract the state dict from the checkpoint.
-
-        NOTE: The model checkpoint will be loaded using the torch function load_state_dict() with argument strict=False,
-        so extra care needs to be taken to check that the state dict is valid.
-        Check the logs for warnings related to missing and unexpected keys.
-        See https://pytorch.org/tutorials/beginner/saving_loading_models.html#warmstarting-model-using-parameters
-        -from-a-different-model
-        for an explanation on why strict=False is useful when loading parameters from other models.
-        :param path_to_checkpoint: Path to the checkpoint file.
-        :return: Dictionary with model and optimizer state dicts. The dict should have at least the following keys:
-        1. Key ModelAndInfo.MODEL_STATE_DICT_KEY and value set to the model state dict.
-        2. Key ModelAndInfo.EPOCH_KEY and value set to the checkpoint epoch.
-        Other (optional) entries corresponding to keys ModelAndInfo.OPTIMIZER_STATE_DICT_KEY and
-        ModelAndInfo.MEAN_TEACHER_STATE_DICT_KEY are also supported.
-        """
-        return load_checkpoint(path_to_checkpoint=path_to_checkpoint, use_gpu=self.use_gpu)
-
     # The code from here on does not need to be modified.
 
     @property
@@ -333,6 +303,15 @@ class LightningContainer(GenericConfig,
         else:
             return self.get_parameter_search_hyperdrive_config(run_config)
 
+    def load_model_checkpoint(self, checkpoint_path: Path) -> None:
+        """
+        Load a checkpoint from the given path. We need to define a separate method since pytorch lightning cannot
+        access the _model attribute to modify it.
+        """
+        if self._model is None:
+            raise ValueError("No Lightning module has been set yet.")
+        self._model = type(self._model).load_from_checkpoint(checkpoint_path=str(checkpoint_path))
+
     def __str__(self) -> str:
         """Returns a string describing the present object, as a list of key: value strings."""
         arguments_str = "\nContainer:\n"
|
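Reviewer note: `load_model_checkpoint` has to rebind `self._model` because PyTorch Lightning's `load_from_checkpoint` is a classmethod that returns a new module rather than updating an existing one in place. The sketch below reproduces only that pattern; `FakeModule` and `FakeContainer` are hypothetical stand-ins so the example runs without Lightning installed, and nothing is read from disk.

```python
from typing import Optional


class FakeModule:
    """Stand-in for a LightningModule; not part of InnerEye or PyTorch Lightning."""

    def __init__(self, weights: str = "randomly initialised") -> None:
        self.weights = weights

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path: str) -> "FakeModule":
        # Like the real classmethod, this builds and returns a NEW instance.
        return cls(weights=f"loaded from {checkpoint_path}")


class FakeContainer:
    """Stand-in for LightningContainer, reduced to the method added in this commit."""

    def __init__(self) -> None:
        self._model: Optional[FakeModule] = None

    def load_model_checkpoint(self, checkpoint_path: str) -> None:
        if self._model is None:
            raise ValueError("No Lightning module has been set yet.")
        # type(self._model) keeps the concrete subclass, so the container needs no knowledge of it.
        self._model = type(self._model).load_from_checkpoint(checkpoint_path=checkpoint_path)


container = FakeContainer()
container._model = FakeModule()
container.load_model_checkpoint("best_checkpoint.ckpt")
print(container._model.weights)
```

Any reference to the module taken before the call still points at the old, unloaded instance, which is presumably why the runner loads the checkpoint before reading `container.model` for `trainer.test`.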
|
@@ -35,7 +35,6 @@ from InnerEye.ML.reports.segmentation_report import boxplot_per_structure
 from InnerEye.ML.scalar_config import ScalarModelBase
 from InnerEye.ML.sequence_config import SequenceModelBase
 from InnerEye.ML.utils import io_util, ml_util
-from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler
 from InnerEye.ML.utils.image_util import binaries_from_multi_label_array
 from InnerEye.ML.utils.io_util import ImageHeader, MedicalImageFileType, load_nifti_image, save_lines_to_file
 from InnerEye.ML.utils.metrics_util import MetricsPerPatientWriter
@@ -47,7 +46,7 @@ MODEL_OUTPUT_CSV = "model_outputs.csv"
 
 def model_test(config: ModelConfigBase,
                data_split: ModelExecutionMode,
-               checkpoint_handler: CheckpointHandler,
+               checkpoint_paths: List[Path],
                model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> Optional[InferenceMetrics]:
     """
     Runs model inference on segmentation or classification models, using a given dataset (that could be training,
@@ -55,7 +54,7 @@ def model_test(config: ModelConfigBase,
     differ for model categories (classification, segmentation).
     :param config: The configuration of the model
     :param data_split: Indicates which of the 3 sets (training, test, or validation) is being processed.
-    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
+    :param checkpoint_paths: Checkpoint paths to initialize model.
     :param model_proc: whether we are testing an ensemble or single model; this affects where results are written.
     :return: The metrics that the model achieved on the given data set, or None if the data set is empty.
     """
@@ -67,17 +66,19 @@ def model_test(config: ModelConfigBase,
                         "and additional data loaders are likely to block.")
         return None
     with logging_section(f"Running {model_proc.value} model on {data_split.name.lower()} set"):
+        if not checkpoint_paths:
+            raise ValueError("There were no checkpoints available for model testing.")
         if isinstance(config, SegmentationModelBase):
-            return segmentation_model_test(config, data_split, checkpoint_handler, model_proc)
+            return segmentation_model_test(config, data_split, checkpoint_paths, model_proc)
         if isinstance(config, ScalarModelBase):
-            return classification_model_test(config, data_split, checkpoint_handler, model_proc,
+            return classification_model_test(config, data_split, checkpoint_paths, model_proc,
                                              config.cross_validation_split_index)
     raise ValueError(f"There is no testing code for models of type {type(config)}")
 
 
 def segmentation_model_test(config: SegmentationModelBase,
                             execution_mode: ModelExecutionMode,
-                            checkpoint_handler: CheckpointHandler,
+                            checkpoint_paths: List[Path],
                             model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> InferenceMetricsForSegmentation:
     """
     The main testing loop for segmentation models.
@@ -89,10 +90,6 @@ def segmentation_model_test(config: SegmentationModelBase,
     :param patient_id: String which contains subject identifier.
     :return: InferenceMetric object that contains metrics related for all of the checkpoint epochs.
     """
-    checkpoints_to_test = checkpoint_handler.get_checkpoints_to_test()
-
-    if not checkpoints_to_test:
-        raise ValueError("There were no checkpoints available for model testing.")
 
     epoch_results_folder = config.outputs_folder / get_best_epoch_results_path(execution_mode, model_proc)
     # save the datasets.csv used
@@ -100,7 +97,7 @@ def segmentation_model_test(config: SegmentationModelBase,
     epoch_and_split = f"{execution_mode.value} set"
     epoch_dice_per_image = segmentation_model_test_epoch(config=copy.deepcopy(config),
                                                          execution_mode=execution_mode,
-                                                         checkpoint_paths=checkpoints_to_test,
+                                                         checkpoint_paths=checkpoint_paths,
                                                          results_folder=epoch_results_folder,
                                                          epoch_and_split=epoch_and_split)
     if epoch_dice_per_image is None:
@@ -410,7 +407,7 @@ def create_metrics_dict_for_scalar_models(config: ScalarModelBase) -> \
 
 def classification_model_test(config: ScalarModelBase,
                               data_split: ModelExecutionMode,
-                              checkpoint_handler: CheckpointHandler,
+                              checkpoint_paths: List[Path],
                               model_proc: ModelProcessing,
                               cross_val_split_index: int) -> InferenceMetricsForClassification:
     """
@@ -419,16 +416,12 @@ def classification_model_test(config: ScalarModelBase,
     :param config: The model configuration.
     :param data_split: The name of the folder to store the results inside each epoch folder in the outputs_dir,
     used mainly in model evaluation using different dataset splits.
-    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
+    :param checkpoint_paths: Checkpoint paths to initialize model
     :param model_proc: whether we are testing an ensemble or single model
     :return: InferenceMetricsForClassification object that contains metrics related for all of the checkpoint epochs.
     """
     posthoc_label_transform = config.get_posthoc_label_transform()
 
-    checkpoint_paths = checkpoint_handler.get_checkpoints_to_test()
-    if not checkpoint_paths:
-        raise ValueError("There were no checkpoints available for model testing.")
-
     pipeline = create_inference_pipeline(config=config,
                                          checkpoint_paths=checkpoint_paths)
     if pipeline is None:
|
|
@@ -28,7 +28,6 @@ from InnerEye.ML.lightning_container import LightningContainer
 from InnerEye.ML.lightning_loggers import AzureMLLogger, StoringLogger
 from InnerEye.ML.lightning_models import SUBJECT_OUTPUT_PER_RANK_PREFIX, ScalarLightning, \
     get_subject_output_file_per_rank
-from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler
 
 TEMP_PREFIX = "temp/"
 
@@ -204,14 +203,14 @@ def start_resource_monitor(config: LightningContainer) -> ResourceMonitor:
     return resource_monitor
 
 
-def model_train(checkpoint_handler: CheckpointHandler,
+def model_train(checkpoint_path: Optional[Path],
                 container: LightningContainer,
                 num_nodes: int = 1) -> Tuple[Trainer, Optional[StoringLogger]]:
     """
     The main training loop. It creates the Pytorch model based on the configuration options passed in,
     creates a Pytorch Lightning trainer, and trains the model.
     If a checkpoint was specified, then it loads the checkpoint before resuming training.
-    :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
+    :param checkpoint_path: Checkpoint path for model initialization
     :param num_nodes: The number of nodes to use in distributed training.
     :param container: A container object that holds the training data in PyTorch Lightning format
     and the model to train.
@@ -219,8 +218,6 @@ def model_train(checkpoint_handler: CheckpointHandler,
     the model. The StoringLogger object is returned when training an InnerEye built-in model, this is None when
     fitting other models.
     """
-    # Get the path to the checkpoint to recover from
-    checkpoint_path = checkpoint_handler.get_recovery_path_train()
     lightning_model = container.model
 
     resource_monitor: Optional[ResourceMonitor] = None
@@ -303,10 +300,6 @@ def model_train(checkpoint_handler: CheckpointHandler,
 
     logging.info("Finished training")
 
-    # Since we have trained the model further, let the checkpoint_handler object know so it can handle
-    # checkpoints correctly.
-    checkpoint_handler.additional_training_done()
-
     # Upload visualization directory to AML run context to be able to see it in the Azure UI.
     if isinstance(container, InnerEyeContainer):
         if container.config.max_batch_grad_cam > 0 and container.visualization_folder.exists():
|
|
@@ -39,7 +39,7 @@ from InnerEye.ML.baselines_util import compare_folders_and_run_outputs
 from InnerEye.ML.common import ModelExecutionMode
 from InnerEye.ML.config import SegmentationModelBase
 from InnerEye.ML.deep_learning_config import CHECKPOINT_FOLDER, DeepLearningConfig, FINAL_ENSEMBLE_MODEL_FOLDER, \
-    FINAL_MODEL_FOLDER, ModelCategory, MultiprocessingStartMethod, load_checkpoint
+    FINAL_MODEL_FOLDER, ModelCategory, MultiprocessingStartMethod, load_checkpoint, EXTRA_RUN_SUBFOLDER
 from InnerEye.ML.lightning_base import InnerEyeContainer
 from InnerEye.ML.lightning_container import InnerEyeInference, LightningContainer
 from InnerEye.ML.metrics import InferenceMetrics, InferenceMetricsForSegmentation
@@ -53,6 +53,7 @@ from InnerEye.ML.reports.notebook_report import generate_classification_crossval
 from InnerEye.ML.scalar_config import ScalarModelBase
 from InnerEye.ML.sequence_config import SequenceModelBase
 from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler
+from InnerEye.ML.utils.run_recovery import RunRecovery
 from InnerEye.ML.visualizers import activation_maps
 from InnerEye.ML.visualizers.plot_cross_validation import \
     get_config_and_results_for_offline_runs, plot_cross_validation_from_files
@@ -223,11 +224,19 @@ class MLRunner:
                                                     run_context=RUN_CONTEXT)
         self.checkpoint_handler.download_recovery_checkpoints_or_weights(only_return_path=not is_global_rank_zero())
 
+        if self.container.pretraining_run_recovery_id is not None:
+            run_to_recover = self.azure_config.fetch_run(self.container.pretraining_run_recovery_id.strip())
+            run_recovery_object = RunRecovery.download_all_checkpoints_from_run(self.container,
+                                                                                run_to_recover,
+                                                                                EXTRA_RUN_SUBFOLDER,
+                                                                                only_return_path=not is_global_rank_zero())
+            self.container.pretraining_run_checkpoints = run_recovery_object
+
         # A lot of the code for the built-in InnerEye models expects the output paths directly in the config files.
         if isinstance(self.container, InnerEyeContainer):
             self.container.config.local_dataset = self.container.local_dataset
             self.container.config.file_system_config = self.container.file_system_config
-            self.container.config.extra_downloaded_run_id = self.container.extra_downloaded_run_id
+            self.container.config.pretraining_run_checkpoints = self.container.pretraining_run_checkpoints
         self.container.setup()
         self.container.create_lightning_module_and_store()
         self._has_setup_run = True
@@ -361,9 +370,12 @@ class MLRunner:
         # train a new model if required
         if self.azure_config.train:
             with logging_section("Model training"):
-                model_train(self.checkpoint_handler,
+                model_train(self.checkpoint_handler.get_recovery_or_checkpoint_path_train(),
                             container=self.container,
                             num_nodes=self.azure_config.num_nodes)
+                # Since we have trained the model further, let the checkpoint_handler object know so it can handle
+                # checkpoints correctly.
+                self.checkpoint_handler.additional_training_done()
             # log the number of epochs used for model training
             RUN_CONTEXT.log(name="Train epochs", value=self.container.num_epochs)
         elif isinstance(self.container, InnerEyeContainer):
@@ -374,14 +386,15 @@ class MLRunner:
         # AzureML.
         if not self.is_offline_run:
             if self.should_register_model():
-                self.register_model(self.checkpoint_handler, ModelProcessing.DEFAULT)
+                self.register_model(self.checkpoint_handler.get_best_checkpoints(), ModelProcessing.DEFAULT)
 
         if not self.azure_config.only_register_model:
+            checkpoint_paths_for_testing = self.checkpoint_handler.get_checkpoints_to_test()
             if isinstance(self.container, InnerEyeContainer):
                 # Inference for the InnerEye built-in models
                 # We specify the ModelProcessing as DEFAULT here even if the run_recovery points to an ensemble run,
                 # because the current run is a single one. See the documentation of ModelProcessing for more details.
-                self.run_inference(self.checkpoint_handler, ModelProcessing.DEFAULT)
+                self.run_inference(checkpoint_paths_for_testing, ModelProcessing.DEFAULT)
 
                 if self.container.generate_report:
                     self.generate_report(ModelProcessing.DEFAULT)
@@ -397,7 +410,7 @@ class MLRunner:
             else:
                 # Inference for all models that are specified via LightningContainers.
                 with logging_section("Model inference"):
-                    self.run_inference_for_lightning_models(self.checkpoint_handler.get_checkpoints_to_test())
+                    self.run_inference_for_lightning_models(checkpoint_paths_for_testing)
                 # We can't enforce that files are written to the output folder, hence change the working directory
                 # manually
                 with change_working_directory(self.container.outputs_folder):
@@ -483,26 +496,26 @@ class MLRunner:
             if torch.distributed.is_initialized():
                 torch.distributed.destroy_process_group()
             trainer, _ = create_lightning_trainer(self.container, num_nodes=1)
+            self.container.load_model_checkpoint(checkpoint_path=checkpoint_paths[0])
             # When training models that are not built-in InnerEye models, we have no guarantee that they write
             # files to the right folder. Best guess is to change the current working directory to where files should go.
             with change_working_directory(self.container.outputs_folder):
                 trainer.test(self.container.model,
-                             test_dataloaders=self.container.get_data_module().test_dataloader(),
-                             ckpt_path=str(checkpoint_paths[0]))
+                             test_dataloaders=self.container.get_data_module().test_dataloader())
         else:
             logging.warning("None of the suitable test methods is overridden. Skipping inference completely.")
 
-    def run_inference(self, checkpoint_handler: CheckpointHandler,
+    def run_inference(self, checkpoint_paths: List[Path],
                       model_proc: ModelProcessing) -> None:
         """
         Run inference on InnerEyeContainer models
-        :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization
+        :param checkpoint_paths: Checkpoint paths to initialize model
         :param model_proc: whether we are running an ensemble model from within a child run with index 0. If we are,
         then outputs will be written to OTHER_RUNS/ENSEMBLE under the main outputs directory.
         """
 
         # run full image inference on existing or newly trained model on the training, and testing set
-        self.model_inference_train_and_test(checkpoint_handler=checkpoint_handler,
+        self.model_inference_train_and_test(checkpoint_paths=checkpoint_paths,
                                             model_proc=model_proc)
 
         self.try_compare_scores_against_baselines(model_proc)
@@ -594,12 +607,12 @@ class MLRunner:
             compare_scores_against_baselines(self.model_config, self.azure_config, model_proc)
 
     def register_model(self,
-                       checkpoint_handler: CheckpointHandler,
+                       checkpoint_paths: List[Path],
                        model_proc: ModelProcessing) -> Tuple[model.Model, Any]:
         """
         Registers a new model in the workspace's model registry on AzureML to be deployed further.
         The AzureML run's tags are updated to describe with information about ensemble creation and the parent run ID.
-        :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model registration.
+        :param checkpoint_path: Checkpoint paths to register.
         :param model_proc: whether it's a single or ensemble model.
         :returns Tuple element 1: AML model object, or None if no model could be registered.
         Tuple element 2: The result of running the model_deployment_hook, or None if no hook was supplied.
@@ -607,7 +620,6 @@ class MLRunner:
         if self.is_offline_run:
             raise ValueError("Cannot register models when InnerEye is running outside of AzureML.")
 
-        checkpoint_paths = checkpoint_handler.get_checkpoints_to_test()
         if not checkpoint_paths:
             raise ValueError("Model registration failed: No checkpoints found")
 
@@ -761,7 +773,7 @@ class MLRunner:
                 raise ValueError(f"Checkpoint file {checkpoint_source} does not exist")
 
     def model_inference_train_and_test(self,
-                                       checkpoint_handler: CheckpointHandler,
+                                       checkpoint_paths: List[Path],
                                        model_proc: ModelProcessing = ModelProcessing.DEFAULT) -> \
             Dict[ModelExecutionMode, InferenceMetrics]:
         metrics: Dict[ModelExecutionMode, InferenceMetrics] = {}
@@ -770,7 +782,7 @@ class MLRunner:
 
         for data_split in ModelExecutionMode:
             if self.container.inference_on_set(model_proc, data_split):
-                opt_metrics = model_test(config, data_split=data_split, checkpoint_handler=checkpoint_handler,
+                opt_metrics = model_test(config, data_split=data_split, checkpoint_paths=checkpoint_paths,
                                          model_proc=model_proc)
                 if opt_metrics is not None:
                     metrics[data_split] = opt_metrics
@@ -831,10 +843,10 @@ class MLRunner:
         # AzureML.
         if not self.is_offline_run:
             if self.should_register_model():
-                self.register_model(checkpoint_handler, ModelProcessing.ENSEMBLE_CREATION)
+                self.register_model(checkpoint_handler.get_best_checkpoints(), ModelProcessing.ENSEMBLE_CREATION)
 
         if not self.azure_config.only_register_model:
-            self.run_inference(checkpoint_handler=checkpoint_handler,
+            self.run_inference(checkpoint_paths=checkpoint_handler.get_checkpoints_to_test(),
                                model_proc=ModelProcessing.ENSEMBLE_CREATION)
 
         crossval_dir = self.plot_cross_validation_and_upload_results()
|
|
|
@ -11,15 +11,18 @@ from typing import List, Optional
|
|||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from azureml.core import Run
|
||||
from azureml.core import Run, Workspace, Model
|
||||
|
||||
from InnerEye.Azure.azure_config import AzureConfig
|
||||
from InnerEye.Common import fixed_paths
|
||||
from InnerEye.Common.fixed_paths import MODEL_INFERENCE_JSON_FILE_NAME
|
||||
from InnerEye.ML.common import find_recovery_checkpoint_and_epoch
|
||||
from InnerEye.ML.deep_learning_config import EXTRA_RUN_SUBFOLDER, OutputParams, WEIGHTS_FILE
|
||||
from InnerEye.ML.deep_learning_config import OutputParams
|
||||
from InnerEye.ML.lightning_container import LightningContainer
|
||||
from InnerEye.ML.utils.run_recovery import RunRecovery
|
||||
from InnerEye.ML.model_inference_config import read_model_inference_config
|
||||
|
||||
|
||||
MODEL_WEIGHTS_DIR_NAME = "trained_models"
|
||||
|
||||
|
||||
class CheckpointHandler:
|
||||
|
@ -35,7 +38,7 @@ class CheckpointHandler:
|
|||
self.run_recovery: Optional[RunRecovery] = None
|
||||
self.project_root = project_root
|
||||
self.run_context = run_context
|
||||
self.local_weights_path: Optional[Path] = None
|
||||
self.trained_weights_paths: List[Path] = []
|
||||
self.has_continued_training = False
|
||||
|
||||
@property
|
||||
|
@ -47,7 +50,7 @@ class CheckpointHandler:
|
|||
|
||||
def download_checkpoints_from_hyperdrive_child_runs(self, hyperdrive_parent_run: Run) -> None:
|
||||
"""
|
||||
Downloads the best checkpoints from all child runs of a Hyperdrive parent runs. This is used to gather results
|
||||
Downloads the best checkpoints from all child runs of a Hyperdrive parent run. This is used to gather results
|
||||
for ensemble creation.
|
||||
"""
|
||||
self.run_recovery = RunRecovery.download_best_checkpoints_from_child_runs(self.output_params,
|
||||
|
@ -70,21 +73,9 @@ class CheckpointHandler:
|
|||
run_to_recover = self.azure_config.fetch_run(self.azure_config.run_recovery_id.strip())
|
||||
self.run_recovery = RunRecovery.download_all_checkpoints_from_run(self.output_params, run_to_recover,
|
||||
only_return_path=only_return_path)
|
||||
else:
|
||||
self.run_recovery = None
|
||||
|
||||
if self.azure_config.pretraining_run_recovery_id is not None:
|
||||
run_to_recover = self.azure_config.fetch_run(self.azure_config.pretraining_run_recovery_id.strip())
|
||||
run_recovery_object = RunRecovery.download_all_checkpoints_from_run(self.output_params,
|
||||
run_to_recover,
|
||||
EXTRA_RUN_SUBFOLDER,
|
||||
only_return_path=only_return_path)
|
||||
self.container.extra_downloaded_run_id = run_recovery_object
|
||||
else:
|
||||
self.container.extra_downloaded_run_id = None
|
||||
|
||||
if self.container.weights_url or self.container.local_weights_path:
|
||||
self.local_weights_path = self.get_and_save_modified_weights()
|
||||
if self.container.weights_url or self.container.local_weights_path or self.container.model_id:
|
||||
self.trained_weights_paths = self.get_local_checkpoints_path_or_download()
|
||||
|
||||
def additional_training_done(self) -> None:
|
||||
"""
|
||||
|
@ -92,7 +83,7 @@ class CheckpointHandler:
|
|||
"""
|
||||
self.has_continued_training = True
|
||||
|
||||
def get_recovery_path_train(self) -> Optional[Path]:
|
||||
def get_recovery_or_checkpoint_path_train(self) -> Optional[Path]:
|
||||
"""
|
||||
Decides the checkpoint path to use for the current training run. Looks for the latest checkpoint in the
|
||||
checkpoint folder. If run_recovery is provided, the checkpoints will have been downloaded to this folder
|
||||
|
@ -105,29 +96,19 @@ class CheckpointHandler:
|
|||
local_recovery_path, recovery_epoch = recovery
|
||||
self.container._start_epoch = recovery_epoch
|
||||
return local_recovery_path
|
||||
|
||||
elif self.local_weights_path:
|
||||
return self.local_weights_path
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_best_checkpoint(self) -> List[Path]:
|
||||
def get_best_checkpoints(self) -> List[Path]:
|
||||
"""
|
||||
Get a list of checkpoints per epoch for testing/registration.
|
||||
1. If a run recovery object is used and no training was done in this run, use checkpoints from run recovery.
|
||||
2. If a run recovery object is used, and training was done in this run, but the start epoch is larger than
|
||||
the epoch parameter provided, use checkpoints from run recovery.
|
||||
3. If a run recovery object is used, and training was done in this run, but the start epoch is smaller than
|
||||
the epoch parameter provided, use checkpoints from the current training run.
|
||||
This function also checks that all the checkpoints at the returned checkpoint paths exist,
|
||||
and drops any that do not.
|
||||
Get a list of checkpoints per epoch for testing/registration from the current training run.
|
||||
This function also checks that the checkpoint at the returned checkpoint path exists.
|
||||
"""
|
||||
if not self.run_recovery and not self.has_continued_training:
|
||||
raise ValueError("Cannot recover checkpoint, no run recovery object provided and"
|
||||
raise ValueError("Cannot recover checkpoint, no run recovery object provided and "
|
||||
"no training has been done in this run.")
|
||||
|
||||
checkpoint_paths = []
|
||||
|
||||
if self.run_recovery:
|
||||
checkpoint_paths = self.run_recovery.get_best_checkpoint_paths()
|
||||
|
||||
|
@ -165,77 +146,81 @@ class CheckpointHandler:
|
|||
|
||||
checkpoints = []
|
||||
|
||||
# If recovery object exists, or model was trained, look for checkpoints by epoch
|
||||
# If model was trained, look for the best checkpoint
|
||||
if self.run_recovery or self.has_continued_training:
|
||||
checkpoints = self.get_best_checkpoint()
|
||||
elif self.local_weights_path and not self.has_continued_training:
|
||||
# No recovery object and model was not trained, check if there is a local weight path.
|
||||
if self.local_weights_path.exists():
|
||||
logging.info(f"Using model weights at {self.local_weights_path} to initialize model")
|
||||
checkpoints = [self.local_weights_path]
|
||||
else:
|
||||
logging.warning(f"local_weights_path does not exist, "
|
||||
f"cannot recover from {self.local_weights_path}")
|
||||
checkpoints = self.get_best_checkpoints()
|
||||
elif self.trained_weights_paths:
|
||||
# Model was not trained, check if there is a local weight path.
|
||||
logging.info(f"Using model weights from {self.trained_weights_paths} to initialize model")
|
||||
checkpoints = self.trained_weights_paths
|
||||
else:
|
||||
logging.warning("Could not find any run recovery object or local_weights_path to get checkpoints from")
|
||||
logging.warning("Could not find any local_weights_path, model_weights or model_id to get checkpoints from")
|
||||
|
||||
return checkpoints
|
||||
|
||||
def download_weights(self) -> Path:
|
||||
@staticmethod
|
||||
def download_weights(urls: List[str], download_folder: Path) -> List[Path]:
|
||||
"""
|
||||
Download a checkpoint from weights_url to the modelweights directory.
|
||||
"""
|
||||
target_folder = self.project_root / fixed_paths.MODEL_WEIGHTS_DIR_NAME
|
||||
target_folder.mkdir(exist_ok=True)
|
||||
checkpoint_paths = []
|
||||
for url in urls:
|
||||
# assign the same filename as in the download url if possible, so that we can check for duplicates
|
||||
# If that fails, map to a random uuid
|
||||
file_name = os.path.basename(urlparse(url).path) or str(uuid.uuid4().hex)
|
||||
result_file = download_folder / file_name
|
||||
checkpoint_paths.append(result_file)
|
||||
# only download if hasn't already been downloaded
|
||||
if result_file.exists():
|
||||
logging.info(f"File already exists, skipping download: {result_file}")
|
||||
else:
|
||||
logging.info(f"Downloading weights from URL {url}")
|
||||
|
||||
url = self.container.weights_url
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
with open(result_file, "wb") as file:
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
file.write(chunk)
|
||||
|
||||
# assign the same filename as in the download url if possible, so that we can check for duplicates
|
||||
# If that fails, map to a random uuid
|
||||
file_name = os.path.basename(urlparse(url).path) or str(uuid.uuid4().hex)
|
||||
result_file = target_folder / file_name
|
||||
return checkpoint_paths
|
||||
|
||||
# only download if hasn't already been downloaded
|
||||
if result_file.exists():
|
||||
logging.info(f"File already exists, skipping download: {result_file}")
|
||||
return result_file
|
||||
@staticmethod
|
||||
def get_checkpoints_from_model(model_id: str, workspace: Workspace, download_path: Path) -> List[Path]:
|
||||
if len(model_id.split(":")) != 2:
|
||||
raise ValueError(
|
||||
f"model_id should be in the form 'model_name:version', got {model_id}")
|
||||
|
||||
logging.info(f"Downloading weights from URL {url}")
|
||||
model_name, model_version = model_id.split(":")
|
||||
model = Model(workspace=workspace, name=model_name, version=int(model_version))
|
||||
model_path = Path(model.download(str(download_path), exist_ok=True))
|
||||
model_inference_config = read_model_inference_config(model_path / MODEL_INFERENCE_JSON_FILE_NAME)
|
||||
checkpoint_paths = [model_path / x for x in model_inference_config.checkpoint_paths]
|
||||
return checkpoint_paths
|
||||
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
with open(result_file, "wb") as file:
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
file.write(chunk)
|
||||
|
||||
return result_file
|
||||
|
||||
def get_local_weights_path_or_download(self) -> Optional[Path]:
|
||||
def get_local_checkpoints_path_or_download(self) -> List[Path]:
|
||||
"""
|
||||
Get the path to the local weights to use or download them and set local_weights_path
|
||||
"""
|
||||
if self.container.local_weights_path:
|
||||
weights_path = self.container.local_weights_path
|
||||
elif self.container.weights_url:
|
||||
weights_path = self.download_weights()
|
||||
else:
|
||||
raise ValueError("Cannot download/modify weights - neither local_weights_path nor weights_url is set in"
|
||||
if not self.container.model_id and not self.container.local_weights_path and not self.container.weights_url:
|
||||
raise ValueError("Cannot download weights - none of model_id, local_weights_path or weights_url is set in "
|
||||
"the model config.")
|
||||
|
||||
return weights_path
|
||||
if self.container.local_weights_path:
|
||||
checkpoint_paths = self.container.local_weights_path
|
||||
else:
|
||||
download_folder = self.output_params.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME
|
||||
download_folder.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
def get_and_save_modified_weights(self) -> Path:
|
||||
"""
|
||||
Downloads the checkpoint weights if needed.
|
||||
Then passes the downloaded or local checkpoint to the modify_checkpoint function from the model_config and saves
|
||||
the modified state dict from the function in the outputs folder with the name weights.pth.
|
||||
"""
|
||||
weights_path = self.get_local_weights_path_or_download()
|
||||
if self.container.model_id:
|
||||
checkpoint_paths = CheckpointHandler.get_checkpoints_from_model(model_id=self.container.model_id,
|
||||
workspace=self.azure_config.get_workspace(),
|
||||
download_path=download_folder)
|
||||
elif self.container.weights_url:
|
||||
urls = self.container.weights_url
|
||||
checkpoint_paths = CheckpointHandler.download_weights(urls=urls,
|
||||
download_folder=download_folder)
|
||||
|
||||
if not weights_path or not weights_path.is_file():
|
||||
raise FileNotFoundError(f"Could not find the weights file at {weights_path}")
|
||||
|
||||
modified_weights = self.container.load_checkpoint_and_modify(weights_path)
|
||||
target_file = self.output_params.outputs_folder / WEIGHTS_FILE
|
||||
torch.save(modified_weights, target_file)
|
||||
return target_file
|
||||
for checkpoint_path in checkpoint_paths:
|
||||
if not checkpoint_path or not checkpoint_path.is_file():
|
||||
raise FileNotFoundError(f"Could not find the weights file at {checkpoint_path}")
|
||||
return checkpoint_paths
|
||||
|
|
|
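The `model_id` check and the URL-to-file-name rule used by `CheckpointHandler` in the diff above can be looked at in isolation. The block below is a minimal sketch under stated assumptions, not the InnerEye implementation: `parse_model_id` and `file_name_from_url` are hypothetical helper names introduced only for illustration, and no AzureML call is made.

```python
import os
import uuid
from typing import Tuple
from urllib.parse import urlparse


def parse_model_id(model_id: str) -> Tuple[str, int]:
    """Split a 'model_name:version' string into its two parts (hypothetical helper)."""
    parts = model_id.split(":")
    if len(parts) != 2:
        raise ValueError(f"model_id should be in the form 'model_name:version', got {model_id}")
    model_name, model_version = parts
    # int() raises ValueError if the version part is not a whole number.
    return model_name, int(model_version)


def file_name_from_url(url: str) -> str:
    """Use the last path component of the URL as the file name (hypothetical helper).

    Falls back to a random UUID hex string when the URL path has no file name,
    so that repeated downloads of a named file can be detected as duplicates.
    """
    return os.path.basename(urlparse(url).path) or uuid.uuid4().hex


if __name__ == "__main__":
    print(parse_model_id("foo:1"))
    print(file_name_from_url("https://example.com/models/weights.ckpt?sig=abc"))
```

Keeping the file name stable across calls is what makes the "skip download if the file already exists" branch useful. With the UUID fallback, every call produces a new name, so such URLs would be downloaded again each time.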
@ -36,16 +36,18 @@ from InnerEye.Common.output_directories import OutputFolderForTests
|
|||
from InnerEye.Common.spawn_subprocess import spawn_and_monitor_subprocess
|
||||
from InnerEye.ML.common import DATASET_CSV_FILE_NAME, ModelExecutionMode
|
||||
from InnerEye.ML.configs.segmentation.BasicModel2Epochs import BasicModel2Epochs
|
||||
from InnerEye.ML.configs.other.HelloContainer import HelloContainer
|
||||
from InnerEye.ML.deep_learning_config import CHECKPOINT_FOLDER, ModelCategory
|
||||
from InnerEye.ML.model_inference_config import read_model_inference_config
|
||||
from InnerEye.ML.model_testing import THUMBNAILS_FOLDER
|
||||
from InnerEye.ML.reports.notebook_report import get_html_report_name
|
||||
from InnerEye.ML.runner import main
|
||||
from InnerEye.ML.run_ml import MLRunner
|
||||
from InnerEye.ML.utils.config_loader import ModelConfigLoader
|
||||
from InnerEye.ML.utils.image_util import get_unit_image_header
|
||||
from InnerEye.ML.utils.io_util import zip_random_dicom_series
|
||||
from InnerEye.Scripts import submit_for_inference
|
||||
from Tests.ML.util import assert_nifti_content, get_default_azure_config, get_nifti_shape
|
||||
from Tests.ML.util import assert_nifti_content, get_default_azure_config, get_nifti_shape, get_default_workspace
|
||||
|
||||
FALLBACK_SINGLE_RUN = "refs_pull_498_merge:refs_pull_498_merge_1624292750_743430ab"
|
||||
FALLBACK_ENSEMBLE_RUN = "refs_pull_498_merge:HD_4bf4efc3-182a-4596-8f93-76f128418142"
|
||||
|
@ -87,10 +89,10 @@ def get_most_recent_run(fallback_run_id_for_local_execution: str = FALLBACK_SING
|
|||
return get_default_azure_config().fetch_run(run_recovery_id=run_recovery_id)
|
||||
|
||||
|
||||
def get_most_recent_model(fallback_run_id_for_local_execution: str = FALLBACK_SINGLE_RUN) -> Model:
|
||||
def get_most_recent_model_id(fallback_run_id_for_local_execution: str = FALLBACK_SINGLE_RUN) -> str:
|
||||
"""
|
||||
Gets the string name of the most recently executed AzureML run, extracts which model that run had registered,
|
||||
and return the instantiated model object.
|
||||
and returns the model id.
|
||||
:param fallback_run_id_for_local_execution: A hardcoded AzureML run ID that is used when executing this code
|
||||
on a local box, outside of Azure build agents.
|
||||
"""
|
||||
|
@ -101,7 +103,18 @@ def get_most_recent_model(fallback_run_id_for_local_execution: str = FALLBACK_SI
|
|||
tags = run.get_tags()
|
||||
model_id = tags.get(MODEL_ID_KEY_NAME, None)
|
||||
assert model_id, f"No model_id tag was found on run {most_recent_run}"
|
||||
return Model(workspace=azure_config.get_workspace(), id=model_id)
|
||||
return model_id
|
||||
|
||||
|
||||
def get_most_recent_model(fallback_run_id_for_local_execution: str = FALLBACK_SINGLE_RUN) -> Model:
|
||||
"""
|
||||
Gets the string name of the most recently executed AzureML run, extracts which model that run had registered,
|
||||
and returns the instantiated model object.
|
||||
:param fallback_run_id_for_local_execution: A hardcoded AzureML run ID that is used when executing this code
|
||||
on a local box, outside of Azure build agents.
|
||||
"""
|
||||
model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=fallback_run_id_for_local_execution)
|
||||
return Model(workspace=get_default_workspace(), id=model_id)
|
||||
|
||||
|
||||
def get_experiment_name_from_environment() -> str:
|
||||
|
@ -433,3 +446,44 @@ def test_download_outputs_skipped(test_output_dirs: OutputFolderForTests) -> Non
|
|||
download_run_outputs_by_prefix(prefix, test_output_dirs.root_dir, run=run)
|
||||
all_files = list(test_output_dirs.root_dir.rglob("*"))
|
||||
assert len(all_files) == 0
|
||||
|
||||
|
||||
@pytest.mark.after_training_hello_container
|
||||
def test_model_inference_on_single_run(test_output_dirs: OutputFolderForTests) -> None:
|
||||
fallback_run_id_for_local_execution = FALLBACK_HELLO_CONTAINER_RUN
|
||||
|
||||
files_to_check = ["test_mse.txt", "test_mae.txt"]
|
||||
|
||||
training_run = get_most_recent_run(fallback_run_id_for_local_execution=fallback_run_id_for_local_execution)
|
||||
all_training_files = training_run.get_file_names()
|
||||
for file in files_to_check:
|
||||
assert f"outputs/{file}" in all_training_files, f"{file} is missing"
|
||||
training_folder = test_output_dirs.root_dir / "training"
|
||||
training_folder.mkdir()
|
||||
training_files = [training_folder / file for file in files_to_check]
|
||||
for file, download_path in zip(files_to_check, training_files):
|
||||
training_run.download_file(f"outputs/{file}", output_file_path=str(download_path))
|
||||
|
||||
container = HelloContainer()
|
||||
container.set_output_to(test_output_dirs.root_dir)
|
||||
container.model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=fallback_run_id_for_local_execution)
|
||||
azure_config = get_default_azure_config()
|
||||
azure_config.train = False
|
||||
ml_runner = MLRunner(container=container, azure_config=azure_config, project_root=test_output_dirs.root_dir)
|
||||
ml_runner.setup()
|
||||
ml_runner.start_logging_to_file()
|
||||
ml_runner.run()
|
||||
|
||||
inference_files = [container.outputs_folder / file for file in files_to_check]
|
||||
for inference_file in inference_files:
|
||||
assert inference_file.exists(), f"{inference_file} is missing"
|
||||
|
||||
for training_file, inference_file in zip(training_files, inference_files):
|
||||
training_lines = training_file.read_text().splitlines()
|
||||
inference_lines = inference_file.read_text().splitlines()
|
||||
# We expect all the files we are reading to have a single float value
|
||||
assert len(training_lines) == 1
|
||||
train_value = float(training_lines[0].strip())
|
||||
assert len(inference_lines) == 1
|
||||
inference_value = float(inference_lines[0].strip())
|
||||
assert inference_value == pytest.approx(train_value, 1e-6)
|
||||
|
|
|
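The comparison at the end of `test_model_inference_on_single_run` reads one float per metrics file and checks that training and inference agree within a relative tolerance. A minimal standalone sketch of that idea follows. It uses `math.isclose` in place of `pytest.approx`, and the helper names `read_single_float` and `metrics_match` are hypothetical, chosen only for this illustration.

```python
import math
import tempfile
from pathlib import Path


def read_single_float(path: Path) -> float:
    """Read a file that is expected to contain exactly one line holding a float."""
    lines = path.read_text().splitlines()
    if len(lines) != 1:
        raise ValueError(f"Expected exactly one line in {path}, got {len(lines)}")
    return float(lines[0].strip())


def metrics_match(training_file: Path, inference_file: Path, rel_tol: float = 1e-6) -> bool:
    """Return True if both files hold the same value up to the given relative tolerance."""
    train_value = read_single_float(training_file)
    inference_value = read_single_float(inference_file)
    return math.isclose(train_value, inference_value, rel_tol=rel_tol)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        train_file = Path(tmp) / "train_mse.txt"
        infer_file = Path(tmp) / "infer_mse.txt"
        train_file.write_text("0.125\n")
        infer_file.write_text("0.125\n")
        print(metrics_match(train_file, infer_file))
```

A relative tolerance is a reasonable choice here because the metric values can differ in scale between models, but note that it is strict for values very close to zero.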
@ -75,7 +75,7 @@ def test_non_image_encoder(test_output_dirs: OutputFolderForTests,
|
|||
# run model inference
|
||||
runner = MLRunner(config)
|
||||
runner.setup()
|
||||
runner.model_inference_train_and_test(checkpoint_handler=checkpoint_handler)
|
||||
runner.model_inference_train_and_test(checkpoint_paths=checkpoint_handler.get_checkpoints_to_test())
|
||||
assert config.get_total_number_of_non_imaging_features() == 18
|
||||
|
||||
|
||||
|
|
|
@ -44,5 +44,6 @@ def test_train_2d_classification_model(test_output_dirs: OutputFolderForTests,
|
|||
assert actual_train_loss == pytest.approx(expected_train_loss, abs=1e-6)
|
||||
assert actual_val_loss == pytest.approx(expected_val_loss, abs=1e-6)
|
||||
assert actual_lr == pytest.approx(expected_learning_rates, rel=1e-5)
|
||||
test_results = model_testing.model_test(config, ModelExecutionMode.TRAIN, checkpoint_handler=checkpoint_handler)
|
||||
test_results = model_testing.model_test(config, ModelExecutionMode.TRAIN,
|
||||
checkpoint_paths=checkpoint_handler.get_checkpoints_to_test())
|
||||
assert isinstance(test_results, InferenceMetricsForClassification)
|
||||
|
|
|
@ -92,7 +92,7 @@ def test_train_classification_model(class_name: str, test_output_dirs: OutputFol
|
|||
assert actual_val_loss == pytest.approx(expected_val_loss, abs=1e-6), "Validation loss"
|
||||
assert actual_lr == pytest.approx(expected_learning_rates, rel=1e-5), "Learning rates"
|
||||
test_results = model_testing.model_test(config, ModelExecutionMode.TRAIN,
|
||||
checkpoint_handler=checkpoint_handler)
|
||||
checkpoint_paths=checkpoint_handler.get_checkpoints_to_test())
|
||||
assert isinstance(test_results, InferenceMetricsForClassification)
|
||||
expected_metrics = [0.636085, 0.735952]
|
||||
assert test_results.metrics.values(class_name)[MetricType.CROSS_ENTROPY.value] == \
|
||||
|
@ -205,7 +205,7 @@ def test_train_classification_multilabel_model(test_output_dirs: OutputFolderFor
|
|||
assert actual_val_loss == pytest.approx(expected_val_loss, abs=1e-6), "Validation loss"
|
||||
assert actual_lr == pytest.approx(expected_learning_rates, rel=1e-5), "Learning rates"
|
||||
test_results = model_testing.model_test(config, ModelExecutionMode.TRAIN,
|
||||
checkpoint_handler=checkpoint_handler)
|
||||
checkpoint_paths=checkpoint_handler.get_checkpoints_to_test())
|
||||
assert isinstance(test_results, InferenceMetricsForClassification)
|
||||
|
||||
expected_metrics = {MetricType.CROSS_ENTROPY: [1.3996, 5.2966, 1.4020, 0.3553, 0.6908],
|
||||
|
@ -387,7 +387,7 @@ def test_runner_restart(test_output_dirs: OutputFolderForTests) -> None:
|
|||
checkpoint_handler = CheckpointHandler(azure_config=azure_config,
|
||||
container=runner.container,
|
||||
project_root=test_output_dirs.root_dir)
|
||||
_, storing_logger = model_train(checkpoint_handler=checkpoint_handler,
|
||||
_, storing_logger = model_train(checkpoint_path=checkpoint_handler.get_recovery_or_checkpoint_path_train(),
|
||||
container=runner.container)
|
||||
# We expect to have 4 checkpoints, FIXED_EPOCH (recovery), FIXED_EPOCH+1, FIXED_EPOCH and best.
|
||||
assert len(os.listdir(runner.container.checkpoint_folder)) == 4
|
||||
|
|
|
@ -233,7 +233,7 @@ def run_model_inference_train_and_test(test_output_dirs: OutputFolderForTests,
|
|||
|
||||
with mock.patch("InnerEye.ML.model_testing.PARENT_RUN_CONTEXT", Mock()) as m:
|
||||
metrics = MLRunner(config).model_inference_train_and_test(
|
||||
checkpoint_handler=checkpoint_handler,
|
||||
checkpoint_paths=checkpoint_handler.get_checkpoints_to_test(),
|
||||
model_proc=model_proc)
|
||||
|
||||
if model_proc == ModelProcessing.DEFAULT:
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:352b6e6e3dc074c7574b50892faa0474aef9a826e25e4fb8f8f41c7d97f5d8b0
|
||||
size 6447
|
|
@ -7,73 +7,107 @@ import pytest
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from InnerEye.ML.deep_learning_config import DeepLearningConfig
|
||||
from InnerEye.ML.deep_learning_config import DatasetParams, WorkflowParams
|
||||
|
||||
|
||||
def test_validate_dataset_params() -> None:
|
||||
# DatasetParams fails validation if neither of these parameters is set
|
||||
with pytest.raises(ValueError) as ex:
|
||||
DeepLearningConfig(local_dataset=None, azure_dataset_id="")
|
||||
DatasetParams(local_dataset=None, azure_dataset_id="").validate()
|
||||
assert ex.value.args[0] == "Either of local_dataset or azure_dataset_id must be set."
|
||||
|
||||
# The following should be okay
|
||||
DeepLearningConfig(local_dataset=Path("foo"))
|
||||
DeepLearningConfig(azure_dataset_id="bar")
|
||||
DatasetParams(local_dataset=Path("foo")).validate()
|
||||
DatasetParams(azure_dataset_id="bar").validate()
|
||||
|
||||
config = DeepLearningConfig(local_dataset=Path("foo"),
|
||||
azure_dataset_id="",
|
||||
extra_azure_dataset_ids=[])
|
||||
config = DatasetParams(local_dataset=Path("foo"),
|
||||
azure_dataset_id="",
|
||||
extra_azure_dataset_ids=[])
|
||||
config.validate()
|
||||
assert not config.all_azure_dataset_ids()
|
||||
|
||||
config = DeepLearningConfig(azure_dataset_id="foo",
|
||||
extra_azure_dataset_ids=[])
|
||||
config = DatasetParams(azure_dataset_id="foo",
|
||||
extra_azure_dataset_ids=[])
|
||||
config.validate()
|
||||
assert len(config.all_azure_dataset_ids()) == 1
|
||||
|
||||
config = DeepLearningConfig(local_dataset=Path("foo"),
|
||||
azure_dataset_id="",
|
||||
extra_azure_dataset_ids=["bar"])
|
||||
config = DatasetParams(local_dataset=Path("foo"),
|
||||
azure_dataset_id="",
|
||||
extra_azure_dataset_ids=["bar"])
|
||||
config.validate()
|
||||
assert len(config.all_azure_dataset_ids()) == 1
|
||||
|
||||
config = DeepLearningConfig(azure_dataset_id="foo",
|
||||
extra_azure_dataset_ids=["bar"])
|
||||
config = DatasetParams(azure_dataset_id="foo",
|
||||
extra_azure_dataset_ids=["bar"])
|
||||
config.validate()
|
||||
assert len(config.all_azure_dataset_ids()) == 2
|
||||
|
||||
config = DeepLearningConfig(azure_dataset_id="foo",
|
||||
dataset_mountpoint="",
|
||||
extra_dataset_mountpoints=[])
|
||||
config = DatasetParams(azure_dataset_id="foo",
|
||||
dataset_mountpoint="",
|
||||
extra_dataset_mountpoints=[])
|
||||
config.validate()
|
||||
assert not config.all_dataset_mountpoints()
|
||||
|
||||
config = DeepLearningConfig(azure_dataset_id="foo",
|
||||
dataset_mountpoint="foo",
|
||||
extra_dataset_mountpoints=[])
|
||||
config = DatasetParams(azure_dataset_id="foo",
|
||||
dataset_mountpoint="foo",
|
||||
extra_dataset_mountpoints=[])
|
||||
config.validate()
|
||||
assert len(config.all_dataset_mountpoints()) == 1
|
||||
|
||||
config = DeepLearningConfig(azure_dataset_id="foo",
|
||||
dataset_mountpoint="",
|
||||
extra_dataset_mountpoints=["bar"])
|
||||
config = DatasetParams(azure_dataset_id="foo",
|
||||
dataset_mountpoint="",
|
||||
extra_dataset_mountpoints=["bar"])
|
||||
config.validate()
|
||||
assert len(config.all_dataset_mountpoints()) == 1
|
||||
|
||||
config = DeepLearningConfig(azure_dataset_id="foo",
|
||||
extra_azure_dataset_ids=["bar"],
|
||||
dataset_mountpoint="foo",
|
||||
extra_dataset_mountpoints=["bar"])
|
||||
config = DatasetParams(azure_dataset_id="foo",
|
||||
extra_azure_dataset_ids=["bar"],
|
||||
dataset_mountpoint="foo",
|
||||
extra_dataset_mountpoints=["bar"])
|
||||
config.validate()
|
||||
assert len(config.all_dataset_mountpoints()) == 2
|
||||
|
||||
with pytest.raises(ValueError) as ex:
|
||||
DeepLearningConfig(azure_dataset_id="foo",
|
||||
dataset_mountpoint="foo",
|
||||
extra_dataset_mountpoints=["bar"])
|
||||
DatasetParams(azure_dataset_id="foo",
|
||||
dataset_mountpoint="foo",
|
||||
extra_dataset_mountpoints=["bar"]).validate()
|
||||
assert "Expected the number of azure datasets to equal the number of mountpoints" in ex.value.args[0]
|
||||
|
||||
|
||||
def test_validate_deep_learning_config() -> None:
|
||||
def test_validate_workflow_params() -> None:
|
||||
|
||||
# DeepLearningConfig cannot be initialized with both these parameters set
|
||||
# WorkflowParams fails validation if more than one of these parameters is set
|
||||
with pytest.raises(ValueError) as ex:
|
||||
DeepLearningConfig(local_dataset=Path("foo"),
|
||||
local_weights_path=Path("foo"), weights_url="bar")
|
||||
assert ex.value.args[0] == "Cannot specify both local_weights_path and weights_url."
|
||||
WorkflowParams(local_dataset=Path("foo"),
|
||||
local_weights_path=[Path("foo")],
|
||||
weights_url=["bar"]).validate()
|
||||
assert ex.value.args[0] == "Cannot specify more than one of local_weights_path, weights_url or model_id."
|
||||
|
||||
with pytest.raises(ValueError) as ex:
|
||||
WorkflowParams(local_dataset=Path("foo"),
|
||||
local_weights_path=[Path("foo")],
|
||||
model_id="foo:1").validate()
|
||||
assert ex.value.args[0] == "Cannot specify more than one of local_weights_path, weights_url or model_id."
|
||||
|
||||
with pytest.raises(ValueError) as ex:
|
||||
WorkflowParams(local_dataset=Path("foo"),
|
||||
weights_url=["foo"],
|
||||
model_id="foo:1").validate()
|
||||
assert ex.value.args[0] == "Cannot specify more than one of local_weights_path, weights_url or model_id."
|
||||
|
||||
with pytest.raises(ValueError) as ex:
|
||||
WorkflowParams(local_dataset=Path("foo"),
|
||||
local_weights_path=[Path("foo")],
|
||||
weights_url=["foo"],
|
||||
model_id="foo:1").validate()
|
||||
assert ex.value.args[0] == "Cannot specify more than one of local_weights_path, weights_url or model_id."
|
||||
|
||||
with pytest.raises(ValueError) as ex:
|
||||
WorkflowParams(local_dataset=Path("foo"),
|
||||
model_id="foo").validate()
|
||||
assert "model_id should be in the form 'model_name:version'" in ex.value.args[0]
|
||||
|
||||
# The following should be okay
|
||||
DeepLearningConfig(local_dataset=Path("foo"), local_weights_path=Path("foo"))
|
||||
DeepLearningConfig(local_dataset=Path("foo"), weights_url="bar")
|
||||
WorkflowParams(local_dataset=Path("foo"), local_weights_path=[Path("foo")]).validate()
|
||||
WorkflowParams(local_dataset=Path("foo"), weights_url=["foo"]).validate()
|
||||
WorkflowParams(local_dataset=Path("foo"), model_id="foo:1").validate()
|
||||
|
|
|
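The rule exercised by `test_validate_workflow_params` is that at most one of `local_weights_path`, `weights_url` and `model_id` may be set, and that `model_id` must look like `model_name:version`. The sketch below shows one way such a check could be written. `WeightSources` is a hypothetical stand-in class, not the real `WorkflowParams`; only the error messages are taken from the tests above.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import List


@dataclass
class WeightSources:
    """Hypothetical stand-in for the three weight-source parameters."""
    local_weights_path: List[Path] = field(default_factory=list)
    weights_url: List[str] = field(default_factory=list)
    model_id: str = ""

    def validate(self) -> None:
        # Empty lists and empty strings count as "not set".
        num_set = sum(bool(x) for x in (self.local_weights_path, self.weights_url, self.model_id))
        if num_set > 1:
            raise ValueError("Cannot specify more than one of local_weights_path, weights_url or model_id.")
        if self.model_id and len(self.model_id.split(":")) != 2:
            raise ValueError(f"model_id should be in the form 'model_name:version', got {self.model_id}")


if __name__ == "__main__":
    WeightSources(model_id="foo:1").validate()
    try:
        WeightSources(weights_url=["bar"], model_id="foo:1").validate()
    except ValueError as ex:
        print(ex)
```

Counting the sources that are set, rather than testing each pair, keeps the check to a single condition if a further source is added later.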
@ -98,13 +98,13 @@ def test_innereye_container_init() -> None:
|
|||
"""
|
||||
# The constructor should copy all fields that belong to either WorkflowParams or DatasetParams from the
|
||||
# config object to the container.
|
||||
for (attrib, type_) in [("weights_url", WorkflowParams), ("azure_dataset_id", DatasetParams)]:
|
||||
for (attrib, type_) in [("weights_url", WorkflowParams), ("extra_dataset_mountpoints", DatasetParams)]:
|
||||
config = ModelConfigBase(should_validate=False)
|
||||
assert hasattr(type_, attrib)
|
||||
assert hasattr(config, attrib)
|
||||
setattr(config, attrib, "foo")
|
||||
setattr(config, attrib, ["foo"])
|
||||
container = InnerEyeContainer(config)
|
||||
assert getattr(container, attrib) == "foo"
|
||||
assert getattr(container, attrib) == ["foo"]
|
||||
|
||||
|
||||
def test_copied_properties() -> None:
|
||||
|
|
|
@ -102,7 +102,7 @@ def test_model_test(
|
|||
checkpoint_handler.additional_training_done()
|
||||
inference_results = model_testing.segmentation_model_test(config,
|
||||
execution_mode=execution_mode,
|
||||
checkpoint_handler=checkpoint_handler)
|
||||
checkpoint_paths=checkpoint_handler.get_checkpoints_to_test())
|
||||
epoch_dir = config.outputs_folder / get_best_epoch_results_path(execution_mode)
|
||||
total_num_patients_column_name = f"total_{MetricsFileColumns.Patient.value}".lower()
|
||||
if not total_num_patients_column_name.endswith("s"):
|
||||
|
|
|
@ -38,7 +38,8 @@ def test_recover_testing_from_run_recovery(mean_teacher_model: bool,
|
|||
assert len(train_results.train_results_per_epoch()) == config.num_epochs
|
||||
|
||||
# Run inference on this
|
||||
test_results = model_test(config=config, data_split=ModelExecutionMode.TEST, checkpoint_handler=checkpoint_handler)
|
||||
test_results = model_test(config=config, data_split=ModelExecutionMode.TEST,
|
||||
checkpoint_paths=checkpoint_handler.get_checkpoints_to_test())
|
||||
assert isinstance(test_results, InferenceMetricsForClassification)
|
||||
|
||||
# Mimic using a run recovery and see if it is the same
|
||||
|
@ -55,7 +56,7 @@ def test_recover_testing_from_run_recovery(mean_teacher_model: bool,
|
|||
shutil.copytree(str(config.checkpoint_folder), str(checkpoint_root))
|
||||
checkpoint_handler_run_recovery.run_recovery = RunRecovery([checkpoint_root])
|
||||
test_results_run_recovery = model_test(config_run_recovery, data_split=ModelExecutionMode.TEST,
|
||||
checkpoint_handler=checkpoint_handler_run_recovery)
|
||||
checkpoint_paths=checkpoint_handler_run_recovery.get_checkpoints_to_test())
|
||||
assert isinstance(test_results_run_recovery, InferenceMetricsForClassification)
|
||||
assert test_results.metrics.values()[MetricType.CROSS_ENTROPY.value] == \
|
||||
test_results_run_recovery.metrics.values()[MetricType.CROSS_ENTROPY.value]
|
||||
|
@ -70,13 +71,13 @@ def test_recover_testing_from_run_recovery(mean_teacher_model: bool,
|
|||
local_weights_path = test_output_dirs.root_dir / "local_weights_file.pth"
|
||||
shutil.copyfile(str(config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX),
|
||||
local_weights_path)
|
||||
config_local_weights.local_weights_path = local_weights_path
|
||||
config_local_weights.local_weights_path = [local_weights_path]
|
||||
|
||||
checkpoint_handler_local_weights = get_default_checkpoint_handler(model_config=config_local_weights,
|
||||
project_root=test_output_dirs.root_dir)
|
||||
checkpoint_handler_local_weights.download_recovery_checkpoints_or_weights()
|
||||
test_results_local_weights = model_test(config_local_weights, data_split=ModelExecutionMode.TEST,
|
||||
checkpoint_handler=checkpoint_handler_local_weights)
|
||||
checkpoint_paths=checkpoint_handler_local_weights.get_checkpoints_to_test())
|
||||
assert isinstance(test_results_local_weights, InferenceMetricsForClassification)
|
||||
assert test_results.metrics.values()[MetricType.CROSS_ENTROPY.value] == \
|
||||
test_results_local_weights.metrics.values()[MetricType.CROSS_ENTROPY.value]
|
||||
|
|
|
@ -285,8 +285,9 @@ def model_train_unittest(config: Optional[DeepLearningConfig],
|
|||
checkpoint_handler = CheckpointHandler(azure_config=azure_config,
|
||||
container=runner.container,
|
||||
project_root=dirs.root_dir)
|
||||
_, storing_logger = model_train(checkpoint_handler=checkpoint_handler,
|
||||
_, storing_logger = model_train(checkpoint_path=checkpoint_handler.get_recovery_or_checkpoint_path_train(),
|
||||
container=runner.container)
|
||||
checkpoint_handler.additional_training_done()
|
||||
return storing_logger, checkpoint_handler # type: ignore
|
||||
|
||||
|
||||
|
|
|
@ -5,22 +5,23 @@
|
|||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from InnerEye.Common.common_util import OTHER_RUNS_SUBDIR_NAME
|
||||
from InnerEye.Common.fixed_paths import MODEL_WEIGHTS_DIR_NAME
|
||||
from InnerEye.Common.fixed_paths import MODEL_INFERENCE_JSON_FILE_NAME
|
||||
from InnerEye.ML.utils.checkpoint_handling import MODEL_WEIGHTS_DIR_NAME
|
||||
from InnerEye.Common.output_directories import OutputFolderForTests
|
||||
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, get_recovery_checkpoint_path
|
||||
from InnerEye.ML.deep_learning_config import WEIGHTS_FILE
|
||||
from InnerEye.ML.deep_learning_config import FINAL_MODEL_FOLDER, FINAL_ENSEMBLE_MODEL_FOLDER
|
||||
from InnerEye.ML.model_config_base import ModelConfigBase
|
||||
from InnerEye.ML.model_inference_config import read_model_inference_config
|
||||
from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler
|
||||
from Tests.AfterTraining.test_after_training import FALLBACK_ENSEMBLE_RUN, FALLBACK_SINGLE_RUN, get_most_recent_run, \
|
||||
get_most_recent_run_id
|
||||
from Tests.ML.configs.DummyModel import DummyModel
|
||||
from Tests.ML.util import get_default_checkpoint_handler
|
||||
get_most_recent_run_id, get_most_recent_model_id
|
||||
from Tests.ML.util import get_default_checkpoint_handler, get_default_workspace
|
||||
|
||||
EXTERNAL_WEIGHTS_URL_EXAMPLE = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
|
||||
|
||||
|
@ -36,10 +37,9 @@ def create_checkpoint_file(file: Path) -> None:
|
|||
assert loaded, "Unable to read the checkpoint file that was just created"
|
||||
|
||||
|
||||
def test_use_local_weights_file(test_output_dirs: OutputFolderForTests) -> None:
|
||||
def test_use_checkpoint_paths_or_urls(test_output_dirs: OutputFolderForTests) -> None:
|
||||
config = ModelConfigBase(should_validate=False)
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
config.outputs_folder.mkdir()
|
||||
|
||||
# No checkpoint handling options set.
|
||||
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
|
||||
|
@ -47,29 +47,31 @@ def test_use_local_weights_file(test_output_dirs: OutputFolderForTests) -> None:
|
|||
|
||||
checkpoint_handler.download_recovery_checkpoints_or_weights()
|
||||
assert not checkpoint_handler.run_recovery
|
||||
assert not checkpoint_handler.local_weights_path
|
||||
assert not checkpoint_handler.trained_weights_paths
|
||||
|
||||
# weights from local_weights_path and weights_url will be modified if needed and stored at this location
|
||||
expected_path = checkpoint_handler.output_params.outputs_folder / WEIGHTS_FILE
|
||||
|
||||
# Set a weights_path
|
||||
checkpoint_handler.azure_config.run_recovery_id = ""
|
||||
checkpoint_handler.container.weights_url = EXTERNAL_WEIGHTS_URL_EXAMPLE
|
||||
checkpoint_handler.container.weights_url = [EXTERNAL_WEIGHTS_URL_EXAMPLE]
|
||||
checkpoint_handler.download_recovery_checkpoints_or_weights()
|
||||
assert checkpoint_handler.local_weights_path == expected_path
|
||||
assert checkpoint_handler.local_weights_path.is_file()
|
||||
expected_download_path = checkpoint_handler.output_params.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME /\
|
||||
os.path.basename(urlparse(EXTERNAL_WEIGHTS_URL_EXAMPLE).path)
|
||||
assert checkpoint_handler.trained_weights_paths[0] == expected_download_path
|
||||
assert checkpoint_handler.trained_weights_paths[0].is_file()
|
||||
|
||||
# set a local_weights_path
|
||||
checkpoint_handler.container.weights_url = ""
|
||||
checkpoint_handler.container.weights_url = []
|
||||
local_weights_path = test_output_dirs.root_dir / "exist.pth"
|
||||
create_checkpoint_file(local_weights_path)
|
||||
checkpoint_handler.container.local_weights_path = local_weights_path
|
||||
checkpoint_handler.container.local_weights_path = [local_weights_path]
|
||||
checkpoint_handler.download_recovery_checkpoints_or_weights()
|
||||
assert checkpoint_handler.local_weights_path == expected_path
|
||||
assert checkpoint_handler.trained_weights_paths[0] == local_weights_path
|
||||
assert checkpoint_handler.trained_weights_paths[0].is_file()
|
||||
|
||||
|
||||
@pytest.mark.after_training_single_run
|
||||
def test_download_checkpoints_from_single_run(test_output_dirs: OutputFolderForTests) -> None:
|
||||
def test_download_recovery_checkpoints_from_single_run(test_output_dirs: OutputFolderForTests) -> None:
|
||||
config = ModelConfigBase(should_validate=False)
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
|
||||
|
@ -92,7 +94,7 @@ def test_download_checkpoints_from_single_run(test_output_dirs: OutputFolderForT
|
|||
|
||||
|
||||
@pytest.mark.after_training_ensemble_run
|
||||
def test_download_checkpoints_from_ensemble_run(test_output_dirs: OutputFolderForTests) -> None:
|
||||
def test_download_recovery_checkpoints_from_ensemble_run(test_output_dirs: OutputFolderForTests) -> None:
|
||||
config = ModelConfigBase(should_validate=False)
|
||||
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
|
||||
project_root=test_output_dirs.root_dir)
|
||||
|
@ -104,33 +106,63 @@ def test_download_checkpoints_from_ensemble_run(test_output_dirs: OutputFolderFo
|
|||
assert "has child runs" in str(ex)
|
||||
|
||||
|
||||
@pytest.mark.after_training_single_run
|
||||
def test_download_model_from_single_run(test_output_dirs: OutputFolderForTests) -> None:
|
||||
config = ModelConfigBase(should_validate=False)
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
|
||||
# No checkpoint handling options set.
|
||||
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
|
||||
project_root=test_output_dirs.root_dir)
|
||||
model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
|
||||
|
||||
# Set the ID of a registered model from a non-ensemble run
|
||||
checkpoint_handler.container.model_id = model_id
|
||||
checkpoint_handler.download_recovery_checkpoints_or_weights()
|
||||
assert checkpoint_handler.trained_weights_paths
|
||||
|
||||
expected_model_root = config.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME / FINAL_MODEL_FOLDER
|
||||
model_inference_config = read_model_inference_config(expected_model_root / MODEL_INFERENCE_JSON_FILE_NAME)
|
||||
expected_paths = [expected_model_root / x for x in model_inference_config.checkpoint_paths]
|
||||
|
||||
assert len(expected_paths) == 1 # A registered model for a non-ensemble run should contain only one checkpoint
|
||||
assert len(checkpoint_handler.trained_weights_paths) == 1
|
||||
assert expected_paths[0] == checkpoint_handler.trained_weights_paths[0]
|
||||
assert expected_paths[0].is_file()
|
||||
|
||||
|
||||
@pytest.mark.after_training_ensemble_run
|
||||
def test_download_model_from_ensemble_run(test_output_dirs: OutputFolderForTests) -> None:
|
||||
config = ModelConfigBase(should_validate=False)
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
|
||||
# No checkpoint handling options set.
|
||||
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
|
||||
project_root=test_output_dirs.root_dir)
|
||||
model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
|
||||
|
||||
# Set the ID of a registered model from an ensemble run
|
||||
checkpoint_handler.container.model_id = model_id
|
||||
checkpoint_handler.download_recovery_checkpoints_or_weights()
|
||||
assert checkpoint_handler.trained_weights_paths
|
||||
|
||||
expected_model_root = config.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME / FINAL_ENSEMBLE_MODEL_FOLDER
|
||||
model_inference_config = read_model_inference_config(expected_model_root / MODEL_INFERENCE_JSON_FILE_NAME)
|
||||
expected_paths = [expected_model_root / x for x in model_inference_config.checkpoint_paths]
|
||||
|
||||
assert len(checkpoint_handler.trained_weights_paths) == len(expected_paths)
|
||||
assert set(checkpoint_handler.trained_weights_paths) == set(expected_paths)
|
||||
for path in expected_paths:
|
||||
assert path.is_file()
|
||||
|
||||
|
||||
def test_get_recovery_path_train(test_output_dirs: OutputFolderForTests) -> None:
|
||||
config = ModelConfigBase(should_validate=False)
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
config.outputs_folder.mkdir()
|
||||
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
|
||||
project_root=test_output_dirs.root_dir)
|
||||
|
||||
assert checkpoint_handler.get_recovery_path_train() is None
|
||||
|
||||
# weights from local_weights_path and weights_url will be modified if needed and stored at this location
|
||||
expected_path = checkpoint_handler.output_params.outputs_folder / WEIGHTS_FILE
|
||||
|
||||
# Set a weights_url to get checkpoint from
|
||||
checkpoint_handler.azure_config.run_recovery_id = ""
|
||||
checkpoint_handler.container.weights_url = EXTERNAL_WEIGHTS_URL_EXAMPLE
|
||||
checkpoint_handler.download_recovery_checkpoints_or_weights()
|
||||
assert checkpoint_handler.local_weights_path == expected_path
|
||||
assert checkpoint_handler.get_recovery_path_train() == expected_path
|
||||
|
||||
# Set a local_weights_path to get checkpoint from
|
||||
checkpoint_handler.container.weights_url = ""
|
||||
local_weights_path = test_output_dirs.root_dir / "exist.pth"
|
||||
create_checkpoint_file(local_weights_path)
|
||||
checkpoint_handler.container.local_weights_path = local_weights_path
|
||||
checkpoint_handler.download_recovery_checkpoints_or_weights()
|
||||
assert checkpoint_handler.local_weights_path == expected_path
|
||||
assert checkpoint_handler.get_recovery_path_train() == expected_path
|
||||
assert checkpoint_handler.get_recovery_or_checkpoint_path_train() is None
|
||||
|
||||
|
||||
@pytest.mark.after_training_single_run
|
||||
|
@ -147,7 +179,7 @@ def test_get_recovery_path_train_single_run(test_output_dirs: OutputFolderForTes
|
|||
|
||||
# Run recovery with start epoch provided should succeed
|
||||
expected_path = get_recovery_checkpoint_path(path=config.checkpoint_folder)
|
||||
assert checkpoint_handler.get_recovery_path_train() == expected_path
|
||||
assert checkpoint_handler.get_recovery_or_checkpoint_path_train() == expected_path
|
||||
|
||||
|
||||
@pytest.mark.after_training_single_run
|
||||
|
@ -159,8 +191,8 @@ def test_get_best_checkpoint_single_run(test_output_dirs: OutputFolderForTests)
|
|||
|
||||
# We have not set a run_recovery, nor have we trained, so this should fail to get a checkpoint
|
||||
with pytest.raises(ValueError) as ex:
|
||||
checkpoint_handler.get_best_checkpoint()
|
||||
assert "no run recovery object provided and no training has been done in this run" in ex.value.args[0]
|
||||
checkpoint_handler.get_best_checkpoints()
|
||||
assert "no run recovery object provided and no training has been done in this run" in ex.value.args[0]
|
||||
|
||||
run_recovery_id = get_most_recent_run_id(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
|
||||
|
||||
|
@ -169,7 +201,7 @@ def test_get_best_checkpoint_single_run(test_output_dirs: OutputFolderForTests)
|
|||
checkpoint_handler.azure_config.run_recovery_id = run_recovery_id
|
||||
checkpoint_handler.download_recovery_checkpoints_or_weights()
|
||||
expected_checkpoint = config.checkpoint_folder / f"{BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX}"
|
||||
checkpoint_paths = checkpoint_handler.get_best_checkpoint()
|
||||
checkpoint_paths = checkpoint_handler.get_best_checkpoints()
|
||||
assert checkpoint_paths
|
||||
assert len(checkpoint_paths) == 1
|
||||
assert expected_checkpoint == checkpoint_paths[0]
|
||||
|
@ -183,7 +215,7 @@ def test_get_best_checkpoint_single_run(test_output_dirs: OutputFolderForTests)
|
|||
checkpoint_handler.download_recovery_checkpoints_or_weights()
|
||||
|
||||
# There is no checkpoint in the current run - use the one from run_recovery
|
||||
checkpoint_paths = checkpoint_handler.get_best_checkpoint()
|
||||
checkpoint_paths = checkpoint_handler.get_best_checkpoints()
|
||||
expected_checkpoint = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
assert checkpoint_paths
|
||||
assert len(checkpoint_paths) == 1
|
||||
|
@ -192,24 +224,23 @@ def test_get_best_checkpoint_single_run(test_output_dirs: OutputFolderForTests)
|
|||
# Copy over checkpoints to make it look like training has happened and a better checkpoint written
|
||||
expected_checkpoint = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
expected_checkpoint.touch()
|
||||
checkpoint_paths = checkpoint_handler.get_best_checkpoint()
|
||||
checkpoint_paths = checkpoint_handler.get_best_checkpoints()
|
||||
assert checkpoint_paths
|
||||
assert len(checkpoint_paths) == 1
|
||||
assert expected_checkpoint == checkpoint_paths[0]
|
||||
|
||||
|
||||
@pytest.mark.after_training_ensemble_run
|
||||
def test_get_all_checkpoints_from_ensemble_run(test_output_dirs: OutputFolderForTests) -> None:
|
||||
def test_download_checkpoints_from_hyperdrive_child_runs(test_output_dirs: OutputFolderForTests) -> None:
|
||||
config = ModelConfigBase(should_validate=False)
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
config.outputs_folder.mkdir()
|
||||
manage_recovery = get_default_checkpoint_handler(model_config=config,
|
||||
project_root=test_output_dirs.root_dir)
|
||||
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
|
||||
project_root=test_output_dirs.root_dir)
|
||||
hyperdrive_run = get_most_recent_run(fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
|
||||
manage_recovery.download_checkpoints_from_hyperdrive_child_runs(hyperdrive_run)
|
||||
checkpoint_handler.download_checkpoints_from_hyperdrive_child_runs(hyperdrive_run)
|
||||
expected_checkpoints = [config.checkpoint_folder / OTHER_RUNS_SUBDIR_NAME / str(i)
|
||||
/ BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX for i in range(2)]
|
||||
checkpoint_paths = manage_recovery.get_best_checkpoint()
|
||||
checkpoint_paths = checkpoint_handler.get_best_checkpoints()
|
||||
assert checkpoint_paths
|
||||
assert len(checkpoint_paths) == 2
|
||||
assert set(expected_checkpoints) == set(checkpoint_paths)
|
||||
|
@ -218,28 +249,27 @@ def test_get_all_checkpoints_from_ensemble_run(test_output_dirs: OutputFolderFor
|
|||
def test_get_checkpoints_to_test(test_output_dirs: OutputFolderForTests) -> None:
|
||||
config = ModelConfigBase(should_validate=False)
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
config.outputs_folder.mkdir()
|
||||
manage_recovery = get_default_checkpoint_handler(model_config=config,
|
||||
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
|
||||
project_root=test_output_dirs.root_dir)
|
||||
|
||||
# Set a local_weights_path to get checkpoint from. Model has not trained and no run recovery provided,
|
||||
# so the local weights should be used ignoring any epochs to test
|
||||
local_weights_path = test_output_dirs.root_dir / "exist.pth"
|
||||
create_checkpoint_file(local_weights_path)
|
||||
manage_recovery.container.local_weights_path = local_weights_path
|
||||
manage_recovery.download_recovery_checkpoints_or_weights()
|
||||
checkpoint_and_paths = manage_recovery.get_checkpoints_to_test()
|
||||
checkpoint_handler.container.local_weights_path = [local_weights_path]
|
||||
checkpoint_handler.download_recovery_checkpoints_or_weights()
|
||||
checkpoint_and_paths = checkpoint_handler.get_checkpoints_to_test()
|
||||
assert checkpoint_and_paths
|
||||
assert len(checkpoint_and_paths) == 1
|
||||
assert checkpoint_and_paths[0] == manage_recovery.output_params.outputs_folder / WEIGHTS_FILE
|
||||
assert checkpoint_and_paths[0] == local_weights_path
|
||||
|
||||
manage_recovery.additional_training_done()
|
||||
manage_recovery.container.checkpoint_folder.mkdir()
|
||||
checkpoint_handler.additional_training_done()
|
||||
checkpoint_handler.container.checkpoint_folder.mkdir(parents=True)
|
||||
|
||||
# Copy checkpoint to make it seem like training has happened
|
||||
expected_checkpoint = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
expected_checkpoint.touch()
|
||||
checkpoint_and_paths = manage_recovery.get_checkpoints_to_test()
|
||||
checkpoint_and_paths = checkpoint_handler.get_checkpoints_to_test()
|
||||
|
||||
assert checkpoint_and_paths
|
||||
assert len(checkpoint_and_paths) == 1
|
||||
|
@ -250,20 +280,19 @@ def test_get_checkpoints_to_test(test_output_dirs: OutputFolderForTests) -> None
|
|||
def test_get_checkpoints_to_test_single_run(test_output_dirs: OutputFolderForTests) -> None:
|
||||
config = ModelConfigBase(should_validate=False)
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
config.outputs_folder.mkdir()
|
||||
manage_recovery = get_default_checkpoint_handler(model_config=config,
|
||||
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
|
||||
project_root=test_output_dirs.root_dir)
|
||||
|
||||
run_recovery_id = get_most_recent_run_id(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
|
||||
|
||||
# Now set a run recovery object and set the start epoch to 1, so we get one epoch from
|
||||
# run recovery and one from the training checkpoints
|
||||
manage_recovery.azure_config.run_recovery_id = run_recovery_id
|
||||
checkpoint_handler.azure_config.run_recovery_id = run_recovery_id
|
||||
|
||||
manage_recovery.additional_training_done()
|
||||
manage_recovery.download_recovery_checkpoints_or_weights()
|
||||
checkpoint_handler.additional_training_done()
|
||||
checkpoint_handler.download_recovery_checkpoints_or_weights()
|
||||
|
||||
checkpoint_and_paths = manage_recovery.get_checkpoints_to_test()
|
||||
checkpoint_and_paths = checkpoint_handler.get_checkpoints_to_test()
|
||||
|
||||
assert checkpoint_and_paths
|
||||
assert len(checkpoint_and_paths) == 1
|
||||
|
@ -272,7 +301,7 @@ def test_get_checkpoints_to_test_single_run(test_output_dirs: OutputFolderForTes
|
|||
# Copy checkpoint to make it seem like training has happened
|
||||
expected_checkpoint = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
expected_checkpoint.touch()
|
||||
checkpoint_and_paths = manage_recovery.get_checkpoints_to_test()
|
||||
checkpoint_and_paths = checkpoint_handler.get_checkpoints_to_test()
|
||||
|
||||
assert checkpoint_and_paths
|
||||
assert len(checkpoint_and_paths) == 1
|
||||
|
@ -281,108 +310,94 @@ def test_get_checkpoints_to_test_single_run(test_output_dirs: OutputFolderForTes
|
|||
|
||||
def test_download_model_weights(test_output_dirs: OutputFolderForTests) -> None:
|
||||
# Download a sample ResNet model from a URL given in the PyTorch docs
|
||||
# The downloaded model does not match the architecture, which is okay since we are only testing the download here.
|
||||
result_path = CheckpointHandler.download_weights(urls=[EXTERNAL_WEIGHTS_URL_EXAMPLE],
|
||||
download_folder=test_output_dirs.root_dir)
|
||||
assert len(result_path) == 1
|
||||
assert result_path[0] == test_output_dirs.root_dir / os.path.basename(urlparse(EXTERNAL_WEIGHTS_URL_EXAMPLE).path)
|
||||
assert result_path[0].is_file()
|
||||
|
||||
model_config = DummyModel(weights_url=EXTERNAL_WEIGHTS_URL_EXAMPLE)
|
||||
manage_recovery = get_default_checkpoint_handler(model_config=model_config,
|
||||
project_root=test_output_dirs.root_dir)
|
||||
result_path = manage_recovery.download_weights()
|
||||
assert result_path.is_file()
|
||||
modified_time = result_path[0].stat().st_mtime
|
||||
|
||||
result_path = CheckpointHandler.download_weights(urls=[EXTERNAL_WEIGHTS_URL_EXAMPLE, EXTERNAL_WEIGHTS_URL_EXAMPLE],
|
||||
download_folder=test_output_dirs.root_dir)
|
||||
assert len(result_path) == 2
|
||||
assert len(list(test_output_dirs.root_dir.glob("*"))) == 1
|
||||
assert result_path[0].samefile(result_path[1])
|
||||
assert result_path[0] == test_output_dirs.root_dir / os.path.basename(urlparse(EXTERNAL_WEIGHTS_URL_EXAMPLE).path)
|
||||
assert result_path[0].is_file()
|
||||
# This call should not re-download the files, just return the existing ones
|
||||
assert result_path[0].stat().st_mtime == modified_time
|
||||
|
||||
|
||||
@pytest.mark.after_training_single_run
|
||||
def test_get_checkpoints_from_model_single_run(test_output_dirs: OutputFolderForTests) -> None:
|
||||
model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=FALLBACK_SINGLE_RUN)
|
||||
|
||||
downloaded_checkpoints = CheckpointHandler.get_checkpoints_from_model(model_id=model_id,
|
||||
workspace=get_default_workspace(),
|
||||
download_path=test_output_dirs.root_dir)
|
||||
# Check a single checkpoint has been downloaded
|
||||
expected_model_root = test_output_dirs.root_dir / FINAL_MODEL_FOLDER
|
||||
assert expected_model_root.is_dir()
|
||||
model_inference_config = read_model_inference_config(expected_model_root / MODEL_INFERENCE_JSON_FILE_NAME)
|
||||
expected_paths = [expected_model_root / x for x in model_inference_config.checkpoint_paths]
|
||||
|
||||
assert len(expected_paths) == 1 # A registered model for a non-ensemble run should contain only one checkpoint
|
||||
assert len(downloaded_checkpoints) == 1
|
||||
assert expected_paths[0] == downloaded_checkpoints[0]
|
||||
assert expected_paths[0].is_file()
|
||||
|
||||
|
||||
@pytest.mark.after_training_ensemble_run
|
||||
def test_get_checkpoints_from_model_ensemble_run(test_output_dirs: OutputFolderForTests) -> None:
|
||||
model_id = get_most_recent_model_id(fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
|
||||
|
||||
downloaded_checkpoints = CheckpointHandler.get_checkpoints_from_model(model_id=model_id,
|
||||
workspace=get_default_workspace(),
|
||||
download_path=test_output_dirs.root_dir)
|
||||
# Check that all the ensemble checkpoints have been downloaded
|
||||
expected_model_root = test_output_dirs.root_dir / FINAL_ENSEMBLE_MODEL_FOLDER
|
||||
assert expected_model_root.is_dir()
|
||||
model_inference_config = read_model_inference_config(expected_model_root / MODEL_INFERENCE_JSON_FILE_NAME)
|
||||
expected_paths = [expected_model_root / x for x in model_inference_config.checkpoint_paths]
|
||||
|
||||
assert len(expected_paths) == len(downloaded_checkpoints)
|
||||
assert set(expected_paths) == set(downloaded_checkpoints)
|
||||
for expected_path in expected_paths:
|
||||
assert expected_path.is_file()
|
||||
|
||||
|
||||
def test_get_local_weights_path_or_download(test_output_dirs: OutputFolderForTests) -> None:
|
||||
config = ModelConfigBase(should_validate=False)
|
||||
manage_recovery = get_default_checkpoint_handler(model_config=config,
|
||||
project_root=test_output_dirs.root_dir)
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
|
||||
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
|
||||
project_root=test_output_dirs.root_dir)
|
||||
|
||||
# If the model has neither local_weights_path nor weights_url set, should fail.
|
||||
with pytest.raises(ValueError) as ex:
|
||||
manage_recovery.get_local_weights_path_or_download()
|
||||
assert "neither local_weights_path nor weights_url is set in the model config" in ex.value.args[0]
|
||||
checkpoint_handler.get_local_checkpoints_path_or_download()
|
||||
assert "none of model_id, local_weights_path or weights_url is set in the model config." in ex.value.args[0]
|
||||
|
||||
# If local_weights_path folder exists, get_local_weights_path_or_download should not do anything.
|
||||
local_weights_path = manage_recovery.project_root / "exist.pth"
|
||||
# If local_weights_path folder exists, get_local_checkpoints_path_or_download should not do anything.
|
||||
local_weights_path = test_output_dirs.root_dir / "exist.pth"
|
||||
create_checkpoint_file(local_weights_path)
|
||||
manage_recovery.container.local_weights_path = local_weights_path
|
||||
returned_weights_path = manage_recovery.get_local_weights_path_or_download()
|
||||
assert local_weights_path == returned_weights_path
|
||||
checkpoint_handler.container.local_weights_path = [local_weights_path]
|
||||
returned_weights_path = checkpoint_handler.get_local_checkpoints_path_or_download()
|
||||
assert local_weights_path == returned_weights_path[0]
|
||||
|
||||
# Pointing the model to a URL should trigger a download
|
||||
manage_recovery.container.local_weights_path = None
|
||||
manage_recovery.container.weights_url = EXTERNAL_WEIGHTS_URL_EXAMPLE
|
||||
downloaded_weights = manage_recovery.get_local_weights_path_or_download()
|
||||
# Download goes into <project_root> / "modelweights" / "resnet18-5c106cde.pth"
|
||||
expected_path = manage_recovery.project_root / MODEL_WEIGHTS_DIR_NAME / \
|
||||
checkpoint_handler.container.local_weights_path = []
|
||||
checkpoint_handler.container.weights_url = [EXTERNAL_WEIGHTS_URL_EXAMPLE]
|
||||
downloaded_weights = checkpoint_handler.get_local_checkpoints_path_or_download()
|
||||
expected_path = checkpoint_handler.output_params.checkpoint_folder / MODEL_WEIGHTS_DIR_NAME / \
|
||||
os.path.basename(urlparse(EXTERNAL_WEIGHTS_URL_EXAMPLE).path)
|
||||
assert downloaded_weights
|
||||
assert downloaded_weights.is_file()
|
||||
assert expected_path == downloaded_weights
|
||||
assert len(downloaded_weights) == 1
|
||||
assert downloaded_weights[0].is_file()
|
||||
assert expected_path == downloaded_weights[0]
|
||||
|
||||
# try again, should not re-download
|
||||
modified_time = downloaded_weights.stat().st_mtime
|
||||
downloaded_weights_new = manage_recovery.get_local_weights_path_or_download()
|
||||
assert downloaded_weights_new
|
||||
assert downloaded_weights_new.stat().st_mtime == modified_time
|
||||
|
||||
|
||||
def test_get_and_modify_local_weights(test_output_dirs: OutputFolderForTests) -> None:
|
||||
config = ModelConfigBase(should_validate=False)
|
||||
config.set_output_to(test_output_dirs.root_dir)
|
||||
config.outputs_folder.mkdir()
|
||||
|
||||
manage_recovery = get_default_checkpoint_handler(model_config=config,
|
||||
project_root=test_output_dirs.root_dir)
|
||||
|
||||
# If the model has neither local_weights_path or weights_url set, should fail.
|
||||
with pytest.raises(ValueError) as ex:
|
||||
manage_recovery.get_and_save_modified_weights()
|
||||
assert "neither local_weights_path nor weights_url is set in the model config" in ex.value.args[0]
|
||||
|
||||
# Pointing the model to a local_weights_path that does not exist will raise an error.
|
||||
manage_recovery.container.local_weights_path = manage_recovery.project_root / "non_exist"
|
||||
with pytest.raises(FileNotFoundError) as file_ex:
|
||||
manage_recovery.get_and_save_modified_weights()
|
||||
assert "Could not find the weights file" in file_ex.value.args[0]
|
||||
|
||||
# Test that weights are properly modified when a local_weights_path is set
|
||||
|
||||
# set a method to modify weights:
|
||||
with mock.patch.object(ModelConfigBase,
|
||||
'load_checkpoint_and_modify',
|
||||
lambda self, path_to_checkpoint: {"modified": "local", # type: ignore
|
||||
"path": path_to_checkpoint}):
|
||||
# Set the local_weights_path to an empty file, which will be passed to modify_checkpoint
|
||||
local_weights_path = manage_recovery.project_root / "exist.pth"
|
||||
create_checkpoint_file(local_weights_path)
|
||||
manage_recovery.container.local_weights_path = local_weights_path
|
||||
weights_path = manage_recovery.get_and_save_modified_weights()
|
||||
expected_path = manage_recovery.output_params.outputs_folder / WEIGHTS_FILE
|
||||
# read from weights_path and check that the dict has been written
|
||||
assert weights_path.is_file()
|
||||
assert expected_path == weights_path
|
||||
read = torch.load(str(weights_path))
|
||||
assert read.keys() == {"modified", "path"}
|
||||
assert read["modified"] == "local"
|
||||
assert read["path"] == local_weights_path
|
||||
# clean up
|
||||
weights_path.unlink()
|
||||
|
||||
# Test that weights are properly modified when weights_url is set
|
||||
|
||||
# set a different method to modify weights, to avoid using old files from other tests:
|
||||
with mock.patch.object(ModelConfigBase,
|
||||
'load_checkpoint_and_modify',
|
||||
lambda self, path_to_checkpoint: {"modified": "url", "path": path_to_checkpoint}):
|
||||
# Set the weights_url to the sample pytorch URL, which will be passed to modify_checkpoint
|
||||
manage_recovery.container.local_weights_path = None
|
||||
manage_recovery.container.weights_url = EXTERNAL_WEIGHTS_URL_EXAMPLE
|
||||
weights_path = manage_recovery.get_and_save_modified_weights()
|
||||
expected_path = manage_recovery.output_params.outputs_folder / WEIGHTS_FILE
|
||||
# read from weights_path and check that the dict has been written
|
||||
assert weights_path.is_file()
|
||||
assert expected_path == weights_path
|
||||
read = torch.load(str(weights_path))
|
||||
assert read.keys() == {"modified", "path"}
|
||||
assert read["modified"] == "url"
|
||||
assert read["path"] == manage_recovery.project_root / MODEL_WEIGHTS_DIR_NAME / \
|
||||
os.path.basename(urlparse(EXTERNAL_WEIGHTS_URL_EXAMPLE).path)
|
||||
modified_time = downloaded_weights[0].stat().st_mtime
|
||||
downloaded_weights_new = checkpoint_handler.get_local_checkpoints_path_or_download()
|
||||
assert len(downloaded_weights_new) == 1
|
||||
assert downloaded_weights_new[0].stat().st_mtime == modified_time
|
||||
|
|
|
@ -36,7 +36,7 @@ steps:
|
|||
# hence don't set PYTHONPATH
|
||||
- bash: |
|
||||
source activate InnerEye
|
||||
pytest ./Tests/ -m "not (gpu or azureml or after_training_single_run or after_training_ensemble_run or inference or after_training_2node or after_training_glaucoma_cv_run)" --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-config=.coveragerc --cov-report=xml -n 2 --dist=loadscope --verbose
|
||||
pytest ./Tests/ -m "not (gpu or azureml or after_training_single_run or after_training_ensemble_run or inference or after_training_2node or after_training_glaucoma_cv_run or after_training_hello_container)" --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-config=.coveragerc --cov-report=xml -n 2 --dist=loadscope --verbose
|
||||
env:
|
||||
APPLICATION_KEY: $(InnerEyeDeepLearningServicePrincipalKey)
|
||||
DATASETS_ACCOUNT_KEY: $(InnerEyePublicDatasetsStorageKey)
|
||||
|
|
|
@ -139,9 +139,9 @@ empty, in which case the union of all validation sets for the `N` child runs wil
|
|||
|
||||
### Recovering failed runs and continuing training
|
||||
|
||||
To train further with an already-created model, give the above command with additional switches like these:
|
||||
To train further with an already-created model, give the above command with the `run_recovery_id` argument:
|
||||
```
|
||||
--run_recovery_id=foo_bar:foo_bar_12345_abcd --start_epoch=120
|
||||
--run_recovery_id=foo_bar:foo_bar_12345_abcd
|
||||
```
|
||||
The run recovery ID is of the form "experiment_id:run_id". When you trained your original model, it will have been
|
||||
queued as a "Run" inside of an "Experiment". The experiment will be given a name derived from the branch name - for
|
||||
|
@ -160,24 +160,45 @@ of the AzureML UI. The easiest way to get it is to go to any of the child runs a
|
|||
run recovery ID without the final underscore and digit.
|
||||
|
||||
### Testing an existing model
|
||||
To evaluate an existing model on a test set, you can use models from previous runs in AzureML or from local checkpoints.
|
||||
To evaluate an existing model on a test set, you can use registered models from previous runs in AzureML, a set of
|
||||
local checkpoints or a set of URLs pointing to model checkpoints. For all these options, you will need to set the
|
||||
flag `--no-train` along with additional command line arguments to specify the checkpoints.
|
||||
|
||||
#### From a previous run in AzureML:
|
||||
This is similar to continuing training using a run_recovery object, but you will need to set `--no-train`.
|
||||
Thus your command should look like this:
|
||||
#### From a registered model on AzureML:
|
||||
You will need to specify the registered model to run on using the `model_id` argument. You can find the model name and
|
||||
version by clicking on `Registered Models` on the Details tab of a run in the AzureML UI.
|
||||
The model id is of the form "model_name:model_version". Thus your command should look like this:
|
||||
|
||||
```shell script
|
||||
python InnerEye/ML/runner.py --azureml --model=Prostate --no-train --cluster=my_cluster_name \
|
||||
--run_recovery_id=foo_bar:foo_bar_12345_abcd --start_epoch=120
|
||||
python InnerEye/ML/runner.py --azureml --model=Prostate --cluster=my_cluster_name \
|
||||
--no-train --model_id=Prostate:1
|
||||
```
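As an illustration of the "model_name:model_version" form described above, here is a minimal sketch that splits such an identifier into its two parts. `parse_model_id` is a hypothetical helper written for this example, not a function from InnerEye, and it assumes the version is an integer that follows the last colon.

```python
from typing import Tuple


def parse_model_id(model_id: str) -> Tuple[str, int]:
    """Split 'model_name:model_version' into (name, version). Hypothetical helper."""
    # rpartition splits on the last colon in the string.
    name, sep, version = model_id.rpartition(":")
    # Assumption: the version is a non-negative integer.
    if not sep or not name or not version.isdigit():
        raise ValueError(f"Expected 'model_name:model_version', got '{model_id}'")
    return name, int(version)


print(parse_model_id("Prostate:1"))
```

A check like this can catch a malformed identifier before a job is submitted to the cluster.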
|
||||
#### From a local checkpoint:
|
||||
To evaluate a model using a local checkpoint, use the local_weights_path to specify the path to the model checkpoint
|
||||
and set train to `False`.
|
||||
#### From local checkpoints:
|
||||
To evaluate a model using one or more local checkpoints, use the `local_weights_path` argument to specify the path(s) to the
|
||||
model checkpoint(s) on the local disk.
|
||||
```shell script
|
||||
python InnerEye/ML/runner.py --model=Prostate --no-train --local_weights_path=path_to_your_checkpoint
|
||||
```
|
||||
To run on multiple checkpoints (if you have trained an ensemble model), specify each checkpoint using the argument
|
||||
`local_weights_path`, separating the paths with commas.
|
||||
```shell script
|
||||
python InnerEye/ML/runner.py --model=Prostate --no-train --local_weights_path=path_to_first_checkpoint,path_to_second_checkpoint
|
||||
```
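The comma-separated value shown in the command above can be pictured as a list of paths. The following is only a sketch of that idea: `split_checkpoint_paths` is a hypothetical helper, and the way the runner actually parses the argument is not shown in this document.

```python
from pathlib import Path
from typing import List


def split_checkpoint_paths(arg_value: str) -> List[Path]:
    """Turn 'first.ckpt,second.ckpt' into a list of Path objects. Hypothetical helper."""
    # Ignore empty entries and surrounding whitespace (an assumption of this sketch).
    return [Path(part.strip()) for part in arg_value.split(",") if part.strip()]


paths = split_checkpoint_paths("path_to_first_checkpoint,path_to_second_checkpoint")
print(len(paths))
```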
|
||||
|
||||
Alternatively, to submit an AzureML run to apply a model to a single image on your local disc,
|
||||
#### From URLs:
|
||||
To evaluate a model using one or more checkpoints each specified by a URL, use the `weights_url` argument to specify the
|
||||
URL(s) from which the model checkpoint(s) should be downloaded.
|
||||
```shell script
|
||||
python InnerEye/ML/runner.py --model=Prostate --no-train --weights_url=url_for_your_checkpoint
|
||||
```
|
||||
To run on multiple checkpoints (if you have trained an ensemble model), specify each checkpoint using the argument
|
||||
`weights_url`, separating the URLs with commas.
|
||||
```shell script
|
||||
python InnerEye/ML/runner.py --model=Prostate --no-train --weights_url=url_for_first_checkpoint,url_for_second_checkpoint
|
||||
```
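The test code earlier in this commit expects a downloaded checkpoint to be stored under the base name of the URL's path, computed as `os.path.basename(urlparse(url).path)`. The sketch below applies the same expression to a made-up URL. `local_name_for_url`, the example URL and the `weights` folder are assumptions for illustration only.

```python
import os
from pathlib import Path
from urllib.parse import urlparse


def local_name_for_url(url: str, download_dir: Path) -> Path:
    """Map a checkpoint URL to a file inside download_dir. Hypothetical helper."""
    # urlparse(...).path excludes the query string, so the file name stays clean.
    return download_dir / os.path.basename(urlparse(url).path)


print(local_name_for_url("https://example.com/models/checkpoint.ckpt?sv=1", Path("weights")))
```

Because the name comes only from the URL path, two URLs whose paths end in the same file name would map to the same local file in this sketch.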
|
||||
|
||||
#### Running a registered AzureML model on a single image on the local disk
|
||||
To submit an AzureML run to apply a model to a single image on your local disk,
|
||||
you can use the script `submit_for_inference.py`, with a command of this form:
|
||||
```shell script
|
||||
python InnerEye/Scripts/submit_for_inference.py --image_file ~/somewhere/ct.nii.gz --model_id Prostate:555 \
|
||||
|
|