Enabling distributed training for SSL online evaluator (#612)
Parent: c7eef5ea69
Commit: 46017f40a0
@@ -93,6 +93,7 @@ in inference-only runs when using lightning containers.
 - ([#558](https://github.com/microsoft/InnerEye-DeepLearning/pull/558)) Fix issue with the CovidModel config where model
   weights from a finetuning run were incompatible with the model architecture created for non-finetuning runs.
 - ([#604](https://github.com/microsoft/InnerEye-DeepLearning/pull/604)) Fix issue where runs on a VM would download the dataset even when a local dataset is provided.
+- ([#612](https://github.com/microsoft/InnerEye-DeepLearning/pull/612)) Fix issue where the SSL online evaluator was not doing distributed training.

 ### Removed

@@ -9,21 +9,24 @@ import pytorch_lightning as pl
 import torch
 from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
 from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
+from pytorch_lightning.utilities import rank_zero_warn
 from torch import Tensor as T
-from health_ml.utils import log_on_epoch
-from torch.nn import functional as F
+from torch.nn import SyncBatchNorm, functional as F
+from torch.nn.parallel import DistributedDataParallel
 from torchmetrics import Metric

 from InnerEye.ML.SSL.utils import SSLDataModuleType
 from InnerEye.ML.lightning_metrics import Accuracy05, AreaUnderPrecisionRecallCurve, AreaUnderRocCurve
+from InnerEye.ML.utils.layer_util import set_model_to_eval_mode
+from health_ml.utils import log_on_epoch

 BatchType = Union[Dict[SSLDataModuleType, Any], Any]

-OPTIMIZER_STATE_NAME = "evaluator_optimizer"
-EVALUATOR_STATE_NAME = "evaluator_weights"
-

 class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
+    OPTIMIZER_STATE_NAME = "evaluator_optimizer"
+    EVALUATOR_STATE_NAME = "evaluator_weights"
+
     def __init__(self,
                  learning_rate: float,
                  class_weights: Optional[torch.Tensor] = None,
@@ -47,11 +50,11 @@ class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
                               Accuracy05()] \
             if self.num_classes == 2 else [Accuracy05()]
         self.class_weights = class_weights
-        self.non_linear_evaluator = SSLEvaluator(n_input=self.z_dim,
-                                                 n_classes=self.num_classes,
-                                                 p=self.drop_p,
-                                                 n_hidden=self.hidden_dim)
-        self.optimizer = torch.optim.Adam(self.non_linear_evaluator.parameters(),
+        self.evaluator = SSLEvaluator(n_input=self.z_dim,
+                                      n_classes=self.num_classes,
+                                      p=self.drop_p,
+                                      n_hidden=self.hidden_dim)
+        self.optimizer = torch.optim.Adam(self.evaluator.parameters(),
                                           lr=self.learning_rate,
                                           weight_decay=self.weight_decay)

@@ -61,24 +64,34 @@ class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
                            checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         # Each callback gets its own state dictionary, that is fed back in during load
         return {
-            OPTIMIZER_STATE_NAME: self.optimizer.state_dict(),
-            EVALUATOR_STATE_NAME: self.non_linear_evaluator.state_dict()
+            self.OPTIMIZER_STATE_NAME: self.optimizer.state_dict(),
+            self.EVALUATOR_STATE_NAME: self.evaluator.state_dict()
         }

     def on_load_checkpoint(self,
                            trainer: pl.Trainer,
                            pl_module: pl.LightningModule,
                            callback_state: Dict[str, Any]) -> None:
-        self.optimizer.load_state_dict(callback_state[OPTIMIZER_STATE_NAME])
-        self.non_linear_evaluator.load_state_dict(callback_state[EVALUATOR_STATE_NAME])
+        self.optimizer.load_state_dict(callback_state[self.OPTIMIZER_STATE_NAME])
+        self.evaluator.load_state_dict(callback_state[self.EVALUATOR_STATE_NAME])

     def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
         """
-        Initializes modules and moves metrics and class weights to module device
+        Moves metrics and the online evaluator to the correct GPU.
+        If training happens via DDP, SyncBatchNorm is enabled for the online evaluator, and it is converted to
+        a DDP module.
         """
         for metric in [*self.train_metrics, *self.val_metrics]:
             metric.to(device=pl_module.device)  # type: ignore
-        self.non_linear_evaluator.to(pl_module.device)
+        self.evaluator.to(pl_module.device)
+        accelerator = trainer.accelerator_connector
+        if accelerator.is_distributed:
+            if accelerator.use_ddp:
+                self.evaluator = SyncBatchNorm.convert_sync_batchnorm(self.evaluator)
+                self.evaluator = DistributedDataParallel(self.evaluator, device_ids=[pl_module.device])  # type: ignore
+            else:
+                rank_zero_warn("This type of distributed accelerator is not supported. "
                                "The online evaluator will not synchronize across GPUs.")

     @staticmethod
     def to_device(batch: Any, device: Union[str, torch.device]) -> Tuple[T, T]:
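A note on the new on_pretrain_routine_start logic above: a module that lives inside a callback is not part of the LightningModule, so Lightning's DDP plugin never wraps it, and without the explicit wrapping each rank would train its own, diverging copy of the linear head. The following standalone sketch isolates that wrapping step; wrap_evaluator_for_ddp is a hypothetical helper, not part of this commit, and it assumes the distributed process group has already been initialized:

    import torch
    from torch.nn import SyncBatchNorm
    from torch.nn.parallel import DistributedDataParallel

    def wrap_evaluator_for_ddp(evaluator: torch.nn.Module, device: torch.device) -> torch.nn.Module:
        # Replace BatchNorm layers so that normalization statistics are computed across all ranks.
        evaluator = SyncBatchNorm.convert_sync_batchnorm(evaluator)
        # Wrap in DDP so that the linear head's gradients are all-reduced, like the main model's.
        return DistributedDataParallel(evaluator, device_ids=[device])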
@@ -108,7 +121,7 @@ class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
         representations = representations.detach()

         # Run the linear-head with SSL embeddings.
-        mlp_preds = self.non_linear_evaluator(representations)
+        mlp_preds = self.evaluator(representations)
         weights = None if self.class_weights is None else self.class_weights.to(device=pl_module.device)
         mlp_loss = F.cross_entropy(mlp_preds, y, weight=weights)

@ -133,15 +146,11 @@ class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
|
|||
ids_linear_head = tuple(batch[SSLDataModuleType.LINEAR_HEAD][0].tolist())
|
||||
if ids_linear_head not in self.visited_ids:
|
||||
self.visited_ids.add(ids_linear_head)
|
||||
# Put the online evaluator into "eval" mode
|
||||
old_mode = self.non_linear_evaluator.training
|
||||
self.non_linear_evaluator.eval()
|
||||
loss = self.shared_step(batch, pl_module, is_training=False)
|
||||
log_on_epoch(pl_module, 'ssl_online_evaluator/val/loss', loss)
|
||||
for metric in self.val_metrics:
|
||||
log_on_epoch(pl_module, f"ssl_online_evaluator/val/{metric.name}", metric)
|
||||
# Put the online evaluator back into the state (eval or train) that it was before calling this method
|
||||
self.non_linear_evaluator.train(old_mode)
|
||||
with set_model_to_eval_mode(self.evaluator):
|
||||
loss = self.shared_step(batch, pl_module, is_training=False)
|
||||
log_on_epoch(pl_module, 'ssl_online_evaluator/val/loss', loss)
|
||||
for metric in self.val_metrics:
|
||||
log_on_epoch(pl_module, f"ssl_online_evaluator/val/{metric.name}", metric)
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: # type: ignore
|
||||
"""
|
||||
|
@@ -2,7 +2,8 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
-from typing import Iterable, Sized, Tuple, Union
+from contextlib import contextmanager
+from typing import Generator, Iterable, Sized, Tuple, Union

 import torch
 from torch.nn import init
@@ -90,3 +91,16 @@ def get_upsampling_kernel_size(downsampling_factor: IntOrTuple3, num_dimensions:
                               upsample_size(downsampling_factor[1]),  # type: ignore
                               upsample_size(downsampling_factor[2]))  # type: ignore
     return upsampling_kernel_size
+
+
+@contextmanager
+def set_model_to_eval_mode(model: torch.nn.Module) -> Generator:
+    """
+    Puts the given torch model into eval mode. At the end of the context, resets the state of the training flag to
+    what it was before the call.
+    :param model: The model to modify.
+    """
+    old_mode = model.training
+    model.eval()
+    yield
+    model.train(old_mode)
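To illustrate the contract of the new context manager, here is a small standalone sketch (not part of the diff; the import path matches the import added in the next hunk below):

    import torch
    from InnerEye.ML.utils.layer_util import set_model_to_eval_mode

    model = torch.nn.BatchNorm1d(num_features=4)
    assert model.training  # modules start out in training mode
    with set_model_to_eval_mode(model):
        assert not model.training  # eval semantics (e.g. frozen BatchNorm statistics) inside the context
    assert model.training  # the original training flag is restored on exit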
@@ -17,6 +17,7 @@ from InnerEye.Common.common_util import logging_only_to_file
 from InnerEye.Common.fixed_paths import DEFAULT_MODEL_SUMMARIES_DIR_PATH
 from InnerEye.ML.utils.device_aware_module import DeviceAwareModule
 from InnerEye.ML.utils.ml_util import RandomStateSnapshot
+from InnerEye.ML.utils.layer_util import set_model_to_eval_mode


 @dataclass
@@ -217,15 +218,11 @@ def forward_preserve_state(module: DeviceAwareModule, inputs: List[torch.Tensor]
         inputs = [input_tensor.cuda() for input_tensor in inputs]

     # collect the current state of the model
-    is_train = module.training
     module_state = RandomStateSnapshot.snapshot_random_state()

-    # set the model in evaluation mode and perform a forward pass
-    module.eval()
-    with torch.no_grad():
-        output = module.forward(*inputs)
-    if is_train:
-        module.train()
+    with set_model_to_eval_mode(module):
+        with torch.no_grad():
+            output = module.forward(*inputs)

     # restore the seed for torch and numpy
     module_state.restore_random_state()
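A minor stylistic aside on the replacement above: Python allows stacking context managers in a single with statement, so the body of forward_preserve_state could equally be written as:

    with set_model_to_eval_mode(module), torch.no_grad():
        output = module.forward(*inputs)

The nested form in the diff keeps the change minimal.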
@@ -12,22 +12,26 @@ import pandas as pd
 import pytest
 import torch
 from pl_bolts.models.self_supervised.resnets import ResNet
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+from torch.nn import Module
 from torch.optim.lr_scheduler import _LRScheduler

 from InnerEye.Common import fixed_paths
 from InnerEye.Common.common_util import is_windows
 from InnerEye.Common.fixed_paths import repository_root_directory
 from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
 from InnerEye.Common.output_directories import OutputFolderForTests
 from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName, SSLDatasetName
 from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye
 from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye
 from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier
-from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import EVALUATOR_STATE_NAME, OPTIMIZER_STATE_NAME, \
-    SSLOnlineEvaluatorInnerEye
+from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import SSLOnlineEvaluatorInnerEye
 from InnerEye.ML.SSL.utils import SSLDataModuleType, SSLTrainingType
 from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
 from InnerEye.ML.configs.ssl.CXR_SSL_configs import CXRImageClassifier
 from InnerEye.ML.runner import Runner
+from Tests.ML.configs.lightning_test_containers import DummyContainerWithModel
 from Tests.ML.utils.test_io_util import write_test_dicom

 path_to_test_dataset = full_ml_test_data_path("cxr_test_dataset")
@@ -133,8 +137,8 @@ def test_innereye_ssl_container_cifar10_resnet_simclr() -> None:
     assert "callbacks" in checkpoint
     assert SSLOnlineEvaluatorInnerEye in checkpoint["callbacks"]
     callback_state = checkpoint["callbacks"][SSLOnlineEvaluatorInnerEye]
-    assert OPTIMIZER_STATE_NAME in callback_state
-    assert EVALUATOR_STATE_NAME in callback_state
+    assert SSLOnlineEvaluatorInnerEye.OPTIMIZER_STATE_NAME in callback_state
+    assert SSLOnlineEvaluatorInnerEye.EVALUATOR_STATE_NAME in callback_state

     # Now run the actual SSL classifier off the stored checkpoint
     args = common_test_args + ["--model=SSLClassifierCIFAR", f"--local_ssl_weights_path={checkpoint_path}"]
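The two assertions above lean on how PyTorch Lightning stores callback state: the checkpoint's "callbacks" entry is keyed by the callback class, and each value is exactly the dictionary that the callback returned from on_save_checkpoint. As a hedged sketch, the evaluator weights could be read back outside of training like this (the checkpoint path is hypothetical):

    import torch
    from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import SSLOnlineEvaluatorInnerEye

    checkpoint = torch.load("last.ckpt", map_location="cpu")  # hypothetical path to a Lightning checkpoint
    callback_state = checkpoint["callbacks"][SSLOnlineEvaluatorInnerEye]
    evaluator_weights = callback_state[SSLOnlineEvaluatorInnerEye.EVALUATOR_STATE_NAME]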
@@ -268,3 +272,130 @@ def test_simclr_lr_scheduler() -> None:
         assert lr[i] < lr[i + 1], f"Not strictly monotonically increasing at index {i}"
     for i in range(highest_lr, len(lr) - 1):
         assert lr[i] > lr[i + 1], f"Not strictly monotonically decreasing at index {i}"
+
+
+def test_online_evaluator_recovery(test_output_dirs: OutputFolderForTests) -> None:
+    """
+    Test checkpoint recovery for the online evaluator in an end-to-end training run.
+    """
+    container = DummyContainerWithModel()
+    model = container.create_model()
+    data = container.get_data_module()
+    checkpoint_folder = test_output_dirs.create_file_or_folder_path("checkpoints")
+    checkpoint_folder.mkdir(exist_ok=True)
+    checkpoints = ModelCheckpoint(dirpath=checkpoint_folder,
+                                  every_n_val_epochs=1,
+                                  save_last=True)
+    # Create a first callback, that will be used in training.
+    callback1 = SSLOnlineEvaluatorInnerEye(class_weights=None,
+                                           z_dim=1,
+                                           num_classes=2,
+                                           dataset="foo",
+                                           drop_p=0.2,
+                                           learning_rate=1e-5)
+    # To simplify the test setup, do not run any actual training (this would require a complicated dataset with a
+    # combined loader)
+    with mock.patch(
+            "InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.SSLOnlineEvaluatorInnerEye.on_train_batch_end",
+            return_value=None) as mock_train:
+        with mock.patch(
+                "InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.SSLOnlineEvaluatorInnerEye"
+                ".on_validation_batch_end",
+                return_value=None):
+            trainer = Trainer(default_root_dir=str(test_output_dirs.root_dir),
+                              callbacks=[checkpoints, callback1],
+                              max_epochs=10)
+            trainer.fit(model, datamodule=data)
+            # Check that the callback was actually used
+            mock_train.assert_called()
+            # Now read out the parameters of the callback.
+            # We will then run a second training job, with a new callback object, that will be initialized randomly,
+            # and should have different parameters initially. After checkpoint recovery, it should have exactly the
+            # same parameters as the first callback.
+            parameters1 = list(callback1.evaluator.parameters())
+            callback2 = SSLOnlineEvaluatorInnerEye(class_weights=None,
+                                                   z_dim=1,
+                                                   num_classes=2,
+                                                   dataset="foo",
+                                                   drop_p=0.2,
+                                                   learning_rate=1e-5)
+            # Ensure that the parameters are really different initially
+            parameters2_before_training = list(callback2.evaluator.parameters())
+            assert not torch.allclose(parameters2_before_training[0], parameters1[0])
+            # Start a second training run with recovery
+            last_checkpoint = checkpoints.last_model_path
+            trainer2 = Trainer(default_root_dir=str(test_output_dirs.root_dir),
+                               callbacks=[callback2],
+                               max_epochs=20,
+                               resume_from_checkpoint=last_checkpoint)
+            trainer2.fit(model, datamodule=data)
+            # Read the parameters and check if they are the same as what was stored in the first callback.
+            parameters2_after_training = list(callback2.evaluator.parameters())
+            assert torch.allclose(parameters2_after_training[0], parameters1[0])
+
+    # It's somewhat redundant, but we can now check that the checkpoint file really contained the optimizer and weights
+    checkpoint = torch.load(last_checkpoint)
+    assert "callbacks" in checkpoint
+    assert SSLOnlineEvaluatorInnerEye in checkpoint["callbacks"]
+    callback_state = checkpoint["callbacks"][SSLOnlineEvaluatorInnerEye]
+    assert SSLOnlineEvaluatorInnerEye.OPTIMIZER_STATE_NAME in callback_state
+    assert SSLOnlineEvaluatorInnerEye.EVALUATOR_STATE_NAME in callback_state
+
+
+@pytest.mark.gpu
+def test_online_evaluator_not_distributed() -> None:
+    """
+    Check if the online evaluator uses the DDP flag correctly when running not distributed.
+    """
+    with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel") as mock_ddp:
+        callback = SSLOnlineEvaluatorInnerEye(class_weights=None,
+                                              z_dim=1,
+                                              num_classes=2,
+                                              dataset="foo",
+                                              drop_p=0.2,
+                                              learning_rate=1e-5)
+        mock_ddp.assert_not_called()
+
+        # Standard trainer without DDP
+        trainer = Trainer()
+        # Test the flag that the internal logic of on_pretrain_routine_start uses
+        assert not trainer.accelerator_connector.is_distributed
+        mock_module = mock.MagicMock(device=torch.device("cpu"))
+        callback.on_pretrain_routine_start(trainer, mock_module)
+        assert isinstance(callback.evaluator, Module)
+        mock_ddp.assert_not_called()
+
+
+@pytest.mark.gpu
+def test_online_evaluator_distributed() -> None:
+    """
+    Check if the online evaluator uses the DDP flag correctly when running distributed.
+    """
+    mock_ddp_result = "mock_ddp_result"
+    mock_sync_result = "mock_sync_result"
+    with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.SyncBatchNorm.convert_sync_batchnorm",
+                    return_value=mock_sync_result) as mock_sync:
+        with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel",
+                        return_value=mock_ddp_result) as mock_ddp:
+            callback = SSLOnlineEvaluatorInnerEye(class_weights=None,
+                                                  z_dim=1,
+                                                  num_classes=2,
+                                                  dataset="foo",
+                                                  drop_p=0.2,
+                                                  learning_rate=1e-5)
+
+            # Trainer with DDP
+            device = torch.device("cuda:0")
+            mock_module = mock.MagicMock(device=device)
+            trainer = Trainer(accelerator="ddp", gpus=2)
+            # Test the two flags that the internal logic of on_pretrain_routine_start uses
+            assert trainer.accelerator_connector.is_distributed
+            assert trainer.accelerator_connector.use_ddp
+            original_evaluator = callback.evaluator
+            callback.on_pretrain_routine_start(trainer, mock_module)
+            # Check that SyncBatchNorm has been turned on
+            mock_sync.assert_called_once_with(original_evaluator)
+            # Check that the evaluator has been turned into a DDP object.
+            # We still need to mock DDP here because the constructor relies on having a process group available.
+            mock_ddp.assert_called_once_with(mock_sync_result, device_ids=[device])
+            assert callback.evaluator == mock_ddp_result
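The comment in the last test is worth spelling out: constructing a real DistributedDataParallel module requires an initialized process group, which a single-process unit test does not have, hence the mock. As a minimal sketch, the initialization that must precede the wrapping in on_pretrain_routine_start looks roughly like this (illustrative values; Lightning's DDP plugin performs the equivalent setup itself, one process per GPU):

    import torch.distributed as dist

    # Illustrative single-node, single-process setup; real DDP runs use one process per GPU.
    dist.init_process_group(backend="nccl", init_method="tcp://127.0.0.1:23456", world_size=1, rank=0)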