Add dropout to DeepMIL and fix feature extractor setup (#653)
* Add dropout to DeepMILModule, with param in BaseMIL * Fix feature extractor setup for torchvision models
This commit is contained in:
Родитель
d617c8107c
Коммит
eda76357f0
|
@ -47,6 +47,7 @@ jobs that run in AzureML.
|
|||
- ([#634](https://github.com/microsoft/InnerEye-DeepLearning/pull/634)) Add WSI heatmaps and thumbnails to standard test outputs
|
||||
- ([#635](https://github.com/microsoft/InnerEye-DeepLearning/pull/635)) Add tile selection and binary label for online evaluation of PANDA SSL
|
||||
- ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL
|
||||
- ([#653](https://github.com/microsoft/InnerEye-DeepLearning/pull/653)) Add dropout to DeepMIL and fix feature extractor setup.
|
||||
|
||||
### Changed
|
||||
- ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.
|
||||
|
|
|
@ -46,6 +46,7 @@ class DeepMILModule(LightningModule):
|
|||
pooling_layer: Callable[[int, int, int], nn.Module],
|
||||
pool_hidden_dim: int = 128,
|
||||
pool_out_dim: int = 1,
|
||||
dropout_rate: Optional[float] = None,
|
||||
class_weights: Optional[Tensor] = None,
|
||||
l_rate: float = 5e-4,
|
||||
weight_decay: float = 1e-4,
|
||||
|
@ -64,6 +65,7 @@ class DeepMILModule(LightningModule):
|
|||
`torch.nn.Module` constructor accepting input, hidden, and output pooling `int` dimensions.
|
||||
:param pool_hidden_dim: Hidden dimension of pooling layer (default=128).
|
||||
:param pool_out_dim: Output dimension of pooling layer (default=1).
|
||||
:param dropout_rate: Rate of pre-classifier dropout (0-1). `None` for no dropout (default).
|
||||
:param class_weights: Tensor containing class weights (default=None).
|
||||
:param l_rate: Optimiser learning rate.
|
||||
:param weight_decay: Weight decay parameter for L2 regularisation.
|
||||
|
@ -82,6 +84,7 @@ class DeepMILModule(LightningModule):
|
|||
self.pool_hidden_dim = pool_hidden_dim
|
||||
self.pool_out_dim = pool_out_dim
|
||||
self.pooling_layer = pooling_layer
|
||||
self.dropout_rate = dropout_rate
|
||||
self.class_weights = class_weights
|
||||
self.encoder = encoder
|
||||
self.num_encoding = self.encoder.num_encoding
|
||||
|
@ -130,8 +133,14 @@ class DeepMILModule(LightningModule):
|
|||
return pooling_layer, num_features
|
||||
|
||||
def get_classifier(self) -> Callable:
|
||||
return nn.Linear(in_features=self.num_pooling,
|
||||
classifier_layer = nn.Linear(in_features=self.num_pooling,
|
||||
out_features=self.n_classes)
|
||||
if self.dropout_rate is None:
|
||||
return classifier_layer
|
||||
elif 0 <= self.dropout_rate < 1:
|
||||
return nn.Sequential(nn.Dropout(self.dropout_rate), classifier_layer)
|
||||
else:
|
||||
raise ValueError(f"Dropout rate should be in [0, 1), got {self.dropout_rate}")
|
||||
|
||||
def get_loss(self) -> Callable:
|
||||
if self.n_classes > 1:
|
||||
|
@ -186,13 +195,13 @@ class DeepMILModule(LightningModule):
|
|||
else:
|
||||
log_on_epoch(self, f'{stage}/{metric_name}', metric_object)
|
||||
|
||||
def forward(self, images: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
|
||||
def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
|
||||
with no_grad():
|
||||
H = self.encoder(images) # N X L x 1 x 1
|
||||
A, M = self.aggregation_fn(H) # A: K x N | M: K x L
|
||||
M = M.view(-1, self.num_encoding * self.pool_out_dim)
|
||||
Y_prob = self.classifier_fn(M)
|
||||
return Y_prob, A
|
||||
instance_features = self.encoder(instances) # N X L x 1 x 1
|
||||
attentions, bag_features = self.aggregation_fn(instance_features) # K x N | K x L
|
||||
bag_features = bag_features.view(-1, self.num_encoding * self.pool_out_dim)
|
||||
bag_logit = self.classifier_fn(bag_features)
|
||||
return bag_logit, attentions
|
||||
|
||||
def configure_optimizers(self) -> optim.Optimizer:
|
||||
return optim.Adam(self.parameters(), lr=self.l_rate, weight_decay=self.weight_decay,
|
||||
|
|
|
@ -3,9 +3,9 @@
|
|||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Callable, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
from torch import as_tensor, device, nn, prod, rand
|
||||
from torch import as_tensor, device, nn, no_grad, prod, rand
|
||||
from torch.hub import load_state_dict_from_url
|
||||
from torchvision.transforms import Normalize
|
||||
|
||||
|
@ -15,15 +15,23 @@ def get_imagenet_preprocessing() -> nn.Module:
|
|||
|
||||
|
||||
def setup_feature_extractor(pretrained_model: nn.Module,
|
||||
input_dim: Tuple[int, int, int]) -> Tuple[Callable, int]:
|
||||
input_dim: Tuple[int, int, int]) -> Tuple[nn.Module, int]:
|
||||
try:
|
||||
# Attempt to auto-detect final classification layer:
|
||||
num_features: int = pretrained_model.fc.in_features # type: ignore
|
||||
pretrained_model.fc = nn.Flatten()
|
||||
feature_extractor = pretrained_model
|
||||
except AttributeError:
|
||||
# Otherwise fallback to sequence of child modules:
|
||||
layers = list(pretrained_model.children())[:-1]
|
||||
layers.append(nn.Flatten()) # flatten non-batch dims in case of spatial feature maps
|
||||
feature_extractor = nn.Sequential(*layers)
|
||||
with no_grad():
|
||||
feature_shape = feature_extractor(rand(1, *input_dim)).shape
|
||||
num_features = int(prod(as_tensor(feature_shape)).item())
|
||||
# fix weights, no fine-tuning
|
||||
for param in feature_extractor.parameters():
|
||||
param.requires_grad = False
|
||||
feature_shape = feature_extractor(rand(1, *input_dim)).shape
|
||||
num_features = int(prod(as_tensor(feature_shape)).item())
|
||||
return feature_extractor, num_features
|
||||
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ It is responsible for instantiating the encoder and full DeepMIL model. Subclass
|
|||
their datamodules and configure experiment-specific parameters.
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Type # noqa
|
||||
from typing import Optional, Type # noqa
|
||||
|
||||
import param
|
||||
from torch import nn
|
||||
|
@ -27,6 +27,7 @@ from InnerEye.ML.Histopathology.models.encoders import (HistoSSLEncoder, Identit
|
|||
class BaseMIL(LightningContainer):
|
||||
# Model parameters:
|
||||
pooling_type: str = param.String(doc="Name of the pooling layer class to use.")
|
||||
dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.")
|
||||
# l_rate, weight_decay, adam_betas are already declared in OptimizerParams superclass
|
||||
|
||||
# Encoder parameters:
|
||||
|
@ -98,6 +99,7 @@ class BaseMIL(LightningContainer):
|
|||
label_column=self.data_module.train_dataset.LABEL_COLUMN,
|
||||
n_classes=self.data_module.train_dataset.N_CLASSES,
|
||||
pooling_layer=self.get_pooling_layer(),
|
||||
dropout_rate=self.dropout_rate,
|
||||
class_weights=self.data_module.class_weights,
|
||||
l_rate=self.l_rate,
|
||||
weight_decay=self.weight_decay,
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
from typing import Callable, Dict, List, Type # noqa
|
||||
from typing import Callable, Dict, List, Optional, Type # noqa
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
@ -39,10 +39,11 @@ def get_supervised_imagenet_encoder() -> TileEncoder:
|
|||
|
||||
@pytest.mark.parametrize("n_classes", [1, 3])
|
||||
@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
|
||||
@pytest.mark.parametrize("batch_size", [1, 2])
|
||||
@pytest.mark.parametrize("max_bag_size", [1, 3])
|
||||
@pytest.mark.parametrize("pool_hidden_dim", [1, 4])
|
||||
@pytest.mark.parametrize("pool_out_dim", [1, 5])
|
||||
@pytest.mark.parametrize("batch_size", [1, 15])
|
||||
@pytest.mark.parametrize("max_bag_size", [1, 7])
|
||||
@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
|
||||
@pytest.mark.parametrize("pool_out_dim", [1, 6])
|
||||
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
|
||||
def test_lightningmodule(
|
||||
n_classes: int,
|
||||
pooling_layer: Callable[[int, int, int], nn.Module],
|
||||
|
@ -50,6 +51,7 @@ def test_lightningmodule(
|
|||
max_bag_size: int,
|
||||
pool_hidden_dim: int,
|
||||
pool_out_dim: int,
|
||||
dropout_rate: Optional[float],
|
||||
) -> None:
|
||||
|
||||
assert n_classes > 0
|
||||
|
@ -63,6 +65,7 @@ def test_lightningmodule(
|
|||
pooling_layer=pooling_layer,
|
||||
pool_hidden_dim=pool_hidden_dim,
|
||||
pool_out_dim=pool_out_dim,
|
||||
dropout_rate=dropout_rate,
|
||||
)
|
||||
|
||||
bag_images = rand([batch_size, max_bag_size, *module.encoder.input_dim])
|
||||
|
|
|
@ -3,43 +3,68 @@
|
|||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Callable
|
||||
from typing import Callable, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from torch import Tensor, float32, nn, rand
|
||||
from torchvision.models import resnet18
|
||||
|
||||
from InnerEye.ML.Histopathology.models.encoders import (TileEncoder, HistoSSLEncoder, ImageNetEncoder,
|
||||
ImageNetSimCLREncoder)
|
||||
from InnerEye.ML.Histopathology.utils.layer_utils import setup_feature_extractor
|
||||
|
||||
TILE_SIZE = 224
|
||||
INPUT_DIMS = (3, TILE_SIZE, TILE_SIZE)
|
||||
|
||||
|
||||
def get_supervised_imagenet_encoder() -> TileEncoder:
|
||||
return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224)
|
||||
return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=TILE_SIZE)
|
||||
|
||||
|
||||
def get_simclr_imagenet_encoder() -> TileEncoder:
|
||||
return ImageNetSimCLREncoder(tile_size=224)
|
||||
return ImageNetSimCLREncoder(tile_size=TILE_SIZE)
|
||||
|
||||
|
||||
def get_histo_ssl_encoder() -> TileEncoder:
|
||||
return HistoSSLEncoder(tile_size=224)
|
||||
return HistoSSLEncoder(tile_size=TILE_SIZE)
|
||||
|
||||
|
||||
def _test_encoder(encoder: nn.Module, input_dims: Tuple[int, ...], output_dim: int,
|
||||
batch_size: int = 5) -> None:
|
||||
if isinstance(encoder, nn.Module):
|
||||
for param_name, param in encoder.named_parameters():
|
||||
assert not param.requires_grad, \
|
||||
f"Feature extractor has unfrozen parameters: {param_name}"
|
||||
|
||||
images = rand(batch_size, *input_dims, dtype=float32)
|
||||
|
||||
features = encoder(images)
|
||||
assert isinstance(features, Tensor)
|
||||
assert features.shape == (batch_size, output_dim)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("create_encoder_fn", [get_supervised_imagenet_encoder,
|
||||
get_simclr_imagenet_encoder,
|
||||
get_histo_ssl_encoder])
|
||||
def test_encoder(create_encoder_fn: Callable[[], TileEncoder]) -> None:
|
||||
batch_size = 10
|
||||
|
||||
encoder = create_encoder_fn()
|
||||
_test_encoder(encoder, input_dims=encoder.input_dim, output_dim=encoder.num_encoding)
|
||||
|
||||
if isinstance(encoder, nn.Module):
|
||||
for param_name, param in encoder.named_parameters():
|
||||
assert not param.requires_grad, \
|
||||
f"Feature extractor has unfrozen parameters: {param_name}"
|
||||
|
||||
images = rand(batch_size, *encoder.input_dim, dtype=float32)
|
||||
def _dummy_classifier() -> nn.Module:
|
||||
input_size = np.prod(INPUT_DIMS)
|
||||
hidden_dim = 10
|
||||
return nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(input_size, hidden_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(hidden_dim, 1)
|
||||
)
|
||||
|
||||
features = encoder(images)
|
||||
assert isinstance(features, Tensor)
|
||||
assert features.shape == (batch_size, encoder.num_encoding)
|
||||
|
||||
@pytest.mark.parametrize('create_classifier_fn', [resnet18, _dummy_classifier])
|
||||
def test_setup_feature_extractor(create_classifier_fn: Callable[[], nn.Module]) -> None:
|
||||
classifier = create_classifier_fn()
|
||||
encoder, num_features = setup_feature_extractor(classifier, INPUT_DIMS)
|
||||
_test_encoder(encoder, input_dims=INPUT_DIMS, output_dim=num_features)
|
||||
|
|
Загрузка…
Ссылка в новой задаче