Add DenseNet encoder to the supported encoders for DeepMIL

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Kenza Bouzid 2023-06-01 13:43:47 +01:00 коммит произвёл GitHub
Родитель 3f0ae71042
Коммит 7003cdd464
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 143 добавлений и 27 удалений

Просмотреть файл

@ -13,9 +13,10 @@ from timm.models import swin_tiny_patch4_window7_224
from timm.models.swin_transformer import SwinTransformer
from torch import Tensor as T, nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from torchvision.models import resnet18, resnet50
from torchvision.models import resnet18, resnet50, densenet121
from torchvision.models.resnet import ResNet
from typing import Callable, Optional, Sequence, Tuple
from torchvision.models.densenet import DenseNet
from typing import Any, Callable, Optional, Sequence, Tuple
from health_cpath.utils.layer_utils import get_imagenet_preprocessing, load_weights_to_model, setup_feature_extractor
@ -115,18 +116,18 @@ class ImageNetEncoder(TileEncoder):
return setup_feature_extractor(pretrained_model, self.input_dim)
class ResNetCheckpointingMixin:
"""Mixin class for checkpointing activations in ResNet-based encoders."""
class BaseCheckpointingMixin:
"""Mixin class for checkpointing activations."""
def __init__(
self,
feature_extractor_fn: ResNet,
feature_extractor_fn: Any,
batchnorm_momentum: Optional[float] = None,
checkpoint_segments_size: int = 2,
use_activation_checkpointing: bool = False,
) -> None:
"""
:param feature_extractor_fn: A ResNet model.
:param feature_extractor_fn: A feature extractor model.
:param batchnorm_momentum: An optional momentum value to use for batch norm layers statistics updates when
`use_activation_checkpointing` is True. If None (default), sqrt of the default momentum retrieved from the
model is used to avoid running statistics from going out of sync due to activations checkpointing.
@ -135,12 +136,53 @@ class ResNetCheckpointingMixin:
:param use_activation_checkpointing: Whether to checkpoint activations during forward pass. This can be used
to reduce the memory required to store gradients by checkpointing the activations (default=False).
"""
assert isinstance(feature_extractor_fn, ResNet), "Expected ResNet model for feature_extractor_fn argument."
self.feature_extractor_fn = feature_extractor_fn
self.checkpoint_segments_size = checkpoint_segments_size
self.batchnorm_momentum = batchnorm_momentum
if use_activation_checkpointing:
self._set_batch_norm_momentum()
self.validate()
def validate(self) -> None:
"""Validation hook for the feature extractor model.

Invoked at the end of `__init__`. The base implementation accepts any model; subclasses override it to
check the type of `self.feature_extractor_fn`.
"""
pass
def _set_batch_norm_momentum_in_layer(self, layer: nn.Module) -> None:
    """Apply `self.batchnorm_momentum` to `layer` when it is a 2D batch norm layer.

    Layers of any other type are left untouched.

    :param layer: The layer to set the batch norm momentum for.
    """
    if not isinstance(layer, nn.BatchNorm2d):
        return
    # The momentum must have been resolved (explicitly or from the model default) before reaching this point.
    assert self.batchnorm_momentum is not None, "batchnorm_momentum must be set"
    layer.momentum = self.batchnorm_momentum
def _set_all_batch_norm_momentum_in_block(self, layer_block: nn.Sequential) -> None:
    """Set the momentum of every batch norm layer nested in the given block to `self.batchnorm_momentum`.

    The block is traversed recursively, so batch norm layers are reached at any nesting depth (e.g. inside the
    `downsample` branch of a ResNet block, or inside the dense layers of a DenseNet dense block), not only
    among the direct children and grandchildren of `layer_block`.

    :param layer_block: A block of layers to set the batch norm momentum for.
    """
    # `modules()` yields the block itself followed by all of its descendants; non batch norm modules are
    # ignored by `_set_batch_norm_momentum_in_layer`.
    for module in layer_block.modules():
        self._set_batch_norm_momentum_in_layer(module)
def _set_batch_norm_momentum(self) -> None:
"""Set the momentum of batch norm layers in the feature extractor model.

Called from `__init__` when `use_activation_checkpointing` is True. Must be implemented by subclasses.
"""
raise NotImplementedError
def custom_forward(self, images: torch.Tensor) -> torch.Tensor:
"""Custom forward pass that uses activation checkpointing to save memory.

Must be implemented by subclasses.

:param images: The input images to run through the feature extractor.
:return: The output of the checkpointed forward pass.
"""
raise NotImplementedError
class ResNetCheckpointingMixin(BaseCheckpointingMixin):
"""Mixin class for checkpointing activations in ResNet-based encoders."""
def validate(self) -> None:
"""Validate that the feature extractor is a ResNet model."""
assert isinstance(self.feature_extractor_fn, ResNet), "Expected ResNet model for feature_extractor_fn argument."
def _set_batch_norm_momentum(self) -> None:
"""Set the momentum of batch norm layers in the ResNet model to avoid running statistics from going out of
@ -148,26 +190,17 @@ class ResNetCheckpointingMixin:
these statistics. We can workaround that by using sqrt of default momentum retrieved from the
feature_extractor_fn.
"""
if self.batchnorm_momentum is not None:
_momentum = self.batchnorm_momentum
else:
_momentum = math.sqrt(self.feature_extractor_fn.bn1.momentum)
self.batchnorm_momentum = _momentum
if self.batchnorm_momentum is None:
self.batchnorm_momentum = math.sqrt(self.feature_extractor_fn.bn1.momentum)
# Set momentum for the first batch norm layer
self.feature_extractor_fn.bn1.momentum = _momentum
def _set_bn_momentum(layer_block: nn.Sequential) -> None:
for sub_layer in layer_block:
for _, layer in sub_layer._modules.items():
if isinstance(layer, nn.BatchNorm2d):
layer.momentum = _momentum
self.feature_extractor_fn.bn1.momentum = self.batchnorm_momentum
# Fetch all nested batch norm layers and set momentum
_set_bn_momentum(self.feature_extractor_fn.layer1)
_set_bn_momentum(self.feature_extractor_fn.layer2)
_set_bn_momentum(self.feature_extractor_fn.layer3)
_set_bn_momentum(self.feature_extractor_fn.layer4)
self._set_all_batch_norm_momentum_in_block(self.feature_extractor_fn.layer1)
self._set_all_batch_norm_momentum_in_block(self.feature_extractor_fn.layer2)
self._set_all_batch_norm_momentum_in_block(self.feature_extractor_fn.layer3)
self._set_all_batch_norm_momentum_in_block(self.feature_extractor_fn.layer4)
def custom_forward(self, images: torch.Tensor) -> torch.Tensor:
"""Custom forward pass that uses activation checkpointing to save memory."""
@ -397,6 +430,75 @@ class SwinTransformer_NoPreproc(SwinTransformerCheckpointingMixin, ImageNetEncod
return pretrained_model, pretrained_model.num_features # type: ignore
class DenseNetCheckpointingMixin(BaseCheckpointingMixin):
    """Mixin class for checkpointing activations in DenseNet-based encoders.

    Relies on `self.feature_extractor_fn` exposing a `features` stack (whose first batch norm layer is `norm0`)
    and a `classifier` head.
    """

    def validate(self) -> None:
        """Validate that the feature extractor is a DenseNet model."""
        assert isinstance(self.feature_extractor_fn, DenseNet), "Expected DenseNet for feature_extractor_fn argument."

    def _set_batch_norm_momentum(self) -> None:
        """Set the momentum of all batch norm layers in the DenseNet model to avoid running statistics from going
        out of sync due to activations checkpointing. The forward pass is applied twice which results in double
        updates of these statistics. We can workaround that by using sqrt of default momentum retrieved from the
        feature_extractor_fn, unless an explicit `batchnorm_momentum` was provided.
        """
        if self.batchnorm_momentum is None:
            self.batchnorm_momentum = math.sqrt(self.feature_extractor_fn.features.norm0.momentum)
        # Walk the complete module tree of `features`: in torchvision's DenseNet most batch norm layers sit
        # inside the dense layers (features -> denseblock -> denselayer -> norm1/norm2), so looking only at the
        # first levels of the hierarchy would leave them with the default momentum.
        for module in self.feature_extractor_fn.features.modules():
            self._set_batch_norm_momentum_in_layer(module)

    def custom_forward(self, images: torch.Tensor) -> torch.Tensor:
        """Custom forward pass that uses activation checkpointing to save memory.

        The `features` stack is run through `checkpoint_sequential` using `self.checkpoint_segments_size` as the
        number of segments; the result then goes through ReLU, global average pooling, flattening and the
        `classifier` head.

        :param images: The input images to encode.
        :return: The output of the `classifier` head applied to the pooled features.
        """
        segments = self.checkpoint_segments_size
        features = checkpoint_sequential(self.feature_extractor_fn.features, segments, images)
        out = nn.functional.relu(features)
        out = nn.functional.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.feature_extractor_fn.classifier(out)
        return out
class DenseNet121_NoPreproc(DenseNetCheckpointingMixin, ImageNetEncoder):
    """DenseNet121 tile encoder that skips the ImageNet preprocessing step."""

    def __init__(
        self,
        tile_size: int = 224,
        n_channels: int = 3,
        use_activation_checkpointing: bool = False,
        checkpoint_segments_size: int = 2,
        batchnorm_momentum: Optional[float] = None,
    ) -> None:
        """
        :param tile_size: The size of the input tiles (default=224).
        :param n_channels: The number of channels in the input tiles (default=3).
        :param use_activation_checkpointing: Whether to checkpoint activations during forward pass. This can be used
            to reduce the memory required to store gradients by checkpointing the activations (default=False).
        :param checkpoint_segments_size: The size of checkpointed segments in sequential layers (default=2).
        :param batchnorm_momentum: An optional momentum value to use for batch norm layers statistics updates when
            `use_activation_checkpointing` is True. If None (default), sqrt of the default momentum retrieved from the
            model is used to avoid running statistics from going out of sync due to activations checkpointing.
        """
        # The encoder is initialised first: the mixin below is handed `self.feature_extractor_fn`, which is
        # expected to be available once the encoder initialiser has run.
        ImageNetEncoder.__init__(
            self,
            feature_extraction_model=densenet121,
            tile_size=tile_size,
            n_channels=n_channels,
            apply_imagenet_preprocessing=False,
            use_activation_checkpointing=use_activation_checkpointing,
        )
        DenseNetCheckpointingMixin.__init__(
            self,
            feature_extractor_fn=self.feature_extractor_fn,
            batchnorm_momentum=batchnorm_momentum,
            checkpoint_segments_size=checkpoint_segments_size,
            use_activation_checkpointing=use_activation_checkpointing,
        )

    def _get_encoder(self) -> Tuple[torch.nn.Module, int]:
        """Create the pretrained model and strip its classification head.

        :return: The model with its `classifier` replaced by `nn.Identity`, and the number of input features
            the original classifier expected.
        """
        model = self.create_feature_extractor_fn(pretrained=True)
        embedding_dim: int = model.classifier.in_features  # type: ignore
        # Drop the classification head so the model outputs the pooled feature vector.
        model.classifier = nn.Identity()
        return model, embedding_dim
class ImageNetSimCLREncoder(TileEncoder):
"""SimCLR encoder pretrained on ImageNet"""

Просмотреть файл

@ -9,6 +9,7 @@ from pathlib import Path
from typing import Optional, Tuple
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_cpath.models.encoders import (
DenseNet121_NoPreproc,
HistoSSLEncoder,
ImageNetSimCLREncoder,
SSLEncoder,
@ -147,6 +148,15 @@ class EncoderParams(param.Parameterized):
checkpoint_segments_size=self.checkpoint_segments_size,
)
elif self.encoder_type == DenseNet121_NoPreproc.__name__:
encoder = DenseNet121_NoPreproc(
tile_size=self.tile_size,
n_channels=self.n_channels,
use_activation_checkpointing=self.use_encoder_checkpointing,
checkpoint_segments_size=self.checkpoint_segments_size,
batchnorm_momentum=self.batchnorm_momentum,
)
elif self.encoder_type == ImageNetSimCLREncoder.__name__:
encoder = ImageNetSimCLREncoder(tile_size=self.tile_size, n_channels=self.n_channels)

Просмотреть файл

@ -31,6 +31,7 @@ from health_cpath.datasets.base_dataset import DEFAULT_LABEL_COLUMN, TilesDatase
from health_cpath.datasets.default_paths import PANDA_5X_TILES_DATASET_ID, TCGA_CRCK_DATASET_DIR
from health_cpath.models.deepmil import DeepMILModule
from health_cpath.models.encoders import (
DenseNet121_NoPreproc,
IdentityEncoder,
ImageNetEncoder,
Resnet18,
@ -764,7 +765,7 @@ def validate_loss_with_activations_checkpointing(
limit = 2
for batch_idx, batch in enumerate(dataloader):
if encoder_type == SwinTransformer_NoPreproc.__name__:
if "Resnet" not in encoder_type:
batch[SlideKey.IMAGE] = [torch.randint(0, 255, (4, 3, 224, 224), dtype=torch.uint8) / 255.0] * 2
loss_ckpt_enc = _get_loss(model_ckpt_enc, batch, batch_idx)
loss_no_ckpt_enc = _get_loss(model_no_ckpt_enc, batch, batch_idx)
@ -781,6 +782,7 @@ def validate_loss_with_activations_checkpointing(
(Resnet18.__name__, 512, None),
(Resnet50.__name__, 2048, 0.1),
(SwinTransformer_NoPreproc.__name__, 768, 0.1),
(DenseNet121_NoPreproc.__name__, 1024, 0.1),
],
)
def test_encoder_checkpointning(
@ -809,7 +811,7 @@ def test_encoder_checkpointning(
# 1. Compare the loss and gradients of the encoder with and without checkpointing
validate_loss_with_activations_checkpointing(train_dataloader, model_ckpt_enc, model_no_ckpt_enc, encoder_type)
if encoder_type != SwinTransformer_NoPreproc.__name__: # SwinT requires images of 224 input size, mock tiles are 28
if "Resnet" in encoder_type: # SwinT and DenseNet require images of 224 input size, mock tiles are 28
# 2. Train the model with and without checkpointing and compare the encoder parameters
trainer_no_ckpt = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2, limit_test_batches=2)
trainer_no_ckpt.fit(model_no_ckpt_enc, train_dataloader, val_dataloader)
@ -822,7 +824,7 @@ def test_encoder_checkpointning(
# 3. Check that the custom forward is called only when checkpointing is enabled
sample = next(iter(train_dataloader))
if encoder_type == SwinTransformer_NoPreproc.__name__:
if "Resnet" not in encoder_type:
sample[SlideKey.IMAGE][0] = torch.randint(0, 255, (4, 3, 224, 224), dtype=torch.uint8) / 255.0
with patch.object(model_no_ckpt_enc.encoder, "custom_forward") as custom_forward:
_, _ = model_no_ckpt_enc(sample[SlideKey.IMAGE][0])

Просмотреть файл

@ -15,6 +15,7 @@ from torchvision.models import resnet18
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME, CheckpointDownloader
from health_cpath.models.encoders import (
DenseNet121_NoPreproc,
ImageNetEncoder,
Resnet18,
Resnet18_NoPreproc,
@ -122,7 +123,8 @@ def test_resnet_checkpointing_bn_momentum(encoder_class: ImageNetEncoder, bn_mom
@pytest.mark.parametrize(
"encoder_class", [Resnet18, Resnet18_NoPreproc, Resnet50, Resnet50_NoPreproc, SwinTransformer_NoPreproc]
"encoder_class",
[Resnet18, Resnet18_NoPreproc, Resnet50, Resnet50_NoPreproc, SwinTransformer_NoPreproc, DenseNet121_NoPreproc],
)
def test_custom_forward(encoder_class: ImageNetEncoder) -> None:
encoder = encoder_class(tile_size=TILE_SIZE, use_activation_checkpointing=True)