Mirror of https://github.com/microsoft/hi-ml.git
ENH: Extend StainNormalisation Transform to work with Slides Pipeline (#644)
Adjust the input/output shapes.
This commit is contained in:
Parent: c156faddb9
Commit: 431ff2769c
@@ -16,7 +16,10 @@ def image_collate(batch: List) -> Any:
 
     for i, item in enumerate(batch):
         data = item[0]
-        data[SlideKey.IMAGE] = torch.tensor(np.array([ix[SlideKey.IMAGE] for ix in item]))
+        if isinstance(data[SlideKey.IMAGE], torch.Tensor):
+            data[SlideKey.IMAGE] = torch.stack([ix[SlideKey.IMAGE] for ix in item], dim=0)
+        else:
+            data[SlideKey.IMAGE] = torch.tensor(np.array([ix[SlideKey.IMAGE] for ix in item]))
         data[SlideKey.LABEL] = torch.tensor(data[SlideKey.LABEL])
         batch[i] = data
     return multibag_collate(batch)
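
Illustration of the two collate paths introduced above (a minimal sketch with assumed tile shapes, not part of the diff): torch tiles are stacked directly, while numpy tiles are first gathered into a single array before conversion; both paths yield the same stacked shape.

import numpy as np
import torch

np_tiles = [np.random.randint(0, 255, size=(3, 4, 4)) for _ in range(5)]
torch_tiles = [torch.randint(0, 255, size=(3, 4, 4)) for _ in range(5)]

stacked_from_np = torch.tensor(np.array(np_tiles))      # shape (5, 3, 4, 4)
stacked_from_torch = torch.stack(torch_tiles, dim=0)    # shape (5, 3, 4, 4)

assert stacked_from_np.shape == stacked_from_torch.shape == (5, 3, 4, 4)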
@@ -2,7 +2,7 @@ import torch
 import pytest
 import numpy as np
 
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union
 from typing import Sequence
 from health_cpath.utils.naming import SlideKey
 from health_cpath.utils.wsi_utils import image_collate
@@ -15,7 +15,8 @@ class MockTiledWSIDataset(Dataset):
                  n_slides: int,
                  n_classes: int,
                  tile_size: Sequence[int],
-                 random_n_tiles: bool) -> None:
+                 random_n_tiles: bool,
+                 img_type: str = "np") -> None:
 
         self.n_tiles = n_tiles
         self.n_slides = n_slides
@@ -23,30 +24,38 @@ class MockTiledWSIDataset(Dataset):
         self.n_classes = n_classes
         self.random_n_tiles = random_n_tiles
         self.slide_ids = torch.arange(self.n_slides)
+        self.img_type = img_type
 
     def __len__(self) -> int:
         return self.n_slides
 
     def __getitem__(self, index: int) -> List[Dict[SlideKey, Any]]:
-        tile_count = np.random.randint(self.n_tiles) if self.random_n_tiles else self.n_tiles
+        tile_count = np.random.randint(low=1, high=self.n_tiles) if self.random_n_tiles else self.n_tiles
         label = np.random.choice(self.n_classes)
+        img: Union[np.ndarray, torch.Tensor]
+        if self.img_type == "np":
+            img = np.random.randint(0, 255, size=(tile_count, *self.tile_size))
+        else:
+            img = torch.randint(0, 255, size=(tile_count, *self.tile_size))
         return [{SlideKey.SLIDE_ID: self.slide_ids[index],
-                 SlideKey.IMAGE: np.random.randint(0, 255, size=self.tile_size),
+                 SlideKey.IMAGE: img,
                  SlideKey.IMAGE_PATH: f"slide_{self.slide_ids[index]}.tiff",
                  SlideKey.LABEL: label
                  } for _ in range(tile_count)
                 ]
 
 
+@pytest.mark.parametrize("img_type", ["np", "torch"])
 @pytest.mark.parametrize("random_n_tiles", [False, True])
-def test_image_collate(random_n_tiles: bool) -> None:
+def test_image_collate(random_n_tiles: bool, img_type: str) -> None:
     # random_n_tiles accounts for both train and inference settings where the number of tiles is fixed (during
     # training) and None during inference (validation and test)
     dataset = MockTiledWSIDataset(n_tiles=20,
                                   n_slides=10,
                                   n_classes=4,
                                   tile_size=(1, 4, 4),
-                                  random_n_tiles=random_n_tiles,
+                                  random_n_tiles=random_n_tiles,
+                                  img_type=img_type)
 
     batch_size = 5
     samples_list = [dataset[idx] for idx in range(batch_size)]
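
The assertions that follow samples_list are outside this hunk. As a hedged sketch (an assumption, not the actual test body), the img_type parametrisation could be checked against the sampled bags before collation, reusing the names defined in the test above:

expected_type = np.ndarray if img_type == "np" else torch.Tensor
for bag in samples_list:
    for tile in bag:
        assert isinstance(tile[SlideKey.IMAGE], expected_type)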
@@ -131,6 +131,9 @@ class StainNormalization(object):
         return nimg
 
     def __call__(self, img: torch.Tensor) -> torch.Tensor:
+        original_shape = img.shape
+        if len(original_shape) == 3:
+            img = img.unsqueeze(0)  # add batch dimension if missing
         # if the input is a bag of images, stain normalization needs to run on each image separately
         if img.shape[0] > 1:
             for i in range(img.shape[0]):
@@ -138,7 +141,10 @@ class StainNormalization(object):
                 img[i] = self.stain_normalize(img_tile.unsqueeze(0), self.reference_mean, self.reference_std)
             return img
         else:
-            return self.stain_normalize(img, self.reference_mean, self.reference_std)
+            img = self.stain_normalize(img, self.reference_mean, self.reference_std)
+            if len(original_shape) == 3:
+                return img.squeeze(0)
+            return img
 
 
 class GaussianBlur(object):
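
Sketch of the resulting shape contract (assumed import path and arbitrary float RGB inputs in [0, 1]; not code from the diff): a 3-dim CHW image gains a temporary batch dimension and is squeezed back on return, so the output shape always matches the input shape, for single images and bags alike.

import torch
from health_ml.utils.data_augmentations import StainNormalization  # assumed location

transform = StainNormalization()
chw_image = torch.rand(3, 224, 224)     # single tile, e.g. from tiling on the fly
bchw_bag = torch.rand(8, 3, 224, 224)   # bag of tiles

assert transform(chw_image).shape == chw_image.shape
assert transform(bchw_bag).shape == bchw_bag.shape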
@@ -65,6 +65,9 @@ def test_stain_normalization() -> None:
     _test_data_augmentation(data_augmentation, dummy_img, expected_output_img, stochastic=False)
     _test_data_augmentation(data_augmentation, dummy_bag, expected_output_bag, stochastic=False)
 
+    # Test tiling on the fly (i.e. when the input image does not have a batch dimension)
+    _test_data_augmentation(data_augmentation, dummy_img.squeeze(0), expected_output_img.squeeze(0), stochastic=False)
+
 
 def test_hed_jitter() -> None:
     data_augmentation = HEDJitter(0.05)
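
The new assertion rests on the equivalence sketched below: normalising a CxHxW image directly should match normalising its 1xCxHxW counterpart and then dropping the batch dimension. This is a self-contained illustration with an assumed import path and an arbitrary random image; the real dummy_img fixture is defined elsewhere in the test module.

import torch
from health_ml.utils.data_augmentations import StainNormalization  # assumed location

transform = StainNormalization()
img = torch.rand(1, 3, 16, 16)  # 1xCxHxW stand-in for the dummy_img fixture

assert torch.allclose(transform(img.squeeze(0)), transform(img).squeeze(0))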