Add subsampling transform and mean pooling (#656)
* Add subsampling transform
* Add option to allow_missing_keys for Subsampled
* Add dropout param to BaseMIL
* Add docstring and tests for Subsampled
* Update changelog
* Update to hi-ml with mean pooling
* Enable mean pooling in DeepMIL
* Add/refactor mean pooling tests
* Update changelog
* Update to latest hi-ml with mean pooling
Parent: 1600ef3ddf
Commit: e2ec5cc839
@@ -48,6 +48,7 @@ jobs that run in AzureML.
 - ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL
 - ([#653](https://github.com/microsoft/InnerEye-DeepLearning/pull/653)) Add dropout to DeepMIL and fix feature extractor setup.
 - ([#650](https://github.com/microsoft/InnerEye-DeepLearning/pull/650)) Enable fine-tuning in DeepMIL using PANDA as the classification task.
+- ([#656](https://github.com/microsoft/InnerEye-DeepLearning/pull/656)) Add subsampling transform and support for MIL mean pooling.
 
 ### Changed
 
 - ([#659](https://github.com/microsoft/InnerEye-DeepLearning/pull/659)) Update cudatoolkit version from 11.1 to 11.3.

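Note (illustrative, not part of the diff): the mean-pooling option added here is selected in `BaseMIL` by class name, so a configuration only needs to carry the string `"MeanPoolingLayer"`. The snippet below is a minimal sketch of that dispatch, assuming hi-ml is installed; the lookup dict is only for illustration, the real selection is the `if/elif` chain in the BaseMIL hunk further down.

```python
from health_ml.networks.layers.attention_layers import (AttentionLayer, GatedAttentionLayer,
                                                        MeanPoolingLayer)

# Illustrative lookup by class name; InnerEye's BaseMIL implements this as an if/elif chain.
pooling_classes = {cls.__name__: cls for cls in (AttentionLayer, GatedAttentionLayer, MeanPoolingLayer)}

assert pooling_classes["MeanPoolingLayer"] is MeanPoolingLayer
```
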
@@ -204,7 +204,7 @@ class DeepMILModule(LightningModule):
         with set_grad_enabled(self.is_finetune):
             instance_features = self.encoder(instances)                    # N X L x 1 x 1
         attentions, bag_features = self.aggregation_fn(instance_features)  # K x N | K x L
-        bag_features = bag_features.view(-1, self.num_encoding * self.pool_out_dim)
+        bag_features = bag_features.view(1, -1)
         bag_logit = self.classifier_fn(bag_features)
         return bag_logit, attentions

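The switch to `view(1, -1)` makes the classifier input shape-agnostic with respect to the pooling layer: an attention-style aggregation returns `K x L` bag features, while mean pooling effectively returns `1 x L`, and both flatten to a single bag-level row. A standalone shape check under those assumptions (sizes are arbitrary, not taken from the code):

```python
import torch

K, L = 4, 16                       # pooling output dim and encoding dim (illustrative values)
attention_bag = torch.randn(K, L)  # bag features from an attention-style aggregation
mean_bag = torch.randn(1, L)       # bag features from mean pooling (K == 1)

# Both become a single row vector, so the same classifier head accepts either.
assert attention_bag.view(1, -1).shape == (1, K * L)
assert mean_bag.view(1, -1).shape == (1, L)
```
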
@@ -10,7 +10,7 @@ import torch
 import numpy as np
 import PIL
 from monai.config.type_definitions import KeysCollection
-from monai.transforms.transform import MapTransform
+from monai.transforms.transform import MapTransform, Randomizable
 from torchvision.transforms.functional import to_tensor
 
 from InnerEye.ML.Histopathology.models.encoders import TileEncoder

@@ -92,7 +92,7 @@ class EncodeTilesBatchd(MapTransform):
                  allow_missing_keys: bool = False,
                  chunk_size: int = 0) -> None:
         """
-        :param keys: Key(s) for the image path(s) in the input dictionary.
+        :param keys: Key(s) for the image tensor(s) in the input dictionary.
         :param encoder: The tile encoder to use for feature extraction.
         :param allow_missing_keys: If `False` (default), raises an exception when an input
         dictionary is missing any of the specified keys.

@@ -128,3 +128,42 @@ class EncodeTilesBatchd(MapTransform):
         for key in self.key_iterator(out_data):
             out_data[key] = self._encode_tiles(data[key])
         return out_data
+
+
+def take_indices(data: Sequence, indices: np.ndarray) -> Sequence:
+    if isinstance(data, (np.ndarray, torch.Tensor)):
+        return data[indices]
+    elif isinstance(data, Sequence):
+        return [data[i] for i in indices]
+    else:
+        raise ValueError(f"Data of type {type(data)} is not indexable")
+
+
+class Subsampled(MapTransform, Randomizable):
+    """Dictionary transform to randomly subsample the data down to a fixed maximum length"""
+
+    def __init__(self, keys: KeysCollection, max_size: int,
+                 allow_missing_keys: bool = False) -> None:
+        """
+        :param keys: Key(s) for all batch elements that must be subsampled.
+        :param max_size: Each specified array, tensor, or sequence will be subsampled uniformly at
+        random down to `max_size` along their first dimension. If shorter, the elements are merely
+        shuffled.
+        :param allow_missing_keys: If `False` (default), raises an exception when an input
+        dictionary is missing any of the specified keys.
+        """
+        super().__init__(keys, allow_missing_keys=allow_missing_keys)
+        self.max_size = max_size
+        self._indices: np.ndarray
+
+    def randomize(self, total_size: int) -> None:
+        subsample_size = min(self.max_size, total_size)
+        self._indices = self.R.choice(total_size, size=subsample_size)
+
+    def __call__(self, data: Mapping) -> Mapping:
+        out_data = dict(data)  # create shallow copy
+        size = len(data[self.keys[0]])
+        self.randomize(size)
+        for key in self.key_iterator(out_data):
+            out_data[key] = take_indices(data[key], self._indices)
+        return out_data

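A short usage sketch for the new transform (key names and tensor sizes are made up for illustration). Because the same random indices are applied to every listed key, paired entries such as tiles and their labels stay aligned after subsampling:

```python
import numpy as np
import torch
from InnerEye.ML.Histopathology.models.transforms import Subsampled

bag = {
    "image": torch.randn(100, 3, 16, 16),  # 100 tiles (illustrative key and size)
    "label": np.arange(100),               # one entry per tile
}

subsample = Subsampled(keys=["image", "label"], max_size=48)
out = subsample(bag)

assert len(out["image"]) == len(out["label"]) == 48
assert len(bag["image"]) == 100  # the input dictionary is shallow-copied, not modified in place
```
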
@@ -14,7 +14,7 @@ import param
 from torch import nn
 from torchvision.models.resnet import resnet18
 
-from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer
+from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer, MeanPoolingLayer
 from InnerEye.ML.lightning_container import LightningContainer
 from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
 from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation, TilesDataModule

@@ -90,6 +90,8 @@ class BaseMIL(LightningContainer):
             return AttentionLayer
         elif self.pooling_type == GatedAttentionLayer.__name__:
             return GatedAttentionLayer
+        elif self.pooling_type == MeanPoolingLayer.__name__:
+            return MeanPoolingLayer
         else:
             raise ValueError(f"Unsupported pooling type: {self.pooling_type}")

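For intuition, mean pooling replaces learned attention weights with a uniform average over the bag. The sketch below shows the expected shapes under that interpretation (consistent with the `K x N | K x L` comment in the DeepMIL hunk above, with `K == 1`); it is an assumption for illustration, not hi-ml's `MeanPoolingLayer` implementation:

```python
from typing import Tuple

import torch

def mean_pool(instance_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Uniform-weight pooling over a bag of N instance features of shape N x L."""
    num_instances = instance_features.shape[0]
    attentions = torch.full((1, num_instances), 1.0 / num_instances)  # 1 x N, all weights equal
    bag_features = instance_features.mean(dim=0, keepdim=True)        # 1 x L
    return attentions, bag_features

attentions, bag_features = mean_pool(torch.randn(7, 16))
assert attentions.shape == (1, 7) and bag_features.shape == (1, 16)
```
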
@@ -14,6 +14,7 @@ from torchvision.models import resnet18
 from health_ml.networks.layers.attention_layers import (
     AttentionLayer,
     GatedAttentionLayer,
+    MeanPoolingLayer,
 )
 
 from InnerEye.ML.lightning_container import LightningContainer

@@ -37,14 +38,7 @@ def get_supervised_imagenet_encoder() -> TileEncoder:
     return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224)
 
 
-@pytest.mark.parametrize("n_classes", [1, 3])
-@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
-@pytest.mark.parametrize("batch_size", [1, 15])
-@pytest.mark.parametrize("max_bag_size", [1, 7])
-@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
-@pytest.mark.parametrize("pool_out_dim", [1, 6])
-@pytest.mark.parametrize("dropout_rate", [None, 0.5])
-def test_lightningmodule(
+def _test_lightningmodule(
         n_classes: int,
         pooling_layer: Callable[[int, int, int], nn.Module],
         batch_size: int,

@@ -119,6 +113,50 @@ def test_lightningmodule(
     assert torch.all(score <= 1)
+
+
+@pytest.mark.parametrize("n_classes", [1, 3])
+@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
+@pytest.mark.parametrize("batch_size", [1, 15])
+@pytest.mark.parametrize("max_bag_size", [1, 7])
+@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
+@pytest.mark.parametrize("pool_out_dim", [1, 6])
+@pytest.mark.parametrize("dropout_rate", [None, 0.5])
+def test_lightningmodule_attention(
+        n_classes: int,
+        pooling_layer: Callable[[int, int, int], nn.Module],
+        batch_size: int,
+        max_bag_size: int,
+        pool_hidden_dim: int,
+        pool_out_dim: int,
+        dropout_rate: Optional[float],
+) -> None:
+    _test_lightningmodule(n_classes=n_classes,
+                          pooling_layer=pooling_layer,
+                          batch_size=batch_size,
+                          max_bag_size=max_bag_size,
+                          pool_hidden_dim=pool_hidden_dim,
+                          pool_out_dim=pool_out_dim,
+                          dropout_rate=dropout_rate)
+
+
+@pytest.mark.parametrize("n_classes", [1, 3])
+@pytest.mark.parametrize("batch_size", [1, 15])
+@pytest.mark.parametrize("max_bag_size", [1, 7])
+@pytest.mark.parametrize("dropout_rate", [None, 0.5])
+def test_lightningmodule_mean_pooling(
+        n_classes: int,
+        batch_size: int,
+        max_bag_size: int,
+        dropout_rate: Optional[float],
+) -> None:
+    _test_lightningmodule(n_classes=n_classes,
+                          pooling_layer=MeanPoolingLayer,
+                          batch_size=batch_size,
+                          max_bag_size=max_bag_size,
+                          pool_hidden_dim=1,
+                          pool_out_dim=1,
+                          dropout_rate=dropout_rate)
 
 
 def move_batch_to_expected_device(batch: Dict[str, List], use_gpu: bool) -> Dict:
     device = "cuda" if use_gpu else "cpu"
     return {

@@ -7,6 +7,7 @@ import os
 from pathlib import Path
 from typing import Callable, Sequence, Union
 
+import numpy as np
 import pytest
 import torch
 from monai.data.dataset import CacheDataset, Dataset, PersistentDataset

@@ -19,7 +20,7 @@ from health_ml.utils.bag_utils import BagDataset
 from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR
 from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
 from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder
-from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd
+from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd, Subsampled
 from Tests.ML.util import assert_dicts_equal

@@ -153,3 +154,57 @@ def test_encode_tiles(tmp_path: Path, use_gpu: bool, chunk_size: int) -> None:
                               bagged_subset,
                               transform=transform,
                               cache_subdir="TCGA-CRCk_embed_cache")
+
+
+@pytest.mark.parametrize('include_non_indexable', [True, False])
+@pytest.mark.parametrize('allow_missing_keys', [True, False])
+def test_subsample(include_non_indexable: bool, allow_missing_keys: bool) -> None:
+    batch_size = 5
+    max_size = batch_size // 2
+    data = {
+        'array_1d': np.random.randn(batch_size),
+        'array_2d': np.random.randn(batch_size, 4),
+        'tensor_1d': torch.randn(batch_size),
+        'tensor_2d': torch.randn(batch_size, 4),
+        'list': torch.randn(batch_size).tolist(),
+        'indices': list(range(batch_size)),
+        'non-indexable': 42,
+    }
+
+    keys_to_subsample = list(data.keys())
+    if not include_non_indexable:
+        keys_to_subsample.remove('non-indexable')
+    keys_to_subsample.append('missing-key')
+
+    subsampling = Subsampled(keys_to_subsample, max_size=max_size,
+                             allow_missing_keys=allow_missing_keys)
+
+    if include_non_indexable:
+        with pytest.raises(ValueError):
+            sub_data = subsampling(data)
+        return
+    elif not allow_missing_keys:
+        with pytest.raises(KeyError):
+            sub_data = subsampling(data)
+        return
+    else:
+        sub_data = subsampling(data)
+
+    assert set(sub_data.keys()) == set(data.keys())
+
+    # Check lengths before and after subsampling
+    for key in keys_to_subsample:
+        if key not in data:
+            continue  # Skip missing keys
+        assert len(data[key]) == batch_size  # type: ignore
+        assert len(sub_data[key]) == min(max_size, batch_size)  # type: ignore
+
+    # Check contents of subsampled elements
+    for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']:
+        for idx, elem in zip(sub_data['indices'], sub_data[key]):
+            assert np.array_equal(elem, data[key][idx])  # type: ignore
+
+    # Check that subsampling is random, i.e. subsequent calls shouldn't give identical results
+    sub_data2 = subsampling(data)
+    for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']:
+        assert not np.array_equal(sub_data[key], sub_data2[key])  # type: ignore

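Because `Subsampled` mixes in MONAI's `Randomizable`, its random state can be fixed with `set_random_state` when deterministic subsampling is needed, for example in tests. A brief sketch, with arbitrary data and seed:

```python
import numpy as np
from InnerEye.ML.Histopathology.models.transforms import Subsampled

data = {"values": np.arange(10)}
subsample = Subsampled(keys=["values"], max_size=4)

subsample.set_random_state(seed=0)  # MONAI Randomizable API
first = subsample(data)["values"]

subsample.set_random_state(seed=0)  # re-seeding reproduces the same draw
second = subsample(data)["values"]

assert np.array_equal(first, second)
```
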
hi-ml (submodule)
@@ -1 +1 @@
-Subproject commit a33c1ed07da8a42486dec9f939cd59eea4b2583e
+Subproject commit 2bc397b4707b56fecca624ce81e6883e0170b24b