SimCLR recovery test (#643)
This commit is contained in:
Родитель
884e3fd7dc
Коммит
6a4919c361
|
@ -11,6 +11,8 @@ created.
|
||||||
|
|
||||||
|
|
||||||
## Upcoming
|
## Upcoming
|
||||||
|
- ([#643](https://github.com/microsoft/InnerEye-DeepLearning/pull/643)) Test for recovery of SSL job. Tracks learning rate and train
|
||||||
|
loss.
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
- ([#594](https://github.com/microsoft/InnerEye-DeepLearning/pull/594)) When supplying a "--tag" argument, the AzureML jobs use that value as the display name, to more easily distinguish run.
|
- ([#594](https://github.com/microsoft/InnerEye-DeepLearning/pull/594)) When supplying a "--tag" argument, the AzureML jobs use that value as the display name, to more easily distinguish run.
|
||||||
|
|
|
@ -3,7 +3,9 @@
|
||||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||||
# ------------------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------------------
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Tuple, Callable, Optional
|
||||||
|
from enum import Enum
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import param
|
import param
|
||||||
|
@ -13,10 +15,17 @@ from torchmetrics.regression import MeanSquaredError
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn import Identity
|
from torch.nn import Identity
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
from torchvision.datasets.vision import VisionDataset
|
||||||
|
from torchvision.transforms import Lambda
|
||||||
|
|
||||||
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
|
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
|
||||||
from InnerEye.ML.common import ModelExecutionMode
|
from InnerEye.ML.common import ModelExecutionMode
|
||||||
from InnerEye.ML.lightning_container import InnerEyeInference, LightningContainer, LightningModuleWithOptimizer
|
from InnerEye.ML.lightning_container import InnerEyeInference, LightningContainer, LightningModuleWithOptimizer
|
||||||
|
from InnerEye.ML.SSL.datamodules_and_datasets.dataset_cls_utils import InnerEyeDataClassBaseWithReturnIndex
|
||||||
|
from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import DualViewTransformWrapper
|
||||||
|
from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName, SSLContainer, SSLDatasetName
|
||||||
|
from InnerEye.ML.SSL.utils import SSLTrainingType
|
||||||
|
from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline
|
||||||
|
|
||||||
|
|
||||||
class DummyContainerWithDatasets(LightningContainer):
|
class DummyContainerWithDatasets(LightningContainer):
|
||||||
|
@ -262,3 +271,84 @@ class DummyContainerWithHooks(LightningContainer):
|
||||||
assert self.hook_local_zero.is_file(), "before_training_on_local_rank_zero should have been called already"
|
assert self.hook_local_zero.is_file(), "before_training_on_local_rank_zero should have been called already"
|
||||||
assert not self.hook_all.is_file(), "before_training_on_all_ranks should only be called once"
|
assert not self.hook_all.is_file(), "before_training_on_all_ranks should only be called once"
|
||||||
self.hook_all.touch()
|
self.hook_all.touch()
|
||||||
|
|
||||||
|
|
||||||
|
class DummySimCLRData(VisionDataset):
|
||||||
|
"""
|
||||||
|
Returns a constant vector of size three [1., 1., 1.]
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
root: str,
|
||||||
|
train: bool = True,
|
||||||
|
transform: Optional[Callable] = None,
|
||||||
|
target_transform: Optional[Callable] = None,
|
||||||
|
download: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super(DummySimCLRData, self).__init__(root, transform=transform,
|
||||||
|
target_transform=target_transform)
|
||||||
|
|
||||||
|
self.train = train
|
||||||
|
self.data = torch.ones(20, 1, 1, 3)
|
||||||
|
|
||||||
|
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
index (int): Index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Any): Sample and meta data, optionally transformed by the respective transforms.
|
||||||
|
"""
|
||||||
|
# item = (self.data[index], self.data[index]), torch.Tensor([0])
|
||||||
|
image = self.data[index]
|
||||||
|
if self.transform:
|
||||||
|
image = self.transform(image)
|
||||||
|
return image, 0
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self.data.shape[0]
|
||||||
|
|
||||||
|
class DummySimCLRInnerEyeData(InnerEyeDataClassBaseWithReturnIndex, DummySimCLRData):
|
||||||
|
"""
|
||||||
|
Wrapper class around the DummySimCLRData class to optionally return the
|
||||||
|
index on top of the image and the label in __getitem__ as well as defining num_classes property.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_classes(self) -> int:
|
||||||
|
return 2
|
||||||
|
|
||||||
|
class DummySimCLRSSLDatasetName(SSLDatasetName, Enum):
|
||||||
|
DUMMY = "DUMMY"
|
||||||
|
|
||||||
|
class DummySimCLR(SSLContainer):
|
||||||
|
"""
|
||||||
|
This module trains an SSL encoder using SimCLR on the DummySimCLRData and finetunes a linear head too.
|
||||||
|
"""
|
||||||
|
SSLContainer._SSLDataClassMappings.update({DummySimCLRSSLDatasetName.DUMMY.value: DummySimCLRInnerEyeData})
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__(ssl_training_dataset_name=DummySimCLRSSLDatasetName.DUMMY,
|
||||||
|
linear_head_dataset_name=DummySimCLRSSLDatasetName.DUMMY,
|
||||||
|
# Train with as little data as possible for the test
|
||||||
|
ssl_training_batch_size=2,
|
||||||
|
linear_head_batch_size=2,
|
||||||
|
ssl_encoder=EncoderName.resnet50, # This gets overwritten by the test itself
|
||||||
|
ssl_training_type=SSLTrainingType.SimCLR,
|
||||||
|
random_seed=0,
|
||||||
|
num_epochs=20,
|
||||||
|
num_workers=0,
|
||||||
|
max_num_gpus=1)
|
||||||
|
|
||||||
|
def _get_transforms(self, augmentation_config: Optional[CfgNode],
|
||||||
|
dataset_name: str,
|
||||||
|
is_ssl_encoder_module: bool) -> Tuple[Any, Any]:
|
||||||
|
|
||||||
|
# is_ssl_encoder_module will be True for ssl training, False for linear head training
|
||||||
|
train_transforms = ImageTransformationPipeline([Lambda(lambda x: x)]) # do nothing
|
||||||
|
val_transforms = ImageTransformationPipeline([Lambda(lambda x: x + 1)]) # add 1
|
||||||
|
|
||||||
|
if is_ssl_encoder_module:
|
||||||
|
train_transforms = DualViewTransformWrapper(train_transforms) # type: ignore
|
||||||
|
val_transforms = DualViewTransformWrapper(val_transforms) # type: ignore
|
||||||
|
return train_transforms, val_transforms
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# ------------------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------------------
|
||||||
import math
|
import math
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple, Optional
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -13,7 +13,7 @@ import pytest
|
||||||
import torch
|
import torch
|
||||||
from pl_bolts.models.self_supervised.resnets import ResNet
|
from pl_bolts.models.self_supervised.resnets import ResNet
|
||||||
from pl_bolts.optimizers import linear_warmup_decay
|
from pl_bolts.optimizers import linear_warmup_decay
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer, seed_everything
|
||||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||||
from pytorch_lightning.trainer.supporters import CombinedLoader
|
from pytorch_lightning.trainer.supporters import CombinedLoader
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
|
@ -23,6 +23,7 @@ from InnerEye.Common import fixed_paths
|
||||||
from InnerEye.Common.fixed_paths import repository_root_directory
|
from InnerEye.Common.fixed_paths import repository_root_directory
|
||||||
from InnerEye.Common.fixed_paths_for_tests import TEST_OUTPUTS_PATH
|
from InnerEye.Common.fixed_paths_for_tests import TEST_OUTPUTS_PATH
|
||||||
from InnerEye.Common.output_directories import OutputFolderForTests
|
from InnerEye.Common.output_directories import OutputFolderForTests
|
||||||
|
from InnerEye.ML.lightning_loggers import StoringLogger
|
||||||
from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName, SSLDatasetName
|
from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName, SSLDatasetName
|
||||||
from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye
|
from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye
|
||||||
from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye
|
from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye
|
||||||
|
@ -33,8 +34,9 @@ from InnerEye.ML.common import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||||
from InnerEye.ML.configs.ssl.CIFAR_SSL_configs import CIFAR10SimCLR
|
from InnerEye.ML.configs.ssl.CIFAR_SSL_configs import CIFAR10SimCLR
|
||||||
from InnerEye.ML.configs.ssl.CXR_SSL_configs import CXRImageClassifier, NIH_RSNA_SimCLR
|
from InnerEye.ML.configs.ssl.CXR_SSL_configs import CXRImageClassifier, NIH_RSNA_SimCLR
|
||||||
from InnerEye.ML.runner import Runner
|
from InnerEye.ML.runner import Runner
|
||||||
from Tests.ML.configs.lightning_test_containers import DummyContainerWithModel
|
from Tests.ML.configs.lightning_test_containers import DummyContainerWithModel, DummySimCLR
|
||||||
from Tests.ML.utils.test_io_util import write_test_dicom
|
from Tests.ML.utils.test_io_util import write_test_dicom
|
||||||
|
from health_ml.utils import AzureMLProgressBar
|
||||||
|
|
||||||
path_to_cxr_test_dataset = TEST_OUTPUTS_PATH / "cxr_test_dataset"
|
path_to_cxr_test_dataset = TEST_OUTPUTS_PATH / "cxr_test_dataset"
|
||||||
|
|
||||||
|
@ -281,6 +283,67 @@ def test_simclr_lr_scheduler() -> None:
|
||||||
for i in range(highest_lr, len(lr) - 1):
|
for i in range(highest_lr, len(lr) - 1):
|
||||||
assert lr[i] > lr[i + 1], f"Not strictly monotonically decreasing at index {i}"
|
assert lr[i] > lr[i + 1], f"Not strictly monotonically decreasing at index {i}"
|
||||||
|
|
||||||
|
@pytest.mark.gpu
|
||||||
|
def test_simclr_training_recovery(test_output_dirs: OutputFolderForTests) -> None:
|
||||||
|
""" This test checks if a SSLContainer correctly resumes training.
|
||||||
|
First we run SSL using a Trainer for 20 epochs.
|
||||||
|
Second, we run a new SSL job for 15 epochs.
|
||||||
|
Third we resume the job and run it for 5 more epochs.
|
||||||
|
The test checks the learning rate and the loss.
|
||||||
|
The test is meant to run on a GPU!
|
||||||
|
"""
|
||||||
|
def run_simclr_dummy_container(test_output_dirs: OutputFolderForTests,
|
||||||
|
num_epochs: int,
|
||||||
|
last_checkpoint: Optional[ModelCheckpoint] = None) -> Tuple[list, list, ModelCheckpoint]:
|
||||||
|
seed_everything(0, workers=True)
|
||||||
|
container = DummySimCLR()
|
||||||
|
container.setup()
|
||||||
|
model = container.create_model()
|
||||||
|
data = container.get_data_module()
|
||||||
|
|
||||||
|
# add logger
|
||||||
|
logger = StoringLogger()
|
||||||
|
progress = AzureMLProgressBar(refresh_rate=1)
|
||||||
|
checkpoint_folder = test_output_dirs.create_file_or_folder_path("checkpoints")
|
||||||
|
checkpoint_folder.mkdir(exist_ok=True)
|
||||||
|
checkpoint = ModelCheckpoint(dirpath=checkpoint_folder,
|
||||||
|
every_n_val_epochs=1,
|
||||||
|
save_last=True)
|
||||||
|
|
||||||
|
trainer = Trainer(default_root_dir=str(test_output_dirs.root_dir),
|
||||||
|
logger=logger,
|
||||||
|
callbacks=[progress, checkpoint],
|
||||||
|
max_epochs=num_epochs,
|
||||||
|
resume_from_checkpoint=last_checkpoint.last_model_path if last_checkpoint is not None else None,
|
||||||
|
deterministic=True,
|
||||||
|
benchmark=False,
|
||||||
|
gpus=1)
|
||||||
|
trainer.fit(model, datamodule=data)
|
||||||
|
|
||||||
|
lrs = []
|
||||||
|
loss = []
|
||||||
|
for item in logger.results_per_epoch:
|
||||||
|
lrs.append(logger.results_per_epoch[item]['simclr/learning_rate'])
|
||||||
|
loss.append(logger.results_per_epoch[item]['simclr/train/loss'])
|
||||||
|
|
||||||
|
return lrs, loss, checkpoint
|
||||||
|
|
||||||
|
small_encoder = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3, 2))
|
||||||
|
with mock.patch("InnerEye.ML.SSL.encoders.create_ssl_encoder", return_value=small_encoder):
|
||||||
|
with mock.patch("InnerEye.ML.SSL.encoders.get_encoder_output_dim", return_value=2):
|
||||||
|
# Normal run
|
||||||
|
normal_lrs, normal_loss, _ = run_simclr_dummy_container(test_output_dirs, 20, last_checkpoint=None)
|
||||||
|
|
||||||
|
# Short run
|
||||||
|
short_lrs, short_loss, short_checkpoint = run_simclr_dummy_container(test_output_dirs, 15, last_checkpoint=None)
|
||||||
|
|
||||||
|
# Resumed run
|
||||||
|
resumed_lrs, resumed_loss, _ = run_simclr_dummy_container(test_output_dirs, 20, last_checkpoint=short_checkpoint)
|
||||||
|
|
||||||
|
resumed_lrs = short_lrs + resumed_lrs
|
||||||
|
assert resumed_lrs == normal_lrs
|
||||||
|
resumed_loss = short_loss + resumed_loss
|
||||||
|
assert resumed_loss == normal_loss
|
||||||
|
|
||||||
def test_online_evaluator_recovery(test_output_dirs: OutputFolderForTests) -> None:
|
def test_online_evaluator_recovery(test_output_dirs: OutputFolderForTests) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
Загрузка…
Ссылка в новой задаче