Checkpoint recovery refactoring (#439)

* Add auto-restart

* Change handling of checkpoints and clean-up

* Save last k recovery checkpoints

* Log epoch for keeping last ckpt

* Keeping k last checkpoints

* Add possibility to recover from particular checkpoint

* Update tests

* Check k recovery

* Re-add skipif

* Correct pick up of recovery runs and add test

* Correct pick up of recovery runs and add test

* Remove all start epochs

* Remove all start epochs

* Simplify run recovery logic

* Fix it

* Merge conflicts import errors

* Fix it

* Fix tests in test_scalar_model.py

* Fix tests in test_model_util.py

* Fix tests in test_scalar_model.py

* Fix tests in test_model_training.py

* Avoid forcing the user to log epoch

* Fix test_get_checkpoints

* Fix test_checkpoint_handling.py

* Fix callback

* Update CHANGELOG.md

* Self PR review comments

* Fix more tests

* Fix argument in test

* Mypy

* Update InnerEye-DeepLearning.iml

* Update InnerEye-DeepLearning.iml

* Fix mypy errors

* Address PR comment

* Typo

* mypy fix

* just style
This commit is contained in:
melanibe 2021-04-21 16:40:20 +02:00 committed by GitHub
Parent f421234c3d
Commit adffa95a14
No key found that matches this signature
GPG key ID: 4AEE18F83AFDEB23
24 changed files: 264 additions and 256 deletions

View file

@ -13,4 +13,4 @@
<component name="TestRunnerService">
<option name="PROJECT_TEST_RUNNER" value="pytest" />
</component>
</module>
</module>

View file

@ -16,7 +16,6 @@ created.
- ([#417](https://github.com/microsoft/InnerEye-DeepLearning/pull/417)) Added a generic way of adding PyTorch Lightning
models to the toolbox. It is now possible to train almost any Lightning model with the InnerEye toolbox in AzureML,
with only minimum code changes required. See [the MD documentation](docs/bring_your_own_model.md) for details.
- ([#438](https://github.com/microsoft/InnerEye-DeepLearning/pull/438)) Add links and small docs to InnerEye-Gateway and InnerEye-Inference
- ([#430](https://github.com/microsoft/InnerEye-DeepLearning/pull/430)) Update conversion to 1.0.1 InnerEye-DICOM-RT to
add: manufacturer, SoftwareVersions, Interpreter and ROIInterpretedTypes.
- ([#385](https://github.com/microsoft/InnerEye-DeepLearning/pull/385)) Add the ability to train a model on multiple
@ -48,6 +47,10 @@ with only minimum code changes required. See [the MD documentation](docs/bring_y
- ([#405](https://github.com/microsoft/InnerEye-DeepLearning/pull/405)) Cross-validation runs for classification models
now also generate a report notebook summarising the metrics from the individual splits. Also includes minor formatting
improvements for standard classification reports.
- ([#438](https://github.com/microsoft/InnerEye-DeepLearning/pull/438)) Add links and small docs to InnerEye-Gateway and InnerEye-Inference
- ([#439](https://github.com/microsoft/InnerEye-DeepLearning/pull/439)) Enable automatic job recovery from the last recovery
checkpoint in case of job pre-emption on AML. Users can now also keep more than one recovery
checkpoint.
### Changed
@ -62,8 +65,11 @@ with only minimum code changes required. See [the MD documentation](docs/bring_y
end-to-end test for classification cross-validation. WARNING: upgrade PL version causes hanging of multi-node
training.
- ([#437](https://github.com/microsoft/InnerEye-DeepLearning/pull/437)) Upgrade to PyTorch-Lightning 1.2.8.
- ([#439](https://github.com/microsoft/InnerEye-DeepLearning/pull/439)) Recovery checkpoints are now
named `recovery_epoch=x.ckpt` instead of `recovery.ckpt` or `recovery-v0.ckpt`.
### Fixed
- ([#422](https://github.com/microsoft/InnerEye-DeepLearning/pull/422)) Documentation - clarified `setting_up_aml.md`
datastore creation instructions and fixed small typos in `hello_world_model.md`
- ([#432](https://github.com/microsoft/InnerEye-DeepLearning/pull/432)) Fixed cross-validation for classification
@ -73,7 +79,9 @@ with only minimum code changes required. See [the MD documentation](docs/bring_y
set, display an error message and terminate the run.
- ([#437](https://github.com/microsoft/InnerEye-DeepLearning/pull/437)) Fixed multi-node DDP bug in PL v1.2.8. Re-add
end-to-end test for multi-node.
### Removed
- ([#439](https://github.com/microsoft/InnerEye-DeepLearning/pull/439)) Deprecated `start_epoch` config argument.
### Deprecated

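For orientation, the two new training flags can be passed on the runner command line exactly as the updated test_runner1 further down does. The snippet below is an illustrative sketch only: the flag names come from TrainerParams in this PR, while the model name is a placeholder.

# Illustrative only: flag names are from TrainerParams in this PR, the model is a placeholder.
args = ["runner.py",
        "--model=DummyClassification",
        "--num_epochs=6",
        "--recovery_checkpoint_save_interval=2",  # write a recovery checkpoint every 2 epochs
        "--recovery_checkpoints_save_last_k=2"]   # keep the 2 most recent; -1 keeps all
print(" ".join(args))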
View file

@ -4,10 +4,13 @@
# ------------------------------------------------------------------------------------------
import abc
import logging
import re
from datetime import datetime
from enum import Enum, unique
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
DATASET_CSV_FILE_NAME = "dataset.csv"
CHECKPOINT_SUFFIX = ".ckpt"
@ -61,18 +64,16 @@ class OneHotEncoderBase(abc.ABC):
raise NotImplementedError("get_feature_length must be implemented by sub classes")
def create_recovery_checkpoint_path(path: Path) -> Path:
def get_recovery_checkpoint_path(path: Path) -> Path:
"""
Returns the file name of a recovery checkpoint in the given folder. Raises a FileNotFoundError if no
Returns the path to the most recent recovery checkpoint in the given folder. Raises a
FileNotFoundError if no
recovery checkpoint file is present.
:param path: Path to checkpoint folder
"""
# Recovery checkpoints are written alternately as recovery.ckpt and recovery-v0.ckpt.
best_checkpoint1 = path / f"{RECOVERY_CHECKPOINT_FILE_NAME_WITH_SUFFIX}"
best_checkpoint2 = path / f"{RECOVERY_CHECKPOINT_FILE_NAME}-v0{CHECKPOINT_SUFFIX}"
for p in [best_checkpoint1, best_checkpoint2]:
if p.is_file():
return p
recovery_ckpt_and_epoch = find_recovery_checkpoint_and_epoch(path)
if recovery_ckpt_and_epoch is not None:
return recovery_ckpt_and_epoch[0]
files = list(path.glob("*"))
raise FileNotFoundError(f"No checkpoint files found in {path}. Existing files: {' '.join(p.name for p in files)}")
@ -85,34 +86,55 @@ def get_best_checkpoint_path(path: Path) -> Path:
return path / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
def keep_latest(path: Path, search_pattern: str) -> Optional[Path]:
def find_all_recovery_checkpoints(path: Path) -> Optional[List[Path]]:
"""
Looks at all files that match the given pattern via "glob", and deletes all of them apart from the most
recent file. The surviving file is returned. If there is no single file that matches the search pattern, then
return None.
Extracts all files starting with RECOVERY_CHECKPOINT_FILE_NAME in path.
:param path: The folder to search for recovery checkpoints.
:return: A list of all recovery checkpoint paths, or None if there are none.
"""
all_recovery_files = [f for f in path.glob(RECOVERY_CHECKPOINT_FILE_NAME + "*")]
if len(all_recovery_files) == 0:
return None
return all_recovery_files
PathAndEpoch = Tuple[Path, int]
def extract_latest_checkpoint_and_epoch(available_files: List[Path]) -> PathAndEpoch:
"""
Checkpoints are saved as recovery_epoch={epoch}.ckpt, find the latest ckpt and epoch number.
:param available_files: all available checkpoints
:return: path to the checkpoint from the latest epoch, and the epoch number
"""
recovery_epochs = [int(re.findall(r"[\d]+", f.stem)[0]) for f in available_files]
idx_max_epoch = int(np.argmax(recovery_epochs))
return available_files[idx_max_epoch], recovery_epochs[idx_max_epoch]
def find_recovery_checkpoint_and_epoch(path: Path) -> Optional[PathAndEpoch]:
"""
Looks at all the recovery files, extracts the epoch number for all of them, and returns the most recent (latest
epoch) checkpoint path along with the corresponding epoch number. If no recovery checkpoints are found, returns
None.
:param path: The folder to start searching in.
:param search_pattern: The glob pattern that specifies the files that should be searched.
:return: None if there is no file matching the search pattern, or a Path object that has the latest file matching
the pattern.
:return: None if no recovery checkpoint is found in the folder, or a Tuple of (recovery checkpoint path,
recovery epoch).
"""
files_and_mod_time = [(f, f.stat().st_mtime) for f in path.glob(search_pattern)]
files_and_mod_time.sort(key=lambda f: f[1], reverse=True)
for (f, _) in files_and_mod_time[1:]:
logging.info(f"Removing file: {f}")
f.unlink()
if files_and_mod_time:
return files_and_mod_time[0][0]
available_checkpoints = find_all_recovery_checkpoints(path)
if available_checkpoints is not None:
return extract_latest_checkpoint_and_epoch(available_checkpoints)
return None
def keep_best_checkpoint(path: Path) -> Path:
def create_best_checkpoint(path: Path) -> Path:
"""
Clean up all checkpoints that are found in the given folder, and keep only the "best" one. "Best" is at the moment
defined as being the last checkpoint, but could be based on some defined policy. The best checkpoint will be
renamed to `best_checkpoint.ckpt`. All other checkpoint files
but the best will be removed (or an existing checkpoint renamed to be the best checkpoint).
Creates the best checkpoint file. "Best" is at the moment defined as being the last checkpoint, but could be
based on some defined policy.
The best checkpoint will be renamed to `best_checkpoint.ckpt`.
:param path: The folder that contains all checkpoint files.
"""
logging.debug(f"Files in checkpoint folder: {' '.join(p.name for p in path.glob('*'))}")
last_ckpt = path / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
all_files = f"Existing files: {' '.join(p.name for p in path.glob('*'))}"
if not last_ckpt.is_file():
@ -124,21 +146,6 @@ def keep_best_checkpoint(path: Path) -> Path:
return best
def cleanup_checkpoint_folder(path: Path) -> None:
"""
Removes surplus files from the checkpoint folder, and unifies the names of the files that are kept:
1) Keep only the most recent recovery checkpoint file
2) Chooses the best checkpoint file according to keep_best_checkpoint, and rename it to
BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
:param path: The folder containing all model checkpoints.
"""
logging.info(f"Files in checkpoint folder: {' '.join(p.name for p in path.glob('*'))}")
recovery = keep_latest(path, RECOVERY_CHECKPOINT_FILE_NAME + "*")
if recovery:
recovery.rename(path / RECOVERY_CHECKPOINT_FILE_NAME_WITH_SUFFIX)
keep_best_checkpoint(path)
def create_unique_timestamp_id() -> str:
"""
Creates a unique string using the current time in UTC, up to seconds precision, with characters that

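The helpers above revolve around the new recovery checkpoint naming. Below is a small self-contained sketch of that logic (a re-implementation for illustration, not the InnerEye code itself): recovery checkpoints are named recovery_epoch={epoch}.ckpt, and the most recent one is chosen by parsing the epoch number out of the file stem.

# Standalone sketch (illustrative only) of the epoch extraction used by the helpers above.
import re
import tempfile
from pathlib import Path
from typing import Optional, Tuple

def latest_recovery_checkpoint(folder: Path) -> Optional[Tuple[Path, int]]:
    # All recovery checkpoints start with "recovery"; the first run of digits in the stem is the epoch.
    candidates = list(folder.glob("recovery*"))
    if not candidates:
        return None
    epochs = [int(re.findall(r"\d+", f.stem)[0]) for f in candidates]
    best = max(range(len(candidates)), key=lambda i: epochs[i])
    return candidates[best], epochs[best]

with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp)
    for name in ["recovery_epoch=1.ckpt", "recovery_epoch=5.ckpt", "recovery_epoch=3.ckpt"]:
        (folder / name).touch()
    print(latest_recovery_checkpoint(folder))  # -> (.../recovery_epoch=5.ckpt, 5)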
View file

@ -35,7 +35,6 @@ class BasicModel2Epochs(SegmentationModelBase):
class_weights=equally_weighted_classes(fg_classes),
num_dataload_workers=1,
train_batch_size=8,
start_epoch=0,
num_epochs=2,
recovery_checkpoint_save_interval=1,
use_mixed_precision=True,

View file

@ -42,7 +42,6 @@ class GbmBase(SegmentationModelBase):
tail=[1.0],
class_weights=equally_weighted_classes(fg_classes),
train_batch_size=8,
start_epoch=0,
num_epochs=200,
l_rate=1e-3,
l_rate_polynomial_gamma=0.9,

View file

@ -99,7 +99,6 @@ class HeadAndNeckBase(SegmentationModelBase):
norm_method=PhotometricNormalizationMethod.CtWindow,
level=50,
window=600,
start_epoch=0,
l_rate=1e-3,
min_l_rate=1e-5,
l_rate_polynomial_gamma=0.9,

View file

@ -58,7 +58,6 @@ class HelloWorld(SegmentationModelBase):
# and testing (ie: how many epochs to test)
num_dataload_workers=0,
train_batch_size=2,
start_epoch=0,
num_epochs=2,
recovery_checkpoint_save_interval=1,
use_mixed_precision=True,

View file

@ -47,7 +47,6 @@ class Lung(SegmentationModelBase):
train_batch_size=8,
inference_batch_size=1,
inference_stride_size=(64, 256, 256),
start_epoch=0,
num_epochs=140,
l_rate=1e-3,
min_l_rate=1e-5,

View file

@ -76,7 +76,6 @@ class ProstateBase(SegmentationModelBase):
num_epochs=120,
opt_eps=1e-4,
optimizer_type=OptimizerType.Adam,
start_epoch=0,
test_crop_size=(128, 512, 512),
train_batch_size=2,
use_mixed_precision=True,

View file

@ -4,20 +4,20 @@
# ------------------------------------------------------------------------------------------
import random
from typing import Any, List
import numpy as np
import pandas as pd
import torch
from torch.nn.parameter import Parameter
from InnerEye.Common.type_annotations import TupleInt3
from InnerEye.ML.config import equally_weighted_classes, ModelArchitectureConfig, SegmentationModelBase
from InnerEye.ML.config import ModelArchitectureConfig, SegmentationModelBase, equally_weighted_classes
from InnerEye.ML.configs.segmentation.Lung import AZURE_DATASET_ID
from InnerEye.ML.models.architectures.base_model import BaseSegmentationModel
from InnerEye.ML.models.parallel.model_parallel import get_device_from_parameters, move_to_device
from InnerEye.ML.utils.model_metadata_util import generate_random_colours_list
from InnerEye.ML.utils.split_dataset import DatasetSplits
RANDOM_COLOUR_GENERATOR = random.Random(0)
RECTANGLE_STROKE_THICKNESS = 3
@ -48,7 +48,6 @@ class PassThroughModel(SegmentationModelBase):
inference_batch_size=1,
class_weights=equally_weighted_classes(fg_classes, background_weight=0.02),
feature_channels=[1],
start_epoch=0,
num_epochs=1,
# Necessary to avoid https://github.com/pytorch/pytorch/issues/45324
max_num_gpus=1,

View file

@ -19,9 +19,8 @@ from InnerEye.Common.common_util import is_windows
from InnerEye.Common.fixed_paths import DEFAULT_AML_UPLOAD_DIR, DEFAULT_LOGS_DIR_NAME
from InnerEye.Common.generic_parsing import CudaAwareConfig, GenericConfig
from InnerEye.Common.type_annotations import PathOrString, TupleFloat2
from InnerEye.ML.common import DATASET_CSV_FILE_NAME, ModelExecutionMode, \
create_recovery_checkpoint_path, create_unique_timestamp_id, \
get_best_checkpoint_path
from InnerEye.ML.common import DATASET_CSV_FILE_NAME, ModelExecutionMode, create_unique_timestamp_id, \
get_best_checkpoint_path, get_recovery_checkpoint_path
# A folder inside of the outputs folder that will contain all information for running the model in inference mode
FINAL_MODEL_FOLDER = "final_model"
@ -352,7 +351,7 @@ class OutputParams(param.Parameterized):
"""
Returns the full path to a recovery checkpoint.
"""
return create_recovery_checkpoint_path(self.checkpoint_folder)
return get_recovery_checkpoint_path(self.checkpoint_folder)
def get_path_to_best_checkpoint(self) -> Path:
"""
@ -435,6 +434,11 @@ class TrainerParams(CudaAwareConfig):
doc="Save epoch checkpoints when epoch number is a multiple "
"of recovery_checkpoint_save_interval. The intended use "
"is to allow restore training from failed runs.")
recovery_checkpoints_save_last_k: int = param.Integer(default=1, bounds=(-1, None),
doc="Number of recovery checkpoints to keep. Recovery "
"checkpoints will be stored as recovery_epoch:{"
"epoch}.ckpt. If set to -1 keep all recovery "
"checkpoints.")
detect_anomaly: bool = param.Boolean(False, doc="If true, test gradients for anomalies (NaN or Inf) during "
"training.")
use_mixed_precision: bool = param.Boolean(False, doc="If true, mixed precision training is activated during "
@ -454,9 +458,6 @@ class TrainerParams(CudaAwareConfig):
doc="Controls the PyTorch Lightning trainer flags 'deterministic' and 'benchmark'. If "
"'pl_deterministic' is True, results are perfectly reproducible. If False, they are not, but "
"you may see training speed increases.")
start_epoch: int = param.Integer(0, bounds=(0, None), doc="The first epoch to train. Set to 0 to start a new "
"training. Set to a value larger than zero for starting"
" from a checkpoint.")
class DeepLearningConfig(WorkflowParams,
@ -546,6 +547,7 @@ class DeepLearningConfig(WorkflowParams,
# This should be annotated as torch.utils.data.Dataset, but we don't want to import torch here.
self._datasets_for_training: Optional[Dict[ModelExecutionMode, Any]] = None
self._datasets_for_inference: Optional[Dict[ModelExecutionMode, Any]] = None
self.recovery_start_epoch = 0
super().__init__(throw_if_unknown_param=True, **params)
logging.info("Creating the default output folder structure.")
self.create_filesystem(fixed_paths.repository_root_directory())
@ -609,7 +611,7 @@ class DeepLearningConfig(WorkflowParams,
Returns the epochs for which training will be performed.
:return:
"""
return list(range(self.start_epoch + 1, self.num_epochs + 1))
return list(range(self.recovery_start_epoch + 1, self.num_epochs + 1))
def get_total_number_of_training_epochs(self) -> int:
"""

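As a quick check of the epoch bookkeeping above, here is a minimal standalone sketch of how the recovery start epoch shifts the training range. The 1-based epoch convention is taken from get_train_epochs in the diff; everything else is simplified and hypothetical.

from typing import List

def train_epochs(num_epochs: int, recovery_start_epoch: int = 0) -> List[int]:
    # A fresh run trains epochs 1..num_epochs; a recovered run resumes at recovery_start_epoch + 1.
    return list(range(recovery_start_epoch + 1, num_epochs + 1))

assert train_epochs(10) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
assert train_epochs(10, recovery_start_epoch=2) == [3, 4, 5, 6, 7, 8, 9, 10]  # 8 epochs left to train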
View file

@ -20,8 +20,8 @@ from InnerEye.Common.metrics_constants import LoggingColumns, MetricType, TRAIN_
from InnerEye.Common.type_annotations import DictStrFloat
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.deep_learning_config import DatasetParams, DeepLearningConfig, WorkflowParams, OutputParams, \
TrainerParams
from InnerEye.ML.deep_learning_config import DatasetParams, DeepLearningConfig, OutputParams, TrainerParams, \
WorkflowParams
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.lightning_loggers import StoringLogger
from InnerEye.ML.metrics import EpochTimers, MAX_ITEM_LOAD_TIME_SEC, store_epoch_metrics

View file

@ -12,7 +12,7 @@ from typing import Any, Dict, Optional, Tuple, TypeVar
import numpy as np
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin
@ -21,7 +21,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
from InnerEye.Azure.azure_util import RUN_CONTEXT, is_offline_run_context
from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME, change_working_directory
from InnerEye.Common.resource_monitor import ResourceMonitor
from InnerEye.ML.common import ModelExecutionMode, RECOVERY_CHECKPOINT_FILE_NAME, cleanup_checkpoint_folder
from InnerEye.ML.common import ModelExecutionMode, RECOVERY_CHECKPOINT_FILE_NAME, create_best_checkpoint
from InnerEye.ML.deep_learning_config import ARGS_TXT, VISUALIZATION_FOLDER
from InnerEye.ML.lightning_base import InnerEyeContainer, InnerEyeLightning
from InnerEye.ML.lightning_container import LightningContainer
@ -71,6 +71,26 @@ def write_args_file(config: Any, outputs_folder: Path) -> None:
logging.info(output)
class InnerEyeRecoveryCheckpointCallback(ModelCheckpoint):
"""
This callback is used to save recovery checkpoints.
In particular, it makes sure we are logging "epoch": this is needed to keep the last k
checkpoints (here save_top_k is based on the epoch number instead of the validation loss,
since PL only allows save_top_k on logged quantities).
"""
def __init__(self, container: LightningContainer):
super().__init__(dirpath=str(container.checkpoint_folder),
monitor="epoch",
filename=RECOVERY_CHECKPOINT_FILE_NAME + "_{epoch}",
period=container.recovery_checkpoint_save_interval,
save_top_k=container.recovery_checkpoints_save_last_k,
mode="max")
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule, outputs: Any) -> None:
pl_module.log(name="epoch", value=trainer.current_epoch)
def create_lightning_trainer(container: LightningContainer,
resume_from_checkpoint: Optional[Path] = None,
num_nodes: int = 1,
@ -95,14 +115,11 @@ def create_lightning_trainer(container: LightningContainer,
# monitor=f"{VALIDATION_PREFIX}{MetricType.LOSS.value}",
# save_top_k=1,
save_last=True)
# Recovery checkpoints: {epoch} will turn into a string like "epoch=1"
# Store 1 recovery checkpoint every recovery_checkpoint_save_interval epochs. Due to a bug in Lightning, this
# will still write alternate files recovery.ckpt and recovery-v0.ckpt, which are cleaned up later in
# cleanup_checkpoint_folder
recovery_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder),
filename=RECOVERY_CHECKPOINT_FILE_NAME,
period=container.recovery_checkpoint_save_interval
)
# Store 1 recovery checkpoint every recovery_checkpoint_save_interval epochs, keep the last
# recovery_checkpoints_save_last_k.
recovery_checkpoint_callback = InnerEyeRecoveryCheckpointCallback(container)
num_gpus = torch.cuda.device_count() if container.use_gpu else 0
logging.info(f"Number of available GPUs: {num_gpus}")
@ -242,7 +259,7 @@ def model_train(checkpoint_handler: CheckpointHandler,
sys.exit()
logging.info("Choosing the best checkpoint and removing redundant files.")
cleanup_checkpoint_folder(container.checkpoint_folder)
create_best_checkpoint(container.checkpoint_folder)
# Lightning modifies a ton of environment variables. If we first run training and then the test suite,
# those environment variables will mislead the training runs in the test suite, and make them crash.
# Hence, restore the original environment after training.
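For reference, a sketch of how the recovery callback above wires ModelCheckpoint to keep the last k checkpoints by ranking on the logged epoch value. It assumes the PyTorch-Lightning 1.2.x API used in this PR (the period argument no longer exists in later PL releases) and uses a simplified file name prefix.

from pathlib import Path
from pytorch_lightning.callbacks import ModelCheckpoint

def make_recovery_callback(checkpoint_folder: Path, save_interval: int, last_k: int) -> ModelCheckpoint:
    # Ranking on the logged "epoch" value means save_top_k keeps the k most recent epochs;
    # last_k = -1 keeps every recovery checkpoint.
    return ModelCheckpoint(dirpath=str(checkpoint_folder),
                           monitor="epoch",
                           filename="recovery_{epoch}",
                           period=save_interval,
                           save_top_k=last_k,
                           mode="max")

# The model (or a callback hook, as in InnerEyeRecoveryCheckpointCallback) must log "epoch"
# once per epoch, e.g.: pl_module.log(name="epoch", value=trainer.current_epoch)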

View file

@ -16,6 +16,7 @@ from azureml.core import Run
from InnerEye.Azure.azure_config import AzureConfig
from InnerEye.Common import fixed_paths
from InnerEye.ML.common import find_recovery_checkpoint_and_epoch
from InnerEye.ML.deep_learning_config import OutputParams, WEIGHTS_FILE
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.utils.run_recovery import RunRecovery
@ -79,26 +80,18 @@ class CheckpointHandler:
def get_recovery_path_train(self) -> Optional[Path]:
"""
Decides the checkpoint path to use for the current training run. If a run recovery object is used, use the
checkpoint from there, otherwise use the checkpoints from the current run.
Decides the checkpoint path to use for the current training run. Looks for the latest checkpoint in the
checkpoint folder. If run_recovery is provided, the checkpoints will have been downloaded to this folder
prior to calling this function. Else, if the run gets pre-empted and automatically restarted in AML,
the latest checkpoint will be present in this folder too.
:return: Constructed checkpoint path to recover from.
"""
start_epoch = self.container.start_epoch
if start_epoch > 0 and not self.run_recovery:
raise ValueError("Start epoch is > 0, but no run recovery object has been provided to resume training.")
recovery = find_recovery_checkpoint_and_epoch(self.container.checkpoint_folder)
if recovery is not None:
local_recovery_path, recovery_epoch = recovery
self.container._start_epoch = recovery_epoch
return local_recovery_path
if self.run_recovery and start_epoch == 0:
raise ValueError("Run recovery set, but start epoch is 0. Please provide start epoch > 0 (for which a "
"checkpoint was saved in the previous run) to resume training from that run.")
if self.run_recovery:
# run_recovery takes first precedence over local_weights_path.
# This is to allow easy recovery of runs which have either of these parameters set in the config:
checkpoints = self.run_recovery.get_recovery_checkpoint_paths()
if len(checkpoints) > 1:
raise ValueError(f"Recovering training of ensemble runs is not supported. Found more than one "
f"checkpoint for epoch {start_epoch}")
return checkpoints[0]
elif self.local_weights_path:
return self.local_weights_path
else:
@ -232,9 +225,3 @@ class CheckpointHandler:
target_file = self.output_params.outputs_folder / WEIGHTS_FILE
torch.save(modified_weights, target_file)
return target_file
def should_load_optimizer_checkpoint(self) -> bool:
"""
Returns true if the optimizer should be loaded from checkpoint. Looks at the model config to determine this.
"""
return self.container.start_epoch > 0
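A standalone sketch of the decision implemented in get_recovery_path_train above (a hypothetical helper, not the CheckpointHandler API): prefer the newest recovery checkpoint in the checkpoint folder, fall back to externally supplied weights, otherwise start a fresh run.

from pathlib import Path
from typing import Optional

def recovery_path_for_training(checkpoint_folder: Path,
                               local_weights_path: Optional[Path] = None) -> Optional[Path]:
    # Recovery checkpoints are named recovery_epoch={epoch}.ckpt; the highest epoch wins.
    recovery = sorted(checkpoint_folder.glob("recovery_epoch=*.ckpt"),
                      key=lambda p: int(p.stem.split("=")[-1]))
    if recovery:
        return recovery[-1]
    return local_weights_path  # may be None: start training from scratch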

View file

@ -13,8 +13,8 @@ from azureml.core import Run
from InnerEye.Azure.azure_util import RUN_CONTEXT, download_outputs_from_run, fetch_child_runs, tag_values_all_distinct
from InnerEye.Common.common_util import OTHER_RUNS_SUBDIR_NAME, check_properties_are_not_none
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, \
create_recovery_checkpoint_path, get_best_checkpoint_path
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, get_best_checkpoint_path, \
get_recovery_checkpoint_path
from InnerEye.ML.deep_learning_config import CHECKPOINT_FOLDER, OutputParams
@ -63,8 +63,7 @@ class RunRecovery:
@staticmethod
def download_all_checkpoints_from_run(config: OutputParams, run: Run) -> RunRecovery:
"""
Downloads all checkpoints of the provided run: The best checkpoint and the recovery checkpoint.
A single folder inside the checkpoints folder will be created that contains the downloaded checkpoints.
Downloads all checkpoints of the provided run into the checkpoints folder.
:param config: Model related configs.
:param run: Run whose checkpoints should be recovered
:return: run recovery information
@ -72,16 +71,15 @@ class RunRecovery:
if fetch_child_runs(run):
raise ValueError(f"AzureML run {run.id} has child runs, this method does not support those.")
root_output_dir = config.checkpoint_folder / run.id
download_outputs_from_run(
blobs_path=Path(CHECKPOINT_FOLDER),
destination=root_output_dir,
destination=config.checkpoint_folder,
run=run
)
return RunRecovery(checkpoints_roots=[root_output_dir])
return RunRecovery(checkpoints_roots=[config.checkpoint_folder])
def get_recovery_checkpoint_paths(self) -> List[Path]:
return [create_recovery_checkpoint_path(x) for x in self.checkpoints_roots]
return [get_recovery_checkpoint_path(x) for x in self.checkpoints_roots]
def get_best_checkpoint_paths(self) -> List[Path]:
return [get_best_checkpoint_path(x) for x in self.checkpoints_roots]

View file

@ -40,10 +40,10 @@ from InnerEye.ML.utils.io_util import zip_random_dicom_series
from InnerEye.Scripts import submit_for_inference
from Tests.ML.util import assert_nifti_content, get_default_azure_config, get_nifti_shape
FALLBACK_ENSEMBLE_RUN = "refs_pull_432_merge:HD_3af84e4a-0043-4260-8be2-04ce9ab09b1f"
FALLBACK_SINGLE_RUN = "refs_pull_407_merge:refs_pull_407_merge_1614271518_cdbeb28e"
FALLBACK_2NODE_RUN = "refs_pull_385_merge:refs_pull_385_merge_1612421371_ba12a007"
FALLBACK_CV_GLAUCOMA = "refs_pull_432_merge_1618332810_b5d10d74"
FALLBACK_ENSEMBLE_RUN = "refs_pull_439_merge:HD_403627fe-c564-4e36-8ba3-c2915d64e220"
FALLBACK_SINGLE_RUN = "refs_pull_439_merge:refs_pull_439_merge_1618850856_cd910071"
FALLBACK_2NODE_RUN = "refs_pull_439_merge:refs_pull_439_merge_1618850855_4d2356f9"
FALLBACK_CV_GLAUCOMA = "refs_pull_439_merge:HD_252cdfa3-bce4-49c5-bf53-995ee3bcab4c"
def get_most_recent_run_id(fallback_run_id_for_local_execution: str = FALLBACK_SINGLE_RUN) -> str:

View file

@ -142,7 +142,7 @@ def test_parsing_with_custom_yaml(test_output_dirs: OutputFolderForTests) -> Non
yaml_file = test_output_dirs.root_dir / "custom.yml"
yaml_file.write_text("""variables:
tenant_id: 'foo'
start_epoch: 7
l_rate: 1e-4
random_seed: 1
""")
# Arguments partly to be set in AzureConfig, and partly in model config.
@ -160,8 +160,8 @@ def test_parsing_with_custom_yaml(test_output_dirs: OutputFolderForTests) -> Non
# This is only present in yaml
# This is present in yaml and command line, and the latter should be used.
assert runner.azure_config.tenant_id == "bar"
# Settings in model config: start_epoch is only in yaml
assert runner.model_config.start_epoch == 7
# Settings in model config: l_rate is only in yaml
assert runner.model_config.l_rate == 1e-4
# Settings in model config: num_epochs is only on commandline
assert runner.model_config.num_epochs == 42
# Settings in model config: random_seed is both in yaml and command line, the latter should be used

View file

@ -49,7 +49,6 @@ class DummyModel(SegmentationModelBase):
trim_percentiles=(1, 99),
inference_batch_size=1,
train_batch_size=2,
start_epoch=0,
num_epochs=2,
l_rate=1e-3,
l_rate_polynomial_gamma=0.9,

View file

@ -4,6 +4,7 @@
# ------------------------------------------------------------------------------------------
import io
import logging
import os
from io import StringIO
from pathlib import Path
from typing import Dict, List, Optional
@ -21,16 +22,20 @@ from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
from InnerEye.Common.metrics_constants import LoggingColumns, MetricType
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML import model_testing, runner
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, CHECKPOINT_SUFFIX, ModelExecutionMode, \
RECOVERY_CHECKPOINT_FILE_NAME
from InnerEye.ML.configs.classification.DummyClassification import DummyClassification
from InnerEye.ML.configs.classification.DummyMulticlassClassification import DummyMulticlassClassification
from InnerEye.ML.dataset.scalar_dataset import ScalarDataset
from InnerEye.ML.metrics import InferenceMetricsForClassification, binary_classification_accuracy, \
compute_scalar_metrics
from InnerEye.ML.metrics_dict import MetricsDict, ScalarMetricsDict
from InnerEye.ML.model_training import model_train
from InnerEye.ML.reports.notebook_report import generate_classification_multilabel_notebook, \
generate_classification_notebook, get_html_report_name, get_ipynb_report_name
from InnerEye.ML.run_ml import MLRunner
from InnerEye.ML.scalar_config import ScalarLoss, ScalarModelBase
from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler
from InnerEye.ML.utils.config_loader import ModelConfigLoader
from InnerEye.ML.visualizers.plot_cross_validation import EpochMetricValues, get_config_and_results_for_offline_runs, \
unroll_aggregate_metrics
@ -38,6 +43,7 @@ from Tests.ML.configs.ClassificationModelForTesting import ClassificationModelFo
from Tests.ML.configs.DummyModel import DummyModel
from Tests.ML.util import get_default_azure_config, machine_has_gpu, \
model_train_unittest
from Tests.ML.utils.test_model_util import FIXED_EPOCH, create_model_and_store_checkpoint
@pytest.mark.cpu_and_gpu
@ -326,7 +332,10 @@ def test_runner1(test_output_dirs: OutputFolderForTests) -> None:
"--random_seed", str(set_from_commandline),
"--non_image_feature_channels", scalar1,
"--output_to", output_root,
"--max_num_gpus", "1"
"--max_num_gpus", "1",
"--recovery_checkpoint_save_interval", "2",
"--recovery_checkpoints_save_last_k", "2",
"--num_epochs", "6",
]
with mock.patch("sys.argv", args):
config, _ = runner.run(project_root=fixed_paths.repository_root_directory(),
@ -337,6 +346,53 @@ def test_runner1(test_output_dirs: OutputFolderForTests) -> None:
assert config.non_image_feature_channels == ["label"]
assert str(config.outputs_folder).startswith(output_root)
assert (config.logs_folder / LOG_FILE_NAME).exists()
# Check that we saved one checkpoint every second epoch, that we kept only the last 2, and that last.ckpt has
# been renamed to best.ckpt
assert len(os.listdir(config.checkpoint_folder)) == 3
assert (config.checkpoint_folder / str(RECOVERY_CHECKPOINT_FILE_NAME + "_epoch=3" + CHECKPOINT_SUFFIX)).exists()
assert (config.checkpoint_folder / str(RECOVERY_CHECKPOINT_FILE_NAME + "_epoch=5" + CHECKPOINT_SUFFIX)).exists()
assert (config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).exists()
@pytest.mark.skipif(common_util.is_windows(), reason="Has OOM issues on windows build")
def test_runner_restart(test_output_dirs: OutputFolderForTests) -> None:
"""
Test if starting training in a folder where the checkpoints folder already contains recovery checkpoints is
picked up as a recovery run. Also checks that we update the start epoch in the config at loading time.
"""
model_config = DummyClassification()
model_config.set_output_to(test_output_dirs.root_dir)
model_config.num_epochs = FIXED_EPOCH + 2
# We save all checkpoints - if recovery works as expected we should get new checkpoints for epochs FIXED_EPOCH and FIXED_EPOCH + 1.
model_config.recovery_checkpoint_save_interval = 1
model_config.recovery_checkpoints_save_last_k = -1
runner = MLRunner(model_config=model_config)
runner.setup(use_mount_or_download_dataset=False)
# Epochs are 0 based for saving
create_model_and_store_checkpoint(model_config,
runner.container.checkpoint_folder / f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
f"{FIXED_EPOCH - 1}{CHECKPOINT_SUFFIX}",
weights_only=False)
azure_config = get_default_azure_config()
checkpoint_handler = CheckpointHandler(azure_config=azure_config,
container=runner.container,
project_root=test_output_dirs.root_dir)
_, storing_logger = model_train(checkpoint_handler=checkpoint_handler,
container=runner.container)
# We expect to have 4 checkpoints: the pre-existing recovery checkpoint from epoch FIXED_EPOCH - 1, new recovery
# checkpoints for epochs FIXED_EPOCH and FIXED_EPOCH + 1, and best.
assert len(os.listdir(runner.container.checkpoint_folder)) == 4
assert (
runner.container.checkpoint_folder / f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
f"{FIXED_EPOCH - 1}{CHECKPOINT_SUFFIX}").exists()
assert (
runner.container.checkpoint_folder / f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
f"{FIXED_EPOCH}{CHECKPOINT_SUFFIX}").exists()
assert (
runner.container.checkpoint_folder / f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
f"{FIXED_EPOCH + 1}{CHECKPOINT_SUFFIX}").exists()
assert (runner.container.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).exists()
# Check that we really restarted epoch from epoch FIXED_EPOCH.
assert list(storing_logger.epochs) == [FIXED_EPOCH, FIXED_EPOCH + 1] # type: ignore
@pytest.mark.skipif(common_util.is_windows(), reason="Has OOM issues on windows build")

View file

@ -19,9 +19,9 @@ from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME, is_windows, l
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
from InnerEye.Common.metrics_constants import MetricType, TrackedMetrics, VALIDATION_PREFIX
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, DATASET_CSV_FILE_NAME, ModelExecutionMode, \
RECOVERY_CHECKPOINT_FILE_NAME_WITH_SUFFIX, \
STORED_CSV_FILE_NAMES
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, CHECKPOINT_SUFFIX, DATASET_CSV_FILE_NAME, \
ModelExecutionMode, \
RECOVERY_CHECKPOINT_FILE_NAME, STORED_CSV_FILE_NAMES
from InnerEye.ML.config import MixtureLossComponent, SegmentationLoss
from InnerEye.ML.configs.classification.DummyClassification import DummyClassification
from InnerEye.ML.dataset.sample import CroppedSample
@ -43,9 +43,10 @@ base_path = full_ml_test_data_path()
def test_get_total_number_of_training_epochs() -> None:
c = DeepLearningConfig(num_epochs=2, should_validate=False)
assert c.get_total_number_of_training_epochs() == 2
c = DeepLearningConfig(num_epochs=10, start_epoch=5, should_validate=False)
assert c.get_total_number_of_training_epochs() == 5
c = DeepLearningConfig(num_epochs=10, should_validate=False)
# Fake recovering training
c.recovery_start_epoch = 2
assert c.get_total_number_of_training_epochs() == 8
@pytest.mark.parametrize("image_channels", [["region"], ["random_123"]])
@pytest.mark.parametrize("ground_truth_ids", [["region", "region"], ["region", "other_region"]])
@ -176,7 +177,8 @@ def _test_model_train(output_dirs: OutputFolderForTests,
assert train_config.checkpoint_folder.is_dir()
actual_checkpoints = list(train_config.checkpoint_folder.rglob("*.ckpt"))
assert len(actual_checkpoints) == 2, f"Actual checkpoints: {actual_checkpoints}"
assert (train_config.checkpoint_folder / RECOVERY_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()
assert (train_config.checkpoint_folder / str(
RECOVERY_CHECKPOINT_FILE_NAME + f"_epoch={train_config.num_epochs - 1}" + CHECKPOINT_SUFFIX)).is_file()
assert (train_config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()
assert (train_config.outputs_folder / DATASET_CSV_FILE_NAME).is_file()
assert (train_config.outputs_folder / STORED_CSV_FILE_NAMES[ModelExecutionMode.TRAIN]).is_file()
@ -324,7 +326,6 @@ def test_recover_training_mean_teacher_model(test_output_dirs: OutputFolderForTe
assert len(list(config.checkpoint_folder.glob("*.*"))) == 2
# Restart training from previous run
config.start_epoch = 2
config.num_epochs = 3
config.set_output_to(test_output_dirs.root_dir / "recovered")
os.makedirs(str(config.outputs_folder))

View file

@ -14,7 +14,7 @@ import torch
from InnerEye.Common.common_util import OTHER_RUNS_SUBDIR_NAME
from InnerEye.Common.fixed_paths import MODEL_WEIGHTS_DIR_NAME
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, create_recovery_checkpoint_path
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, get_recovery_checkpoint_path
from InnerEye.ML.deep_learning_config import WEIGHTS_FILE
from InnerEye.ML.model_config_base import ModelConfigBase
from Tests.AfterTraining.test_after_training import FALLBACK_ENSEMBLE_RUN, FALLBACK_SINGLE_RUN, get_most_recent_run, \
@ -83,8 +83,8 @@ def test_download_checkpoints_from_single_run(test_output_dirs: OutputFolderForT
checkpoint_handler.download_recovery_checkpoints_or_weights()
assert checkpoint_handler.run_recovery
expected_checkpoint_root = config.checkpoint_folder / run_recovery_id.split(":")[1]
expected_paths = [create_recovery_checkpoint_path(path=expected_checkpoint_root),
expected_checkpoint_root = config.checkpoint_folder
expected_paths = [get_recovery_checkpoint_path(path=expected_checkpoint_root),
expected_checkpoint_root / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX]
assert checkpoint_handler.run_recovery.checkpoints_roots == [expected_checkpoint_root]
for path in expected_paths:
@ -121,13 +121,7 @@ def test_get_recovery_path_train(test_output_dirs: OutputFolderForTests) -> None
checkpoint_handler.container.weights_url = EXTERNAL_WEIGHTS_URL_EXAMPLE
checkpoint_handler.download_recovery_checkpoints_or_weights()
assert checkpoint_handler.local_weights_path == expected_path
checkpoint_handler.container.start_epoch = 0
assert checkpoint_handler.get_recovery_path_train() == expected_path
# Can't resume training from an external checkpoint
checkpoint_handler.container.start_epoch = 20
with pytest.raises(ValueError) as ex:
checkpoint_handler.get_recovery_path_train()
assert ex.value.args[0] == "Start epoch is > 0, but no run recovery object has been provided to resume training."
# Set a local_weights_path to get checkpoint from
checkpoint_handler.container.weights_url = ""
@ -136,13 +130,7 @@ def test_get_recovery_path_train(test_output_dirs: OutputFolderForTests) -> None
checkpoint_handler.container.local_weights_path = local_weights_path
checkpoint_handler.download_recovery_checkpoints_or_weights()
assert checkpoint_handler.local_weights_path == expected_path
checkpoint_handler.container.start_epoch = 0
assert checkpoint_handler.get_recovery_path_train() == expected_path
# Can't resume training from an external checkpoint
checkpoint_handler.container.start_epoch = 20
with pytest.raises(ValueError) as ex:
checkpoint_handler.get_recovery_path_train()
assert ex.value.args[0] == "Start epoch is > 0, but no run recovery object has been provided to resume training."
@pytest.mark.after_training_single_run
@ -157,14 +145,8 @@ def test_get_recovery_path_train_single_run(test_output_dirs: OutputFolderForTes
checkpoint_handler.azure_config.run_recovery_id = run_recovery_id
checkpoint_handler.download_recovery_checkpoints_or_weights()
# We have not set a start_epoch but we are trying to use run_recovery, this should fail
with pytest.raises(ValueError) as ex:
checkpoint_handler.get_recovery_path_train()
assert "Run recovery set, but start epoch is 0" in ex.value.args[0]
# Run recovery with start epoch provided should succeed
checkpoint_handler.container.start_epoch = 20
expected_path = create_recovery_checkpoint_path(path=config.checkpoint_folder / run_recovery_id.split(":")[1])
expected_path = get_recovery_checkpoint_path(path=config.checkpoint_folder)
assert checkpoint_handler.get_recovery_path_train() == expected_path
@ -186,8 +168,7 @@ def test_get_best_checkpoint_single_run(test_output_dirs: OutputFolderForTests)
# in the run, into a subfolder of the checkpoint folder
checkpoint_handler.azure_config.run_recovery_id = run_recovery_id
checkpoint_handler.download_recovery_checkpoints_or_weights()
expected_checkpoint = config.checkpoint_folder / run_recovery_id.split(":")[1] \
/ f"{BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX}"
expected_checkpoint = config.checkpoint_folder / f"{BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX}"
checkpoint_paths = checkpoint_handler.get_best_checkpoint()
assert checkpoint_paths
assert len(checkpoint_paths) == 1
@ -201,11 +182,9 @@ def test_get_best_checkpoint_single_run(test_output_dirs: OutputFolderForTests)
checkpoint_handler.azure_config.run_recovery_id = run_recovery_id
checkpoint_handler.download_recovery_checkpoints_or_weights()
checkpoint_handler.container.start_epoch = 1
# There is no checkpoint in the current run - use the one from run_recovery
checkpoint_paths = checkpoint_handler.get_best_checkpoint()
expected_checkpoint = config.checkpoint_folder / run_recovery_id.split(":")[1] \
/ BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
expected_checkpoint = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
assert checkpoint_paths
assert len(checkpoint_paths) == 1
assert checkpoint_paths[0] == expected_checkpoint
@ -254,7 +233,6 @@ def test_get_checkpoints_to_test(test_output_dirs: OutputFolderForTests) -> None
assert len(checkpoint_and_paths) == 1
assert checkpoint_and_paths[0] == manage_recovery.output_params.outputs_folder / WEIGHTS_FILE
manage_recovery.container.start_epoch = 1
manage_recovery.additional_training_done()
manage_recovery.container.checkpoint_folder.mkdir()
@ -281,7 +259,7 @@ def test_get_checkpoints_to_test_single_run(test_output_dirs: OutputFolderForTes
# Now set a run recovery object and set the start epoch to 1, so we get one epoch from
# run recovery and one from the training checkpoints
manage_recovery.azure_config.run_recovery_id = run_recovery_id
config.start_epoch = 1
manage_recovery.additional_training_done()
manage_recovery.download_recovery_checkpoints_or_weights()
@ -289,8 +267,7 @@ def test_get_checkpoints_to_test_single_run(test_output_dirs: OutputFolderForTes
assert checkpoint_and_paths
assert len(checkpoint_and_paths) == 1
assert checkpoint_and_paths[0] == config.checkpoint_folder / run_recovery_id.split(":")[1] / \
BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
assert checkpoint_and_paths[0] == config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
# Copy checkpoint to make it seem like training has happened
expected_checkpoint = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX

View file

@ -32,36 +32,6 @@ def enumerate_scheduler(scheduler: _LRScheduler, steps: int) -> List[float]:
return lrs
def test_create_lr_scheduler_last_epoch() -> None:
"""
Test to check if the lr scheduler is initialized to the correct epoch
"""
l_rate = 1e-3
gamma = 0.5
total_epochs = 5
expected_lrs_per_epoch = [l_rate * (gamma ** i) for i in range(total_epochs)]
config = DummyModel()
config.l_rate = l_rate
config.l_rate_scheduler = LRSchedulerType.Step
config.l_rate_step_step_size = 1
config.l_rate_step_gamma = gamma
# create lr scheduler
initial_scheduler, initial_optimizer = _create_lr_scheduler_and_optimizer(config)
# check lr scheduler initialization step
initial_epochs = 3
assert np.allclose(enumerate_scheduler(initial_scheduler, initial_epochs), expected_lrs_per_epoch[:initial_epochs])
# create lr scheduler for recovery checkpoint
config.start_epoch = initial_epochs
recovery_scheduler, recovery_optimizer = _create_lr_scheduler_and_optimizer(config)
# Both the scheduler and the optimizer need to be loaded from the checkpoint.
recovery_scheduler.load_state_dict(initial_scheduler.state_dict())
recovery_optimizer.load_state_dict(initial_optimizer.state_dict())
assert recovery_scheduler.last_epoch == config.start_epoch
# check lr scheduler initialization matches the checkpoint epoch
# as training will start for start_epoch + 1 in this case
assert np.allclose(enumerate_scheduler(recovery_scheduler, 2), expected_lrs_per_epoch[initial_epochs:])
@pytest.mark.parametrize("lr_scheduler_type", [x for x in LRSchedulerType])
def test_lr_monotonically_decreasing_function(lr_scheduler_type: LRSchedulerType) -> None:
"""
@ -220,45 +190,6 @@ def test_lr_scheduler_with_warmup(warmup_epochs: int, expected_values: List[floa
assert lrs == expected_values
# Exclude Polynomial scheduler because that uses lambdas, which we can't save to a state dict
@pytest.mark.parametrize("lr_scheduler_type", [x for x in LRSchedulerType if x != LRSchedulerType.Polynomial])
@pytest.mark.parametrize("warmup_epochs", [0, 3, 4, 5])
def test_resume_from_saved_state(lr_scheduler_type: LRSchedulerType, warmup_epochs: int) -> None:
"""
Tests if LR scheduler when restarted from an epoch continues as expected.
"""
restart_from_epoch = 4
config = DummyModel(num_epochs=7,
l_rate_scheduler=lr_scheduler_type,
l_rate_exponential_gamma=0.9,
l_rate_step_gamma=0.9,
l_rate_step_step_size=2,
l_rate_multi_step_gamma=0.9,
l_rate_multi_step_milestones=[3, 5, 7],
l_rate_polynomial_gamma=0.9,
l_rate_warmup=LRWarmUpType.Linear if warmup_epochs > 0 else LRWarmUpType.NoWarmUp,
l_rate_warmup_epochs=warmup_epochs)
# This scheduler mimics what happens if we train for the full set of epochs
scheduler_all_epochs, _ = _create_lr_scheduler_and_optimizer(config)
expected_lr_list = enumerate_scheduler(scheduler_all_epochs, config.num_epochs)
# Create a scheduler where training will be recovered
scheduler1, optimizer1 = _create_lr_scheduler_and_optimizer(config)
# Scheduler 1 is only run for 4 epochs, and then "restarted" to train the rest of the epochs.
result_lr_list = enumerate_scheduler(scheduler1, restart_from_epoch)
# resume state: This just means setting start_epoch in the config
config.start_epoch = restart_from_epoch
scheduler_resume, optimizer_resume = _create_lr_scheduler_and_optimizer(config)
# Load a "checkpoint" for both scheduler and optimizer
scheduler_resume.load_state_dict(scheduler1.state_dict())
optimizer_resume.load_state_dict(optimizer1.state_dict())
result_lr_list.extend(enumerate_scheduler(scheduler_resume, config.num_epochs - restart_from_epoch))
print(f"Actual schedule: {result_lr_list}")
print(f"Expected schedule: {expected_lr_list}")
assert len(result_lr_list) == len(expected_lr_list)
assert np.allclose(result_lr_list, expected_lr_list)
@pytest.mark.parametrize("lr_scheduler_type", [x for x in LRSchedulerType])
def test_save_and_load_state_dict(lr_scheduler_type: LRSchedulerType) -> None:
def object_dict_same(lr1: SchedulerWithWarmUp, lr2: SchedulerWithWarmUp) -> bool:

View file

@ -12,8 +12,8 @@ from pytorch_lightning.callbacks import ModelCheckpoint
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, LAST_CHECKPOINT_FILE_NAME, \
LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, RECOVERY_CHECKPOINT_FILE_NAME, RECOVERY_CHECKPOINT_FILE_NAME_WITH_SUFFIX, \
cleanup_checkpoint_folder, keep_best_checkpoint, keep_latest
LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, RECOVERY_CHECKPOINT_FILE_NAME, create_best_checkpoint, \
extract_latest_checkpoint_and_epoch, find_all_recovery_checkpoints, find_recovery_checkpoint_and_epoch
from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.lightning_base import InnerEyeContainer
from InnerEye.ML.lightning_helpers import load_from_checkpoint_and_adjust_for_inference
@ -28,7 +28,8 @@ FIXED_EPOCH = 42
FIXED_GLOBAL_STEP = 4242
def create_model_and_store_checkpoint(config: ModelConfigBase, checkpoint_path: Path) -> None:
def create_model_and_store_checkpoint(config: ModelConfigBase, checkpoint_path: Path,
weights_only: bool = True) -> None:
"""
Creates a Lightning model for the given model configuration, and stores it as a checkpoint file.
If a GPU is available, the model is moved to the GPU before storing.
@ -48,7 +49,7 @@ def create_model_and_store_checkpoint(config: ModelConfigBase, checkpoint_path:
trainer.global_step = FIXED_GLOBAL_STEP - 1
# In PL, it is the Trainer's responsibility to save the model. Checkpoint handling refers back to the trainer
# to get a save_func. Mimicking that here.
trainer.save_checkpoint(checkpoint_path, weights_only=True)
trainer.save_checkpoint(checkpoint_path, weights_only=weights_only)
@pytest.mark.cpu_and_gpu
@ -95,38 +96,66 @@ def test_checkpoint_path() -> None:
assert LAST_CHECKPOINT_FILE_NAME == ModelCheckpoint.CHECKPOINT_NAME_LAST
def test_keep_latest(test_output_dirs: OutputFolderForTests) -> None:
def test_find_all_recovery_checkpoints(test_output_dirs: OutputFolderForTests) -> None:
checkpoint_folder = test_output_dirs.root_dir
# No recovery yet available
(checkpoint_folder / "epoch=2.ckpt").touch()
assert find_all_recovery_checkpoints(checkpoint_folder) is None
# Add recovery file to fake folder
file_list = ["recovery_epoch=1.ckpt", "recovery.ckpt"]
for f in file_list:
(checkpoint_folder / f).touch()
found_file_names = set([f.stem for f in find_all_recovery_checkpoints(checkpoint_folder)])  # type: ignore
expected_file_names = set([Path(f).stem for f in file_list])
assert found_file_names == expected_file_names
def test_find_latest_checkpoint_and_epoch() -> None:
file_list = [Path("epoch=1.ckpt"), Path("epoch=3.ckpt"), Path("epoch=2.ckpt")]
assert (Path("epoch=3.ckpt"), 3) == extract_latest_checkpoint_and_epoch(file_list)
invalid_file_list = [Path("epoch.ckpt"), Path("epoch=3.ckpt"), Path("epoch=2.ckpt")]
with pytest.raises(IndexError):
extract_latest_checkpoint_and_epoch(invalid_file_list)
def test_find_recovery_checkpoint(test_output_dirs: OutputFolderForTests) -> None:
"""
Test that the recovery checkpoint with the highest epoch number is found.
"""
folder = test_output_dirs.root_dir
prefix = "foo"
pattern = prefix + "*"
file1 = folder / (prefix + ".txt")
file2 = folder / (prefix + "2.txt")
prefix = RECOVERY_CHECKPOINT_FILE_NAME
file1 = folder / (prefix + "epoch=1.txt")
file2 = folder / (prefix + "epoch=2.txt")
# No file present yet
assert keep_latest(folder, pattern) is None
assert find_recovery_checkpoint_and_epoch(folder) is None
# Single file present: This should be returned.
file1.touch()
# Without sleeping, the test can fail in Azure build agents
time.sleep(0.1)
latest = keep_latest(folder, pattern)
assert latest == file1
assert latest.is_file()
# Two files present: keep file2, file1 should be deleted
recovery = find_recovery_checkpoint_and_epoch(folder)
assert recovery is not None
latest_checkpoint, latest_epoch = recovery
assert latest_checkpoint == file1
assert latest_epoch == 1
assert latest_checkpoint.is_file()
# Two files present: file2 should be returned
file2.touch()
time.sleep(0.1)
latest = keep_latest(folder, pattern)
assert latest == file2
assert latest.is_file()
assert not file1.is_file()
# Add file1 again: Now this one should be the most recent one
recovery = find_recovery_checkpoint_and_epoch(folder)
assert recovery is not None
latest_checkpoint, latest_epoch = recovery
assert latest_checkpoint == file2
assert latest_checkpoint.is_file()
assert latest_epoch == 2
# Touch file1 again: file2 should still be returned as it has the
# higher epoch number
file1.touch()
time.sleep(0.1)
latest = keep_latest(folder, pattern)
assert latest == file1
assert latest.is_file()
assert not file2.is_file()
recovery = find_recovery_checkpoint_and_epoch(folder)
assert recovery is not None
latest_checkpoint, latest_epoch = recovery
assert latest_checkpoint == file2
assert latest_checkpoint.is_file()
assert latest_epoch == 2
def test_keep_best_checkpoint(test_output_dirs: OutputFolderForTests) -> None:
@ -135,11 +164,11 @@ def test_keep_best_checkpoint(test_output_dirs: OutputFolderForTests) -> None:
"""
folder = test_output_dirs.root_dir
with pytest.raises(FileNotFoundError) as ex:
keep_best_checkpoint(folder)
create_best_checkpoint(folder)
assert "Checkpoint file" in str(ex)
last = folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
last.touch()
actual = keep_best_checkpoint(folder)
actual = create_best_checkpoint(folder)
assert not last.is_file(), "Checkpoint file should have been renamed"
expected = folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
assert actual == expected
@ -149,12 +178,12 @@ def test_keep_best_checkpoint(test_output_dirs: OutputFolderForTests) -> None:
def test_cleanup_checkpoints1(test_output_dirs: OutputFolderForTests) -> None:
folder = test_output_dirs.root_dir
with pytest.raises(FileNotFoundError) as ex:
cleanup_checkpoint_folder(folder)
create_best_checkpoint(folder)
assert "Checkpoint file" in str(ex)
# Single checkpoint file, nothing else: This file should be renamed to best_checkpoint
last = folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
last.touch()
cleanup_checkpoint_folder(folder)
create_best_checkpoint(folder)
assert len(list(folder.glob("*"))) == 1
assert (folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()
@ -164,9 +193,12 @@ def test_cleanup_checkpoints2(test_output_dirs: OutputFolderForTests) -> None:
folder = test_output_dirs.root_dir
last = folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
last.touch()
(folder / f"{RECOVERY_CHECKPOINT_FILE_NAME}-v0").touch()
(folder / f"{RECOVERY_CHECKPOINT_FILE_NAME}-v1").touch()
cleanup_checkpoint_folder(folder)
assert len(list(folder.glob("*"))) == 2
(folder / f"{RECOVERY_CHECKPOINT_FILE_NAME}-epoch=3").touch()
(folder / f"{RECOVERY_CHECKPOINT_FILE_NAME}-epoch=6").touch()
# Before cleanup: last.ckpt, recovery-epoch=6.ckpt, recovery-epoch=3.ckpt
create_best_checkpoint(folder)
# After cleanup: best.ckpt, recovery-epoch=6.ckpt, recovery-epoch=3.ckpt
assert len(list(folder.glob("*"))) == 3
assert (folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()
assert (folder / RECOVERY_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()
assert (folder / f"{RECOVERY_CHECKPOINT_FILE_NAME}-epoch=6").is_file()
assert (folder / f"{RECOVERY_CHECKPOINT_FILE_NAME}-epoch=3").is_file()

View file

@ -18,7 +18,7 @@ steps:
branch_name_without_prefix=${full_branch_name#$branch_prefix}
python $(Agent.TempDirectory)/InnerEye/TestSubmodule/test_submodule_runner.py --azureml=True --model="$(model)" --train="$(train)" $(more_switches) --wait_for_completion="${{parameters.wait_for_completion}}" --max_run_duration="${{parameters.max_run_duration}}" --cluster="$(cluster)" --tag="$(tag)" --build_number=$(Build.BuildId) --build_user="$(Build.RequestedFor)" --build_user_email="" --build_branch="$branch_name_without_prefix" --build_source_repository="$(Build.Repository.Name)" --monitoring_interval_seconds=5 --show_patch_sampling=0
mv most_recent_run.txt training_run.txt
python $(Agent.TempDirectory)/InnerEye/TestSubmodule/test_submodule_runner.py --run_recovery_id=`cat training_run.txt` --start_epoch=2 --num_epochs=4 --azureml=True --model="$(model)" --train="$(train)" $(more_switches) --wait_for_completion="${{parameters.wait_for_completion}}" --max_run_duration="${{parameters.max_run_duration}}" --cluster="$(cluster)" --tag="$(tag)" --build_number=$(Build.BuildId) --build_user="$(Build.RequestedFor)" --build_user_email="" --build_branch="$branch_name_without_prefix" --build_source_id="$(Build.SourceVersion)" --build_source_message="$source_version_message" --build_source_author="$(Build.SourceVersionAuthor)" --build_source_repository="$(Build.Repository.Name)" --show_patch_sampling=0
python $(Agent.TempDirectory)/InnerEye/TestSubmodule/test_submodule_runner.py --run_recovery_id=`cat training_run.txt` --num_epochs=4 --azureml=True --model="$(model)" --train="$(train)" $(more_switches) --wait_for_completion="${{parameters.wait_for_completion}}" --max_run_duration="${{parameters.max_run_duration}}" --cluster="$(cluster)" --tag="$(tag)" --build_number=$(Build.BuildId) --build_user="$(Build.RequestedFor)" --build_user_email="" --build_branch="$branch_name_without_prefix" --build_source_id="$(Build.SourceVersion)" --build_source_message="$source_version_message" --build_source_author="$(Build.SourceVersionAuthor)" --build_source_repository="$(Build.Repository.Name)" --show_patch_sampling=0
mv training_run.txt most_recent_run.txt
env:
PYTHONPATH: $(Agent.TempDirectory)/InnerEye