diff --git a/CHANGELOG.md b/CHANGELOG.md index fbe9e935..d5d072f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ created. ## Upcoming +- ([#643](https://github.com/microsoft/InnerEye-DeepLearning/pull/643)) Test for recovery of SSL job. Tracks learning rate and train +loss. ### Added - ([#594](https://github.com/microsoft/InnerEye-DeepLearning/pull/594)) When supplying a "--tag" argument, the AzureML jobs use that value as the display name, to more easily distinguish run. diff --git a/Tests/ML/configs/lightning_test_containers.py b/Tests/ML/configs/lightning_test_containers.py index 645277d0..5c15e094 100644 --- a/Tests/ML/configs/lightning_test_containers.py +++ b/Tests/ML/configs/lightning_test_containers.py @@ -3,7 +3,9 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ from pathlib import Path -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Callable, Optional +from enum import Enum +from yacs.config import CfgNode import pandas as pd import param @@ -13,10 +15,17 @@ from torchmetrics.regression import MeanSquaredError from torch import Tensor from torch.nn import Identity from torch.utils.data import DataLoader, Dataset +from torchvision.datasets.vision import VisionDataset +from torchvision.transforms import Lambda from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path from InnerEye.ML.common import ModelExecutionMode from InnerEye.ML.lightning_container import InnerEyeInference, LightningContainer, LightningModuleWithOptimizer +from InnerEye.ML.SSL.datamodules_and_datasets.dataset_cls_utils import InnerEyeDataClassBaseWithReturnIndex +from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import DualViewTransformWrapper +from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName, SSLContainer, SSLDatasetName +from 
class DummySimCLRData(VisionDataset):
    """
    Tiny in-memory dataset for SSL tests. It holds 20 identical samples, each a tensor of ones of
    shape (1, 1, 3), and every sample carries the constant label 0.
    """

    def __init__(
            self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ) -> None:
        """
        :param root: Root folder, forwarded to VisionDataset. This class itself reads nothing from it.
        :param train: Stored in self.train. The data is the same for either value.
        :param transform: Optional callable that is applied to each sample in __getitem__.
        :param target_transform: Forwarded to VisionDataset. __getitem__ does not apply it, the label is always 0.
        :param download: Unused by this class (presumably kept for signature compatibility with torchvision
            datasets - TODO confirm against the callers).
        """
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.train = train
        # 20 samples of shape (1, 1, 3), all filled with ones.
        self.data = torch.ones(20, 1, 1, 3)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        :param index: Position of the sample in self.data.
        :return: A tuple (image, label): the sample, passed through self.transform if one is set,
            and the constant label 0.
        """
        image = self.data[index]
        # Compare against None explicitly: a transform object is not required to be truthy.
        if self.transform is not None:
            image = self.transform(image)
        return image, 0

    def __len__(self) -> int:
        """
        :return: The number of samples, i.e. the size of the first dimension of self.data.
        """
        return self.data.shape[0]


class DummySimCLRInnerEyeData(InnerEyeDataClassBaseWithReturnIndex, DummySimCLRData):
    """
    Wrapper class around the DummySimCLRData class to optionally return the
    index on top of the image and the label in __getitem__ as well as defining num_classes property.
    """

    @property
    def num_classes(self) -> int:
        """
        :return: The fixed number of classes (2) reported for the dummy dataset.
        """
        return 2


class DummySimCLRSSLDatasetName(SSLDatasetName, Enum):
    """
    Dataset name enum with the single entry that identifies the dummy dataset.
    """
    DUMMY = "DUMMY"


class DummySimCLR(SSLContainer):
    """
    This module trains an SSL encoder using SimCLR on the DummySimCLRData and finetunes a linear head too.
    """
    # Runs once, when this class body is evaluated: adds the dummy dataset class under the "DUMMY" key to
    # the mapping stored on SSLContainer (presumably so that the dataset names used in __init__ can be
    # resolved - TODO confirm). Note that this mutates state that lives on the base class.
    SSLContainer._SSLDataClassMappings.update({DummySimCLRSSLDatasetName.DUMMY.value: DummySimCLRInnerEyeData})

    def __init__(self) -> None:
        super().__init__(ssl_training_dataset_name=DummySimCLRSSLDatasetName.DUMMY,
                         linear_head_dataset_name=DummySimCLRSSLDatasetName.DUMMY,
                         # Train with as little data as possible for the test
                         ssl_training_batch_size=2,
                         linear_head_batch_size=2,
                         ssl_encoder=EncoderName.resnet50,  # This gets overwritten by the test itself
                         ssl_training_type=SSLTrainingType.SimCLR,
                         random_seed=0,
                         num_epochs=20,
                         num_workers=0,
                         max_num_gpus=1)

    def _get_transforms(self, augmentation_config: Optional[CfgNode],
                        dataset_name: str,
                        is_ssl_encoder_module: bool) -> Tuple[Any, Any]:
        """
        Creates the fixed transforms for the dummy data: identity for training, "add 1" for validation.

        :param augmentation_config: Not used, the transforms are hard-coded.
        :param dataset_name: Not used, the transforms are hard-coded.
        :param is_ssl_encoder_module: True for SSL training, False for linear head training. If True, both
            transforms are wrapped in a DualViewTransformWrapper.
        :return: A tuple (train_transforms, val_transforms).
        """
        train_transforms = ImageTransformationPipeline([Lambda(lambda x: x)])  # do nothing
        val_transforms = ImageTransformationPipeline([Lambda(lambda x: x + 1)])  # add 1

        if is_ssl_encoder_module:
            train_transforms = DualViewTransformWrapper(train_transforms)  # type: ignore
            val_transforms = DualViewTransformWrapper(val_transforms)  # type: ignore
        return train_transforms, val_transforms
@pytest.mark.gpu
def test_simclr_training_recovery(test_output_dirs: OutputFolderForTests) -> None:
    """
    This test checks if a SSLContainer correctly resumes training.
    First, we run SSL using a Trainer for 20 epochs (the reference run).
    Second, we run a new SSL job for 15 epochs.
    Third, we resume that job from its last checkpoint and train it up to epoch 20.
    The learning rates and training losses of the second and third job, concatenated, must be equal
    to those of the reference run.
    The test is meant to run on a GPU!
    """
    def run_simclr_dummy_container(test_output_dirs: OutputFolderForTests,
                                   num_epochs: int,
                                   last_checkpoint: Optional[ModelCheckpoint] = None,
                                   checkpoint_folder_name: str = "checkpoints"
                                   ) -> Tuple[list, list, ModelCheckpoint]:
        """
        Trains a freshly created DummySimCLR container and returns the logged metrics.

        :param test_output_dirs: Test output folder, used as the Trainer root and for the checkpoints.
        :param num_epochs: Value for the Trainer's max_epochs.
        :param last_checkpoint: If given, training resumes from its last_model_path. If None, training
            starts from scratch.
        :param checkpoint_folder_name: Name of the folder (inside the test output folder) that the
            checkpoints of this run are written to.
        :return: A tuple (learning rates, training losses, checkpoint callback). The two lists have one
            entry per entry in the logger's results_per_epoch.
        """
        # Seed before anything is created, so that every run starts from the same random state.
        seed_everything(0, workers=True)
        container = DummySimCLR()
        container.setup()
        model = container.create_model()
        data = container.get_data_module()

        # add logger
        logger = StoringLogger()
        progress = AzureMLProgressBar(refresh_rate=1)
        checkpoint_folder = test_output_dirs.create_file_or_folder_path(checkpoint_folder_name)
        checkpoint_folder.mkdir(exist_ok=True)
        checkpoint = ModelCheckpoint(dirpath=checkpoint_folder,
                                     every_n_val_epochs=1,
                                     save_last=True)

        resume_from = last_checkpoint.last_model_path if last_checkpoint is not None else None
        trainer = Trainer(default_root_dir=str(test_output_dirs.root_dir),
                          logger=logger,
                          callbacks=[progress, checkpoint],
                          max_epochs=num_epochs,
                          resume_from_checkpoint=resume_from,
                          deterministic=True,
                          benchmark=False,
                          gpus=1)
        trainer.fit(model, datamodule=data)

        results = logger.results_per_epoch
        lrs = [results[epoch]['simclr/learning_rate'] for epoch in results]
        loss = [results[epoch]['simclr/train/loss'] for epoch in results]

        return lrs, loss, checkpoint

    def create_small_encoder(*args: object, **kwargs: object) -> torch.nn.Module:
        # Build a new encoder on every call. Handing out one shared module instance would carry the
        # weights trained in one run over into the next run, so the runs would not be independent.
        return torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3, 2))

    with mock.patch("InnerEye.ML.SSL.encoders.create_ssl_encoder", side_effect=create_small_encoder), \
            mock.patch("InnerEye.ML.SSL.encoders.get_encoder_output_dim", return_value=2):
        # Normal run. It gets its own checkpoint folder so that its files can't be mixed up with those
        # of the run that is resumed below.
        normal_lrs, normal_loss, _ = run_simclr_dummy_container(test_output_dirs, 20,
                                                                last_checkpoint=None,
                                                                checkpoint_folder_name="checkpoints_reference")

        # Short run
        short_lrs, short_loss, short_checkpoint = run_simclr_dummy_container(test_output_dirs, 15,
                                                                             last_checkpoint=None)
        assert short_checkpoint.last_model_path, "The short run did not write a checkpoint to resume from"

        # Resumed run
        resumed_lrs, resumed_loss, _ = run_simclr_dummy_container(test_output_dirs, 20,
                                                                  last_checkpoint=short_checkpoint)

    # The logger of the resumed run only holds the epochs that were trained after resuming, hence
    # prepend the results of the short run before comparing with the reference run.
    resumed_lrs = short_lrs + resumed_lrs
    assert resumed_lrs == normal_lrs, \
        f"Learning rates after recovery {resumed_lrs} differ from those of the normal run {normal_lrs}"
    resumed_loss = short_loss + resumed_loss
    assert resumed_loss == normal_loss, \
        f"Training losses after recovery {resumed_loss} differ from those of the normal run {normal_loss}"