Enabling distributed training for SSL online evaluator (#612)

This commit is contained in:
Anton Schwaighofer 2021-12-10 09:33:59 +00:00 committed by GitHub
Parent c7eef5ea69
Commit 46017f40a0
No key matching this signature was found
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 190 additions and 38 deletions

View file

@@ -93,6 +93,7 @@ in inference-only runs when using lightning containers.
- ([#558](https://github.com/microsoft/InnerEye-DeepLearning/pull/558)) Fix issue with the CovidModel config where model
weights from a finetuning run were incompatible with the model architecture created for non-finetuning runs.
- ([#604](https://github.com/microsoft/InnerEye-DeepLearning/pull/604)) Fix issue where runs on a VM would download the dataset even when a local dataset is provided.
- ([#612](https://github.com/microsoft/InnerEye-DeepLearning/pull/612)) Fix issue where the SSL online evaluator was not performing distributed training.
### Removed

View file

@@ -9,21 +9,24 @@ import pytorch_lightning as pl
import torch
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
from pytorch_lightning.utilities import rank_zero_warn
from torch import Tensor as T
from health_ml.utils import log_on_epoch
from torch.nn import functional as F
from torch.nn import SyncBatchNorm, functional as F
from torch.nn.parallel import DistributedDataParallel
from torchmetrics import Metric
from InnerEye.ML.SSL.utils import SSLDataModuleType
from InnerEye.ML.lightning_metrics import Accuracy05, AreaUnderPrecisionRecallCurve, AreaUnderRocCurve
from InnerEye.ML.utils.layer_util import set_model_to_eval_mode
from health_ml.utils import log_on_epoch
BatchType = Union[Dict[SSLDataModuleType, Any], Any]
OPTIMIZER_STATE_NAME = "evaluator_optimizer"
EVALUATOR_STATE_NAME = "evaluator_weights"
class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
OPTIMIZER_STATE_NAME = "evaluator_optimizer"
EVALUATOR_STATE_NAME = "evaluator_weights"
def __init__(self,
learning_rate: float,
class_weights: Optional[torch.Tensor] = None,
@@ -47,11 +50,11 @@ class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
Accuracy05()] \
if self.num_classes == 2 else [Accuracy05()]
self.class_weights = class_weights
self.non_linear_evaluator = SSLEvaluator(n_input=self.z_dim,
n_classes=self.num_classes,
p=self.drop_p,
n_hidden=self.hidden_dim)
self.optimizer = torch.optim.Adam(self.non_linear_evaluator.parameters(),
self.evaluator = SSLEvaluator(n_input=self.z_dim,
n_classes=self.num_classes,
p=self.drop_p,
n_hidden=self.hidden_dim)
self.optimizer = torch.optim.Adam(self.evaluator.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay)
@@ -61,24 +64,34 @@ class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
checkpoint: Dict[str, Any]) -> Dict[str, Any]:
# Each callback gets its own state dictionary, which is fed back in during load
return {
OPTIMIZER_STATE_NAME: self.optimizer.state_dict(),
EVALUATOR_STATE_NAME: self.non_linear_evaluator.state_dict()
self.OPTIMIZER_STATE_NAME: self.optimizer.state_dict(),
self.EVALUATOR_STATE_NAME: self.evaluator.state_dict()
}
def on_load_checkpoint(self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
callback_state: Dict[str, Any]) -> None:
self.optimizer.load_state_dict(callback_state[OPTIMIZER_STATE_NAME])
self.non_linear_evaluator.load_state_dict(callback_state[EVALUATOR_STATE_NAME])
self.optimizer.load_state_dict(callback_state[self.OPTIMIZER_STATE_NAME])
self.evaluator.load_state_dict(callback_state[self.EVALUATOR_STATE_NAME])
def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
"""
Initializes modules and moves metrics and class weights to module device
Moves metrics and the online evaluator to the correct GPU.
If training happens via DDP, SyncBatchNorm is enabled for the online evaluator, and it is converted to
a DDP module.
"""
for metric in [*self.train_metrics, *self.val_metrics]:
metric.to(device=pl_module.device) # type: ignore
self.non_linear_evaluator.to(pl_module.device)
self.evaluator.to(pl_module.device)
accelerator = trainer.accelerator_connector
if accelerator.is_distributed:
if accelerator.use_ddp:
self.evaluator = SyncBatchNorm.convert_sync_batchnorm(self.evaluator)
self.evaluator = DistributedDataParallel(self.evaluator, device_ids=[pl_module.device]) # type: ignore
else:
rank_zero_warn("This type of distributed accelerator is not supported. "
"The online evaluator will not synchronize across GPUs.")
@staticmethod
def to_device(batch: Any, device: Union[str, torch.device]) -> Tuple[T, T]:
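
The wrapping performed in on_pretrain_routine_start above is the core of this change. As a rough, self-contained sketch (not the exact callback code; it assumes the trainer has already initialized a torch.distributed process group and that `device` is the local GPU of the current rank), converting a small evaluation head for DDP training looks like this:

```python
# Sketch only: make a small evaluation head DDP-aware, assuming a process group
# has already been initialized (e.g. by the PyTorch Lightning trainer).
import torch
from torch.nn import SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel


def wrap_evaluator_for_ddp(evaluator: torch.nn.Module, device: torch.device) -> torch.nn.Module:
    # Move the head to the GPU that belongs to this rank.
    evaluator = evaluator.to(device)
    # Replace BatchNorm layers with SyncBatchNorm so statistics are shared across ranks.
    evaluator = SyncBatchNorm.convert_sync_batchnorm(evaluator)
    # Wrap in DDP so the head's gradients are all-reduced like those of the main model.
    return DistributedDataParallel(evaluator, device_ids=[device])
```
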
@@ -108,7 +121,7 @@ class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
representations = representations.detach()
# Run the linear-head with SSL embeddings.
mlp_preds = self.non_linear_evaluator(representations)
mlp_preds = self.evaluator(representations)
weights = None if self.class_weights is None else self.class_weights.to(device=pl_module.device)
mlp_loss = F.cross_entropy(mlp_preds, y, weight=weights)
@@ -133,15 +146,11 @@ class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
ids_linear_head = tuple(batch[SSLDataModuleType.LINEAR_HEAD][0].tolist())
if ids_linear_head not in self.visited_ids:
self.visited_ids.add(ids_linear_head)
# Put the online evaluator into "eval" mode
old_mode = self.non_linear_evaluator.training
self.non_linear_evaluator.eval()
loss = self.shared_step(batch, pl_module, is_training=False)
log_on_epoch(pl_module, 'ssl_online_evaluator/val/loss', loss)
for metric in self.val_metrics:
log_on_epoch(pl_module, f"ssl_online_evaluator/val/{metric.name}", metric)
# Put the online evaluator back into the state (eval or train) that it was before calling this method
self.non_linear_evaluator.train(old_mode)
with set_model_to_eval_mode(self.evaluator):
loss = self.shared_step(batch, pl_module, is_training=False)
log_on_epoch(pl_module, 'ssl_online_evaluator/val/loss', loss)
for metric in self.val_metrics:
log_on_epoch(pl_module, f"ssl_online_evaluator/val/{metric.name}", metric)
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: # type: ignore
"""

View file

@@ -2,7 +2,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Iterable, Sized, Tuple, Union
from contextlib import contextmanager
from typing import Generator, Iterable, Sized, Tuple, Union
import torch
from torch.nn import init
@@ -90,3 +91,16 @@ def get_upsampling_kernel_size(downsampling_factor: IntOrTuple3, num_dimensions:
upsample_size(downsampling_factor[1]), # type: ignore
upsample_size(downsampling_factor[2])) # type: ignore
return upsampling_kernel_size
@contextmanager
def set_model_to_eval_mode(model: torch.nn.Module) -> Generator:
"""
Puts the given torch model into eval mode. At the end of the context, resets the training flag to
what it was before the call.
:param model: The model to modify.
"""
old_mode = model.training
model.eval()
yield
model.train(old_mode)
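
For context, a minimal usage example of the new helper (with a hypothetical throwaway model) would be:

```python
# Minimal usage sketch of the context manager added above; the model here is a
# hypothetical example with a Dropout layer so the training/eval distinction matters.
import torch
from InnerEye.ML.utils.layer_util import set_model_to_eval_mode

model = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.Dropout(0.5))
model.train()  # model starts in training mode

with set_model_to_eval_mode(model):
    # Inside the context, dropout is disabled and the model reports eval mode.
    assert not model.training
    output = model(torch.rand(8, 4))

# On exit, the previous training flag is restored.
assert model.training
```
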

View file

@@ -17,6 +17,7 @@ from InnerEye.Common.common_util import logging_only_to_file
from InnerEye.Common.fixed_paths import DEFAULT_MODEL_SUMMARIES_DIR_PATH
from InnerEye.ML.utils.device_aware_module import DeviceAwareModule
from InnerEye.ML.utils.ml_util import RandomStateSnapshot
from InnerEye.ML.utils.layer_util import set_model_to_eval_mode
@dataclass
@@ -217,15 +218,11 @@ def forward_preserve_state(module: DeviceAwareModule, inputs: List[torch.Tensor]
inputs = [input_tensor.cuda() for input_tensor in inputs]
# collect the current state of the model
is_train = module.training
module_state = RandomStateSnapshot.snapshot_random_state()
# set the model in evaluation mode and perform a forward pass
module.eval()
with torch.no_grad():
output = module.forward(*inputs)
if is_train:
module.train()
with set_model_to_eval_mode(module):
with torch.no_grad():
output = module.forward(*inputs)
# restore the seed for torch and numpy
module_state.restore_random_state()
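
For illustration, the same pattern can be condensed into plain PyTorch (a simplified sketch: the real forward_preserve_state also snapshots numpy and CUDA random state via RandomStateSnapshot and moves inputs to the GPU when the module lives there):

```python
# Simplified sketch of the "forward pass without side effects" pattern used above:
# eval mode, no gradients, and the global RNG state restored afterwards.
import torch
from InnerEye.ML.utils.layer_util import set_model_to_eval_mode


def forward_without_side_effects(module: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    rng_state = torch.get_rng_state()  # simplified stand-in for RandomStateSnapshot
    with set_model_to_eval_mode(module):
        with torch.no_grad():
            output = module(x)
    torch.set_rng_state(rng_state)
    return output
```
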

View file

@@ -12,22 +12,26 @@ import pandas as pd
import pytest
import torch
from pl_bolts.models.self_supervised.resnets import ResNet
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn import Module
from torch.optim.lr_scheduler import _LRScheduler
from InnerEye.Common import fixed_paths
from InnerEye.Common.common_util import is_windows
from InnerEye.Common.fixed_paths import repository_root_directory
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName, SSLDatasetName
from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye
from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye
from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier
from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import EVALUATOR_STATE_NAME, OPTIMIZER_STATE_NAME, \
SSLOnlineEvaluatorInnerEye
from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import SSLOnlineEvaluatorInnerEye
from InnerEye.ML.SSL.utils import SSLDataModuleType, SSLTrainingType
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
from InnerEye.ML.configs.ssl.CXR_SSL_configs import CXRImageClassifier
from InnerEye.ML.runner import Runner
from Tests.ML.configs.lightning_test_containers import DummyContainerWithModel
from Tests.ML.utils.test_io_util import write_test_dicom
path_to_test_dataset = full_ml_test_data_path("cxr_test_dataset")
@@ -133,8 +137,8 @@ def test_innereye_ssl_container_cifar10_resnet_simclr() -> None:
assert "callbacks" in checkpoint
assert SSLOnlineEvaluatorInnerEye in checkpoint["callbacks"]
callback_state = checkpoint["callbacks"][SSLOnlineEvaluatorInnerEye]
assert OPTIMIZER_STATE_NAME in callback_state
assert EVALUATOR_STATE_NAME in callback_state
assert SSLOnlineEvaluatorInnerEye.OPTIMIZER_STATE_NAME in callback_state
assert SSLOnlineEvaluatorInnerEye.EVALUATOR_STATE_NAME in callback_state
# Now run the actual SSL classifier off the stored checkpoint
args = common_test_args + ["--model=SSLClassifierCIFAR", f"--local_ssl_weights_path={checkpoint_path}"]
@@ -268,3 +272,130 @@ def test_simclr_lr_scheduler() -> None:
assert lr[i] < lr[i + 1], f"Not strictly monotonically increasing at index {i}"
for i in range(highest_lr, len(lr) - 1):
assert lr[i] > lr[i + 1], f"Not strictly monotonically decreasing at index {i}"
def test_online_evaluator_recovery(test_output_dirs: OutputFolderForTests) -> None:
"""
Test checkpoint recovery for the online evaluator in an end-to-end training run.
"""
container = DummyContainerWithModel()
model = container.create_model()
data = container.get_data_module()
checkpoint_folder = test_output_dirs.create_file_or_folder_path("checkpoints")
checkpoint_folder.mkdir(exist_ok=True)
checkpoints = ModelCheckpoint(dirpath=checkpoint_folder,
every_n_val_epochs=1,
save_last=True)
# Create a first callback, that will be used in training.
callback1 = SSLOnlineEvaluatorInnerEye(class_weights=None,
z_dim=1,
num_classes=2,
dataset="foo",
drop_p=0.2,
learning_rate=1e-5)
# To simplify the test setup, do not run any actual training (this would require a complicated dataset with a
# combined loader)
with mock.patch(
"InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.SSLOnlineEvaluatorInnerEye.on_train_batch_end",
return_value=None) as mock_train:
with mock.patch(
"InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.SSLOnlineEvaluatorInnerEye"
".on_validation_batch_end",
return_value=None):
trainer = Trainer(default_root_dir=str(test_output_dirs.root_dir),
callbacks=[checkpoints, callback1],
max_epochs=10)
trainer.fit(model, datamodule=data)
# Check that the callback was actually used
mock_train.assert_called()
# Now read out the parameters of the callback.
# We will then run a second training job, with a new callback object, that will be initialized randomly,
# and should have different parameters initially. After checkpoint recovery, it should have exactly the
# same parameters as the first callback.
parameters1 = list(callback1.evaluator.parameters())
callback2 = SSLOnlineEvaluatorInnerEye(class_weights=None,
z_dim=1,
num_classes=2,
dataset="foo",
drop_p=0.2,
learning_rate=1e-5)
# Ensure that the parameters are really different initially
parameters2_before_training = list(callback2.evaluator.parameters())
assert not torch.allclose(parameters2_before_training[0], parameters1[0])
# Start a second training run with recovery
last_checkpoint = checkpoints.last_model_path
trainer2 = Trainer(default_root_dir=str(test_output_dirs.root_dir),
callbacks=[callback2],
max_epochs=20,
resume_from_checkpoint=last_checkpoint)
trainer2.fit(model, datamodule=data)
# Read the parameters and check if they are the same as what was stored in the first callback.
parameters2_after_training = list(callback2.evaluator.parameters())
assert torch.allclose(parameters2_after_training[0], parameters1[0])
# This is somewhat redundant, but also check that the checkpoint file really contained the optimizer and evaluator weights
checkpoint = torch.load(last_checkpoint)
assert "callbacks" in checkpoint
assert SSLOnlineEvaluatorInnerEye in checkpoint["callbacks"]
callback_state = checkpoint["callbacks"][SSLOnlineEvaluatorInnerEye]
assert SSLOnlineEvaluatorInnerEye.OPTIMIZER_STATE_NAME in callback_state
assert SSLOnlineEvaluatorInnerEye.EVALUATOR_STATE_NAME in callback_state
@pytest.mark.gpu
def test_online_evaluator_not_distributed() -> None:
"""
Check that the online evaluator handles the DDP flags correctly when not running distributed training.
"""
with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel") as mock_ddp:
callback = SSLOnlineEvaluatorInnerEye(class_weights=None,
z_dim=1,
num_classes=2,
dataset="foo",
drop_p=0.2,
learning_rate=1e-5)
mock_ddp.assert_not_called()
# Standard trainer without DDP
trainer = Trainer()
# Test the flag that the internal logic of on_pretrain_routine_start uses
assert not trainer.accelerator_connector.is_distributed
mock_module = mock.MagicMock(device=torch.device("cpu"))
callback.on_pretrain_routine_start(trainer, mock_module)
assert isinstance(callback.evaluator, Module)
mock_ddp.assert_not_called()
@pytest.mark.gpu
def test_online_evaluator_distributed() -> None:
"""
Check if the online evaluator uses the DDP flag correctly when running distributed.
"""
mock_ddp_result = "mock_ddp_result"
mock_sync_result = "mock_sync_result"
with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.SyncBatchNorm.convert_sync_batchnorm",
return_value=mock_sync_result) as mock_sync:
with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel",
return_value=mock_ddp_result) as mock_ddp:
callback = SSLOnlineEvaluatorInnerEye(class_weights=None,
z_dim=1,
num_classes=2,
dataset="foo",
drop_p=0.2,
learning_rate=1e-5)
# Trainer with DDP
device = torch.device("cuda:0")
mock_module = mock.MagicMock(device=device)
trainer = Trainer(accelerator="ddp", gpus=2)
# Test the two flags that the internal logic of on_pretrain_routine_start uses
assert trainer.accelerator_connector.is_distributed
assert trainer.accelerator_connector.use_ddp
original_evaluator = callback.evaluator
callback.on_pretrain_routine_start(trainer, mock_module)
# Check that SyncBatchNorm has been turned on
mock_sync.assert_called_once_with(original_evaluator)
# Check that the evaluator has been turned into a DDP object
# We still need to mock DDP here because the constructor relies on having a process group available
mock_ddp.assert_called_once_with(mock_sync_result, device_ids=[device])
assert callback.evaluator == mock_ddp_result