Add subsampling transform and mean pooling (#656)

* Add subsampling transform

* Add allow_missing_keys option to Subsampled

* Add dropout param to BaseMIL

* Add docstring and tests for Subsampled

* Update changelog

* Update to hi-ml with mean pooling

* Enable mean pooling in DeepMIL

* Add/refactor mean pooling tests

* Update changelog

* Update to latest hi-ml with mean pooling
This commit is contained in:
Daniel Coelho de Castro 2022-02-21 11:24:14 +00:00 committed by GitHub
Parent 1600ef3ddf
Commit e2ec5cc839
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 149 additions and 14 deletions


@@ -48,6 +48,7 @@ jobs that run in AzureML.
- ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL
- ([#653](https://github.com/microsoft/InnerEye-DeepLearning/pull/653)) Add dropout to DeepMIL and fix feature extractor setup.
- ([#650](https://github.com/microsoft/InnerEye-DeepLearning/pull/650)) Enable fine-tuning in DeepMIL using PANDA as the classification task.
+- ([#656](https://github.com/microsoft/InnerEye-DeepLearning/pull/656)) Add subsampling transform and support for MIL mean pooling.

### Changed
- ([#659](https://github.com/microsoft/InnerEye-DeepLearning/pull/659)) Update cudatoolkit version from 11.1 to 11.3.


@@ -204,7 +204,7 @@ class DeepMILModule(LightningModule):
        with set_grad_enabled(self.is_finetune):
            instance_features = self.encoder(instances)                    # N X L x 1 x 1
        attentions, bag_features = self.aggregation_fn(instance_features)  # K x N | K x L
-       bag_features = bag_features.view(-1, self.num_encoding * self.pool_out_dim)
+       bag_features = bag_features.view(1, -1)
        bag_logit = self.classifier_fn(bag_features)
        return bag_logit, attentions
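For context (an illustration, not part of the diff): the attention aggregators return a K x L bag embedding, where K is `pool_out_dim`, whereas mean pooling presumably yields a single 1 x L vector, so the shape-agnostic `view(1, -1)` replaces the reshape that assumed `num_encoding * pool_out_dim` columns. A minimal sketch with made-up sizes:

import torch

L, K = 512, 4                       # embedding dim and pool_out_dim, both made up
attention_bag = torch.randn(K, L)   # bag embedding from an attention pooling layer
mean_bag = torch.randn(1, L)        # bag embedding from mean pooling

# Both collapse to a single feature row for the classifier head:
assert attention_bag.view(1, -1).shape == (1, K * L)
assert mean_bag.view(1, -1).shape == (1, L)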


@@ -10,7 +10,7 @@ import torch
import numpy as np
import PIL
from monai.config.type_definitions import KeysCollection
-from monai.transforms.transform import MapTransform
+from monai.transforms.transform import MapTransform, Randomizable
from torchvision.transforms.functional import to_tensor

from InnerEye.ML.Histopathology.models.encoders import TileEncoder
@@ -92,7 +92,7 @@ class EncodeTilesBatchd(MapTransform):
                 allow_missing_keys: bool = False,
                 chunk_size: int = 0) -> None:
        """
-        :param keys: Key(s) for the image path(s) in the input dictionary.
+        :param keys: Key(s) for the image tensor(s) in the input dictionary.
        :param encoder: The tile encoder to use for feature extraction.
        :param allow_missing_keys: If `False` (default), raises an exception when an input
        dictionary is missing any of the specified keys.
@@ -128,3 +128,42 @@
        for key in self.key_iterator(out_data):
            out_data[key] = self._encode_tiles(data[key])
        return out_data


def take_indices(data: Sequence, indices: np.ndarray) -> Sequence:
    if isinstance(data, (np.ndarray, torch.Tensor)):
        return data[indices]
    elif isinstance(data, Sequence):
        return [data[i] for i in indices]
    else:
        raise ValueError(f"Data of type {type(data)} is not indexable")


class Subsampled(MapTransform, Randomizable):
    """Dictionary transform to randomly subsample the data down to a fixed maximum length"""

    def __init__(self, keys: KeysCollection, max_size: int,
                 allow_missing_keys: bool = False) -> None:
        """
        :param keys: Key(s) for all batch elements that must be subsampled.
        :param max_size: Each specified array, tensor, or sequence will be subsampled uniformly at
        random down to `max_size` along its first dimension. If shorter, the elements are merely
        shuffled.
        :param allow_missing_keys: If `False` (default), raises an exception when an input
        dictionary is missing any of the specified keys.
        """
        super().__init__(keys, allow_missing_keys=allow_missing_keys)
        self.max_size = max_size
        self._indices: np.ndarray

    def randomize(self, total_size: int) -> None:
        subsample_size = min(self.max_size, total_size)
        # Sample without replacement, so that inputs shorter than max_size are merely shuffled
        self._indices = self.R.choice(total_size, size=subsample_size, replace=False)

    def __call__(self, data: Mapping) -> Mapping:
        out_data = dict(data)  # create shallow copy
        size = len(data[self.keys[0]])
        self.randomize(size)
        for key in self.key_iterator(out_data):
            out_data[key] = take_indices(data[key], self._indices)
        return out_data
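As a usage illustration (not part of the diff): the transform caps the number of tiles per bag while keeping per-tile metadata aligned, because a single set of random indices is applied to every listed key. The dictionary keys below are invented for the example; keys not listed pass through untouched.

import numpy as np
import torch

bag = {
    'image': torch.randn(500, 3, 224, 224),   # 500 tiles (illustrative key names)
    'coords': np.arange(500),                 # per-tile metadata
    'slide_id': 'slide_007',                  # untouched: not listed in `keys`
}
subsample = Subsampled(keys=['image', 'coords'], max_size=100)
subsampled_bag = subsample(bag)

assert subsampled_bag['image'].shape[0] == 100
assert len(subsampled_bag['coords']) == 100       # same indices as 'image'
assert subsampled_bag['slide_id'] == 'slide_007'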


@@ -14,7 +14,7 @@ import param
from torch import nn
from torchvision.models.resnet import resnet18

-from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer
+from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer, MeanPoolingLayer
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation, TilesDataModule
@@ -90,6 +90,8 @@ class BaseMIL(LightningContainer):
            return AttentionLayer
        elif self.pooling_type == GatedAttentionLayer.__name__:
            return GatedAttentionLayer
+        elif self.pooling_type == MeanPoolingLayer.__name__:
+            return MeanPoolingLayer
        else:
            raise ValueError(f"Unsupported pooling type: {self.pooling_type}")
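For illustration only, a hedged sketch of how a concrete container could opt into the new pooling mode; `MyMeanPoolingMIL` is hypothetical, and the sketch assumes `pooling_type` can be passed through the `BaseMIL` constructor in the same way the existing attention-based configs set it:

from health_ml.networks.layers.attention_layers import MeanPoolingLayer


class MyMeanPoolingMIL(BaseMIL):  # hypothetical config, not part of this PR
    def __init__(self, **kwargs) -> None:
        # get_pooling_layer() above resolves this name to MeanPoolingLayer
        super().__init__(pooling_type=MeanPoolingLayer.__name__, **kwargs)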


@@ -14,6 +14,7 @@ from torchvision.models import resnet18
from health_ml.networks.layers.attention_layers import (
    AttentionLayer,
    GatedAttentionLayer,
+    MeanPoolingLayer,
)

from InnerEye.ML.lightning_container import LightningContainer
@@ -37,14 +38,7 @@ def get_supervised_imagenet_encoder() -> TileEncoder:
    return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224)


-@pytest.mark.parametrize("n_classes", [1, 3])
-@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
-@pytest.mark.parametrize("batch_size", [1, 15])
-@pytest.mark.parametrize("max_bag_size", [1, 7])
-@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
-@pytest.mark.parametrize("pool_out_dim", [1, 6])
-@pytest.mark.parametrize("dropout_rate", [None, 0.5])
-def test_lightningmodule(
+def _test_lightningmodule(
        n_classes: int,
        pooling_layer: Callable[[int, int, int], nn.Module],
        batch_size: int,
@@ -119,6 +113,50 @@ def test_lightningmodule(
    assert torch.all(score <= 1)


@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
@pytest.mark.parametrize("pool_out_dim", [1, 6])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule_attention(
        n_classes: int,
        pooling_layer: Callable[[int, int, int], nn.Module],
        batch_size: int,
        max_bag_size: int,
        pool_hidden_dim: int,
        pool_out_dim: int,
        dropout_rate: Optional[float],
) -> None:
    _test_lightningmodule(n_classes=n_classes,
                          pooling_layer=pooling_layer,
                          batch_size=batch_size,
                          max_bag_size=max_bag_size,
                          pool_hidden_dim=pool_hidden_dim,
                          pool_out_dim=pool_out_dim,
                          dropout_rate=dropout_rate)


@pytest.mark.parametrize("n_classes", [1, 3])
@pytest.mark.parametrize("batch_size", [1, 15])
@pytest.mark.parametrize("max_bag_size", [1, 7])
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
def test_lightningmodule_mean_pooling(
        n_classes: int,
        batch_size: int,
        max_bag_size: int,
        dropout_rate: Optional[float],
) -> None:
    _test_lightningmodule(n_classes=n_classes,
                          pooling_layer=MeanPoolingLayer,
                          batch_size=batch_size,
                          max_bag_size=max_bag_size,
                          pool_hidden_dim=1,
                          pool_out_dim=1,
                          dropout_rate=dropout_rate)


def move_batch_to_expected_device(batch: Dict[str, List], use_gpu: bool) -> Dict:
    device = "cuda" if use_gpu else "cpu"
    return {
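As background (an illustration, not part of the diff): a mean-pooling aggregator has no hidden or output dimension to tune, which is why `test_lightningmodule_mean_pooling` above pins `pool_hidden_dim` and `pool_out_dim` to 1. A rough conceptual stand-in, assuming uniform attention weights over the instances and averaged features (this is not the hi-ml `MeanPoolingLayer` source):

import torch
from torch import nn


class UniformMeanPool(nn.Module):
    """Toy mean-pooling MIL aggregator: uniform attention, averaged instance features."""

    def forward(self, features: torch.Tensor):  # features: N x L instance embeddings
        num_instances = features.shape[0]
        attentions = torch.full((1, num_instances), 1.0 / num_instances)  # every tile weighted equally
        bag_features = features.mean(dim=0, keepdim=True)                 # 1 x L bag embedding
        return attentions, bag_features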


@@ -7,6 +7,7 @@ import os
from pathlib import Path
from typing import Callable, Sequence, Union

+import numpy as np
import pytest
import torch
from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
@@ -19,7 +20,7 @@ from health_ml.utils.bag_utils import BagDataset
from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR
from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder
-from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd
+from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd, Subsampled

from Tests.ML.util import assert_dicts_equal
@@ -153,3 +154,57 @@ def test_encode_tiles(tmp_path: Path, use_gpu: bool, chunk_size: int) -> None:
                                        bagged_subset,
                                        transform=transform,
                                        cache_subdir="TCGA-CRCk_embed_cache")


@pytest.mark.parametrize('include_non_indexable', [True, False])
@pytest.mark.parametrize('allow_missing_keys', [True, False])
def test_subsample(include_non_indexable: bool, allow_missing_keys: bool) -> None:
    batch_size = 5
    max_size = batch_size // 2
    data = {
        'array_1d': np.random.randn(batch_size),
        'array_2d': np.random.randn(batch_size, 4),
        'tensor_1d': torch.randn(batch_size),
        'tensor_2d': torch.randn(batch_size, 4),
        'list': torch.randn(batch_size).tolist(),
        'indices': list(range(batch_size)),
        'non-indexable': 42,
    }

    keys_to_subsample = list(data.keys())
    if not include_non_indexable:
        keys_to_subsample.remove('non-indexable')
    keys_to_subsample.append('missing-key')

    subsampling = Subsampled(keys_to_subsample, max_size=max_size,
                             allow_missing_keys=allow_missing_keys)

    if include_non_indexable:
        with pytest.raises(ValueError):
            sub_data = subsampling(data)
        return
    elif not allow_missing_keys:
        with pytest.raises(KeyError):
            sub_data = subsampling(data)
        return
    else:
        sub_data = subsampling(data)

    assert set(sub_data.keys()) == set(data.keys())

    # Check lengths before and after subsampling
    for key in keys_to_subsample:
        if key not in data:
            continue  # Skip missing keys
        assert len(data[key]) == batch_size  # type: ignore
        assert len(sub_data[key]) == min(max_size, batch_size)  # type: ignore

    # Check contents of subsampled elements
    for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']:
        for idx, elem in zip(sub_data['indices'], sub_data[key]):
            assert np.array_equal(elem, data[key][idx])  # type: ignore

    # Check that subsampling is random, i.e. subsequent calls shouldn't give identical results
    sub_data2 = subsampling(data)
    for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']:
        assert not np.array_equal(sub_data[key], sub_data2[key])  # type: ignore
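One practical note (not part of the PR): since `Subsampled` derives from MONAI's `Randomizable`, its sampling can be made reproducible by seeding the transform's RNG with `set_random_state`, e.g. in regression tests. The key name below is illustrative:

import torch

bag = {'image': torch.randn(10, 8)}
subsampling = Subsampled(['image'], max_size=3)

subsampling.set_random_state(seed=0)
first = subsampling(bag)['image']
subsampling.set_random_state(seed=0)   # re-seed to reproduce the same draw
second = subsampling(bag)['image']
assert torch.equal(first, second)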

hi-ml

@@ -1 +1 @@
-Subproject commit a33c1ed07da8a42486dec9f939cd59eea4b2583e
+Subproject commit 2bc397b4707b56fecca624ce81e6883e0170b24b