Add a transformation adaptor to allow transformations working only on Tensors/Images to work with a dict in input
This commit is contained in:
vale-salvatelli 2022-03-03 13:36:48 +00:00 коммит произвёл GitHub
Родитель 30854eae4f
Коммит 0250715c5a
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 57 добавлений и 2 удалений

Просмотреть файл

@ -15,6 +15,7 @@ the section headers (Added/Changed/...) and incrementing the package version.
### Added
- ([#179](https://github.com/microsoft/hi-ml/pull/179)) Add GaussianBlur and RotationByMultiplesOf90 augmentations. Added torchvision and opencv to
the environment file since it is necessary for the augmentations.
- ([#193](https://github.com/microsoft/hi-ml/pull/193)) Add transformation adaptor to hi-ml-histopathology.
- ([#178](https://github.com/microsoft/hi-ml/pull/178)) Add runner script for running ML experiments.
- ([#181](https://github.com/microsoft/hi-ml/pull/181)) Add computational pathology tools in hi-ml-histopathology folder.
- ([#187](https://github.com/microsoft/hi-ml/pull/187)) Add mean pooling layer for MIL.

Просмотреть файл

@ -4,7 +4,7 @@
# ------------------------------------------------------------------------------------------
from pathlib import Path
from typing import Mapping, Sequence, Union
from typing import Mapping, Sequence, Union, Callable, Dict
import torch
import numpy as np
@ -43,6 +43,37 @@ def load_image_stack_as_tensor(image_paths: Sequence[PathOrString],
return torch.stack(image_tensors, dim=0)
def transform_dict_adaptor(function: Callable, k_input: str = None, k_output: str = None) -> Callable:
"""Adapt transformations to work with an input dictionary (rather than a tensor).
We can't reuse monai.transforms.adaptors because it is only compatible with transformations that accept
a dict as input.
:param function: a transformation function
:param k_input: key of the input dictionary that contains the object
to which function should be applied
:param k_output: key of the input dictionary where to place the function output. If None the ouput of
the transformation is returned
:return: adapted transformation
"""
def _inner(ditems: dict) -> Dict:
if k_input is None:
dinputs = ditems
else:
dinputs = ditems[k_input]
ret = function(dinputs)
if k_output is None:
ditems = ret
else:
if isinstance(ret, type(ditems[k_output])):
ditems[k_output] = ret
else:
raise ValueError("The transformation is not expect to change the type."
"Check input and output are used correctly ")
return ditems
return _inner
class LoadTiled(MapTransform):
"""Dictionary transform to load an individual image tile as a tensor from an input path"""

Просмотреть файл

@ -14,13 +14,15 @@ from monai.transforms import Compose
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import Subset
from torchvision.models import resnet18
from torchvision.transforms import RandomHorizontalFlip
from health_ml.utils.bag_utils import BagDataset
from health_ml.utils.data_augmentations import HEDJitter
from histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR
from histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
from histopathology.models.encoders import ImageNetEncoder
from histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd
from histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd, transform_dict_adaptor
import testhisto
@ -163,3 +165,22 @@ def test_encode_tiles(tmp_path: Path, use_gpu: bool, chunk_size: int) -> None:
bagged_subset,
transform=transform,
cache_subdir="TCGA-CRCk_embed_cache")
def test_transform_dict_adaptor() -> None:
key = "key"
transf1 = transform_dict_adaptor(RandomHorizontalFlip(p=0), key, key)
transf2 = transform_dict_adaptor(RandomHorizontalFlip(p=1), key, key)
transf3 = transform_dict_adaptor(HEDJitter(0), key, key)
input_tensor = torch.arange(24).view(2, 3, 2, 2)
input_dict = {'dummy': [], key: input_tensor}
output_dict1 = transf1(input_dict)
output_dict2 = transf2(input_dict)
output_dict3 = transf3(input_dict)
expected_output_dict2 = input_dict
expected_output_dict2[key] = torch.flip(input_dict[key], [2]) # type: ignore
assert output_dict1 == input_dict
assert output_dict2 == expected_output_dict2
assert output_dict3 == input_dict

Просмотреть файл

@ -72,6 +72,8 @@ class HEDJitter(object):
return img
def __call__(self, img: torch.Tensor) -> torch.Tensor:
if img.shape[1] != 3:
raise ValueError("HED jitter can only be applied to images with 3 channels (RGB).")
return self.adjust_hed(img, self.theta, self.hed_from_rgb, self.rgb_from_hed)