зеркало из https://github.com/microsoft/hi-ml.git
ENH: Add DenseNet encoder (#892)
Add densenet encoder to supported encoder for deepmil --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Родитель
3f0ae71042
Коммит
7003cdd464
|
@ -13,9 +13,10 @@ from timm.models import swin_tiny_patch4_window7_224
|
|||
from timm.models.swin_transformer import SwinTransformer
|
||||
from torch import Tensor as T, nn
|
||||
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
|
||||
from torchvision.models import resnet18, resnet50
|
||||
from torchvision.models import resnet18, resnet50, densenet121
|
||||
from torchvision.models.resnet import ResNet
|
||||
from typing import Callable, Optional, Sequence, Tuple
|
||||
from torchvision.models.densenet import DenseNet
|
||||
from typing import Any, Callable, Optional, Sequence, Tuple
|
||||
|
||||
from health_cpath.utils.layer_utils import get_imagenet_preprocessing, load_weights_to_model, setup_feature_extractor
|
||||
|
||||
|
@ -115,18 +116,18 @@ class ImageNetEncoder(TileEncoder):
|
|||
return setup_feature_extractor(pretrained_model, self.input_dim)
|
||||
|
||||
|
||||
class ResNetCheckpointingMixin:
|
||||
"""Mixin class for checkpointing activations in ResNet-based encoders."""
|
||||
class BaseCheckpointingMixin:
|
||||
"""Mixin class for checkpointing activations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_extractor_fn: ResNet,
|
||||
feature_extractor_fn: Any,
|
||||
batchnorm_momentum: Optional[float] = None,
|
||||
checkpoint_segments_size: int = 2,
|
||||
use_activation_checkpointing: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
:param feature_extractor_fn: A ResNet model.
|
||||
:param feature_extractor_fn: A feature extractor model.
|
||||
:param batchnorm_momentum: An optional momentum value to use for batch norm layers statistics updates when
|
||||
`use_activation_checkpointing` is True. If None (default), sqrt of the default momentum retrieved from the
|
||||
model is used to avoid running statistics from going out of sync due to activations checkpointing.
|
||||
|
@ -135,12 +136,53 @@ class ResNetCheckpointingMixin:
|
|||
:param use_activation_checkpointing: Whether to checkpoint activations during forward pass. This can be used
|
||||
to reduce the memory required to store gradients by checkpointing the activations (default=False).
|
||||
"""
|
||||
assert isinstance(feature_extractor_fn, ResNet), "Expected ResNet model for feature_extractor_fn argument."
|
||||
self.feature_extractor_fn = feature_extractor_fn
|
||||
self.checkpoint_segments_size = checkpoint_segments_size
|
||||
self.batchnorm_momentum = batchnorm_momentum
|
||||
if use_activation_checkpointing:
|
||||
self._set_batch_norm_momentum()
|
||||
self.validate()
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validation checks for the feature extractor model."""
|
||||
pass
|
||||
|
||||
def _set_batch_norm_momentum_in_layer(self, layer: nn.Module) -> None:
|
||||
"""Set the momentum of batch norm layers in the given layer to the value of `self.batchnorm_momentum`.
|
||||
|
||||
:param layer: The layer to set the batch norm momentum for.
|
||||
"""
|
||||
if isinstance(layer, nn.BatchNorm2d):
|
||||
assert self.batchnorm_momentum is not None, "batchnorm_momentum must be set"
|
||||
layer.momentum = self.batchnorm_momentum
|
||||
|
||||
def _set_all_batch_norm_momentum_in_block(self, layer_block: nn.Sequential) -> None:
|
||||
"""Set the momentum of batch norm layers in the given block to the value of `self.batchnorm_momentum`.
|
||||
|
||||
:param layer_block: A block of layers to set the batch norm momentum for.
|
||||
"""
|
||||
for sub_layer in layer_block:
|
||||
if len(sub_layer._modules) > 0:
|
||||
for _, layer in sub_layer._modules.items():
|
||||
self._set_batch_norm_momentum_in_layer(layer)
|
||||
else:
|
||||
self._set_batch_norm_momentum_in_layer(sub_layer)
|
||||
|
||||
def _set_batch_norm_momentum(self) -> None:
|
||||
"""Set the momentum of batch norm layers in the feature extractor model"""
|
||||
raise NotImplementedError
|
||||
|
||||
def custom_forward(self, images: torch.Tensor) -> torch.Tensor:
|
||||
"""Custom forward pass that uses activation checkpointing to save memory."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ResNetCheckpointingMixin(BaseCheckpointingMixin):
|
||||
"""Mixin class for checkpointing activations in ResNet-based encoders."""
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate that the feature extractor is a ResNet model."""
|
||||
assert isinstance(self.feature_extractor_fn, ResNet), "Expected ResNet model for feature_extractor_fn argument."
|
||||
|
||||
def _set_batch_norm_momentum(self) -> None:
|
||||
"""Set the momentum of batch norm layers in the ResNet model to avoid running statistics from going out of
|
||||
|
@ -148,26 +190,17 @@ class ResNetCheckpointingMixin:
|
|||
these statistics. We can workaround that by using sqrt of default momentum retrieved from the
|
||||
feature_extractor_fn.
|
||||
"""
|
||||
if self.batchnorm_momentum is not None:
|
||||
_momentum = self.batchnorm_momentum
|
||||
else:
|
||||
_momentum = math.sqrt(self.feature_extractor_fn.bn1.momentum)
|
||||
self.batchnorm_momentum = _momentum
|
||||
if self.batchnorm_momentum is None:
|
||||
self.batchnorm_momentum = math.sqrt(self.feature_extractor_fn.bn1.momentum)
|
||||
|
||||
# Set momentum for the first batch norm layer
|
||||
self.feature_extractor_fn.bn1.momentum = _momentum
|
||||
|
||||
def _set_bn_momentum(layer_block: nn.Sequential) -> None:
|
||||
for sub_layer in layer_block:
|
||||
for _, layer in sub_layer._modules.items():
|
||||
if isinstance(layer, nn.BatchNorm2d):
|
||||
layer.momentum = _momentum
|
||||
self.feature_extractor_fn.bn1.momentum = self.batchnorm_momentum
|
||||
|
||||
# Fetch all nested batch norm layers and set momentum
|
||||
_set_bn_momentum(self.feature_extractor_fn.layer1)
|
||||
_set_bn_momentum(self.feature_extractor_fn.layer2)
|
||||
_set_bn_momentum(self.feature_extractor_fn.layer3)
|
||||
_set_bn_momentum(self.feature_extractor_fn.layer4)
|
||||
self._set_all_batch_norm_momentum_in_block(self.feature_extractor_fn.layer1)
|
||||
self._set_all_batch_norm_momentum_in_block(self.feature_extractor_fn.layer2)
|
||||
self._set_all_batch_norm_momentum_in_block(self.feature_extractor_fn.layer3)
|
||||
self._set_all_batch_norm_momentum_in_block(self.feature_extractor_fn.layer4)
|
||||
|
||||
def custom_forward(self, images: torch.Tensor) -> torch.Tensor:
|
||||
"""Custom forward pass that uses activation checkpointing to save memory."""
|
||||
|
@ -397,6 +430,75 @@ class SwinTransformer_NoPreproc(SwinTransformerCheckpointingMixin, ImageNetEncod
|
|||
return pretrained_model, pretrained_model.num_features # type: ignore
|
||||
|
||||
|
||||
class DenseNetCheckpointingMixin(BaseCheckpointingMixin):
|
||||
"""Mixin class for checkpointing activations in DenseNet-based encoders."""
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate that the feature extractor is a DenseNet model."""
|
||||
assert isinstance(self.feature_extractor_fn, DenseNet), "Expected DenseNet for feature_extractor_fn argument."
|
||||
|
||||
def _set_batch_norm_momentum(self) -> None:
|
||||
"""Set the momentum of batch norm layers in the DenseNet model to avoid running statistics from going out of
|
||||
sync due to activations checkpointing. The forward pass is applied twice which results in double updates of
|
||||
these statistics. We can workaround that by using sqrt of default momentum retrieved from the
|
||||
feature_extractor_fn.
|
||||
"""
|
||||
if self.batchnorm_momentum is None:
|
||||
self.batchnorm_momentum = math.sqrt(self.feature_extractor_fn.features.norm0.momentum)
|
||||
|
||||
self._set_all_batch_norm_momentum_in_block(self.feature_extractor_fn.features)
|
||||
|
||||
def custom_forward(self, images: torch.Tensor) -> torch.Tensor:
|
||||
"""Custom forward pass that uses activation checkpointing to save memory."""
|
||||
segments = self.checkpoint_segments_size
|
||||
features = checkpoint_sequential(self.feature_extractor_fn.features, segments, images)
|
||||
out = nn.functional.relu(features)
|
||||
out = nn.functional.adaptive_avg_pool2d(out, (1, 1))
|
||||
out = torch.flatten(out, 1)
|
||||
out = self.feature_extractor_fn.classifier(out)
|
||||
return out
|
||||
|
||||
|
||||
class DenseNet121_NoPreproc(DenseNetCheckpointingMixin, ImageNetEncoder):
|
||||
"""DenseNet121 encoder without imagenet preprocessing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tile_size: int = 224,
|
||||
n_channels: int = 3,
|
||||
use_activation_checkpointing: bool = False,
|
||||
checkpoint_segments_size: int = 2,
|
||||
batchnorm_momentum: Optional[float] = None,
|
||||
) -> None:
|
||||
"""
|
||||
:param tile_size: The size of the input tiles (default=224).
|
||||
:param n_channels: The number of channels in the input tiles (default=3).
|
||||
:param use_activation_checkpointing: Whether to checkpoint activations during forward pass. This can be used
|
||||
to reduce the memory required to store gradients by checkpointing the activations (default=False).
|
||||
:param checkpoint_segments_size: The size of checkpointed segments in sequential layers (default=2).
|
||||
:param batchnorm_momentum: An optional momentum value to use for batch norm layers statistics updates when
|
||||
`use_activation_checkpointing` is True. If None (default), sqrt of the default momentum retrieved from the
|
||||
model is used to avoid running statistics from going out of sync due to activations checkpointing.
|
||||
"""
|
||||
ImageNetEncoder.__init__(
|
||||
self,
|
||||
feature_extraction_model=densenet121,
|
||||
tile_size=tile_size,
|
||||
n_channels=n_channels,
|
||||
apply_imagenet_preprocessing=False,
|
||||
use_activation_checkpointing=use_activation_checkpointing,
|
||||
)
|
||||
DenseNetCheckpointingMixin.__init__(
|
||||
self, self.feature_extractor_fn, batchnorm_momentum, checkpoint_segments_size, use_activation_checkpointing
|
||||
)
|
||||
|
||||
def _get_encoder(self) -> Tuple[torch.nn.Module, int]:
|
||||
pretrained_model = self.create_feature_extractor_fn(pretrained=True)
|
||||
num_features: int = pretrained_model.classifier.in_features # type: ignore
|
||||
pretrained_model.classifier = nn.Identity()
|
||||
return pretrained_model, num_features
|
||||
|
||||
|
||||
class ImageNetSimCLREncoder(TileEncoder):
|
||||
"""SimCLR encoder pretrained on ImageNet"""
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ from pathlib import Path
|
|||
from typing import Optional, Tuple
|
||||
from health_ml.utils.checkpoint_utils import CheckpointParser
|
||||
from health_cpath.models.encoders import (
|
||||
DenseNet121_NoPreproc,
|
||||
HistoSSLEncoder,
|
||||
ImageNetSimCLREncoder,
|
||||
SSLEncoder,
|
||||
|
@ -147,6 +148,15 @@ class EncoderParams(param.Parameterized):
|
|||
checkpoint_segments_size=self.checkpoint_segments_size,
|
||||
)
|
||||
|
||||
elif self.encoder_type == DenseNet121_NoPreproc.__name__:
|
||||
encoder = DenseNet121_NoPreproc(
|
||||
tile_size=self.tile_size,
|
||||
n_channels=self.n_channels,
|
||||
use_activation_checkpointing=self.use_encoder_checkpointing,
|
||||
checkpoint_segments_size=self.checkpoint_segments_size,
|
||||
batchnorm_momentum=self.batchnorm_momentum,
|
||||
)
|
||||
|
||||
elif self.encoder_type == ImageNetSimCLREncoder.__name__:
|
||||
encoder = ImageNetSimCLREncoder(tile_size=self.tile_size, n_channels=self.n_channels)
|
||||
|
||||
|
|
|
@ -31,6 +31,7 @@ from health_cpath.datasets.base_dataset import DEFAULT_LABEL_COLUMN, TilesDatase
|
|||
from health_cpath.datasets.default_paths import PANDA_5X_TILES_DATASET_ID, TCGA_CRCK_DATASET_DIR
|
||||
from health_cpath.models.deepmil import DeepMILModule
|
||||
from health_cpath.models.encoders import (
|
||||
DenseNet121_NoPreproc,
|
||||
IdentityEncoder,
|
||||
ImageNetEncoder,
|
||||
Resnet18,
|
||||
|
@ -764,7 +765,7 @@ def validate_loss_with_activations_checkpointing(
|
|||
limit = 2
|
||||
|
||||
for batch_idx, batch in enumerate(dataloader):
|
||||
if encoder_type == SwinTransformer_NoPreproc.__name__:
|
||||
if "Resnet" not in encoder_type:
|
||||
batch[SlideKey.IMAGE] = [torch.randint(0, 255, (4, 3, 224, 224), dtype=torch.uint8) / 255.0] * 2
|
||||
loss_ckpt_enc = _get_loss(model_ckpt_enc, batch, batch_idx)
|
||||
loss_no_ckpt_enc = _get_loss(model_no_ckpt_enc, batch, batch_idx)
|
||||
|
@ -781,6 +782,7 @@ def validate_loss_with_activations_checkpointing(
|
|||
(Resnet18.__name__, 512, None),
|
||||
(Resnet50.__name__, 2048, 0.1),
|
||||
(SwinTransformer_NoPreproc.__name__, 768, 0.1),
|
||||
(DenseNet121_NoPreproc.__name__, 1024, 0.1),
|
||||
],
|
||||
)
|
||||
def test_encoder_checkpointning(
|
||||
|
@ -809,7 +811,7 @@ def test_encoder_checkpointning(
|
|||
# 1. Compare the loss and gradients of the encoder with and without checkpointing
|
||||
validate_loss_with_activations_checkpointing(train_dataloader, model_ckpt_enc, model_no_ckpt_enc, encoder_type)
|
||||
|
||||
if encoder_type != SwinTransformer_NoPreproc.__name__: # SwinT requires images of 224 input size, mock tiles are 28
|
||||
if "Resnet" in encoder_type: # SwinT and DenseNet require images of 224 input size, mock tiles are 28
|
||||
# 2. Train the model with and without checkpointing and compare the encoder parameters
|
||||
trainer_no_ckpt = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2, limit_test_batches=2)
|
||||
trainer_no_ckpt.fit(model_no_ckpt_enc, train_dataloader, val_dataloader)
|
||||
|
@ -822,7 +824,7 @@ def test_encoder_checkpointning(
|
|||
|
||||
# 3. Check that the custom forward is called only when checkpointing is enabled
|
||||
sample = next(iter(train_dataloader))
|
||||
if encoder_type == SwinTransformer_NoPreproc.__name__:
|
||||
if "Resnet" not in encoder_type:
|
||||
sample[SlideKey.IMAGE][0] = torch.randint(0, 255, (4, 3, 224, 224), dtype=torch.uint8) / 255.0
|
||||
with patch.object(model_no_ckpt_enc.encoder, "custom_forward") as custom_forward:
|
||||
_, _ = model_no_ckpt_enc(sample[SlideKey.IMAGE][0])
|
||||
|
|
|
@ -15,6 +15,7 @@ from torchvision.models import resnet18
|
|||
from health_ml.utils.common_utils import DEFAULT_AML_CHECKPOINT_DIR
|
||||
from health_ml.utils.checkpoint_utils import LAST_CHECKPOINT_FILE_NAME, CheckpointDownloader
|
||||
from health_cpath.models.encoders import (
|
||||
DenseNet121_NoPreproc,
|
||||
ImageNetEncoder,
|
||||
Resnet18,
|
||||
Resnet18_NoPreproc,
|
||||
|
@ -122,7 +123,8 @@ def test_resnet_checkpointing_bn_momentum(encoder_class: ImageNetEncoder, bn_mom
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"encoder_class", [Resnet18, Resnet18_NoPreproc, Resnet50, Resnet50_NoPreproc, SwinTransformer_NoPreproc]
|
||||
"encoder_class",
|
||||
[Resnet18, Resnet18_NoPreproc, Resnet50, Resnet50_NoPreproc, SwinTransformer_NoPreproc, DenseNet121_NoPreproc],
|
||||
)
|
||||
def test_custom_forward(encoder_class: ImageNetEncoder) -> None:
|
||||
encoder = encoder_class(tile_size=TILE_SIZE, use_activation_checkpointing=True)
|
||||
|
|
Загрузка…
Ссылка в новой задаче