Add dropout to DeepMIL and fix feature extractor setup (#653)

* Add dropout to DeepMILModule, with param in BaseMIL
* Fix feature extractor setup for torchvision models
This commit is contained in:
Daniel Coelho de Castro 2022-02-07 13:09:04 +00:00 коммит произвёл GitHub
Родитель d617c8107c
Коммит eda76357f0
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 84 добавлений и 36 удалений

Просмотреть файл

@ -47,6 +47,7 @@ jobs that run in AzureML.
- ([#634](https://github.com/microsoft/InnerEye-DeepLearning/pull/634)) Add WSI heatmaps and thumbnails to standard test outputs
- ([#635](https://github.com/microsoft/InnerEye-DeepLearning/pull/635)) Add tile selection and binary label for online evaluation of PANDA SSL
- ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL
- ([#653](https://github.com/microsoft/InnerEye-DeepLearning/pull/653)) Add dropout to DeepMIL and fix feature extractor setup.
### Changed
- ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.

Просмотреть файл

@ -46,6 +46,7 @@ class DeepMILModule(LightningModule):
pooling_layer: Callable[[int, int, int], nn.Module],
pool_hidden_dim: int = 128,
pool_out_dim: int = 1,
dropout_rate: Optional[float] = None,
class_weights: Optional[Tensor] = None,
l_rate: float = 5e-4,
weight_decay: float = 1e-4,
@ -64,6 +65,7 @@ class DeepMILModule(LightningModule):
`torch.nn.Module` constructor accepting input, hidden, and output pooling `int` dimensions.
:param pool_hidden_dim: Hidden dimension of pooling layer (default=128).
:param pool_out_dim: Output dimension of pooling layer (default=1).
:param dropout_rate: Rate of pre-classifier dropout (0-1). `None` for no dropout (default).
:param class_weights: Tensor containing class weights (default=None).
:param l_rate: Optimiser learning rate.
:param weight_decay: Weight decay parameter for L2 regularisation.
@ -82,6 +84,7 @@ class DeepMILModule(LightningModule):
self.pool_hidden_dim = pool_hidden_dim
self.pool_out_dim = pool_out_dim
self.pooling_layer = pooling_layer
self.dropout_rate = dropout_rate
self.class_weights = class_weights
self.encoder = encoder
self.num_encoding = self.encoder.num_encoding
@ -130,8 +133,14 @@ class DeepMILModule(LightningModule):
return pooling_layer, num_features
def get_classifier(self) -> Callable:
return nn.Linear(in_features=self.num_pooling,
out_features=self.n_classes)
classifier_layer = nn.Linear(in_features=self.num_pooling,
out_features=self.n_classes)
if self.dropout_rate is None:
return classifier_layer
elif 0 <= self.dropout_rate < 1:
return nn.Sequential(nn.Dropout(self.dropout_rate), classifier_layer)
else:
raise ValueError(f"Dropout rate should be in [0, 1), got {self.dropout_rate}")
def get_loss(self) -> Callable:
if self.n_classes > 1:
@ -186,13 +195,13 @@ class DeepMILModule(LightningModule):
else:
log_on_epoch(self, f'{stage}/{metric_name}', metric_object)
def forward(self, images: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
with no_grad():
H = self.encoder(images) # N X L x 1 x 1
A, M = self.aggregation_fn(H) # A: K x N | M: K x L
M = M.view(-1, self.num_encoding * self.pool_out_dim)
Y_prob = self.classifier_fn(M)
return Y_prob, A
instance_features = self.encoder(instances) # N X L x 1 x 1
attentions, bag_features = self.aggregation_fn(instance_features) # K x N | K x L
bag_features = bag_features.view(-1, self.num_encoding * self.pool_out_dim)
bag_logit = self.classifier_fn(bag_features)
return bag_logit, attentions
def configure_optimizers(self) -> optim.Optimizer:
return optim.Adam(self.parameters(), lr=self.l_rate, weight_decay=self.weight_decay,

Просмотреть файл

@ -3,9 +3,9 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Callable, Tuple
from typing import Tuple
from torch import as_tensor, device, nn, prod, rand
from torch import as_tensor, device, nn, no_grad, prod, rand
from torch.hub import load_state_dict_from_url
from torchvision.transforms import Normalize
@ -15,15 +15,23 @@ def get_imagenet_preprocessing() -> nn.Module:
def setup_feature_extractor(pretrained_model: nn.Module,
input_dim: Tuple[int, int, int]) -> Tuple[Callable, int]:
layers = list(pretrained_model.children())[:-1]
layers.append(nn.Flatten()) # flatten non-batch dims in case of spatial feature maps
feature_extractor = nn.Sequential(*layers)
input_dim: Tuple[int, int, int]) -> Tuple[nn.Module, int]:
try:
# Attempt to auto-detect final classification layer:
num_features: int = pretrained_model.fc.in_features # type: ignore
pretrained_model.fc = nn.Flatten()
feature_extractor = pretrained_model
except AttributeError:
# Otherwise fallback to sequence of child modules:
layers = list(pretrained_model.children())[:-1]
layers.append(nn.Flatten()) # flatten non-batch dims in case of spatial feature maps
feature_extractor = nn.Sequential(*layers)
with no_grad():
feature_shape = feature_extractor(rand(1, *input_dim)).shape
num_features = int(prod(as_tensor(feature_shape)).item())
# fix weights, no fine-tuning
for param in feature_extractor.parameters():
param.requires_grad = False
feature_shape = feature_extractor(rand(1, *input_dim)).shape
num_features = int(prod(as_tensor(feature_shape)).item())
return feature_extractor, num_features

Просмотреть файл

@ -8,7 +8,7 @@ It is responsible for instantiating the encoder and full DeepMIL model. Subclass
their datamodules and configure experiment-specific parameters.
"""
from pathlib import Path
from typing import Type # noqa
from typing import Optional, Type # noqa
import param
from torch import nn
@ -27,6 +27,7 @@ from InnerEye.ML.Histopathology.models.encoders import (HistoSSLEncoder, Identit
class BaseMIL(LightningContainer):
# Model parameters:
pooling_type: str = param.String(doc="Name of the pooling layer class to use.")
dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.")
# l_rate, weight_decay, adam_betas are already declared in OptimizerParams superclass
# Encoder parameters:
@ -98,6 +99,7 @@ class BaseMIL(LightningContainer):
label_column=self.data_module.train_dataset.LABEL_COLUMN,
n_classes=self.data_module.train_dataset.N_CLASSES,
pooling_layer=self.get_pooling_layer(),
dropout_rate=self.dropout_rate,
class_weights=self.data_module.class_weights,
l_rate=self.l_rate,
weight_decay=self.weight_decay,

Просмотреть файл

@ -4,7 +4,7 @@
# ------------------------------------------------------------------------------------------
import os
from typing import Callable, Dict, List, Type # noqa
from typing import Callable, Dict, List, Optional, Type # noqa
import pytest
import torch
@ -39,10 +39,11 @@ def get_supervised_imagenet_encoder() -> TileEncoder:
@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("max_bag_size", [1, 3])
@pytest.mark.parametrize("pool_hidden_dim", [1, 4])
@pytest.mark.parametrize("pool_out_dim", [1, 5])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
@pytest.mark.parametrize("pool_out_dim", [1, 6])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule(
n_classes: int,
pooling_layer: Callable[[int, int, int], nn.Module],
@ -50,6 +51,7 @@ def test_lightningmodule(
max_bag_size: int,
pool_hidden_dim: int,
pool_out_dim: int,
dropout_rate: Optional[float],
) -> None:
assert n_classes > 0
@ -63,6 +65,7 @@ def test_lightningmodule(
pooling_layer=pooling_layer,
pool_hidden_dim=pool_hidden_dim,
pool_out_dim=pool_out_dim,
dropout_rate=dropout_rate,
)
bag_images = rand([batch_size, max_bag_size, *module.encoder.input_dim])

Просмотреть файл

@ -3,43 +3,68 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Callable
from typing import Callable, Tuple
import numpy as np
import pytest
from torch import Tensor, float32, nn, rand
from torchvision.models import resnet18
from InnerEye.ML.Histopathology.models.encoders import (TileEncoder, HistoSSLEncoder, ImageNetEncoder,
ImageNetSimCLREncoder)
from InnerEye.ML.Histopathology.utils.layer_utils import setup_feature_extractor
TILE_SIZE = 224
INPUT_DIMS = (3, TILE_SIZE, TILE_SIZE)
def get_supervised_imagenet_encoder() -> TileEncoder:
return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224)
return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=TILE_SIZE)
def get_simclr_imagenet_encoder() -> TileEncoder:
return ImageNetSimCLREncoder(tile_size=224)
return ImageNetSimCLREncoder(tile_size=TILE_SIZE)
def get_histo_ssl_encoder() -> TileEncoder:
return HistoSSLEncoder(tile_size=224)
return HistoSSLEncoder(tile_size=TILE_SIZE)
def _test_encoder(encoder: nn.Module, input_dims: Tuple[int, ...], output_dim: int,
batch_size: int = 5) -> None:
if isinstance(encoder, nn.Module):
for param_name, param in encoder.named_parameters():
assert not param.requires_grad, \
f"Feature extractor has unfrozen parameters: {param_name}"
images = rand(batch_size, *input_dims, dtype=float32)
features = encoder(images)
assert isinstance(features, Tensor)
assert features.shape == (batch_size, output_dim)
@pytest.mark.parametrize("create_encoder_fn", [get_supervised_imagenet_encoder,
get_simclr_imagenet_encoder,
get_histo_ssl_encoder])
def test_encoder(create_encoder_fn: Callable[[], TileEncoder]) -> None:
batch_size = 10
encoder = create_encoder_fn()
_test_encoder(encoder, input_dims=encoder.input_dim, output_dim=encoder.num_encoding)
if isinstance(encoder, nn.Module):
for param_name, param in encoder.named_parameters():
assert not param.requires_grad, \
f"Feature extractor has unfrozen parameters: {param_name}"
images = rand(batch_size, *encoder.input_dim, dtype=float32)
def _dummy_classifier() -> nn.Module:
input_size = np.prod(INPUT_DIMS)
hidden_dim = 10
return nn.Sequential(
nn.Flatten(),
nn.Linear(input_size, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, 1)
)
features = encoder(images)
assert isinstance(features, Tensor)
assert features.shape == (batch_size, encoder.num_encoding)
@pytest.mark.parametrize('create_classifier_fn', [resnet18, _dummy_classifier])
def test_setup_feature_extractor(create_classifier_fn: Callable[[], nn.Module]) -> None:
classifier = create_classifier_fn()
encoder, num_features = setup_feature_extractor(classifier, INPUT_DIMS)
_test_encoder(encoder, input_dims=INPUT_DIMS, output_dim=num_features)