maxilse 2022-01-24 19:56:54 +01:00 committed by GitHub
Parent 884e3fd7dc
Commit 6a4919c361
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 159 additions and 4 deletions

View file

@@ -11,6 +11,8 @@ created.
## Upcoming
- ([#643](https://github.com/microsoft/InnerEye-DeepLearning/pull/643)) Test for recovery of an SSL job. Tracks the learning rate and training loss.
### Added
- ([#594](https://github.com/microsoft/InnerEye-DeepLearning/pull/594)) When supplying a "--tag" argument, the AzureML jobs use that value as the display name, to more easily distinguish runs.

View file

@@ -3,7 +3,9 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Any, Dict, List, Tuple, Callable, Optional
from enum import Enum
from yacs.config import CfgNode
import pandas as pd
import param
@@ -13,10 +15,17 @@ from torchmetrics.regression import MeanSquaredError
from torch import Tensor
from torch.nn import Identity
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets.vision import VisionDataset
from torchvision.transforms import Lambda
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.lightning_container import InnerEyeInference, LightningContainer, LightningModuleWithOptimizer
from InnerEye.ML.SSL.datamodules_and_datasets.dataset_cls_utils import InnerEyeDataClassBaseWithReturnIndex
from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import DualViewTransformWrapper
from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName, SSLContainer, SSLDatasetName
from InnerEye.ML.SSL.utils import SSLTrainingType
from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline
class DummyContainerWithDatasets(LightningContainer):
@@ -262,3 +271,84 @@ class DummyContainerWithHooks(LightningContainer):
assert self.hook_local_zero.is_file(), "before_training_on_local_rank_zero should have been called already"
assert not self.hook_all.is_file(), "before_training_on_all_ranks should only be called once"
self.hook_all.touch()
class DummySimCLRData(VisionDataset):
"""
Holds 20 constant samples; each is a tensor of shape (1, 1, 3) filled with ones ([1., 1., 1.]).
"""
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(DummySimCLRData, self).__init__(root, transform=transform,
target_transform=target_transform)
self.train = train
self.data = torch.ones(20, 1, 1, 3)
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
Tuple[Any, Any]: The (optionally transformed) sample and a constant dummy label of 0.
"""
# item = (self.data[index], self.data[index]), torch.Tensor([0])
image = self.data[index]
if self.transform:
image = self.transform(image)
return image, 0
def __len__(self) -> int:
return self.data.shape[0]
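# Illustrative sketch (not part of this commit): the expected behaviour of the dummy
# dataset above when no transform is supplied. Every item is a constant (1, 1, 3)
# tensor of ones paired with a dummy label of 0, and there are 20 items in total.
# The _demo_* names below are hypothetical and only used for this illustration.
_demo_dataset = DummySimCLRData(root="")
_demo_sample, _demo_label = _demo_dataset[0]
assert _demo_sample.shape == (1, 1, 3)
assert bool((_demo_sample == 1.0).all()) and _demo_label == 0
assert len(_demo_dataset) == 20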
class DummySimCLRInnerEyeData(InnerEyeDataClassBaseWithReturnIndex, DummySimCLRData):
"""
Wrapper class around the DummySimCLRData class to optionally return the
index on top of the image and the label in __getitem__ as well as defining num_classes property.
"""
@property
def num_classes(self) -> int:
return 2
class DummySimCLRSSLDatasetName(SSLDatasetName, Enum):
DUMMY = "DUMMY"
class DummySimCLR(SSLContainer):
"""
This module trains an SSL encoder using SimCLR on DummySimCLRData and also fine-tunes a linear head.
"""
SSLContainer._SSLDataClassMappings.update({DummySimCLRSSLDatasetName.DUMMY.value: DummySimCLRInnerEyeData})
def __init__(self) -> None:
super().__init__(ssl_training_dataset_name=DummySimCLRSSLDatasetName.DUMMY,
linear_head_dataset_name=DummySimCLRSSLDatasetName.DUMMY,
# Train with as little data as possible for the test
ssl_training_batch_size=2,
linear_head_batch_size=2,
ssl_encoder=EncoderName.resnet50, # This gets overwritten by the test itself
ssl_training_type=SSLTrainingType.SimCLR,
random_seed=0,
num_epochs=20,
num_workers=0,
max_num_gpus=1)
def _get_transforms(self, augmentation_config: Optional[CfgNode],
dataset_name: str,
is_ssl_encoder_module: bool) -> Tuple[Any, Any]:
# is_ssl_encoder_module will be True for ssl training, False for linear head training
train_transforms = ImageTransformationPipeline([Lambda(lambda x: x)]) # do nothing
val_transforms = ImageTransformationPipeline([Lambda(lambda x: x + 1)]) # add 1
if is_ssl_encoder_module:
train_transforms = DualViewTransformWrapper(train_transforms) # type: ignore
val_transforms = DualViewTransformWrapper(val_transforms) # type: ignore
return train_transforms, val_transforms
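# Illustrative sketch (not part of this commit): the only difference between the
# SSL-encoder stage and the linear-head stage is that the training pipeline gets
# wrapped in DualViewTransformWrapper for contrastive (two-view) training. This
# assumes the container can be constructed outside the encoder mocks used in the
# test below; the _demo_* / _ssl_* / _head_* names are hypothetical.
_demo_container = DummySimCLR()
_ssl_train, _ssl_val = _demo_container._get_transforms(None, DummySimCLRSSLDatasetName.DUMMY.value, True)
_head_train, _head_val = _demo_container._get_transforms(None, DummySimCLRSSLDatasetName.DUMMY.value, False)
assert isinstance(_ssl_train, DualViewTransformWrapper)
assert isinstance(_head_train, ImageTransformationPipeline)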

View file

@@ -4,7 +4,7 @@
# ------------------------------------------------------------------------------------------
import math
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from unittest import mock
import numpy as np
@@ -13,7 +13,7 @@ import pytest
import torch
from pl_bolts.models.self_supervised.resnets import ResNet
from pl_bolts.optimizers import linear_warmup_decay
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.supporters import CombinedLoader
from torch.nn import Module
@@ -23,6 +23,7 @@ from InnerEye.Common import fixed_paths
from InnerEye.Common.fixed_paths import repository_root_directory
from InnerEye.Common.fixed_paths_for_tests import TEST_OUTPUTS_PATH
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML.lightning_loggers import StoringLogger
from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName, SSLDatasetName
from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye
from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye
@@ -33,8 +34,9 @@ from InnerEye.ML.common import LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
from InnerEye.ML.configs.ssl.CIFAR_SSL_configs import CIFAR10SimCLR
from InnerEye.ML.configs.ssl.CXR_SSL_configs import CXRImageClassifier, NIH_RSNA_SimCLR
from InnerEye.ML.runner import Runner
from Tests.ML.configs.lightning_test_containers import DummyContainerWithModel, DummySimCLR
from Tests.ML.utils.test_io_util import write_test_dicom
from health_ml.utils import AzureMLProgressBar
path_to_cxr_test_dataset = TEST_OUTPUTS_PATH / "cxr_test_dataset"
@@ -281,6 +283,67 @@ def test_simclr_lr_scheduler() -> None:
for i in range(highest_lr, len(lr) - 1):
assert lr[i] > lr[i + 1], f"Not strictly monotonically decreasing at index {i}"
@pytest.mark.gpu
def test_simclr_training_recovery(test_output_dirs: OutputFolderForTests) -> None:
""" This test checks if a SSLContainer correctly resumes training.
First we run SSL using a Trainer for 20 epochs.
Second, we run a new SSL job for 15 epochs.
Third we resume the job and run it for 5 more epochs.
The test checks the learning rate and the loss.
The test is meant to run on a GPU!
"""
def run_simclr_dummy_container(test_output_dirs: OutputFolderForTests,
num_epochs: int,
last_checkpoint: Optional[ModelCheckpoint] = None) -> Tuple[list, list, ModelCheckpoint]:
seed_everything(0, workers=True)
container = DummySimCLR()
container.setup()
model = container.create_model()
data = container.get_data_module()
# add logger
logger = StoringLogger()
progress = AzureMLProgressBar(refresh_rate=1)
checkpoint_folder = test_output_dirs.create_file_or_folder_path("checkpoints")
checkpoint_folder.mkdir(exist_ok=True)
checkpoint = ModelCheckpoint(dirpath=checkpoint_folder,
every_n_val_epochs=1,
save_last=True)
trainer = Trainer(default_root_dir=str(test_output_dirs.root_dir),
logger=logger,
callbacks=[progress, checkpoint],
max_epochs=num_epochs,
resume_from_checkpoint=last_checkpoint.last_model_path if last_checkpoint is not None else None,
deterministic=True,
benchmark=False,
gpus=1)
trainer.fit(model, datamodule=data)
lrs = []
loss = []
for item in logger.results_per_epoch:
lrs.append(logger.results_per_epoch[item]['simclr/learning_rate'])
loss.append(logger.results_per_epoch[item]['simclr/train/loss'])
return lrs, loss, checkpoint
small_encoder = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3, 2))
with mock.patch("InnerEye.ML.SSL.encoders.create_ssl_encoder", return_value=small_encoder):
with mock.patch("InnerEye.ML.SSL.encoders.get_encoder_output_dim", return_value=2):
# Normal run
normal_lrs, normal_loss, _ = run_simclr_dummy_container(test_output_dirs, 20, last_checkpoint=None)
# Short run
short_lrs, short_loss, short_checkpoint = run_simclr_dummy_container(test_output_dirs, 15, last_checkpoint=None)
# Resumed run
resumed_lrs, resumed_loss, _ = run_simclr_dummy_container(test_output_dirs, 20, last_checkpoint=short_checkpoint)
resumed_lrs = short_lrs + resumed_lrs
assert resumed_lrs == normal_lrs
resumed_loss = short_loss + resumed_loss
assert resumed_loss == normal_loss
def test_online_evaluator_recovery(test_output_dirs: OutputFolderForTests) -> None:
"""