Add subsampling transform and mean pooling (#656)
* Add subsampling transform * Add option to allow_missing_keys for Subsampled * Add dropout param to BaseMIL * Add docstring and tests for Subsampled * Update changelog * Update to hi-ml with mean pooling * Enable mean pooling in DeepMIL * Add/refactor mean pooling tests * Update changelog * Update to latest hi-ml with mean pooling
This commit is contained in:
Родитель
1600ef3ddf
Коммит
e2ec5cc839
|
@ -48,6 +48,7 @@ jobs that run in AzureML.
|
|||
- ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL
|
||||
- ([#653](https://github.com/microsoft/InnerEye-DeepLearning/pull/653)) Add dropout to DeepMIL and fix feature extractor setup.
|
||||
- ([#650](https://github.com/microsoft/InnerEye-DeepLearning/pull/650)) Enable fine-tuning in DeepMIL using PANDA as the classification task.
|
||||
- ([#656](https://github.com/microsoft/InnerEye-DeepLearning/pull/656)) Add subsampling transform and support for MIL mean pooling.
|
||||
|
||||
### Changed
|
||||
- ([#659](https://github.com/microsoft/InnerEye-DeepLearning/pull/659)) Update cudatoolkit version from 11.1 to 11.3.
|
||||
|
|
|
@ -204,7 +204,7 @@ class DeepMILModule(LightningModule):
|
|||
with set_grad_enabled(self.is_finetune):
|
||||
instance_features = self.encoder(instances) # N X L x 1 x 1
|
||||
attentions, bag_features = self.aggregation_fn(instance_features) # K x N | K x L
|
||||
bag_features = bag_features.view(-1, self.num_encoding * self.pool_out_dim)
|
||||
bag_features = bag_features.view(1, -1)
|
||||
bag_logit = self.classifier_fn(bag_features)
|
||||
return bag_logit, attentions
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ import torch
|
|||
import numpy as np
|
||||
import PIL
|
||||
from monai.config.type_definitions import KeysCollection
|
||||
from monai.transforms.transform import MapTransform
|
||||
from monai.transforms.transform import MapTransform, Randomizable
|
||||
from torchvision.transforms.functional import to_tensor
|
||||
|
||||
from InnerEye.ML.Histopathology.models.encoders import TileEncoder
|
||||
|
@ -92,7 +92,7 @@ class EncodeTilesBatchd(MapTransform):
|
|||
allow_missing_keys: bool = False,
|
||||
chunk_size: int = 0) -> None:
|
||||
"""
|
||||
:param keys: Key(s) for the image path(s) in the input dictionary.
|
||||
:param keys: Key(s) for the image tensor(s) in the input dictionary.
|
||||
:param encoder: The tile encoder to use for feature extraction.
|
||||
:param allow_missing_keys: If `False` (default), raises an exception when an input
|
||||
dictionary is missing any of the specified keys.
|
||||
|
@ -128,3 +128,42 @@ class EncodeTilesBatchd(MapTransform):
|
|||
for key in self.key_iterator(out_data):
|
||||
out_data[key] = self._encode_tiles(data[key])
|
||||
return out_data
|
||||
|
||||
|
||||
def take_indices(data: Sequence, indices: np.ndarray) -> Sequence:
|
||||
if isinstance(data, (np.ndarray, torch.Tensor)):
|
||||
return data[indices]
|
||||
elif isinstance(data, Sequence):
|
||||
return [data[i] for i in indices]
|
||||
else:
|
||||
raise ValueError(f"Data of type {type(data)} is not indexable")
|
||||
|
||||
|
||||
class Subsampled(MapTransform, Randomizable):
|
||||
"""Dictionary transform to randomly subsample the data down to a fixed maximum length"""
|
||||
|
||||
def __init__(self, keys: KeysCollection, max_size: int,
|
||||
allow_missing_keys: bool = False) -> None:
|
||||
"""
|
||||
:param keys: Key(s) for all batch elements that must be subsampled.
|
||||
:param max_size: Each specified array, tensor, or sequence will be subsampled uniformly at
|
||||
random down to `max_size` along their first dimension. If shorter, the elements are merely
|
||||
shuffled.
|
||||
:param allow_missing_keys: If `False` (default), raises an exception when an input
|
||||
dictionary is missing any of the specified keys.
|
||||
"""
|
||||
super().__init__(keys, allow_missing_keys=allow_missing_keys)
|
||||
self.max_size = max_size
|
||||
self._indices: np.ndarray
|
||||
|
||||
def randomize(self, total_size: int) -> None:
|
||||
subsample_size = min(self.max_size, total_size)
|
||||
self._indices = self.R.choice(total_size, size=subsample_size)
|
||||
|
||||
def __call__(self, data: Mapping) -> Mapping:
|
||||
out_data = dict(data) # create shallow copy
|
||||
size = len(data[self.keys[0]])
|
||||
self.randomize(size)
|
||||
for key in self.key_iterator(out_data):
|
||||
out_data[key] = take_indices(data[key], self._indices)
|
||||
return out_data
|
||||
|
|
|
@ -14,7 +14,7 @@ import param
|
|||
from torch import nn
|
||||
from torchvision.models.resnet import resnet18
|
||||
|
||||
from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer
|
||||
from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer, MeanPoolingLayer
|
||||
from InnerEye.ML.lightning_container import LightningContainer
|
||||
from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset
|
||||
from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation, TilesDataModule
|
||||
|
@ -90,6 +90,8 @@ class BaseMIL(LightningContainer):
|
|||
return AttentionLayer
|
||||
elif self.pooling_type == GatedAttentionLayer.__name__:
|
||||
return GatedAttentionLayer
|
||||
elif self.pooling_type == MeanPoolingLayer.__name__:
|
||||
return MeanPoolingLayer
|
||||
else:
|
||||
raise ValueError(f"Unsupported pooling type: {self.pooling_type}")
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ from torchvision.models import resnet18
|
|||
from health_ml.networks.layers.attention_layers import (
|
||||
AttentionLayer,
|
||||
GatedAttentionLayer,
|
||||
MeanPoolingLayer,
|
||||
)
|
||||
|
||||
from InnerEye.ML.lightning_container import LightningContainer
|
||||
|
@ -37,14 +38,7 @@ def get_supervised_imagenet_encoder() -> TileEncoder:
|
|||
return ImageNetEncoder(feature_extraction_model=resnet18, tile_size=224)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_classes", [1, 3])
|
||||
@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
|
||||
@pytest.mark.parametrize("batch_size", [1, 15])
|
||||
@pytest.mark.parametrize("max_bag_size", [1, 7])
|
||||
@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
|
||||
@pytest.mark.parametrize("pool_out_dim", [1, 6])
|
||||
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
|
||||
def test_lightningmodule(
|
||||
def _test_lightningmodule(
|
||||
n_classes: int,
|
||||
pooling_layer: Callable[[int, int, int], nn.Module],
|
||||
batch_size: int,
|
||||
|
@ -119,6 +113,50 @@ def test_lightningmodule(
|
|||
assert torch.all(score <= 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_classes", [1, 3])
|
||||
@pytest.mark.parametrize("pooling_layer", [AttentionLayer, GatedAttentionLayer])
|
||||
@pytest.mark.parametrize("batch_size", [1, 15])
|
||||
@pytest.mark.parametrize("max_bag_size", [1, 7])
|
||||
@pytest.mark.parametrize("pool_hidden_dim", [1, 5])
|
||||
@pytest.mark.parametrize("pool_out_dim", [1, 6])
|
||||
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
|
||||
def test_lightningmodule_attention(
|
||||
n_classes: int,
|
||||
pooling_layer: Callable[[int, int, int], nn.Module],
|
||||
batch_size: int,
|
||||
max_bag_size: int,
|
||||
pool_hidden_dim: int,
|
||||
pool_out_dim: int,
|
||||
dropout_rate: Optional[float],
|
||||
) -> None:
|
||||
_test_lightningmodule(n_classes=n_classes,
|
||||
pooling_layer=pooling_layer,
|
||||
batch_size=batch_size,
|
||||
max_bag_size=max_bag_size,
|
||||
pool_hidden_dim=pool_hidden_dim,
|
||||
pool_out_dim=pool_out_dim,
|
||||
dropout_rate=dropout_rate)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_classes", [1, 3])
|
||||
@pytest.mark.parametrize("batch_size", [1, 15])
|
||||
@pytest.mark.parametrize("max_bag_size", [1, 7])
|
||||
@pytest.mark.parametrize("dropout_rate", [None, 0.5])
|
||||
def test_lightningmodule_mean_pooling(
|
||||
n_classes: int,
|
||||
batch_size: int,
|
||||
max_bag_size: int,
|
||||
dropout_rate: Optional[float],
|
||||
) -> None:
|
||||
_test_lightningmodule(n_classes=n_classes,
|
||||
pooling_layer=MeanPoolingLayer,
|
||||
batch_size=batch_size,
|
||||
max_bag_size=max_bag_size,
|
||||
pool_hidden_dim=1,
|
||||
pool_out_dim=1,
|
||||
dropout_rate=dropout_rate)
|
||||
|
||||
|
||||
def move_batch_to_expected_device(batch: Dict[str, List], use_gpu: bool) -> Dict:
|
||||
device = "cuda" if use_gpu else "cpu"
|
||||
return {
|
||||
|
|
|
@ -7,6 +7,7 @@ import os
|
|||
from pathlib import Path
|
||||
from typing import Callable, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
|
||||
|
@ -19,7 +20,7 @@ from health_ml.utils.bag_utils import BagDataset
|
|||
from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR
|
||||
from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
|
||||
from InnerEye.ML.Histopathology.models.encoders import ImageNetEncoder
|
||||
from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd
|
||||
from InnerEye.ML.Histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd, Subsampled
|
||||
from Tests.ML.util import assert_dicts_equal
|
||||
|
||||
|
||||
|
@ -153,3 +154,57 @@ def test_encode_tiles(tmp_path: Path, use_gpu: bool, chunk_size: int) -> None:
|
|||
bagged_subset,
|
||||
transform=transform,
|
||||
cache_subdir="TCGA-CRCk_embed_cache")
|
||||
|
||||
|
||||
@pytest.mark.parametrize('include_non_indexable', [True, False])
|
||||
@pytest.mark.parametrize('allow_missing_keys', [True, False])
|
||||
def test_subsample(include_non_indexable: bool, allow_missing_keys: bool) -> None:
|
||||
batch_size = 5
|
||||
max_size = batch_size // 2
|
||||
data = {
|
||||
'array_1d': np.random.randn(batch_size),
|
||||
'array_2d': np.random.randn(batch_size, 4),
|
||||
'tensor_1d': torch.randn(batch_size),
|
||||
'tensor_2d': torch.randn(batch_size, 4),
|
||||
'list': torch.randn(batch_size).tolist(),
|
||||
'indices': list(range(batch_size)),
|
||||
'non-indexable': 42,
|
||||
}
|
||||
|
||||
keys_to_subsample = list(data.keys())
|
||||
if not include_non_indexable:
|
||||
keys_to_subsample.remove('non-indexable')
|
||||
keys_to_subsample.append('missing-key')
|
||||
|
||||
subsampling = Subsampled(keys_to_subsample, max_size=max_size,
|
||||
allow_missing_keys=allow_missing_keys)
|
||||
|
||||
if include_non_indexable:
|
||||
with pytest.raises(ValueError):
|
||||
sub_data = subsampling(data)
|
||||
return
|
||||
elif not allow_missing_keys:
|
||||
with pytest.raises(KeyError):
|
||||
sub_data = subsampling(data)
|
||||
return
|
||||
else:
|
||||
sub_data = subsampling(data)
|
||||
|
||||
assert set(sub_data.keys()) == set(data.keys())
|
||||
|
||||
# Check lenghts before and after subsampling
|
||||
for key in keys_to_subsample:
|
||||
if key not in data:
|
||||
continue # Skip missing keys
|
||||
assert len(data[key]) == batch_size # type: ignore
|
||||
assert len(sub_data[key]) == min(max_size, batch_size) # type: ignore
|
||||
|
||||
# Check contents of subsampled elements
|
||||
for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']:
|
||||
for idx, elem in zip(sub_data['indices'], sub_data[key]):
|
||||
assert np.array_equal(elem, data[key][idx]) # type: ignore
|
||||
|
||||
# Check that subsampling is random, i.e. subsequent calls shouldn't give identical results
|
||||
sub_data2 = subsampling(data)
|
||||
for key in ['tensor_1d', 'tensor_2d', 'array_1d', 'array_2d', 'list']:
|
||||
assert not np.array_equal(sub_data[key], sub_data2[key]) # type: ignore
|
||||
|
|
2
hi-ml
2
hi-ml
|
@ -1 +1 @@
|
|||
Subproject commit a33c1ed07da8a42486dec9f939cd59eea4b2583e
|
||||
Subproject commit 2bc397b4707b56fecca624ce81e6883e0170b24b
|
Загрузка…
Ссылка в новой задаче