зеркало из https://github.com/microsoft/hi-ml.git
Adding transformation adaptor (#193)
Add a transformation adaptor to allow transformations working only on Tensors/Images to work with a dict in input
This commit is contained in:
Родитель
30854eae4f
Коммит
0250715c5a
|
@ -15,6 +15,7 @@ the section headers (Added/Changed/...) and incrementing the package version.
|
|||
### Added
|
||||
- ([#179](https://github.com/microsoft/hi-ml/pull/179)) Add GaussianBlur and RotationByMultiplesOf90 augmentations. Added torchvision and opencv to
|
||||
the environment file since it is necessary for the augmentations.
|
||||
- ([#193](https://github.com/microsoft/hi-ml/pull/193)) Add transformation adaptor to hi-ml-histopathology.
|
||||
- ([#178](https://github.com/microsoft/hi-ml/pull/178)) Add runner script for running ML experiments.
|
||||
- ([#181](https://github.com/microsoft/hi-ml/pull/181)) Add computational pathology tools in hi-ml-histopathology folder.
|
||||
- ([#187](https://github.com/microsoft/hi-ml/pull/187)) Add mean pooling layer for MIL.
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Mapping, Sequence, Union
|
||||
from typing import Mapping, Sequence, Union, Callable, Dict
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
@ -43,6 +43,37 @@ def load_image_stack_as_tensor(image_paths: Sequence[PathOrString],
|
|||
return torch.stack(image_tensors, dim=0)
|
||||
|
||||
|
||||
def transform_dict_adaptor(function: Callable, k_input: str = None, k_output: str = None) -> Callable:
|
||||
"""Adapt transformations to work with an input dictionary (rather than a tensor).
|
||||
We can't reuse monai.transforms.adaptors because it is only compatible with transformations that accept
|
||||
a dict as input.
|
||||
|
||||
:param function: a transformation function
|
||||
:param k_input: key of the input dictionary that contains the object
|
||||
to which function should be applied
|
||||
:param k_output: key of the input dictionary where to place the function output. If None the ouput of
|
||||
the transformation is returned
|
||||
|
||||
:return: adapted transformation
|
||||
"""
|
||||
def _inner(ditems: dict) -> Dict:
|
||||
if k_input is None:
|
||||
dinputs = ditems
|
||||
else:
|
||||
dinputs = ditems[k_input]
|
||||
ret = function(dinputs)
|
||||
if k_output is None:
|
||||
ditems = ret
|
||||
else:
|
||||
if isinstance(ret, type(ditems[k_output])):
|
||||
ditems[k_output] = ret
|
||||
else:
|
||||
raise ValueError("The transformation is not expect to change the type."
|
||||
"Check input and output are used correctly ")
|
||||
return ditems
|
||||
return _inner
|
||||
|
||||
|
||||
class LoadTiled(MapTransform):
|
||||
"""Dictionary transform to load an individual image tile as a tensor from an input path"""
|
||||
|
||||
|
|
|
@ -14,13 +14,15 @@ from monai.transforms import Compose
|
|||
from torch.utils.data import Dataset as TorchDataset
|
||||
from torch.utils.data import Subset
|
||||
from torchvision.models import resnet18
|
||||
from torchvision.transforms import RandomHorizontalFlip
|
||||
|
||||
from health_ml.utils.bag_utils import BagDataset
|
||||
from health_ml.utils.data_augmentations import HEDJitter
|
||||
|
||||
from histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR
|
||||
from histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
|
||||
from histopathology.models.encoders import ImageNetEncoder
|
||||
from histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd
|
||||
from histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd, transform_dict_adaptor
|
||||
|
||||
import testhisto
|
||||
|
||||
|
@ -163,3 +165,22 @@ def test_encode_tiles(tmp_path: Path, use_gpu: bool, chunk_size: int) -> None:
|
|||
bagged_subset,
|
||||
transform=transform,
|
||||
cache_subdir="TCGA-CRCk_embed_cache")
|
||||
|
||||
|
||||
def test_transform_dict_adaptor() -> None:
|
||||
key = "key"
|
||||
transf1 = transform_dict_adaptor(RandomHorizontalFlip(p=0), key, key)
|
||||
transf2 = transform_dict_adaptor(RandomHorizontalFlip(p=1), key, key)
|
||||
transf3 = transform_dict_adaptor(HEDJitter(0), key, key)
|
||||
input_tensor = torch.arange(24).view(2, 3, 2, 2)
|
||||
input_dict = {'dummy': [], key: input_tensor}
|
||||
output_dict1 = transf1(input_dict)
|
||||
output_dict2 = transf2(input_dict)
|
||||
output_dict3 = transf3(input_dict)
|
||||
|
||||
expected_output_dict2 = input_dict
|
||||
expected_output_dict2[key] = torch.flip(input_dict[key], [2]) # type: ignore
|
||||
|
||||
assert output_dict1 == input_dict
|
||||
assert output_dict2 == expected_output_dict2
|
||||
assert output_dict3 == input_dict
|
||||
|
|
|
@ -72,6 +72,8 @@ class HEDJitter(object):
|
|||
return img
|
||||
|
||||
def __call__(self, img: torch.Tensor) -> torch.Tensor:
|
||||
if img.shape[1] != 3:
|
||||
raise ValueError("HED jitter can only be applied to images with 3 channels (RGB).")
|
||||
return self.adjust_hed(img, self.theta, self.hed_from_rgb, self.rgb_from_hed)
|
||||
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче