Add subsampling transform and mean pooling (#656)

* Add subsampling transform

* Add option to allow_missing_keys for Subsampled

* Add dropout param to BaseMIL

* Add docstring and tests for Subsampled

* Update changelog

* Update to hi-ml with mean pooling

* Enable mean pooling in DeepMIL

* Add/refactor mean pooling tests

* Update changelog

* Update to latest hi-ml with mean pooling
This commit is contained in:
Daniel Coelho de Castro 2022-02-21 11:24:14 +00:00 коммит произвёл GitHub
Родитель 1600ef3ddf
Коммит e2ec5cc839
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 149 добавлений и 14 удалений

Просмотреть файл

@ -48,6 +48,7 @@ jobs that run in AzureML.
- ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL
- ([#653](https://github.com/microsoft/InnerEye-DeepLearning/pull/653)) Add dropout to DeepMIL and fix feature extractor setup.
- ([#650](https://github.com/microsoft/InnerEye-DeepLearning/pull/650)) Enable fine-tuning in DeepMIL using PANDA as the classification task.
- ([#656](https://github.com/microsoft/InnerEye-DeepLearning/pull/656)) Add subsampling transform and support for MIL mean pooling.
### Changed
- ([#659](https://github.com/microsoft/InnerEye-DeepLearning/pull/659)) Update cudatoolkit version from 11.1 to 11.3.

Просмотреть файл

@ -204,7 +204,7 @@ class DeepMILModule(LightningModule):
with set_grad_enabled(self.is_finetune):
instance_features = self.encoder(instances) # N X L x 1 x 1
attentions, bag_features = self.aggregation_fn(instance_features) # K x N | K x L
bag_features = bag_features.view(-1, self.num_encoding * self.pool_out_dim)
bag_features = bag_features.view(1, -1)
bag_logit = self.classifier_fn(bag_features)
return bag_logit, attentions

Просмотреть файл

@ -10,7 +10,7 @@ import torch
import numpy as np
import PIL
from monai.config.type_definitions import KeysCollection
from monai.transforms.transform import MapTransform
from monai.transforms.transform import MapTransform, Randomizable
from torchvision.transforms.functional import to_tensor
from InnerEye.ML.Histopathology.models.encoders import TileEncoder
@ -92,7 +92,7 @@ class EncodeTilesBatchd(MapTransform):
allow_missing_keys: bool = False,
chunk_size: int = 0) -> None:
"""
:param keys: Key(s) for the image path(s) in the input dictionary.
:param keys: Key(s) for the image tensor(s) in the input dictionary.
:param encoder: The tile encoder to use for feature extraction.
:param allow_missing_keys: If `False` (default), raises an exception when an input
dictionary is missing any of the specified keys.
@ -128,3 +128,42 @@ class EncodeTilesBatchd(MapTransform):
for key in self.key_iterator(out_data):
out_data[key] = self._encode_tiles(data[key])
return out_data
def take_indices(data: Sequence, indices: np.ndarray) -> Sequence:
if isinstance(data, (np.ndarray, torch.Tensor)):
return data[indices]
elif isinstance(data, Sequence):
return [data[i] for i in indices]
else:
raise ValueError(f"Data of type {type(data)} is not indexable")
class Subsampled(MapTransform, Randomizable):
"""Dictionary transform to randomly subsample the data down to a fixed maximum length"""
def __init__(self, keys: KeysCollection, max_size: int,
allow_missing_keys: bool = False) -> None:
"""
:param keys: Key(s) for all batch elements that must be subsampled.
:param max_size: Each specified array, tensor, or sequence will be subsampled uniformly at
random down to `max_size` along their first dimension. If shorter, the elements are merely
shuffled.
:param allow_missing_keys: If `False` (default), raises an exception when an input
dictionary is missing any of the specified keys.
"""
super().__init__(keys, allow_missing_keys=allow_missing_keys)
self.max_size = max_size
self._indices: np.ndarray
def randomize(self, total_size: int) -> None:
subsample_size = min(self.max_size, total_size)
self._indices = self.R.choice(total_size, size=subsample_size)
def __call__(self, data: Mapping) -> Mapping:
out_data = dict(data) # create shallow copy
size = len(data[self.keys[0]])
self.randomize(size)
for key in self.key_iterator(out_data):
out_data[key] = take_indices(data[key], self._indices)
return out_data

Просмотреть файл

@ -14,7 +14,7 @@ import param
from torch import nn
from torchvision.models.resnet import resnet18
from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer
from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer, MeanPoolingLayer
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation, TilesDataModule
@ -90,6 +90,8 @@ class BaseMIL(LightningContainer):
return AttentionLayer
elif self.pooling_type == GatedAttentionLayer.__name__:
return GatedAttentionLayer
elif self.pooling_type == MeanPoolingLayer.__name__:
return MeanPoolingLayer
else:
raise ValueError(f"Unsupported pooling type: {self.pooling_type}")

Просмотреть файл

@ -14,6 +14,7 @@ from torchvision.models import resnet18
from health_ml.networks.layers.attention_layers import (
AttentionLayer,
GatedAttentionLayer,
MeanPoolingLayer,
)
from InnerEye.ML.lightning_container import LightningContainer
@ -37,14 +38,7 @@ def get_supervised_imagenet_encoder() -> TileEncoder:
return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224)
@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
@pytest.mark.parametrize("pool_out_dim", [1, 6])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule(
def _test_lightningmodule(
n_classes: int,
pooling_layer: Callable[[int, int, int], nn.Module],
batch_size: int,
@ -119,6 +113,50 @@ def test_lightningmodule(
assert torch.all(score <= 1)
@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
@pytest.mark.parametrize("pool_out_dim", [1, 6])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule_attention(
n_classes: int,
pooling_layer: Callable[[int, int, int], nn.Module],
batch_size: int,
max_bag_size: int,
pool_hidden_dim: int,
pool_out_dim: int,
dropout_rate: Optional[float],
) -> None:
_test_lightningmodule(n_classes=n_classes,
pooling_layer=pooling_layer,
batch_size=batch_size,
max_bag_size=max_bag_size,
pool_hidden_dim=pool_hidden_dim,
pool_out_dim=pool_out_dim,
dropout_rate=dropout_rate)
@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule_mean_pooling(
n_classes: int,
batch_size: int,
max_bag_size: int,
dropout_rate: Optional[float],
) -> None:
_test_lightningmodule(n_classes=n_classes,
pooling_layer=MeanPoolingLayer,
batch_size=batch_size,
max_bag_size=max_bag_size,
pool_hidden_dim=1,
pool_out_dim=1,
dropout_rate=dropout_rate)
def move_batch_to_expected_device(batch: Dict[str, List], use_gpu: bool) -> Dict:
device = "cuda" if use_gpu else "cpu"
return {

Просмотреть файл

@ -7,6 +7,7 @@ import os
from pathlib import Path
from typing import Callable, Sequence, Union
import numpy as np
import pytest
import torch
from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
@ -19,7 +20,7 @@ from health_ml.utils.bag_utils import BagDataset
from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR
from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder
from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd
from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd, Subsampled
from Tests.ML.util import assert_dicts_equal
@ -153,3 +154,57 @@ def test_encode_tiles(tmp_path: Path, use_gpu: bool, chunk_size: int) -> None:
bagged_subset,
transform=transform,
cache_subdir="TCGA-CRCk_embed_cache")
@pytest.mark.parametrize('include_non_indexable', [True, False])
@pytest.mark.parametrize('allow_missing_keys', [True, False])
def test_subsample(include_non_indexable: bool, allow_missing_keys: bool) -> None:
batch_size = 5
max_size = batch_size // 2
data = {
'array_1d': np.random.randn(batch_size),
'array_2d': np.random.randn(batch_size, 4),
'tensor_1d': torch.randn(batch_size),
'tensor_2d': torch.randn(batch_size, 4),
'list': torch.randn(batch_size).tolist(),
'indices': list(range(batch_size)),
'non-indexable': 42,
}
keys_to_subsample = list(data.keys())
if not include_non_indexable:
keys_to_subsample.remove('non-indexable')
keys_to_subsample.append('missing-key')
subsampling = Subsampled(keys_to_subsample, max_size=max_size,
allow_missing_keys=allow_missing_keys)
if include_non_indexable:
with pytest.raises(ValueError):
sub_data = subsampling(data)
return
elif not allow_missing_keys:
with pytest.raises(KeyError):
sub_data = subsampling(data)
return
else:
sub_data = subsampling(data)
assert set(sub_data.keys()) == set(data.keys())
# Check lenghts before and after subsampling
for key in keys_to_subsample:
if key not in data:
continue # Skip missing keys
assert len(data[key]) == batch_size # type: ignore
assert len(sub_data[key]) == min(max_size, batch_size) # type: ignore
# Check contents of subsampled elements
for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']:
for idx, elem in zip(sub_data['indices'], sub_data[key]):
assert np.array_equal(elem, data[key][idx]) # type: ignore
# Check that subsampling is random, i.e. subsequent calls shouldn't give identical results
sub_data2 = subsampling(data)
for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']:
assert not np.array_equal(sub_data[key], sub_data2[key]) # type: ignore

2
hi-ml

@ -1 +1 @@
Subproject commit a33c1ed07da8a42486dec9f939cd59eea4b2583e
Subproject commit 2bc397b4707b56fecca624ce81e6883e0170b24b