ENH: Extend StainNormalisation Transform to work with Slides Pipeline (#644)

Adjust the input/output shapes.
This commit is contained in:
Kenza Bouzid 2022-10-28 10:57:02 +01:00 committed by GitHub
Parent c156faddb9
Commit 431ff2769c
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 29 additions and 8 deletions

View File

@@ -16,7 +16,10 @@ def image_collate(batch: List) -> Any:
     for i, item in enumerate(batch):
         data = item[0]
-        data[SlideKey.IMAGE] = torch.tensor(np.array([ix[SlideKey.IMAGE] for ix in item]))
+        if isinstance(data[SlideKey.IMAGE], torch.Tensor):
+            data[SlideKey.IMAGE] = torch.stack([ix[SlideKey.IMAGE] for ix in item], dim=0)
+        else:
+            data[SlideKey.IMAGE] = torch.tensor(np.array([ix[SlideKey.IMAGE] for ix in item]))
         data[SlideKey.LABEL] = torch.tensor(data[SlideKey.LABEL])
         batch[i] = data
     return multibag_collate(batch)
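The hunk above lets the collate function accept tiles that are already torch tensors (as produced by on-the-fly tiling in the slides pipeline) in addition to numpy arrays. A minimal sketch of the two code paths, with illustrative tile shapes that are not taken from the commit:

```python
import numpy as np
import torch

# Four tiles for one slide, either as numpy arrays or as torch tensors.
np_tiles = [np.zeros((3, 224, 224), dtype=np.uint8) for _ in range(4)]
torch_tiles = [torch.zeros(3, 224, 224) for _ in range(4)]

# numpy path: convert the whole stack at once, as before.
stacked_np = torch.tensor(np.array(np_tiles))
# tensor path: stack along a new leading dimension instead of round-tripping through numpy.
stacked_torch = torch.stack(torch_tiles, dim=0)

assert stacked_np.shape == stacked_torch.shape == (4, 3, 224, 224)
```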

View File

@@ -2,7 +2,7 @@ import torch
 import pytest
 import numpy as np
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union
 from typing import Sequence
 from health_cpath.utils.naming import SlideKey
 from health_cpath.utils.wsi_utils import image_collate
@@ -15,7 +15,8 @@ class MockTiledWSIDataset(Dataset):
                  n_slides: int,
                  n_classes: int,
                  tile_size: Sequence[int],
-                 random_n_tiles: bool) -> None:
+                 random_n_tiles: bool,
+                 img_type: str = "np") -> None:
         self.n_tiles = n_tiles
         self.n_slides = n_slides
@@ -23,30 +24,38 @@ class MockTiledWSIDataset(Dataset):
         self.n_classes = n_classes
         self.random_n_tiles = random_n_tiles
         self.slide_ids = torch.arange(self.n_slides)
+        self.img_type = img_type

     def __len__(self) -> int:
         return self.n_slides

     def __getitem__(self, index: int) -> List[Dict[SlideKey, Any]]:
-        tile_count = np.random.randint(self.n_tiles) if self.random_n_tiles else self.n_tiles
+        tile_count = np.random.randint(low=1, high=self.n_tiles) if self.random_n_tiles else self.n_tiles
         label = np.random.choice(self.n_classes)
+        img: Union[np.ndarray, torch.Tensor]
+        if self.img_type == "np":
+            img = np.random.randint(0, 255, size=(tile_count, *self.tile_size))
+        else:
+            img = torch.randint(0, 255, size=(tile_count, *self.tile_size))
         return [{SlideKey.SLIDE_ID: self.slide_ids[index],
-                 SlideKey.IMAGE: np.random.randint(0, 255, size=self.tile_size),
+                 SlideKey.IMAGE: img,
                  SlideKey.IMAGE_PATH: f"slide_{self.slide_ids[index]}.tiff",
                  SlideKey.LABEL: label
                  } for _ in range(tile_count)
                 ]

+@pytest.mark.parametrize("img_type", ["np", "torch"])
 @pytest.mark.parametrize("random_n_tiles", [False, True])
-def test_image_collate(random_n_tiles: bool) -> None:
+def test_image_collate(random_n_tiles: bool, img_type: str) -> None:
     # random_n_tiles accounts for both train and inference settings where the number of tiles is fixed (during
     # training) and None during inference (validation and test)
     dataset = MockTiledWSIDataset(n_tiles=20,
                                   n_slides=10,
                                   n_classes=4,
                                   tile_size=(1, 4, 4),
-                                  random_n_tiles=random_n_tiles)
+                                  random_n_tiles=random_n_tiles,
+                                  img_type=img_type)
     batch_size = 5
     samples_list = [dataset[idx] for idx in range(batch_size)]
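The visible part of the hunk ends before the test's assertions. Purely as an illustration of what the parametrised mock now yields (not the actual test body), one could check the per-slide tile stacks like this:

```python
# Hypothetical check, not part of the commit: every mock slide carries a stack of
# tiles whose type follows img_type and whose shape is (n_tiles, *tile_size).
for sample in samples_list:
    tile = sample[0]
    expected_type = np.ndarray if img_type == "np" else torch.Tensor
    assert isinstance(tile[SlideKey.IMAGE], expected_type)
    assert tuple(tile[SlideKey.IMAGE].shape) == (len(sample), 1, 4, 4)
```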

View File

@@ -131,6 +131,9 @@ class StainNormalization(object):
         return nimg

     def __call__(self, img: torch.Tensor) -> torch.Tensor:
+        original_shape = img.shape
+        if len(original_shape) == 3:
+            img = img.unsqueeze(0)  # add batch dimension if missing
         # if the input is a bag of images, stain normalization needs to run on each image separately
         if img.shape[0] > 1:
             for i in range(img.shape[0]):
@@ -138,7 +141,10 @@ class StainNormalization(object):
                 img[i] = self.stain_normalize(img_tile.unsqueeze(0), self.reference_mean, self.reference_std)
             return img
         else:
-            return self.stain_normalize(img, self.reference_mean, self.reference_std)
+            img = self.stain_normalize(img, self.reference_mean, self.reference_std)
+            if len(original_shape) == 3:
+                return img.squeeze(0)
+            return img


 class GaussianBlur(object):
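The two hunks above make the transform rank-preserving: a single (C, H, W) tile coming from on-the-fly tiling gets a temporary batch dimension, and a (N, C, H, W) bag is normalised tile by tile. A minimal sketch of that shape bookkeeping, using a stand-in identity function in place of stain_normalize and an illustrative wrapper that is not the library API:

```python
import torch

def fake_stain_normalize(img: torch.Tensor) -> torch.Tensor:
    # Stand-in for StainNormalization.stain_normalize, which expects a (1, C, H, W) batch.
    return img

def normalize(img: torch.Tensor) -> torch.Tensor:
    original_shape = img.shape
    if len(original_shape) == 3:          # single tile (C, H, W) from on-the-fly tiling
        img = img.unsqueeze(0)            # -> (1, C, H, W)
    if img.shape[0] > 1:                  # bag of tiles: normalise each tile in turn
        for i in range(img.shape[0]):
            img[i] = fake_stain_normalize(img[i].unsqueeze(0)).squeeze(0)
        return img
    img = fake_stain_normalize(img)
    return img.squeeze(0) if len(original_shape) == 3 else img

# Output rank matches input rank in both cases.
assert normalize(torch.zeros(3, 8, 8)).shape == (3, 8, 8)
assert normalize(torch.zeros(5, 3, 8, 8)).shape == (5, 3, 8, 8)
```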

View File

@@ -65,6 +65,9 @@ def test_stain_normalization() -> None:
     _test_data_augmentation(data_augmentation, dummy_img, expected_output_img, stochastic=False)
     _test_data_augmentation(data_augmentation, dummy_bag, expected_output_bag, stochastic=False)

+    # Test tiling on the fly (i.e. when the input image does not have a batch dimension)
+    _test_data_augmentation(data_augmentation, dummy_img.squeeze(0), expected_output_img.squeeze(0), stochastic=False)


 def test_hed_jitter() -> None:
     data_augmentation = HEDJitter(0.05)