SimCLR recovery test (#643)
This commit is contained in:
Родитель
884e3fd7dc
Коммит
6a4919c361
|
@ -11,6 +11,8 @@ created.
|
|||
|
||||
|
||||
## Upcoming
|
||||
- ([#643](https://github.com/microsoft/InnerEye-DeepLearning/pull/643)) Test for recovery of SSL job. Tracks learning rate and train
|
||||
loss.
|
||||
|
||||
### Added
|
||||
- ([#594](https://github.com/microsoft/InnerEye-DeepLearning/pull/594)) When supplying a "--tag" argument, the AzureML jobs use that value as the display name, to more easily distinguish run.
|
||||
|
|
|
@ -3,7 +3,9 @@
|
|||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Tuple, Callable, Optional
|
||||
from enum import Enum
|
||||
from yacs.config import CfgNode
|
||||
|
||||
import pandas as pd
|
||||
import param
|
||||
|
@ -13,10 +15,17 @@ from torchmetrics.regression import MeanSquaredError
|
|||
from torch import Tensor
|
||||
from torch.nn import Identity
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision.datasets.vision import VisionDataset
|
||||
from torchvision.transforms import Lambda
|
||||
|
||||
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
|
||||
from InnerEye.ML.common import ModelExecutionMode
|
||||
from InnerEye.ML.lightning_container import InnerEyeInference, LightningContainer, LightningModuleWithOptimizer
|
||||
from InnerEye.ML.SSL.datamodules_and_datasets.dataset_cls_utils import InnerEyeDataClassBaseWithReturnIndex
|
||||
from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import DualViewTransformWrapper
|
||||
from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName, SSLContainer, SSLDatasetName
|
||||
from InnerEye.ML.SSL.utils import SSLTrainingType
|
||||
from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline
|
||||
|
||||
|
||||
class DummyContainerWithDatasets(LightningContainer):
|
||||
|
@ -262,3 +271,84 @@ class DummyContainerWithHooks(LightningContainer):
|
|||
assert self.hook_local_zero.is_file(), "before_training_on_local_rank_zero should have been called already"
|
||||
assert not self.hook_all.is_file(), "before_training_on_all_ranks should only be called once"
|
||||
self.hook_all.touch()
|
||||
|
||||
|
||||
class DummySimCLRData(VisionDataset):
|
||||
"""
|
||||
Returns a constant vector of size three [1., 1., 1.]
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
root: str,
|
||||
train: bool = True,
|
||||
transform: Optional[Callable] = None,
|
||||
target_transform: Optional[Callable] = None,
|
||||
download: bool = False,
|
||||
) -> None:
|
||||
super(DummySimCLRData, self).__init__(root, transform=transform,
|
||||
target_transform=target_transform)
|
||||
|
||||
self.train = train
|
||||
self.data = torch.ones(20, 1, 1, 3)
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
||||
"""
|
||||
Args:
|
||||
index (int): Index
|
||||
|
||||
Returns:
|
||||
(Any): Sample and meta data, optionally transformed by the respective transforms.
|
||||
"""
|
||||
# item = (self.data[index], self.data[index]), torch.Tensor([0])
|
||||
image = self.data[index]
|
||||
if self.transform:
|
||||
image = self.transform(image)
|
||||
return image, 0
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.data.shape[0]
|
||||
|
||||
class DummySimCLRInnerEyeData(InnerEyeDataClassBaseWithReturnIndex, DummySimCLRData):
|
||||
"""
|
||||
Wrapper class around the DummySimCLRData class to optionally return the
|
||||
index on top of the image and the label in __getitem__ as well as defining num_classes property.
|
||||
"""
|
||||
|
||||
@property
|
||||
def num_classes(self) -> int:
|
||||
return 2
|
||||
|
||||
class DummySimCLRSSLDatasetName(SSLDatasetName, Enum):
|
||||
DUMMY = "DUMMY"
|
||||
|
||||
class DummySimCLR(SSLContainer):
|
||||
"""
|
||||
This module trains an SSL encoder using SimCLR on the DummySimCLRData and finetunes a linear head too.
|
||||
"""
|
||||
SSLContainer._SSLDataClassMappings.update({DummySimCLRSSLDatasetName.DUMMY.value: DummySimCLRInnerEyeData})
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(ssl_training_dataset_name=DummySimCLRSSLDatasetName.DUMMY,
|
||||
linear_head_dataset_name=DummySimCLRSSLDatasetName.DUMMY,
|
||||
# Train with as little data as possible for the test
|
||||
ssl_training_batch_size=2,
|
||||
linear_head_batch_size=2,
|
||||
ssl_encoder=EncoderName.resnet50, # This gets overwritten by the test itself
|
||||
ssl_training_type=SSLTrainingType.SimCLR,
|
||||
random_seed=0,
|
||||
num_epochs=20,
|
||||
num_workers=0,
|
||||
max_num_gpus=1)
|
||||
|
||||
def _get_transforms(self, augmentation_config: Optional[CfgNode],
|
||||
dataset_name: str,
|
||||
is_ssl_encoder_module: bool) -> Tuple[Any, Any]:
|
||||
|
||||
# is_ssl_encoder_module will be True for ssl training, False for linear head training
|
||||
train_transforms = ImageTransformationPipeline([Lambda(lambda x: x)]) # do nothing
|
||||
val_transforms = ImageTransformationPipeline([Lambda(lambda x: x + 1)]) # add 1
|
||||
|
||||
if is_ssl_encoder_module:
|
||||
train_transforms = DualViewTransformWrapper(train_transforms) # type: ignore
|
||||
val_transforms = DualViewTransformWrapper(val_transforms) # type: ignore
|
||||
return train_transforms, val_transforms
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
|
@ -13,7 +13,7 @@ import pytest
|
|||
import torch
|
||||
from pl_bolts.models.self_supervised.resnets import ResNet
|
||||
from pl_bolts.optimizers import linear_warmup_decay
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning import Trainer, seed_everything
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.trainer.supporters import CombinedLoader
|
||||
from torch.nn import Module
|
||||
|
@ -23,6 +23,7 @@ from InnerEye.Common import fixed_paths
|
|||
from InnerEye.Common.fixed_paths import repository_root_directory
|
||||
from InnerEye.Common.fixed_paths_for_tests import TEST_OUTPUTS_PATH
|
||||
from InnerEye.Common.output_directories import OutputFolderForTests
|
||||
from InnerEye.ML.lightning_loggers import StoringLogger
|
||||
from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName, SSLDatasetName
|
||||
from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye
|
||||
from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye
|
||||
|
@ -33,8 +34,9 @@ from InnerEye.ML.common import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
|||
from InnerEye.ML.configs.ssl.CIFAR_SSL_configs import CIFAR10SimCLR
|
||||
from InnerEye.ML.configs.ssl.CXR_SSL_configs import CXRImageClassifier, NIH_RSNA_SimCLR
|
||||
from InnerEye.ML.runner import Runner
|
||||
from Tests.ML.configs.lightning_test_containers import DummyContainerWithModel
|
||||
from Tests.ML.configs.lightning_test_containers import DummyContainerWithModel, DummySimCLR
|
||||
from Tests.ML.utils.test_io_util import write_test_dicom
|
||||
from health_ml.utils import AzureMLProgressBar
|
||||
|
||||
path_to_cxr_test_dataset = TEST_OUTPUTS_PATH / "cxr_test_dataset"
|
||||
|
||||
|
@ -281,6 +283,67 @@ def test_simclr_lr_scheduler() -> None:
|
|||
for i in range(highest_lr, len(lr) - 1):
|
||||
assert lr[i] > lr[i + 1], f"Not strictly monotonically decreasing at index {i}"
|
||||
|
||||
@pytest.mark.gpu
|
||||
def test_simclr_training_recovery(test_output_dirs: OutputFolderForTests) -> None:
|
||||
""" This test checks if a SSLContainer correctly resumes training.
|
||||
First we run SSL using a Trainer for 20 epochs.
|
||||
Second, we run a new SSL job for 15 epochs.
|
||||
Third we resume the job and run it for 5 more epochs.
|
||||
The test checks the learning rate and the loss.
|
||||
The test is meant to run on a GPU!
|
||||
"""
|
||||
def run_simclr_dummy_container(test_output_dirs: OutputFolderForTests,
|
||||
num_epochs: int,
|
||||
last_checkpoint: Optional[ModelCheckpoint] = None) -> Tuple[list, list, ModelCheckpoint]:
|
||||
seed_everything(0, workers=True)
|
||||
container = DummySimCLR()
|
||||
container.setup()
|
||||
model = container.create_model()
|
||||
data = container.get_data_module()
|
||||
|
||||
# add logger
|
||||
logger = StoringLogger()
|
||||
progress = AzureMLProgressBar(refresh_rate=1)
|
||||
checkpoint_folder = test_output_dirs.create_file_or_folder_path("checkpoints")
|
||||
checkpoint_folder.mkdir(exist_ok=True)
|
||||
checkpoint = ModelCheckpoint(dirpath=checkpoint_folder,
|
||||
every_n_val_epochs=1,
|
||||
save_last=True)
|
||||
|
||||
trainer = Trainer(default_root_dir=str(test_output_dirs.root_dir),
|
||||
logger=logger,
|
||||
callbacks=[progress, checkpoint],
|
||||
max_epochs=num_epochs,
|
||||
resume_from_checkpoint=last_checkpoint.last_model_path if last_checkpoint is not None else None,
|
||||
deterministic=True,
|
||||
benchmark=False,
|
||||
gpus=1)
|
||||
trainer.fit(model, datamodule=data)
|
||||
|
||||
lrs = []
|
||||
loss = []
|
||||
for item in logger.results_per_epoch:
|
||||
lrs.append(logger.results_per_epoch[item]['simclr/learning_rate'])
|
||||
loss.append(logger.results_per_epoch[item]['simclr/train/loss'])
|
||||
|
||||
return lrs, loss, checkpoint
|
||||
|
||||
small_encoder = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3, 2))
|
||||
with mock.patch("InnerEye.ML.SSL.encoders.create_ssl_encoder", return_value=small_encoder):
|
||||
with mock.patch("InnerEye.ML.SSL.encoders.get_encoder_output_dim", return_value=2):
|
||||
# Normal run
|
||||
normal_lrs, normal_loss, _ = run_simclr_dummy_container(test_output_dirs, 20, last_checkpoint=None)
|
||||
|
||||
# Short run
|
||||
short_lrs, short_loss, short_checkpoint = run_simclr_dummy_container(test_output_dirs, 15, last_checkpoint=None)
|
||||
|
||||
# Resumed run
|
||||
resumed_lrs, resumed_loss, _ = run_simclr_dummy_container(test_output_dirs, 20, last_checkpoint=short_checkpoint)
|
||||
|
||||
resumed_lrs = short_lrs + resumed_lrs
|
||||
assert resumed_lrs == normal_lrs
|
||||
resumed_loss = short_loss + resumed_loss
|
||||
assert resumed_loss == normal_loss
|
||||
|
||||
def test_online_evaluator_recovery(test_output_dirs: OutputFolderForTests) -> None:
|
||||
"""
|
||||
|
|
Загрузка…
Ссылка в новой задаче