Improve recovery of preempted jobs (#633)
Checkpoints are now autosaved by default after every epoch, to a fixed file name. The "top k" recovery checkpoint notion is retired, because it was only tied to specific models that needed more than one checkpoint.
Parent: 25db288768
Commit: ccb53d01ad
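The change in behaviour can be summarized by how the checkpoint callbacks are now wired up. A minimal sketch adapted from the `create_lightning_trainer` changes in this commit (the folder path and epoch count are illustrative, not taken from the repository):

```python
from pathlib import Path

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_folder = Path("outputs/checkpoints")  # illustrative path

# "last.ckpt" is written at the end of every epoch and is later used as the "best" checkpoint.
last_checkpoint_callback = ModelCheckpoint(dirpath=str(checkpoint_folder),
                                           save_last=True,
                                           save_top_k=0)
# A single fixed-name "autosave" checkpoint, refreshed every N validation epochs,
# replaces the old top-k "recovery_epoch=*" checkpoints.
autosave_callback = ModelCheckpoint(dirpath=str(checkpoint_folder),
                                    filename="autosave",
                                    every_n_val_epochs=1,
                                    save_last=False)
trainer = Trainer(callbacks=[last_checkpoint_callback, autosave_callback], max_epochs=10)
```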
@@ -3,7 +3,6 @@
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
<mapping directory="$PROJECT_DIR$/fastMRI" vcs="Git" />
<mapping directory="$PROJECT_DIR$/hi-ml/src" vcs="Git" />
<mapping directory="$PROJECT_DIR$/hi-ml-azure/src" vcs="Git" />
<mapping directory="$PROJECT_DIR$/hi-ml" vcs="Git" />
</component>
</project>

@@ -72,6 +72,8 @@ gets uploaded to AzureML, by skipping all test folders.
- ([#596](https://github.com/microsoft/InnerEye-DeepLearning/pull/596)) Add `cudatoolkit=11.1` specification to environment.yml.
- ([#615](https://github.com/microsoft/InnerEye-DeepLearning/pull/615)) Minor changes to checkpoint download from AzureML.
- ([#605](https://github.com/microsoft/InnerEye-DeepLearning/pull/605)) Make build jobs deterministic for regression testing.
- ([#633](https://github.com/microsoft/InnerEye-DeepLearning/pull/633)) Model training now only writes one recovery checkpoint, rather than multiple ones. Frequency is controlled by
  `autosave_every_n_val_epochs`.
- ([#632](https://github.com/microsoft/InnerEye-DeepLearning/pull/632)) Nifti test data is no longer stored in Git LFS.

### Fixed

@@ -125,6 +127,9 @@ in inference-only runs when using lightning containers.

### Deprecated

- ([#633](https://github.com/microsoft/InnerEye-DeepLearning/pull/633)) Model fields `recovery_checkpoint_save_interval` and `recovery_checkpoints_save_last_k` have been retired.
  Recovery checkpoint handling is now controlled by `autosave_every_n_val_epochs`.

## 0.3 (2021-06-01)

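For model configs that previously set the now-retired fields, the migration implied by the entries above is mechanical. A minimal sketch, assuming the repository is installed (the numeric values are illustrative):

```python
from InnerEye.ML.deep_learning_config import TrainerParams

# Before this commit, configs would set the retired fields, e.g.
#     recovery_checkpoint_save_interval=10,
#     recovery_checkpoints_save_last_k=-1,
# After this commit, the only knob is the autosave frequency, in validation epochs:
params = TrainerParams(num_epochs=200, autosave_every_n_val_epochs=1)
print(params.autosave_every_n_val_epochs)
```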
@@ -8,12 +8,12 @@ import re
from pathlib import Path
from typing import Generator, List, Optional, Tuple

from azureml.core import Experiment, Run, Workspace, get_run
from azureml.core import Experiment, Run, Workspace
from azureml.exceptions import UserErrorException

from InnerEye.Common import fixed_paths
from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME
from health_azure.utils import create_run_recovery_id
from health_azure.utils import create_run_recovery_id, get_aml_run_from_run_id

DEFAULT_CROSS_VALIDATION_SPLIT_INDEX = -1
EXPERIMENT_RUN_SEPARATOR = ":"

@@ -79,40 +79,7 @@ def fetch_run(workspace: Workspace, run_recovery_id: str) -> Run:
or just the run_id
:return: The AzureML run.
"""
experiment, run = split_recovery_id(run_recovery_id)
try:
experiment_to_recover = Experiment(workspace, experiment)
except Exception as ex:
raise Exception(
f"Unable to retrieve run {run} in experiment {experiment}: {str(ex)}"
)
run_to_recover = fetch_run_for_experiment(experiment_to_recover, run)
logging.info(
"Fetched run #{} {} from experiment {}.".format(
run, run_to_recover.number, experiment
)
)
return run_to_recover


def fetch_run_for_experiment(experiment_to_recover: Experiment, run_id: str) -> Run:
"""
:param experiment_to_recover: an experiment
:param run_id: a string representing the Run ID of one of the runs of the experiment
:return: the run matching run_id_or_number; raises an exception if not found
"""
try:
return get_run(experiment=experiment_to_recover, run_id=run_id, rehydrate=True)
except Exception:
available_runs = experiment_to_recover.get_runs()
available_ids = ", ".join([run.id for run in available_runs])
raise (
Exception(
"Run {} not found for experiment: {}. Available runs are: {}".format(
run_id, experiment_to_recover.name, available_ids
)
)
)
return get_aml_run_from_run_id(aml_workspace=workspace, run_id=run_recovery_id)


def fetch_runs(experiment: Experiment, filters: List[str]) -> List[Run]:

@@ -133,9 +100,9 @@ def fetch_runs(experiment: Experiment, filters: List[str]) -> List[Run]:


def fetch_child_runs(
run: Run,
status: Optional[str] = None,
expected_number_cross_validation_splits: int = 0,
run: Run,
status: Optional[str] = None,
expected_number_cross_validation_splits: int = 0,
) -> List[Run]:
"""
Fetch child runs for the provided runs that have the provided AML status (or fetch all by default)

@@ -312,7 +279,7 @@ def download_run_output_file(blob_path: Path, destination: Path, run: Run) -> Pa


def download_run_outputs_by_prefix(
blobs_prefix: Path, destination: Path, run: Run
blobs_prefix: Path, destination: Path, run: Run
) -> None:
"""
Download all the blobs from the run's default output directory: DEFAULT_AML_UPLOAD_DIR ("outputs") that

@@ -354,7 +321,7 @@ def is_running_on_azure_agent() -> bool:


def get_comparison_baseline_paths(
outputs_folder: Path, blob_path: Path, run: Run, dataset_csv_file_name: str
outputs_folder: Path, blob_path: Path, run: Run, dataset_csv_file_name: str
) -> Tuple[Optional[Path], Optional[Path]]:
run_rec_id = run.id
# We usually find dataset.csv in the same directory as metrics.csv, but we sometimes

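The rewritten `fetch_run` above now delegates to `get_aml_run_from_run_id` from `health_azure`. A minimal usage sketch (the workspace configuration file and run id are placeholders):

```python
from azureml.core import Workspace

from InnerEye.Azure.azure_util import fetch_run

# Placeholder workspace config path and run recovery id, for illustration only.
workspace = Workspace.from_config("config.json")
run = fetch_run(workspace, run_recovery_id="my_experiment:my_experiment_1234567890_abcdef12")
print(run.id, run.status)
```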
@@ -13,11 +13,15 @@ from typing import Any, Dict, List
DATASET_CSV_FILE_NAME = "dataset.csv"
CHECKPOINT_SUFFIX = ".ckpt"

RECOVERY_CHECKPOINT_FILE_NAME = "recovery"
RECOVERY_CHECKPOINT_FILE_NAME_WITH_SUFFIX = RECOVERY_CHECKPOINT_FILE_NAME + CHECKPOINT_SUFFIX
# The file names for the legacy "recovery" checkpoints behaviour, which stored the most recent N checkpoints
LEGACY_RECOVERY_CHECKPOINT_FILE_NAME = "recovery"

BEST_CHECKPOINT_FILE_NAME = "best_checkpoint"
BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX = BEST_CHECKPOINT_FILE_NAME + CHECKPOINT_SUFFIX
# The file names for the new recovery checkpoint behaviour: A single fixed checkpoint that is written every N epochs.
# Lightning does not overwrite files in place, and will hence create files "autosave.ckpt", "autosave-v1.ckpt"
# alternatingly
AUTOSAVE_CHECKPOINT_FILE_NAME = "autosave"
AUTOSAVE_CHECKPOINT_CANDIDATES = [AUTOSAVE_CHECKPOINT_FILE_NAME + CHECKPOINT_SUFFIX,
AUTOSAVE_CHECKPOINT_FILE_NAME + "-v1" + CHECKPOINT_SUFFIX]

# This is a constant that must match a filename defined in pytorch_lightning.ModelCheckpoint, but we don't want
# to import that here.

@@ -84,4 +88,4 @@ def get_best_checkpoint_path(path: Path) -> Path:
Given a path and checkpoint, formats a path based on the checkpoint file name format.
:param path to checkpoint folder
"""
return path / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
return path / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX

@@ -56,8 +56,6 @@ class DeepSMILECrck(BaseMIL):
# To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI
# declared in TrainerParams:
num_epochs=16,
recovery_checkpoint_save_interval=16,
recovery_checkpoints_save_last_k=-1,
# declared in WorkflowParams:
number_of_cross_validation_splits=5,
cross_validation_split_index=0,

@@ -41,8 +41,6 @@ class DeepSMILEPanda(BaseMIL):
# To mount the dataset instead of downloading in AML, pass --use_dataset_mount in the CLI
# declared in TrainerParams:
num_epochs=200,
recovery_checkpoint_save_interval=10,
recovery_checkpoints_save_last_k=-1,
# use_mixed_precision = True,
# declared in WorkflowParams:
number_of_cross_validation_splits=5,

@@ -7,7 +7,7 @@ from typing import Any
import pandas as pd

from InnerEye.ML.config import PhotometricNormalizationMethod, SegmentationModelBase, equally_weighted_classes
from InnerEye.ML.configs.segmentation.Lung import AZURE_DATASET_ID
from InnerEye.ML.configs.segmentation.Lung import LUNG_AZURE_DATASET_ID
from InnerEye.ML.deep_learning_config import LRSchedulerType
from InnerEye.ML.utils.split_dataset import DatasetSplits


@@ -40,9 +40,8 @@ class BasicModel2Epochs(SegmentationModelBase):
num_dataload_workers=1,
train_batch_size=8,
num_epochs=2,
recovery_checkpoint_save_interval=1,
use_mixed_precision=True,
azure_dataset_id=AZURE_DATASET_ID,
azure_dataset_id=LUNG_AZURE_DATASET_ID,
comparison_blob_storage_paths=comparison_blob_storage_paths,
inference_on_val_set=True,
inference_on_test_set=True,

@@ -50,7 +50,6 @@ class GbmBase(SegmentationModelBase):
adam_betas=(0.9, 0.999),
momentum=0.9,
weight_decay=1e-4,
recovery_checkpoint_save_interval=10,
use_mixed_precision=True,
use_model_parallel=True,
)

@@ -86,7 +86,6 @@ class HeadAndNeckBase(SegmentationModelBase):
super().__init__(
should_validate=False,  # we'll validate after kwargs are added
num_epochs=num_epochs,
recovery_checkpoint_save_interval=10,
architecture="UNet3D",
kernel_size=3,
train_batch_size=1,

@@ -59,7 +59,6 @@ class HelloWorld(SegmentationModelBase):
num_dataload_workers=0,
train_batch_size=2,
num_epochs=2,
recovery_checkpoint_save_interval=1,
use_mixed_precision=True,

# Pre-processing - in this section we define how to normalize our inputs, in this case we are doing

@@ -13,7 +13,7 @@ from InnerEye.ML.deep_learning_config import OptimizerType
from InnerEye.ML.utils.split_dataset import DatasetSplits

# Change this string to the name of your dataset on Azure blob storage.
AZURE_DATASET_ID = "2339eba2-8ec5-4ccb-86ff-c170470ac6e2_geonorm_with_train_test_split_2020_05_26"
LUNG_AZURE_DATASET_ID = "2339eba2-8ec5-4ccb-86ff-c170470ac6e2_geonorm_with_train_test_split_2020_05_26"


class Lung(SegmentationModelBase):

@@ -29,7 +29,7 @@ class Lung(SegmentationModelBase):
architecture="UNet3D",
feature_channels=[32],
kernel_size=3,
azure_dataset_id=AZURE_DATASET_ID,
azure_dataset_id=LUNG_AZURE_DATASET_ID,
crop_size=(64, 224, 224),
test_crop_size=(128, 512, 512),
image_channels=["ct"],

@@ -56,7 +56,6 @@ class Lung(SegmentationModelBase):
adam_betas=(0.9, 0.999),
momentum=0.9,
weight_decay=1e-4,
recovery_checkpoint_save_interval=10,
use_mixed_precision=True,
use_model_parallel=True,
monitoring_interval_seconds=0,

@@ -20,7 +20,6 @@ class CIFAR10SimCLR(SSLContainer):
ssl_encoder=EncoderName.resnet50,
ssl_training_type=SSLTrainingType.SimCLR,
random_seed=1,
recovery_checkpoint_save_interval=200,
num_epochs=2500,
num_workers=6)

@@ -38,7 +37,6 @@ class CIFAR10BYOL(SSLContainer):
ssl_encoder=EncoderName.resnet50,
ssl_training_type=SSLTrainingType.BYOL,
random_seed=1,
recovery_checkpoint_save_interval=200,
num_epochs=2500,
num_workers=6)

@@ -55,6 +53,5 @@ class CIFAR10CIFAR100BYOL(SSLContainer):
ssl_encoder=EncoderName.resnet50,
ssl_training_type=SSLTrainingType.BYOL,
random_seed=1,
recovery_checkpoint_save_interval=200,
num_epochs=2500,
num_workers=6)

@@ -11,7 +11,6 @@ class SSLClassifierCIFAR(SSLClassifierContainer):
super().__init__(
linear_head_dataset_name=SSLDatasetName.CIFAR10,
random_seed=1,
recovery_checkpoint_save_interval=5,
num_epochs=100,
l_rate=1e-4,
num_workers=6)

@@ -27,7 +27,6 @@ class NIH_RSNA_BYOL(SSLContainer):
linear_head_dataset_name=SSLDatasetName.RSNAKaggleCXR,
azure_dataset_id=NIH_AZURE_DATASET_ID,
random_seed=1,
recovery_checkpoint_save_interval=200,
num_epochs=1000,
# We usually train this model with 16 GPUs, giving an effective batch size of 1200
ssl_training_batch_size=75,

@@ -44,7 +43,6 @@ class NIH_RSNA_SimCLR(SSLContainer):
linear_head_dataset_name=SSLDatasetName.RSNAKaggleCXR,
azure_dataset_id=NIH_AZURE_DATASET_ID,
random_seed=1,
recovery_checkpoint_save_interval=200,
num_epochs=1000,
# We usually train this model with 16 GPUs, giving an effective batch size of 1200
ssl_training_batch_size=75,

@@ -60,7 +58,6 @@ class CXRImageClassifier(SSLClassifierContainer):
def __init__(self) -> None:
super().__init__(linear_head_dataset_name=SSLDatasetName.RSNAKaggleCXR,
random_seed=1,
recovery_checkpoint_save_interval=10,
num_epochs=200,
use_balanced_binary_loss_for_linear_head=True,
azure_dataset_id=RSNA_AZURE_DATASET_ID,

@@ -20,8 +20,6 @@ class NIH_COVID_BYOL(SSLContainer):
super().__init__(ssl_training_dataset_name=SSLDatasetName.NIHCXR,
linear_head_dataset_name=SSLDatasetName.Covid,
random_seed=1,
recovery_checkpoint_save_interval=50,
recovery_checkpoints_save_last_k=3,
num_epochs=500,
ssl_training_batch_size=75, # This runs with 16 gpus (4 nodes)
num_workers=12,

@@ -12,7 +12,7 @@ from torch.nn.parameter import Parameter

from InnerEye.Common.type_annotations import TupleInt3
from InnerEye.ML.config import ModelArchitectureConfig, SegmentationModelBase, equally_weighted_classes
from InnerEye.ML.configs.segmentation.Lung import AZURE_DATASET_ID
from InnerEye.ML.configs.segmentation.Lung import LUNG_AZURE_DATASET_ID
from InnerEye.ML.models.architectures.base_model import BaseSegmentationModel
from InnerEye.ML.models.parallel.model_parallel import get_device_from_parameters, move_to_device
from InnerEye.ML.utils.model_metadata_util import generate_random_colours_list

@@ -34,7 +34,7 @@ class PassThroughModel(SegmentationModelBase):
should_validate=False,
# Set as UNet3D only because this does not shrink patches in the forward pass.
architecture=ModelArchitectureConfig.UNet3D,
azure_dataset_id=AZURE_DATASET_ID,
azure_dataset_id=LUNG_AZURE_DATASET_ID,
crop_size=(64, 224, 224),
num_dataload_workers=1,
# Disable monitoring so that we can use VS Code remote debugging

@@ -475,7 +475,6 @@ class OutputParams(param.Parameterized):
@property
def checkpoint_folder(self) -> Path:
"""Gets the full path in which the model checkpoints should be stored during training."""
print(f"Expected Checkpoint path {self.outputs_folder / CHECKPOINT_FOLDER}")
return self.outputs_folder / CHECKPOINT_FOLDER

@property

@@ -567,15 +566,11 @@ class OptimizerParams(param.Parameterized):

class TrainerParams(param.Parameterized):
num_epochs: int = param.Integer(100, bounds=(1, None), doc="Number of epochs to train.")
recovery_checkpoint_save_interval: int = param.Integer(10, bounds=(0, None),
doc="Save epoch checkpoints when epoch number is a multiple "
"of recovery_checkpoint_save_interval. The intended use "
"is to allow restore training from failed runs.")
recovery_checkpoints_save_last_k: int = param.Integer(default=1, bounds=(-1, None),
doc="Number of recovery checkpoints to keep. Recovery "
"checkpoints will be stored as recovery_epoch:{"
"epoch}.ckpt. If set to -1 keep all recovery "
"checkpoints.")
autosave_every_n_val_epochs: int = param.Integer(1, bounds=(0, None),
doc="Save epoch checkpoints every N validation epochs. "
"If pl_check_val_every_n_epoch > 1, this means that "
"checkpoints are saved every N * pl_check_val_every_n_epoch "
"training epochs.")
detect_anomaly: bool = param.Boolean(False, doc="If true, test gradients for anomalies (NaN or Inf) during "
"training.")
use_mixed_precision: bool = param.Boolean(False, doc="If true, mixed precision training is activated during "

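The docstring of `autosave_every_n_val_epochs` above notes that the interval is expressed in validation epochs. A small sketch of the resulting interval in training epochs, assuming the standard Lightning `check_val_every_n_epoch` setting:

```python
def autosave_interval_in_training_epochs(autosave_every_n_val_epochs: int,
                                         check_val_every_n_epoch: int = 1) -> int:
    """Number of training epochs between two autosave checkpoints."""
    return autosave_every_n_val_epochs * check_val_every_n_epoch

# With the default of one validation pass per epoch, autosaving every validation epoch
# means one checkpoint per training epoch:
assert autosave_interval_in_training_epochs(1, 1) == 1
# If validation only runs every 3rd epoch, autosaving every 2nd validation epoch
# results in a checkpoint every 6 training epochs:
assert autosave_interval_in_training_epochs(2, 3) == 6
```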
@@ -8,7 +8,7 @@ import sys
from pathlib import Path
from typing import Any, List, Optional, Tuple, TypeVar

from pytorch_lightning import Callback, LightningModule, Trainer, seed_everything
from pytorch_lightning import Callback, Trainer, seed_everything
from pytorch_lightning.callbacks import GPUStatsMonitor, ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin

@@ -17,15 +17,16 @@ from InnerEye.Azure.azure_runner import ENV_GLOBAL_RANK, ENV_LOCAL_RANK, ENV_NOD
from InnerEye.Azure.azure_util import RUN_CONTEXT, is_offline_run_context
from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME, change_working_directory
from InnerEye.Common.resource_monitor import ResourceMonitor
from InnerEye.ML.common import ARGS_TXT, ModelExecutionMode, RECOVERY_CHECKPOINT_FILE_NAME, VISUALIZATION_FOLDER
from InnerEye.ML.common import ARGS_TXT, AUTOSAVE_CHECKPOINT_FILE_NAME, ModelExecutionMode, \
VISUALIZATION_FOLDER
from InnerEye.ML.lightning_base import InnerEyeContainer, InnerEyeLightning
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.lightning_loggers import StoringLogger
from InnerEye.ML.lightning_models import SUBJECT_OUTPUT_PER_RANK_PREFIX, ScalarLightning, \
get_subject_output_file_per_rank
from InnerEye.ML.utils.checkpoint_handling import create_best_checkpoint
from InnerEye.ML.utils.checkpoint_handling import cleanup_checkpoints
from health_azure.utils import is_global_rank_zero, is_local_rank_zero
from health_ml.utils import AzureMLLogger, AzureMLProgressBar, log_on_epoch
from health_ml.utils import AzureMLLogger, AzureMLProgressBar

TEMP_PREFIX = "temp/"

@@ -54,29 +55,6 @@ def write_args_file(config: Any, outputs_folder: Path) -> None:
logging.info(output)


class InnerEyeRecoveryCheckpointCallback(ModelCheckpoint):
"""
This callback is used to save recovery checkpoints.
In particular, it makes sure we are logging "epoch", this is needed to the last k
checkpoints (here save_top_k is based on the epoch number instead of validation loss,
PL only allows to save_top_k for logged quantities).
"""

def __init__(self, container: LightningContainer):
super().__init__(dirpath=str(container.checkpoint_folder),
monitor="epoch_started",
filename=RECOVERY_CHECKPOINT_FILE_NAME + "_{epoch}",
every_n_epochs=container.recovery_checkpoint_save_interval,
save_top_k=container.recovery_checkpoints_save_last_k,
mode="max",
save_last=False)

def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule, unused: bool = None) -> None:
# The metric to monitor must be logged on all ranks in distributed training
log_on_epoch(pl_module, name="epoch_started", value=trainer.current_epoch, sync_dist=False) # type: ignore
super().on_train_epoch_end(trainer, pl_module)


def create_lightning_trainer(container: LightningContainer,
resume_from_checkpoint: Optional[Path] = None,
num_nodes: int = 1) -> \

@@ -126,15 +104,20 @@ def create_lightning_trainer(container: LightningContainer,
deterministic = False
benchmark = True

# For now, stick with the legacy behaviour of always saving only the last epoch checkpoint. For large segmentation
# The last checkpoint is considered the "best" checkpoint. For large segmentation
# models, this still appears to be the best way of choosing them because validation loss on the relatively small
# training patches is not stable enough. Going by the validation loss somehow works for the Prostate model, but
# not for the HeadAndNeck model.
last_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder), save_last=True, save_top_k=0)
# Recovery checkpoints: {epoch} will turn into a string like "epoch=1"
# Store 1 recovery checkpoint every recovery_checkpoint_save_interval epochs, keep the last
# recovery_checkpoints_save_last_k.
recovery_checkpoint_callback = InnerEyeRecoveryCheckpointCallback(container)
# Note that "last" is somehow a misnomer, it should rather be "latest". There is a "last" checkpoint written in
# every epoch. We could use that for recovery too, but it could happen that the job gets preempted right during
# writing that file, and we would end up with an invalid file.
last_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder),
save_last=True,
save_top_k=0)
recovery_checkpoint_callback = ModelCheckpoint(dirpath=str(container.checkpoint_folder),
filename=AUTOSAVE_CHECKPOINT_FILE_NAME,
every_n_val_epochs=container.autosave_every_n_val_epochs,
save_last=False)
callbacks: List[Callback] = [
last_checkpoint_callback,
recovery_checkpoint_callback,

@@ -287,8 +270,8 @@ def model_train(checkpoint_path: Optional[Path],
logging.info(f"Terminating training thread with rank {lightning_model.global_rank}.")
sys.exit()

logging.info("Choosing the best checkpoint and removing redundant files.")
create_best_checkpoint(container.checkpoint_folder)
logging.info("Removing redundant checkpoint files.")
cleanup_checkpoints(container.checkpoint_folder)
# Lightning modifies a ton of environment variables. If we first run training and then the test suite,
# those environment variables will mislead the training runs in the test suite, and make them crash.
# Hence, restore the original environment after training.

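Putting the pieces of `model_train` together: on a restart after preemption, the recovery checkpoint found by the checkpoint handler is handed to the Lightning trainer, and the autosave files are removed once training has finished. A simplified sketch of that flow, assuming a generic `LightningModule` in place of the real InnerEye container classes:

```python
from pathlib import Path
from typing import Optional

from pytorch_lightning import LightningModule, Trainer

from InnerEye.ML.utils.checkpoint_handling import cleanup_checkpoints, find_recovery_checkpoint_on_disk_or_cloud


def train_with_recovery(model: LightningModule, checkpoint_folder: Path, max_epochs: int) -> None:
    """Sketch: resume from the newest autosave/last checkpoint, train, then clean up."""
    recovery_checkpoint: Optional[Path] = find_recovery_checkpoint_on_disk_or_cloud(checkpoint_folder)
    trainer = Trainer(default_root_dir=str(checkpoint_folder.parent),
                      max_epochs=max_epochs,
                      resume_from_checkpoint=None if recovery_checkpoint is None else str(recovery_checkpoint))
    trainer.fit(model)
    # Training is done: the autosave checkpoints are now obsolete and get deleted.
    cleanup_checkpoints(checkpoint_folder)
```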
@@ -4,17 +4,17 @@
# ------------------------------------------------------------------------------------------
import logging
import os
import re
import tempfile
import time
import uuid
from builtins import property
from pathlib import Path
from typing import List, Optional, Tuple
from typing import List, Optional
from urllib.parse import urlparse

import numpy as np
import requests
import torch

from azureml.core import Model, Run, Workspace

from InnerEye.Azure.azure_config import AzureConfig

@@ -22,8 +22,8 @@ from InnerEye.Azure.azure_util import RUN_CONTEXT, download_run_output_file, dow
fetch_child_runs, tag_values_all_distinct
from InnerEye.Common.common_util import OTHER_RUNS_SUBDIR_NAME
from InnerEye.Common.fixed_paths import DEFAULT_AML_UPLOAD_DIR, MODEL_INFERENCE_JSON_FILE_NAME
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, CHECKPOINT_FOLDER, \
LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, RECOVERY_CHECKPOINT_FILE_NAME
from InnerEye.ML.common import (AUTOSAVE_CHECKPOINT_CANDIDATES, CHECKPOINT_FOLDER,
LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, LEGACY_RECOVERY_CHECKPOINT_FILE_NAME)
from InnerEye.ML.deep_learning_config import OutputParams
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.model_inference_config import read_model_inference_config

@@ -105,13 +105,7 @@ class CheckpointHandler:
logging.info(f"Available checkpoints: {len(checkpoints)}")
for f in checkpoints:
logging.info(f)
recovery = find_recovery_checkpoint_and_epoch(self.container.checkpoint_folder)
if recovery is not None:
local_recovery_path, recovery_epoch = recovery
self.container._start_epoch = recovery_epoch
return local_recovery_path
else:
return None
return find_recovery_checkpoint_on_disk_or_cloud(self.container.checkpoint_folder)

def get_best_checkpoints(self) -> List[Path]:
"""

@@ -278,85 +272,85 @@ def download_folder_from_run_to_temp_folder(folder: str,
return temp_folder / cleaned_prefix


PathAndEpoch = Tuple[Path, int]


def find_recovery_checkpoint_and_epoch(path: Path) -> Optional[PathAndEpoch]:
def find_recovery_checkpoint_on_disk_or_cloud(path: Path) -> Optional[Path]:
"""
Looks at all the recovery files, extracts the epoch number for all of them and returns the most recent (latest
epoch)
checkpoint path along with the corresponding epoch number. If no recovery checkpoint are found, return None.
Looks at all the checkpoint files and returns the path to the one that should be used for recovery.
If no checkpoint files are found on disk, the function attempts to download from the current AzureML run.
:param path: The folder to start searching in.
:return: None if there is no file matching the search pattern, or a Tuple with Path object and integer pointing to
recovery checkpoint path and recovery epoch.
:return: None if there is no suitable recovery checkpoint, or else a full path to the checkpoint file.
"""
available_checkpoints = find_all_recovery_checkpoints(path)
if available_checkpoints is None and is_running_in_azure_ml():
recovery_checkpoint = find_recovery_checkpoint(path)
if recovery_checkpoint is None and is_running_in_azure_ml():
logging.info("No checkpoints available in the checkpoint folder. Trying to find checkpoints in AzureML.")
# Download checkpoints from AzureML, then try to find recovery checkpoints among those.
# Downloads should go to a temporary folder because downloading the files to the checkpoint folder might
# cause artifact conflicts later.
temp_folder = download_folder_from_run_to_temp_folder(folder=f"{DEFAULT_AML_UPLOAD_DIR}/{CHECKPOINT_FOLDER}/")
available_checkpoints = find_all_recovery_checkpoints(temp_folder)
if available_checkpoints is not None:
return extract_latest_checkpoint_and_epoch(available_checkpoints)
return None
recovery_checkpoint = find_recovery_checkpoint(temp_folder)
return recovery_checkpoint


def get_recovery_checkpoint_path(path: Path) -> Path:
"""
Returns the path to the last recovery checkpoint in the given folder or the provided filename. Raises a
FileNotFoundError if no
recovery checkpoint file is present.
FileNotFoundError if no recovery checkpoint file is present.
:param path: Path to checkpoint folder
"""
recovery_ckpt_and_epoch = find_recovery_checkpoint_and_epoch(path)
if recovery_ckpt_and_epoch is not None:
return recovery_ckpt_and_epoch[0]
files = list(path.glob("*"))
raise FileNotFoundError(f"No checkpoint files found in {path}. Existing files: {' '.join(p.name for p in files)}")
recovery_checkpoint = find_recovery_checkpoint(path)
if recovery_checkpoint is None:
files = [f.name for f in path.glob("*")]
raise FileNotFoundError(f"No checkpoint files found in {path}. Existing files: {' '.join(files)}")
return recovery_checkpoint


def find_all_recovery_checkpoints(path: Path) -> Optional[List[Path]]:
def find_recovery_checkpoint(path: Path) -> Optional[Path]:
"""
Extracts all file starting with RECOVERY_CHECKPOINT_FILE_NAME in path
:param path:
:return:
Finds the checkpoint file in the given path that can be used for re-starting the present job.
This can be an autosave checkpoint, or the last checkpoint. All existing checkpoints are loaded, and the one
for the highest epoch is used for recovery.
:param path: The folder to search in.
:return: Returns the checkpoint file to use for re-starting, or None if no such file was found.
"""
all_recovery_files = [f for f in path.glob(RECOVERY_CHECKPOINT_FILE_NAME + "*")]
if len(all_recovery_files) == 0:
return None
return all_recovery_files
legacy_recovery_checkpoints = list(path.glob(LEGACY_RECOVERY_CHECKPOINT_FILE_NAME + "*"))
if len(legacy_recovery_checkpoints) > 0:
logging.warning(f"Found these legacy checkpoint files: {legacy_recovery_checkpoints}")
raise ValueError("The legacy recovery checkpoint setup is no longer supported. As a workaround, you can take "
f"one of the legacy checkpoints and upload as '{AUTOSAVE_CHECKPOINT_CANDIDATES[0]}'")
candidates = [*AUTOSAVE_CHECKPOINT_CANDIDATES, LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX]
highest_epoch: Optional[int] = None
file_with_highest_epoch: Optional[Path] = None
for f in candidates:
full_path = path / f
if full_path.is_file():
try:
checkpoint = torch.load(str(full_path), map_location=torch.device("cpu"))
epoch = checkpoint["epoch"]
logging.info(f"Checkpoint for epoch {epoch} in {full_path}")
if (highest_epoch is None) or (epoch > highest_epoch):
highest_epoch = epoch
file_with_highest_epoch = full_path
except Exception as ex:
logging.warning(f"Unable to load checkpoint file {full_path}: {ex}")
return file_with_highest_epoch


def extract_latest_checkpoint_and_epoch(available_files: List[Path]) -> PathAndEpoch:
def cleanup_checkpoints(path: Path) -> None:
"""
Checkpoints are saved as recovery_epoch={epoch}.ckpt, find the latest ckpt and epoch number.
:param available_files: all available checkpoints
:return: path the checkpoint from latest epoch and epoch number
"""
recovery_epochs = [int(re.findall(r"[\d]+", f.stem)[0]) for f in available_files]
idx_max_epoch = int(np.argmax(recovery_epochs))
return available_files[idx_max_epoch], recovery_epochs[idx_max_epoch]


def create_best_checkpoint(path: Path) -> Path:
"""
Creates the best checkpoint file. "Best" is at the moment defined as being the last checkpoint, but could be
based on some defined policy.
The best checkpoint will be renamed to `best_checkpoint.ckpt`.
Remove autosave checkpoints from the given checkpoint folder, and check if a "last.ckpt" checkpoint is present.
:param path: The folder that contains all checkpoint files.
"""
logging.debug(f"Files in checkpoint folder: {' '.join(p.name for p in path.glob('*'))}")
logging.info(f"Files in checkpoint folder: {' '.join(p.name for p in path.glob('*'))}")
last_ckpt = path / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
all_files = f"Existing files: {' '.join(p.name for p in path.glob('*'))}"
if not last_ckpt.is_file():
raise FileNotFoundError(f"Checkpoint file {LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX} not found. {all_files}")
logging.info(f"Using {LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX} as the best checkpoint: Renaming to "
f"{BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX}")
best = path / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
last_ckpt.rename(best)
return best
# Training is finished now. To save storage, remove the autosave checkpoint which is now obsolete.
# Lightning does not overwrite checkpoints in-place. Rather, it writes "autosave.ckpt",
# then "autosave-v1.ckpt" and deletes "autosave.ckpt", then "autosave.ckpt" and deletes "autosave-v1.ckpt"
for candidate in AUTOSAVE_CHECKPOINT_CANDIDATES:
autosave = path / candidate
if autosave.is_file():
autosave.unlink()


def download_best_checkpoints_from_child_runs(config: OutputParams, run: Run) -> RunRecovery:

@@ -385,7 +379,7 @@ def download_best_checkpoints_from_child_runs(config: OutputParams, run: Run) ->
subdir = str(child.tags[tag_to_use] if can_use_split_indices else child.number)
child_dst = config.checkpoint_folder / OTHER_RUNS_SUBDIR_NAME / subdir
download_run_output_file(
blob_path=Path(CHECKPOINT_FOLDER) / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
blob_path=Path(CHECKPOINT_FOLDER) / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
destination=child_dst,
run=child
)

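The new `find_recovery_checkpoint` above picks between `autosave.ckpt`, `autosave-v1.ckpt` and `last.ckpt` by reading the epoch stored inside each file. The same check can be done by hand when inspecting a checkpoint folder; a small sketch (the folder path is a placeholder):

```python
from pathlib import Path

import torch

checkpoint_folder = Path("outputs/checkpoints")  # placeholder path
for name in ["autosave.ckpt", "autosave-v1.ckpt", "last.ckpt"]:
    candidate = checkpoint_folder / name
    if candidate.is_file():
        # Lightning checkpoints are plain dictionaries; "epoch" records when the file was written.
        checkpoint = torch.load(str(candidate), map_location=torch.device("cpu"))
        print(name, "-> epoch", checkpoint["epoch"])
```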
@@ -1 +1 @@
{"model_name": "BasicModel2Epochs", "checkpoint_paths": ["checkpoints/best_checkpoint.ckpt"], "model_configs_namespace": "InnerEye.ML.configs.segmentation.BasicModel2Epochs"}
{"model_name": "BasicModel2Epochs", "checkpoint_paths": ["checkpoints/last.ckpt"], "model_configs_namespace": "InnerEye.ML.configs.segmentation.BasicModel2Epochs"}

@@ -1 +1 @@
{"model_name": "GlaucomaPublic", "checkpoint_paths": ["checkpoints/OTHER_RUNS/1/best_checkpoint.ckpt", "checkpoints/best_checkpoint.ckpt"], "model_configs_namespace": "InnerEye.ML.configs.classification.GlaucomaPublic"}
{"model_name": "GlaucomaPublic", "checkpoint_paths": ["checkpoints/OTHER_RUNS/1/last.ckpt", "checkpoints/last.ckpt"], "model_configs_namespace": "InnerEye.ML.configs.classification.GlaucomaPublic"}

@@ -1 +1 @@
{"model_name": "HelloContainer", "checkpoint_paths": ["checkpoints/best_checkpoint.ckpt"], "model_configs_namespace": "InnerEye.ML.configs.other.HelloContainer"}
{"model_name": "HelloContainer", "checkpoint_paths": ["checkpoints/last.ckpt"], "model_configs_namespace": "InnerEye.ML.configs.other.HelloContainer"}

@@ -1 +1 @@
{"model_name": "BasicModelForEnsembleTest", "checkpoint_paths": ["checkpoints/OTHER_RUNS/1/best_checkpoint.ckpt", "checkpoints/best_checkpoint.ckpt"], "model_configs_namespace": "InnerEye.ML.configs.segmentation.BasicModel2Epochs"}
{"model_name": "BasicModelForEnsembleTest", "checkpoint_paths": ["checkpoints/OTHER_RUNS/1/last.ckpt", "checkpoints/last.ckpt"], "model_configs_namespace": "InnerEye.ML.configs.segmentation.BasicModel2Epochs"}

@@ -36,8 +36,8 @@ from InnerEye.Common.fixed_paths import (DEFAULT_AML_UPLOAD_DIR, DEFAULT_RESULT_
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.Common.spawn_subprocess import spawn_and_monitor_subprocess
from InnerEye.ML.common import (CHECKPOINT_FOLDER, DATASET_CSV_FILE_NAME, ModelExecutionMode,
RECOVERY_CHECKPOINT_FILE_NAME)
from InnerEye.ML.common import (LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, CHECKPOINT_FOLDER, DATASET_CSV_FILE_NAME,
ModelExecutionMode)
from InnerEye.ML.configs.other.HelloContainer import HelloContainer
from InnerEye.ML.configs.segmentation.BasicModel2Epochs import BasicModel2Epochs
from InnerEye.ML.deep_learning_config import ModelCategory

@@ -47,7 +47,7 @@ from InnerEye.ML.reports.notebook_report import get_html_report_name
from InnerEye.ML.run_ml import MLRunner
from InnerEye.ML.runner import main
from InnerEye.ML.utils.checkpoint_handling import download_folder_from_run_to_temp_folder, \
find_recovery_checkpoint_and_epoch
find_recovery_checkpoint_on_disk_or_cloud
from InnerEye.ML.utils.config_loader import ModelConfigLoader
from InnerEye.ML.utils.image_util import get_unit_image_header
from InnerEye.ML.utils.io_util import zip_random_dicom_series

@@ -56,11 +56,11 @@ from InnerEye.Scripts import submit_for_inference
from Tests.ML.util import assert_nifti_content, get_default_azure_config, get_default_workspace, get_nifti_shape
from health_azure.himl import RUN_RECOVERY_FILE

FALLBACK_SINGLE_RUN = "refs_pull_606_merge:refs_pull_606_merge_1638867172_17ba8dc5"
FALLBACK_ENSEMBLE_RUN = "refs_pull_606_merge:HD_b8a6ad93-8c19-45de-8ea1-f87fce92c3bd"
FALLBACK_2NODE_RUN = "refs_pull_593_merge:refs_pull_591_merge_1639416130_e5d29ba7"
FALLBACK_SINGLE_RUN = "refs_pull_633_merge_1642019743_f212b068" # PR job TrainBasicModel
FALLBACK_ENSEMBLE_RUN = "HD_5ebb378b-272b-4633-a5f8-23e958ddbf8f" # PR job TrainEnsemble
FALLBACK_2NODE_RUN = "refs_pull_633_merge_1642019739_a6dbe9e6" # PR job Train2Nodes
FALLBACK_CV_GLAUCOMA = "refs_pull_545_merge:HD_72ecc647-07c3-4353-a538-620346114ebd"
FALLBACK_HELLO_CONTAINER_RUN = "refs_pull_606_merge:refs_pull_606_merge_1638867108_789991ac"
FALLBACK_HELLO_CONTAINER_RUN = "refs_pull_633_merge_1642019742_0d8a7e73" # PR job HelloContainerPR


def get_most_recent_run_id(fallback_run_id_for_local_execution: str = FALLBACK_SINGLE_RUN) -> str:

@@ -232,7 +232,8 @@ def test_download_checkpoints_from_aml(test_output_dirs: OutputFolderForTests) -
run=run,
workspace=get_default_workspace())
files = list(temp_folder.glob("*"))
assert len(files) == 2
assert len(files) == 1
assert (temp_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()
# Test if what's in the folder are really files, not directories
for file in files:
assert file.is_file()

@@ -243,13 +244,10 @@ def test_download_checkpoints_from_aml(test_output_dirs: OutputFolderForTests) -
return_value=temp_folder) as download:
# Call the checkpoint finder with a temp folder that does not contain any files, so it should try to
# download
result = find_recovery_checkpoint_and_epoch(test_output_dirs.root_dir)
result = find_recovery_checkpoint_on_disk_or_cloud(test_output_dirs.root_dir)
download.assert_called_once_with(folder=checkpoint_folder)
assert result is not None
p, epoch = result
# The basic model only writes one checkpoint at epoch 1
assert epoch == 1
assert RECOVERY_CHECKPOINT_FILE_NAME in p.stem
assert result.name == LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX


@pytest.mark.inference

@@ -217,7 +217,7 @@ def test_rnn_classifier_via_config_1(use_combined_model: bool,
image_and_seg = ImageAndSegmentations[np.ndarray](images=np.random.uniform(0, 1, SCAN_SIZE),
segmentations=np.random.randint(0, 2, SCAN_SIZE))
with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=image_and_seg):
model_train_unittest(config, dirs=test_output_dirs)
model_train_unittest(config, output_folder=test_output_dirs)


@pytest.mark.skipif(common_util.is_windows(), reason="Has issues on windows build")

@@ -384,7 +384,7 @@ def test_rnn_classifier_via_config_2(test_output_dirs: OutputFolderForTests) ->
config.num_epochs = 2
config.set_output_to(test_output_dirs.root_dir)
config.dataset_data_frame = _get_mock_sequence_dataset(dataset_contents)
results, _ = model_train_unittest(config, dirs=test_output_dirs)
results, _ = model_train_unittest(config, output_folder=test_output_dirs)

actual_train_loss = results.get_metric(is_training=True, metric_type=MetricType.LOSS.value)[-1]
actual_val_loss = results.get_metric(is_training=False, metric_type=MetricType.LOSS.value)[-1]

@@ -232,7 +232,7 @@ S3,week1,scan3.npy,True,6,60,Male,Val2
summarizer.generate_summary(input_sizes=input_size)
config.local_dataset = dataset_folder
config.validate()
model_train_unittest(config, dirs=test_output_dirs)
model_train_unittest(config, output_folder=test_output_dirs)
# No further asserts here because the models are still in experimental state. Most errors would come
# from having invalid model architectures, which would throw runtime errors during training.

@@ -71,7 +71,7 @@ def test_non_image_encoder(test_output_dirs: OutputFolderForTests,
config.max_batch_grad_cam = 1
config.validate()
# run model training
_, checkpoint_handler = model_train_unittest(config, dirs=test_output_dirs)
_, checkpoint_handler = model_train_unittest(config, output_folder=test_output_dirs)
# run model inference
runner = MLRunner(config)
runner.setup()

@@ -30,7 +30,7 @@ def test_train_2d_classification_model(test_output_dirs: OutputFolderForTests,
# Train for 4 epochs, checkpoints at epochs 2 and 4
config.num_epochs = 4
config.use_mixed_precision = use_mixed_precision
model_training_result, checkpoint_handler = model_train_unittest(config, dirs=test_output_dirs)
model_training_result, checkpoint_handler = model_train_unittest(config, output_folder=test_output_dirs)
assert model_training_result is not None
expected_learning_rates = [0.0001, 9.99971e-05, 9.99930e-05, 9.99861e-05]

@@ -169,5 +169,5 @@ def test_run_model_with_invalid_trainer_arguments(test_output_dirs: OutputFolder
"""
container = DummyContainerWithInvalidTrainerArguments()
with pytest.raises(Exception) as ex:
model_train_unittest(config=None, dirs=test_output_dirs, lightning_container=container)
model_train_unittest(config=None, output_folder=test_output_dirs, lightning_container=container)
assert "no_such_argument" in str(ex)

@@ -4,7 +4,6 @@
# ------------------------------------------------------------------------------------------
import io
import logging
import os
from io import StringIO
from pathlib import Path
from typing import Dict, List, Optional

@@ -15,34 +14,29 @@ import pytest
import torch

from InnerEye.Common import common_util, fixed_paths
from InnerEye.Common.common_util import BEST_EPOCH_FOLDER_NAME, CROSSVAL_RESULTS_FOLDER, EPOCH_METRICS_FILE_NAME, \
METRICS_AGGREGATES_FILE, SUBJECT_METRICS_FILE_NAME, get_best_epoch_results_path, logging_to_stdout
from InnerEye.Common.common_util import (BEST_EPOCH_FOLDER_NAME, CROSSVAL_RESULTS_FOLDER, EPOCH_METRICS_FILE_NAME,
METRICS_AGGREGATES_FILE, SUBJECT_METRICS_FILE_NAME,
get_best_epoch_results_path, logging_to_stdout)
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
from InnerEye.Common.metrics_constants import LoggingColumns, MetricType
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML import model_testing, runner
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, CHECKPOINT_SUFFIX, ModelExecutionMode, \
RECOVERY_CHECKPOINT_FILE_NAME
from InnerEye.ML.configs.classification.DummyClassification import DummyClassification
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.configs.classification.DummyMulticlassClassification import DummyMulticlassClassification
from InnerEye.ML.dataset.scalar_dataset import ScalarDataset
from InnerEye.ML.metrics import InferenceMetricsForClassification, binary_classification_accuracy, \
compute_scalar_metrics
from InnerEye.ML.metrics_dict import MetricsDict, ScalarMetricsDict
from InnerEye.ML.model_training import model_train
from InnerEye.ML.reports.notebook_report import generate_classification_multilabel_notebook, \
generate_classification_notebook, get_html_report_name, get_ipynb_report_name
from InnerEye.ML.run_ml import MLRunner
from InnerEye.ML.scalar_config import ScalarLoss, ScalarModelBase
from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler
from InnerEye.ML.utils.config_loader import ModelConfigLoader
from InnerEye.ML.visualizers.plot_cross_validation import EpochMetricValues, get_config_and_results_for_offline_runs, \
unroll_aggregate_metrics
from InnerEye.ML.visualizers.plot_cross_validation import (EpochMetricValues, get_config_and_results_for_offline_runs,
unroll_aggregate_metrics)
from Tests.ML.configs.ClassificationModelForTesting import ClassificationModelForTesting
from Tests.ML.configs.DummyModel import DummyModel
from Tests.ML.util import get_default_azure_config, machine_has_gpu, \
model_train_unittest
from Tests.ML.utils.test_model_util import FIXED_EPOCH, create_model_and_store_checkpoint
from Tests.ML.util import get_default_azure_config, machine_has_gpu, model_train_unittest


@pytest.mark.cpu_and_gpu

@@ -59,7 +53,7 @@ def test_train_classification_model(class_name: str, test_output_dirs: OutputFol
config.set_output_to(test_output_dirs.root_dir)
# Train for 4 epochs, checkpoints at epochs 2 and 4
config.num_epochs = 4
model_training_result, checkpoint_handler = model_train_unittest(config, dirs=test_output_dirs)
model_training_result, checkpoint_handler = model_train_unittest(config, output_folder=test_output_dirs)
assert model_training_result is not None
expected_learning_rates = [0.0001, 9.99971e-05, 9.99930e-05, 9.99861e-05]
expected_train_loss = [0.686614, 0.686465, 0.686316, 0.686167]

@@ -168,7 +162,7 @@ def test_train_classification_multilabel_model(test_output_dirs: OutputFolderFor
config.set_output_to(test_output_dirs.root_dir)
# Train for 4 epochs, checkpoints at epochs 2 and 4
config.num_epochs = 4
model_training_result, checkpoint_handler = model_train_unittest(config, dirs=test_output_dirs)
model_training_result, checkpoint_handler = model_train_unittest(config, output_folder=test_output_dirs)
assert model_training_result is not None
expected_learning_rates = [0.0001, 9.99971e-05, 9.99930e-05, 9.99861e-05]
expected_train_loss = [0.699870228767395, 0.6239662170410156, 0.551329493522644, 0.4825132489204407]

@@ -338,8 +332,6 @@ def test_runner1(test_output_dirs: OutputFolderForTests) -> None:
"--non_image_feature_channels", scalar1,
"--output_to", output_root,
"--max_num_gpus", "1",
"--recovery_checkpoint_save_interval", "2",
"--recovery_checkpoints_save_last_k", "2",
"--num_epochs", "6",
]
with mock.patch("sys.argv", args):

@@ -350,53 +342,6 @@ def test_runner1(test_output_dirs: OutputFolderForTests) -> None:
assert config.get_effective_random_seed() == set_from_commandline
assert config.non_image_feature_channels == ["label"]
assert str(config.outputs_folder).startswith(output_root)
# Check that we saved one checkpoint every second epoch and that we kept only that last 2 and that last.ckpt has
# been renamed to best.ckpt
assert len(os.listdir(config.checkpoint_folder)) == 3
assert (config.checkpoint_folder / str(RECOVERY_CHECKPOINT_FILE_NAME + "_epoch=3" + CHECKPOINT_SUFFIX)).exists()
assert (config.checkpoint_folder / str(RECOVERY_CHECKPOINT_FILE_NAME + "_epoch=5" + CHECKPOINT_SUFFIX)).exists()
assert (config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).exists()


@pytest.mark.skipif(common_util.is_windows(), reason="Has OOM issues on windows build")
def test_runner_restart(test_output_dirs: OutputFolderForTests) -> None:
"""
Test if starting training from a folder where the checkpoints folder already has recovery checkpoints picks up
that it is a recovery run. Also checks that we update the start epoch in the config at loading time.
"""
model_config = DummyClassification()
model_config.set_output_to(test_output_dirs.root_dir)
model_config.num_epochs = FIXED_EPOCH + 2
# We save all checkpoints - if recovery works as expected we should have a new checkpoint for epoch 4, 5.
model_config.recovery_checkpoint_save_interval = 1
model_config.recovery_checkpoints_save_last_k = -1
runner = MLRunner(model_config=model_config)
runner.setup()
# Epochs are 0 based for saving
create_model_and_store_checkpoint(model_config,
runner.container.checkpoint_folder / f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
f"{FIXED_EPOCH - 1}{CHECKPOINT_SUFFIX}",
weights_only=False)
azure_config = get_default_azure_config()
checkpoint_handler = CheckpointHandler(azure_config=azure_config,
container=runner.container,
project_root=test_output_dirs.root_dir)
_, storing_logger = model_train(checkpoint_path=checkpoint_handler.get_recovery_or_checkpoint_path_train(),
container=runner.container)
# We expect to have 4 checkpoints, FIXED_EPOCH (recovery), FIXED_EPOCH+1, FIXED_EPOCH and best.
assert len(os.listdir(runner.container.checkpoint_folder)) == 4
assert (
runner.container.checkpoint_folder / f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
f"{FIXED_EPOCH - 1}{CHECKPOINT_SUFFIX}").exists()
assert (
runner.container.checkpoint_folder / f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
f"{FIXED_EPOCH}{CHECKPOINT_SUFFIX}").exists()
assert (
runner.container.checkpoint_folder / f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
f"{FIXED_EPOCH + 1}{CHECKPOINT_SUFFIX}").exists()
assert (runner.container.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).exists()
# Check that we really restarted epoch from epoch FIXED_EPOCH.
assert list(storing_logger.epochs) == [FIXED_EPOCH, FIXED_EPOCH + 1] # type: ignore


@pytest.mark.skipif(common_util.is_windows(), reason="Has OOM issues on windows build")

@@ -21,7 +21,7 @@ from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME, ModelProcessi
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.Common.type_annotations import TupleInt3
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, DATASET_CSV_FILE_NAME, ModelExecutionMode
from InnerEye.ML.common import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, DATASET_CSV_FILE_NAME, ModelExecutionMode
from InnerEye.ML.configs.unit_testing.passthrough_model import PassThroughModel
from InnerEye.ML.deep_learning_config import DeepLearningConfig
from InnerEye.ML.metrics import InferenceMetricsForSegmentation

@@ -282,7 +282,7 @@ def run_model_inference_train_and_test(test_output_dirs: OutputFolderForTests,
train_and_test_data_small_dir,
"data")

checkpoint_path = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
checkpoint_path = config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
create_model_and_store_checkpoint(config, checkpoint_path)
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
project_root=test_output_dirs.root_dir)

@@ -14,7 +14,7 @@ from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
from InnerEye.Common.metrics_constants import MetricsFileColumns
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML import model_testing
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, DATASET_CSV_FILE_NAME, ModelExecutionMode
from InnerEye.ML.common import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, DATASET_CSV_FILE_NAME, ModelExecutionMode
from InnerEye.ML.config import DATASET_ID_FILE, GROUND_TRUTH_IDS_FILE, ModelArchitectureConfig
from InnerEye.ML.dataset.full_image_dataset import FullImageDataset
from InnerEye.ML.model_config_base import ModelConfigBase

@@ -98,7 +98,7 @@ def test_model_test(
execution_mode = ModelExecutionMode.TEST
checkpoint_handler = get_default_checkpoint_handler(model_config=config, project_root=test_output_dirs.root_dir)
# Mimic the behaviour that checkpoints are downloaded from blob storage into the checkpoints folder.
create_model_and_store_checkpoint(config, config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX)
create_model_and_store_checkpoint(config, config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX)
checkpoint_handler.additional_training_done()
inference_results = model_testing.segmentation_model_test(config,
execution_mode=execution_mode,

@@ -10,7 +10,7 @@ import pytest

from InnerEye.Common.metrics_constants import MetricType
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, ModelExecutionMode
from InnerEye.ML.common import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, ModelExecutionMode
from InnerEye.ML.configs.classification.DummyClassification import DummyClassification
from InnerEye.ML.metrics import InferenceMetricsForClassification
from InnerEye.ML.model_testing import model_test

@@ -32,9 +32,8 @@ def test_recover_testing_from_run_recovery(mean_teacher_model: bool,
config.mean_teacher_alpha = 0.999
config.set_output_to(test_output_dirs.root_dir / "original")
os.makedirs(str(config.outputs_folder))
config.recovery_checkpoint_save_interval = 2

train_results, checkpoint_handler = model_train_unittest(config, dirs=test_output_dirs)
train_results, checkpoint_handler = model_train_unittest(config, output_folder=test_output_dirs)
assert len(train_results.train_results_per_epoch()) == config.num_epochs

# Run inference on this

@@ -69,7 +68,7 @@ def test_recover_testing_from_run_recovery(mean_teacher_model: bool,
os.makedirs(str(config_local_weights.outputs_folder))

local_weights_path = test_output_dirs.root_dir / "local_weights_file.pth"
shutil.copyfile(str(config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX),
shutil.copyfile(str(config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX),
local_weights_path)
config_local_weights.local_weights_path = [local_weights_path]

@@ -81,3 +80,41 @@ def test_recover_testing_from_run_recovery(mean_teacher_model: bool,
assert isinstance(test_results_local_weights, InferenceMetricsForClassification)
assert test_results.metrics.values()[MetricType.CROSS_ENTROPY.value] == \
test_results_local_weights.metrics.values()[MetricType.CROSS_ENTROPY.value]


@pytest.mark.parametrize("num_epochs", [1, 2])
def test_autosave_checkpoints(test_output_dirs: OutputFolderForTests, num_epochs: int) -> None:
"""
Tests that all autosave checkpoints are cleaned up after training.
"""
# Lightning does not overwrite checkpoints in-place. Rather, it writes "autosave.ckpt",
# then "autosave-v1.ckpt" and deletes "autosave.ckpt", then "autosave.ckpt" and deletes "autosave-v1.ckpt"
# All those checkpoints should be cleaned up after training, only the best checkpoint should remain.
config = DummyClassification()
config.autosave_every_n_val_epochs = 1
config.set_output_to(test_output_dirs.root_dir)
config.num_epochs = num_epochs
model_train_unittest(config, output_folder=test_output_dirs)
assert len(list(config.checkpoint_folder.glob("*.*"))) == 1
assert (config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()


def test_recovery_e2e(test_output_dirs: OutputFolderForTests) -> None:
"""
Test restarting a training: Train a small model for 5 epochs, then continue training to epoch 10 from the results
of the first training run.
"""
model_config = DummyClassification()
model_config.set_output_to(test_output_dirs.root_dir)
num_epochs_1 = 5
model_config.num_epochs = num_epochs_1
storing_logger_1, checkpoint_handler = model_train_unittest(model_config, output_folder=test_output_dirs)
# Logger should have results for epochs 0..4
assert list(storing_logger_1.epochs) == list(range(num_epochs_1))
# Now restart the job, train to epoch 10
num_epochs_2 = 10
model_config.num_epochs = num_epochs_2
storing_logger_2, _ = model_train_unittest(model_config, output_folder=test_output_dirs,
checkpoint_handler=checkpoint_handler)
# Logger should have results only for epochs 5..9
assert list(storing_logger_2.epochs) == list(range(num_epochs_1, num_epochs_2))

@ -2,26 +2,25 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import h5py
|
||||
import logging
|
||||
import numpy as np
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import shutil
|
||||
|
||||
from pathlib import Path
|
||||
from torch.utils.data import DataLoader
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from InnerEye.Common import fixed_paths
|
||||
from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME, is_windows, logging_to_stdout
|
||||
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
|
||||
from InnerEye.Common.metrics_constants import MetricType, TrackedMetrics, VALIDATION_PREFIX
|
||||
from InnerEye.Common.output_directories import OutputFolderForTests
|
||||
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, CHECKPOINT_SUFFIX, DATASET_CSV_FILE_NAME, \
|
||||
ModelExecutionMode, \
|
||||
RECOVERY_CHECKPOINT_FILE_NAME, STORED_CSV_FILE_NAMES
|
||||
from InnerEye.ML.common import (LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX,
|
||||
DATASET_CSV_FILE_NAME, ModelExecutionMode, STORED_CSV_FILE_NAMES)
|
||||
from InnerEye.ML.config import MixtureLossComponent, SegmentationLoss
|
||||
from InnerEye.ML.configs.classification.DummyClassification import DummyClassification
|
||||
from InnerEye.ML.dataset.sample import CroppedSample
|
||||
|
@ -100,7 +99,6 @@ def _test_model_train(output_dirs: OutputFolderForTests,
|
|||
train_config.random_seed = 42
|
||||
train_config.class_weights = [0.5, 0.25, 0.25]
|
||||
train_config.store_dataset_sample = no_mask_channel
|
||||
train_config.recovery_checkpoint_save_interval = 1
|
||||
train_config.check_exclusive = False
|
||||
|
||||
if machine_has_gpu:
|
||||
|
@ -112,7 +110,7 @@ def _test_model_train(output_dirs: OutputFolderForTests,
|
|||
loss_absolute_tolerance = 1e-6
|
||||
expected_learning_rates = [train_config.l_rate, 5.3589e-4]
|
||||
|
||||
model_training_result, _ = model_train_unittest(train_config, dirs=output_dirs)
|
||||
model_training_result, _ = model_train_unittest(train_config, output_folder=output_dirs)
|
||||
assert isinstance(model_training_result, StoringLogger)
|
||||
# Check that all metrics from the BatchTimeCallback are present
|
||||
# # TODO: re-enable once the BatchTimeCallback is fixed
|
||||
|
@ -193,10 +191,8 @@ def _test_model_train(output_dirs: OutputFolderForTests,
|
|||
# Checkpoint folder
|
||||
assert train_config.checkpoint_folder.is_dir()
|
||||
actual_checkpoints = list(train_config.checkpoint_folder.rglob("*.ckpt"))
|
||||
assert len(actual_checkpoints) == 2, f"Actual checkpoints: {actual_checkpoints}"
|
||||
assert (train_config.checkpoint_folder / str(
|
||||
RECOVERY_CHECKPOINT_FILE_NAME + f"_epoch={train_config.num_epochs - 1}" + CHECKPOINT_SUFFIX)).is_file()
|
||||
assert (train_config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()
|
||||
assert len(actual_checkpoints) == 1, f"Actual checkpoints: {actual_checkpoints}"
|
||||
assert (train_config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()
|
||||
assert (train_config.outputs_folder / DATASET_CSV_FILE_NAME).is_file()
|
||||
assert (train_config.outputs_folder / STORED_CSV_FILE_NAMES[ModelExecutionMode.TRAIN]).is_file()
|
||||
assert (train_config.outputs_folder / STORED_CSV_FILE_NAMES[ModelExecutionMode.VAL]).is_file()
|
||||
|
@ -324,16 +320,17 @@ def test_recover_training_mean_teacher_model(test_output_dirs: OutputFolderForTe
|
|||
"""
|
||||
config = DummyClassification()
|
||||
config.mean_teacher_alpha = 0.999
|
||||
config.recovery_checkpoint_save_interval = 1
|
||||
config.autosave_every_n_val_epochs = 1
|
||||
config.set_output_to(test_output_dirs.root_dir / "original")
|
||||
os.makedirs(str(config.outputs_folder))
|
||||
|
||||
original_checkpoint_folder = config.checkpoint_folder
|
||||
|
||||
# First round of training
|
||||
config.num_epochs = 2
|
||||
model_train_unittest(config, dirs=test_output_dirs)
|
||||
assert len(list(config.checkpoint_folder.glob("*.*"))) == 2
|
||||
config.num_epochs = 4
|
||||
model_train_unittest(config, output_folder=test_output_dirs)
|
||||
assert len(list(config.checkpoint_folder.glob("*.*"))) == 1
|
||||
assert (config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()
|
||||
|
||||
# Restart training from previous run
|
||||
config.num_epochs = 3
|
||||
|
@ -348,10 +345,10 @@ def test_recover_training_mean_teacher_model(test_output_dirs: OutputFolderForTe
|
|||
project_root=test_output_dirs.root_dir)
|
||||
checkpoint_handler.run_recovery = RunRecovery([checkpoint_root])
|
||||
|
||||
model_train_unittest(config, dirs=test_output_dirs, checkpoint_handler=checkpoint_handler)
|
||||
model_train_unittest(config, output_folder=test_output_dirs, checkpoint_handler=checkpoint_handler)
|
||||
# remove recovery checkpoints
|
||||
shutil.rmtree(checkpoint_root)
|
||||
assert len(list(config.checkpoint_folder.glob("*.*"))) == 2
|
||||
assert len(list(config.checkpoint_folder.glob("*.ckpt"))) == 1
|
||||
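As context for the restart exercised above: the checkpoint handler resolves the newest recovery checkpoint on disk and training continues from it. Below is a minimal, self-contained sketch of that resume pattern in plain PyTorch Lightning; it assumes a recent Lightning where Trainer.fit accepts ckpt_path (older versions used the resume_from_checkpoint Trainer argument) and is not the InnerEye code itself.

# Sketch of the resume pattern the test exercises (plain Lightning, assumed file names).
from pathlib import Path
from typing import Optional
import pytorch_lightning as pl

def train_or_resume(model: pl.LightningModule, output_dir: Path, max_epochs: int) -> None:
    autosave = output_dir / "checkpoints" / "autosave.ckpt"
    resume_from: Optional[str] = str(autosave) if autosave.is_file() else None
    trainer = pl.Trainer(default_root_dir=str(output_dir), max_epochs=max_epochs)
    # On a preempted or restarted job the autosave file exists, so training picks up at the stored epoch.
    trainer.fit(model, ckpt_path=resume_from)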
|
||||
|
||||
def test_script_names_correct() -> None:
|
||||
|
|
|
@ -189,7 +189,7 @@ def test_compare_folder_against_run(test_output_dirs: OutputFolderForTests) -> N
|
|||
FINAL_MODEL_FOLDER / MODEL_INFERENCE_JSON_FILE_NAME
|
||||
create_folder_and_write_text(file1,
|
||||
'{"model_name": "BasicModel2Epochs", "checkpoint_paths": ['
|
||||
'"checkpoints/best_checkpoint.ckpt"], '
|
||||
'"checkpoints/last.ckpt"], '
|
||||
'"model_configs_namespace": "InnerEye.ML.configs.segmentation.BasicModel2Epochs"}')
|
||||
with mock.patch("InnerEye.ML.baselines_util.RUN_CONTEXT", run):
|
||||
# First comparison only on the .json file should pass
|
||||
|
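Assembled from the string literals above, the inference metadata written for the final model now references last.ckpt rather than best_checkpoint.ckpt. A rough reconstruction of that JSON content as a Python dict, illustrative only and not copied from an actual output file:

# Reconstructed model_inference_config.json content (sketch, based on the strings above).
import json

model_inference_config = {
    "model_name": "BasicModel2Epochs",
    "checkpoint_paths": ["checkpoints/last.ckpt"],
    "model_configs_namespace": "InnerEye.ML.configs.segmentation.BasicModel2Epochs",
}
print(json.dumps(model_inference_config, indent=2))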
|
|
@ -277,7 +277,7 @@ def get_default_workspace() -> Workspace:
|
|||
|
||||
|
||||
def model_train_unittest(config: Optional[DeepLearningConfig],
|
||||
dirs: OutputFolderForTests,
|
||||
output_folder: Union[OutputFolderForTests, Path],
|
||||
checkpoint_handler: Optional[CheckpointHandler] = None,
|
||||
lightning_container: Optional[LightningContainer] = None) -> \
|
||||
Tuple[StoringLogger, CheckpointHandler]:
|
||||
|
@ -285,7 +285,7 @@ def model_train_unittest(config: Optional[DeepLearningConfig],
|
|||
A shortcut for running model training in the unit test suite. It runs training for the given config, with the
|
||||
default checkpoint handler initialized to point to the test output folder specified in output_folder.
|
||||
:param config: The configuration of the model to train.
|
||||
:param dirs: The test fixture that provides an output folder for the test.
|
||||
:param output_folder: The test fixture that provides an output folder for the test.
|
||||
:param lightning_container: An optional LightningContainer object that will be passed through to the training routine.
|
||||
:param checkpoint_handler: The checkpoint handler that should be used for training. If not provided, it will be
|
||||
created via get_default_checkpoint_handler.
|
||||
|
@ -299,9 +299,10 @@ def model_train_unittest(config: Optional[DeepLearningConfig],
|
|||
runner.setup()
|
||||
if checkpoint_handler is None:
|
||||
azure_config = get_default_azure_config()
|
||||
output_folder = output_folder if isinstance(output_folder, Path) else output_folder.root_dir
|
||||
checkpoint_handler = CheckpointHandler(azure_config=azure_config,
|
||||
container=runner.container,
|
||||
project_root=dirs.root_dir)
|
||||
project_root=output_folder)
|
||||
_, storing_logger = model_train(checkpoint_path=checkpoint_handler.get_recovery_or_checkpoint_path_train(),
|
||||
container=runner.container)
|
||||
checkpoint_handler.additional_training_done()
|
||||
|
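For reference, the renamed helper is now called with output_folder instead of dirs and accepts either the test fixture or a plain Path. A condensed usage sketch, mirroring the call sites shown earlier in this diff; the import location of model_train_unittest is an assumption.

# Typical call after the rename from dirs= to output_folder= (names as used in this diff).
from InnerEye.ML.configs.classification.DummyClassification import DummyClassification
from Tests.ML.util import model_train_unittest  # assumed module path for the helper

# test_output_dirs is the OutputFolderForTests pytest fixture
config = DummyClassification()
config.set_output_to(test_output_dirs.root_dir)
config.num_epochs = 2
storing_logger, checkpoint_handler = model_train_unittest(config, output_folder=test_output_dirs)
# A plain Path is also accepted; the helper falls back to it when building the CheckpointHandler.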
|
|
@ -14,7 +14,7 @@ from InnerEye.Common.common_util import OTHER_RUNS_SUBDIR_NAME
|
|||
from InnerEye.Common.fixed_paths import MODEL_INFERENCE_JSON_FILE_NAME
|
||||
from InnerEye.ML.utils.checkpoint_handling import MODEL_WEIGHTS_DIR_NAME, get_recovery_checkpoint_path
|
||||
from InnerEye.Common.output_directories import OutputFolderForTests
|
||||
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, FINAL_ENSEMBLE_MODEL_FOLDER, FINAL_MODEL_FOLDER
|
||||
from InnerEye.ML.common import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, FINAL_ENSEMBLE_MODEL_FOLDER, FINAL_MODEL_FOLDER
|
||||
from InnerEye.ML.model_config_base import ModelConfigBase
|
||||
from InnerEye.ML.model_inference_config import read_model_inference_config
|
||||
from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler
|
||||
|
@ -86,7 +86,7 @@ def test_download_recovery_checkpoints_from_single_run(test_output_dirs: OutputF
|
|||
|
||||
expected_checkpoint_root = config.checkpoint_folder
|
||||
expected_paths = [get_recovery_checkpoint_path(path=expected_checkpoint_root),
|
||||
expected_checkpoint_root / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX]
|
||||
expected_checkpoint_root / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX]
|
||||
assert checkpoint_handler.run_recovery.checkpoints_roots == [expected_checkpoint_root]
|
||||
for path in expected_paths:
|
||||
assert path.is_file()
|
||||
|
@ -199,7 +199,7 @@ def test_get_best_checkpoint_single_run(test_output_dirs: OutputFolderForTests)
|
|||
# in the run, into a subfolder of the checkpoint folder
|
||||
checkpoint_handler.azure_config.run_recovery_id = run_recovery_id
|
||||
checkpoint_handler.download_recovery_checkpoints_or_weights()
|
||||
expected_checkpoint = config.checkpoint_folder / f"{BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX}"
|
||||
expected_checkpoint = config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
checkpoint_paths = checkpoint_handler.get_best_checkpoints()
|
||||
assert checkpoint_paths
|
||||
assert len(checkpoint_paths) == 1
|
||||
|
@ -215,13 +215,13 @@ def test_get_best_checkpoint_single_run(test_output_dirs: OutputFolderForTests)
|
|||
|
||||
# There is no checkpoint in the current run - use the one from run_recovery
|
||||
checkpoint_paths = checkpoint_handler.get_best_checkpoints()
|
||||
expected_checkpoint = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
expected_checkpoint = config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
assert checkpoint_paths
|
||||
assert len(checkpoint_paths) == 1
|
||||
assert checkpoint_paths[0] == expected_checkpoint
|
||||
|
||||
# Copy over checkpoints to make it look like training has happened and a better checkpoint written
|
||||
expected_checkpoint = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
expected_checkpoint = config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
expected_checkpoint.touch()
|
||||
checkpoint_paths = checkpoint_handler.get_best_checkpoints()
|
||||
assert checkpoint_paths
|
||||
|
@ -238,7 +238,7 @@ def test_download_checkpoints_from_hyperdrive_child_runs(test_output_dirs: Outpu
|
|||
hyperdrive_run = get_most_recent_run(fallback_run_id_for_local_execution=FALLBACK_ENSEMBLE_RUN)
|
||||
checkpoint_handler.download_checkpoints_from_hyperdrive_child_runs(hyperdrive_run)
|
||||
expected_checkpoints = [config.checkpoint_folder / OTHER_RUNS_SUBDIR_NAME / str(i)
|
||||
/ BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX for i in range(2)]
|
||||
/ LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX for i in range(2)]
|
||||
checkpoint_paths = checkpoint_handler.get_best_checkpoints()
|
||||
assert checkpoint_paths
|
||||
assert len(checkpoint_paths) == 2
|
||||
|
@ -266,7 +266,7 @@ def test_get_checkpoints_to_test(test_output_dirs: OutputFolderForTests) -> None
|
|||
checkpoint_handler.container.checkpoint_folder.mkdir(parents=True)
|
||||
|
||||
# Copy checkpoint to make it seem like training has happened
|
||||
expected_checkpoint = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
expected_checkpoint = config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
expected_checkpoint.touch()
|
||||
checkpoint_and_paths = checkpoint_handler.get_checkpoints_to_test()
|
||||
|
||||
|
@ -295,10 +295,10 @@ def test_get_checkpoints_to_test_single_run(test_output_dirs: OutputFolderForTes
|
|||
|
||||
assert checkpoint_and_paths
|
||||
assert len(checkpoint_and_paths) == 1
|
||||
assert checkpoint_and_paths[0] == config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
assert checkpoint_and_paths[0] == config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
|
||||
# Copy checkpoint to make it seem like training has happened
|
||||
expected_checkpoint = config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
expected_checkpoint = config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
expected_checkpoint.touch()
|
||||
checkpoint_and_paths = checkpoint_handler.get_checkpoints_to_test()
|
||||
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
@ -11,17 +10,15 @@ import torch
|
|||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
|
||||
from InnerEye.Common.output_directories import OutputFolderForTests
|
||||
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, LAST_CHECKPOINT_FILE_NAME, \
|
||||
LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, RECOVERY_CHECKPOINT_FILE_NAME
|
||||
from InnerEye.ML.utils.checkpoint_handling import create_best_checkpoint, extract_latest_checkpoint_and_epoch, \
|
||||
find_all_recovery_checkpoints, \
|
||||
find_recovery_checkpoint_and_epoch
|
||||
from InnerEye.ML.common import (AUTOSAVE_CHECKPOINT_CANDIDATES, LAST_CHECKPOINT_FILE_NAME,
|
||||
LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX, LEGACY_RECOVERY_CHECKPOINT_FILE_NAME)
|
||||
from InnerEye.ML.config import SegmentationModelBase
|
||||
from InnerEye.ML.lightning_base import InnerEyeContainer
|
||||
from InnerEye.ML.lightning_helpers import load_from_checkpoint_and_adjust_for_inference
|
||||
from InnerEye.ML.lightning_models import create_lightning_model
|
||||
from InnerEye.ML.model_config_base import ModelConfigBase
|
||||
from InnerEye.ML.model_training import create_lightning_trainer
|
||||
from InnerEye.ML.utils.checkpoint_handling import (cleanup_checkpoints, find_recovery_checkpoint)
|
||||
from Tests.ML.configs.ClassificationModelForTesting import ClassificationModelForTesting
|
||||
from Tests.ML.configs.DummyModel import DummyModel
|
||||
from Tests.ML.util import machine_has_gpu
|
||||
|
@ -98,66 +95,45 @@ def test_checkpoint_path() -> None:
|
|||
assert LAST_CHECKPOINT_FILE_NAME == ModelCheckpoint.CHECKPOINT_NAME_LAST
|
||||
|
||||
|
||||
def test_find_all_recovery_checkpoints(test_output_dirs: OutputFolderForTests) -> None:
|
||||
def test_recovery_checkpoints_fails(test_output_dirs: OutputFolderForTests) -> None:
|
||||
"""
|
||||
Using old-style recovery checkpoints is not supported and should raise an error.
|
||||
"""
|
||||
checkpoint_folder = test_output_dirs.root_dir
|
||||
# No recovery yet available
|
||||
(checkpoint_folder / "epoch=2.ckpt").touch()
|
||||
assert find_all_recovery_checkpoints(checkpoint_folder) is None
|
||||
# Add recovery file to fake folder
|
||||
file_list = ["recovery_epoch=1.ckpt", "recovery.ckpt"]
|
||||
for f in file_list:
|
||||
(checkpoint_folder / f).touch()
|
||||
found_file_names = set([f.stem for f in find_all_recovery_checkpoints(checkpoint_folder)]) # type: ignore
|
||||
assert len(found_file_names.difference(found_file_names)) == 0
|
||||
assert find_recovery_checkpoint(checkpoint_folder) is None
|
||||
(checkpoint_folder / LEGACY_RECOVERY_CHECKPOINT_FILE_NAME).touch()
|
||||
with pytest.raises(ValueError) as ex:
|
||||
find_recovery_checkpoint(checkpoint_folder)
|
||||
assert "The legacy recovery checkpoint setup is no longer supported." in str(ex)
|
||||
|
||||
|
||||
def test_find_latest_checkpoint_and_epoch() -> None:
|
||||
file_list = [Path("epoch=1.ckpt"), Path("epoch=3.ckpt"), Path("epoch=2.ckpt")]
|
||||
assert Path("epoch=3.ckpt"), 3 == extract_latest_checkpoint_and_epoch(file_list)
|
||||
invalid_file_list = [Path("epoch.ckpt"), Path("epoch=3.ckpt"), Path("epoch=2.ckpt")]
|
||||
with pytest.raises(IndexError):
|
||||
extract_latest_checkpoint_and_epoch(invalid_file_list)
|
||||
|
||||
|
||||
def test_find_recovery_checkpoint(test_output_dirs: OutputFolderForTests) -> None:
|
||||
def test_find_all_recovery_checkpoints(test_output_dirs: OutputFolderForTests) -> None:
|
||||
"""
|
||||
Test if the logic to keep only the most recently modified file works.
|
||||
Test if the search for recovery checkpoints respects the correct order of files.
|
||||
"""
|
||||
folder = test_output_dirs.root_dir
|
||||
prefix = RECOVERY_CHECKPOINT_FILE_NAME
|
||||
file1 = folder / (prefix + "epoch=1.txt")
|
||||
file2 = folder / (prefix + "epoch=2.txt")
|
||||
# No file present yet
|
||||
assert find_recovery_checkpoint_and_epoch(folder) is None
|
||||
# Single file present: This should be returned.
|
||||
file1.touch()
|
||||
# Without sleeping, the test can fail in Azure build agents
|
||||
time.sleep(0.1)
|
||||
recovery = find_recovery_checkpoint_and_epoch(folder)
|
||||
assert recovery is not None
|
||||
latest_checkpoint, latest_epoch = recovery
|
||||
assert latest_checkpoint == file1
|
||||
assert latest_epoch == 1
|
||||
assert latest_checkpoint.is_file()
|
||||
# Two files present: keep file2 should be returned
|
||||
file2.touch()
|
||||
time.sleep(0.1)
|
||||
recovery = find_recovery_checkpoint_and_epoch(folder)
|
||||
assert recovery is not None
|
||||
latest_checkpoint, latest_epoch = recovery
|
||||
assert latest_checkpoint == file2
|
||||
assert latest_checkpoint.is_file()
|
||||
assert latest_epoch == 2
|
||||
# Add file1 again: file2 should still be returned as it has the
|
||||
# highest epoch number
|
||||
file1.touch()
|
||||
time.sleep(0.1)
|
||||
recovery = find_recovery_checkpoint_and_epoch(folder)
|
||||
assert recovery is not None
|
||||
latest_checkpoint, latest_epoch = recovery
|
||||
assert latest_checkpoint == file2
|
||||
assert latest_checkpoint.is_file()
|
||||
assert latest_epoch == 2
|
||||
checkpoint_folder = test_output_dirs.root_dir
|
||||
# If the checkpoint folder only contains a single checkpoint file of whatever kind, return that.
|
||||
single_files = [*AUTOSAVE_CHECKPOINT_CANDIDATES, LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX]
|
||||
for i, file in enumerate(single_files):
|
||||
subfolder = checkpoint_folder / str(i)
|
||||
subfolder.mkdir()
|
||||
full_file = subfolder / file
|
||||
torch.save({"epoch": 1}, full_file)
|
||||
result = find_recovery_checkpoint(subfolder)
|
||||
assert result is not None
|
||||
assert result.name == file
|
||||
|
||||
# If both "autosave" and "best_checkpoint" are present, return the one with the highest epoch
|
||||
both = checkpoint_folder / "both"
|
||||
both.mkdir()
|
||||
file_with_highest_epoch = AUTOSAVE_CHECKPOINT_CANDIDATES[1]
|
||||
for file in single_files:
|
||||
full_file = both / file
|
||||
epoch = 100 if file == file_with_highest_epoch else 1
|
||||
torch.save({"epoch": epoch}, full_file)
|
||||
result_both = find_recovery_checkpoint(both)
|
||||
assert result_both is not None
|
||||
assert result_both.name == file_with_highest_epoch
|
||||
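A rough sketch of the selection logic these assertions imply: among the autosave candidates and last.ckpt, pick the file whose stored epoch is highest, and refuse to run if a legacy recovery checkpoint from the retired scheme is present. This illustrates the behaviour the test checks under assumed file names, not the production find_recovery_checkpoint function.

# Sketch of the behaviour asserted above (not the production implementation).
from pathlib import Path
from typing import Optional
import torch

AUTOSAVE_CANDIDATES = ("autosave.ckpt", "autosave-v1.ckpt")  # assumed names
LAST = "last.ckpt"
LEGACY_PREFIX = "recovery"  # assumed prefix of the retired recovery checkpoints

def find_recovery_checkpoint_sketch(folder: Path) -> Optional[Path]:
    if any(folder.glob(f"{LEGACY_PREFIX}*")):
        raise ValueError("The legacy recovery checkpoint setup is no longer supported.")
    best: Optional[Path] = None
    best_epoch = -1
    for name in (*AUTOSAVE_CANDIDATES, LAST):
        candidate = folder / name
        if candidate.is_file():
            epoch = torch.load(candidate, map_location="cpu")["epoch"]
            if epoch > best_epoch:
                best, best_epoch = candidate, epoch
    return best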
|
||||
|
||||
def test_keep_best_checkpoint(test_output_dirs: OutputFolderForTests) -> None:
|
||||
|
@ -166,41 +142,18 @@ def test_keep_best_checkpoint(test_output_dirs: OutputFolderForTests) -> None:
|
|||
"""
|
||||
folder = test_output_dirs.root_dir
|
||||
with pytest.raises(FileNotFoundError) as ex:
|
||||
create_best_checkpoint(folder)
|
||||
cleanup_checkpoints(folder)
|
||||
assert "Checkpoint file" in str(ex)
|
||||
# Create a folder with a "last" and "autosave" checkpoint, as they come out of the trainer loop.
|
||||
last = folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
last.touch()
|
||||
actual = create_best_checkpoint(folder)
|
||||
assert not last.is_file(), "Checkpoint file should have been renamed"
|
||||
expected = folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
assert actual == expected
|
||||
assert actual.is_file()
|
||||
|
||||
|
||||
def test_cleanup_checkpoints1(test_output_dirs: OutputFolderForTests) -> None:
|
||||
folder = test_output_dirs.root_dir
|
||||
with pytest.raises(FileNotFoundError) as ex:
|
||||
create_best_checkpoint(folder)
|
||||
assert "Checkpoint file" in str(ex)
|
||||
# Single checkpoint file, nothing else: This file should be renamed to best_checkpoint
|
||||
last = folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
last.touch()
|
||||
create_best_checkpoint(folder)
|
||||
for autosave in AUTOSAVE_CHECKPOINT_CANDIDATES:
|
||||
(folder / autosave).touch()
|
||||
assert len(list(folder.glob("*"))) > 1
|
||||
cleanup_checkpoints(folder)
|
||||
# All code outside the trainer loop assumes that there is a checkpoint with this name. The constant actually
|
||||
# matches "last.ckpt", but the constant is kept to reduce code changes from the legacy behaviour.
|
||||
expected = folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
assert expected.is_file()
|
||||
# The autosave checkpoint should be deleted after training, only the single best checkpoint should remain
|
||||
assert len(list(folder.glob("*"))) == 1
|
||||
assert (folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()
|
||||
|
||||
|
||||
def test_cleanup_checkpoints2(test_output_dirs: OutputFolderForTests) -> None:
|
||||
# Single checkpoint file and two recovery checkpoints: Should keep the last and rename it.
|
||||
folder = test_output_dirs.root_dir
|
||||
last = folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
last.touch()
|
||||
(folder / f"{RECOVERY_CHECKPOINT_FILE_NAME}-epoch=3").touch()
|
||||
(folder / f"{RECOVERY_CHECKPOINT_FILE_NAME}-epoch=6").touch()
|
||||
# Before cleanup: last.ckpt, recovery-epoch=6.ckpt, recovery-epoch=3.ckpt
|
||||
create_best_checkpoint(folder)
|
||||
# After cleanup: best.ckpt, recovery-epoch=6.ckpt, recovery-epoch=3.ckpt
|
||||
assert len(list(folder.glob("*"))) == 3
|
||||
assert (folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()
|
||||
assert (folder / f"{RECOVERY_CHECKPOINT_FILE_NAME}-epoch=6").is_file()
|
||||
assert (folder / f"{RECOVERY_CHECKPOINT_FILE_NAME}-epoch=3").is_file()
|
||||
|
|
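The cleanup step asserted above can be pictured roughly as follows: after training, any autosave files are deleted, and exactly one checkpoint, last.ckpt, must remain for the rest of the pipeline to pick up. Again a sketch under those assumed file names, not the InnerEye cleanup_checkpoints function itself.

# Sketch of the post-training cleanup the tests expect (assumed file names).
from pathlib import Path

AUTOSAVE_CANDIDATES = ("autosave.ckpt", "autosave-v1.ckpt")
LAST = "last.ckpt"

def cleanup_checkpoints_sketch(folder: Path) -> None:
    if not (folder / LAST).is_file():
        # Mirrors the FileNotFoundError branch exercised in test_keep_best_checkpoint.
        raise FileNotFoundError(f"Checkpoint file {LAST} not found in {folder}")
    for name in AUTOSAVE_CANDIDATES:
        candidate = folder / name
        if candidate.is_file():
            candidate.unlink()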
|
@ -28,7 +28,7 @@ from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye
|
|||
from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier
|
||||
from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import SSLOnlineEvaluatorInnerEye
|
||||
from InnerEye.ML.SSL.utils import SSLDataModuleType, SSLTrainingType
|
||||
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
from InnerEye.ML.common import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
from InnerEye.ML.configs.ssl.CIFAR_SSL_configs import CIFAR10SimCLR
|
||||
from InnerEye.ML.configs.ssl.CXR_SSL_configs import CXRImageClassifier
|
||||
from InnerEye.ML.runner import Runner
|
||||
|
@ -114,7 +114,6 @@ def test_innereye_ssl_container_cifar10_resnet_simclr() -> None:
|
|||
assert loaded_config.encoder_output_dim == 2048
|
||||
assert loaded_config.l_rate == 1e-4
|
||||
assert loaded_config.num_epochs == 1
|
||||
assert loaded_config.recovery_checkpoint_save_interval == 200
|
||||
assert loaded_config.ssl_training_type == SSLTrainingType.SimCLR
|
||||
assert loaded_config.online_eval.num_classes == 10
|
||||
assert loaded_config.online_eval.dataset == SSLDatasetName.CIFAR10.value
|
||||
|
@ -131,13 +130,12 @@ def test_innereye_ssl_container_cifar10_resnet_simclr() -> None:
|
|||
'simclr/train/loss': 3.6261844635009766,
|
||||
'simclr/learning_rate': 0.0,
|
||||
'ssl_online_evaluator/train/loss': 3.1140503883361816,
|
||||
'ssl_online_evaluator/train/online_AccuracyAtThreshold05': 0.0,
|
||||
'epoch_started': 0.0}
|
||||
'ssl_online_evaluator/train/online_AccuracyAtThreshold05': 0.0}
|
||||
|
||||
_compare_stored_metrics(runner, expected_metrics, abs=5e-5)
|
||||
|
||||
# Check that the checkpoint contains both the optimizer for the embedding and for the linear head
|
||||
checkpoint_path = loaded_config.outputs_folder / "checkpoints" / "best_checkpoint.ckpt"
|
||||
checkpoint_path = loaded_config.outputs_folder / "checkpoints" / "last.ckpt"
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
assert len(checkpoint["optimizer_states"]) == 1
|
||||
assert len(checkpoint["lr_schedulers"]) == 1
|
||||
|
@ -225,13 +223,12 @@ def test_innereye_ssl_container_rsna() -> None:
|
|||
'ssl_online_evaluator/train/loss': 0.6938587427139282,
|
||||
'ssl_online_evaluator/train/online_AreaUnderRocCurve': 0.5,
|
||||
'ssl_online_evaluator/train/online_AreaUnderPRCurve': 0.6000000238418579,
|
||||
'ssl_online_evaluator/train/online_AccuracyAtThreshold05': 0.20000000298023224,
|
||||
'epoch_started': 0.0}
|
||||
'ssl_online_evaluator/train/online_AccuracyAtThreshold05': 0.20000000298023224}
|
||||
|
||||
_compare_stored_metrics(runner, expected_metrics)
|
||||
|
||||
# Check that we are able to load the checkpoint and create classifier model
|
||||
checkpoint_path = loaded_config.checkpoint_folder / BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
checkpoint_path = loaded_config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
args = common_test_args + ["--model=CXRImageClassifier",
|
||||
f"--local_dataset={str(path_to_test_dataset)}",
|
||||
"--use_balanced_binary_loss_for_linear_head=True",
|
||||
|
@ -413,7 +410,7 @@ def test_online_evaluator_distributed() -> None:
|
|||
assert callback.evaluator == mock_ddp_result
|
||||
|
||||
|
||||
def test_simclr_batch_size() -> None:
|
||||
def test_simclr_num_nodes() -> None:
|
||||
"""
|
||||
Test if the number of nodes is correctly passed through to the SIMCLR model. After an update of the semantics of
|
||||
the "gpus" argument in LightningBolts, we had a regression, leading to incorrect use of the cosine
|
||||
|
@ -434,3 +431,44 @@ def test_simclr_batch_size() -> None:
|
|||
container.num_nodes = 2
|
||||
model2 = container.create_model()
|
||||
assert model2.train_iters_per_epoch == old_iters_per_epoch // container.num_nodes # type:ignore
|
||||
|
||||
|
||||
def test_simclr_num_gpus() -> None:
|
||||
"""
|
||||
Test if the number of GPUs is correctly passed through to the SIMCLR model.
|
||||
"""
|
||||
device_count = 8
|
||||
num_epochs = 30
|
||||
# Warmup epochs == 10 is hardcoded in SIMClr. The core SIMClr module has an argument for it, but we are not
|
||||
# passing that through.
|
||||
warmup_epochs = 10
|
||||
with mock.patch("torch.cuda.device_count", return_value=device_count):
|
||||
with mock.patch("InnerEye.ML.deep_learning_config.TrainerParams.use_gpu", return_value=True):
|
||||
with mock.patch("InnerEye.ML.SSL.lightning_containers.ssl_container.get_encoder_output_dim", return_value=1):
|
||||
container = CIFAR10SimCLR()
|
||||
container.num_epochs = num_epochs
|
||||
num_samples = 800
|
||||
batch_size = 10
|
||||
container.data_module = mock.MagicMock(num_samples=num_samples, batch_size=batch_size)
|
||||
model1 = container.create_model()
|
||||
assert model1.train_iters_per_epoch == num_samples // (batch_size * device_count)
|
||||
# Reducing the number of GPUs should decrease effective batch size, and hence increase number of
|
||||
# iterations per epoch
|
||||
container.max_num_gpus = 4
|
||||
model2 = container.create_model()
|
||||
assert model2.train_iters_per_epoch == num_samples // (batch_size * container.max_num_gpus)
|
||||
scheduler = model2.configure_optimizers()[1][0]["scheduler"]
|
||||
|
||||
total_training_steps = model2.train_iters_per_epoch * num_epochs # type: ignore
|
||||
warmup_steps = model2.train_iters_per_epoch * warmup_epochs # type: ignore
|
||||
previous_lr = None
|
||||
for i in range(total_training_steps):
|
||||
lr = scheduler.get_last_lr()
|
||||
if previous_lr is not None:
|
||||
if i <= warmup_steps:
|
||||
assert lr > previous_lr, "During warmup, LR should increase"
|
||||
else:
|
||||
assert lr < previous_lr, "After warmup, LR should decrease"
|
||||
print(f"Iteration {i}: LR = {lr}")
|
||||
scheduler.step()
|
||||
previous_lr = lr
|
||||
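The expected iteration counts in this test follow directly from the effective batch size: iterations per epoch equals num_samples divided by (batch_size times the number of GPUs). A quick check of the numbers used above:

# Worked check of the iteration counts asserted in test_simclr_num_gpus.
num_samples, batch_size = 800, 10
for gpus in (8, 4):
    print(gpus, num_samples // (batch_size * gpus))  # prints 10 for 8 GPUs, 20 for 4 GPUs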
|
|