diff --git a/CHANGELOG.md b/CHANGELOG.md index df403cbd..6375baed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ the section headers (Added/Changed/...) and incrementing the package version. ### Added - ([#179](https://github.com/microsoft/hi-ml/pull/179)) Add GaussianBlur and RotationByMultiplesOf90 augmentations. Added torchvision and opencv to the environment file since it is necessary for the augmentations. +- ([#193](https://github.com/microsoft/hi-ml/pull/193)) Add transformation adaptor to hi-ml-histopathology. - ([#178](https://github.com/microsoft/hi-ml/pull/178)) Add runner script for running ML experiments. - ([#181](https://github.com/microsoft/hi-ml/pull/181)) Add computational pathology tools in hi-ml-histopathology folder. - ([#187](https://github.com/microsoft/hi-ml/pull/187)) Add mean pooling layer for MIL. diff --git a/hi-ml-histopathology/src/histopathology/models/transforms.py b/hi-ml-histopathology/src/histopathology/models/transforms.py index 38487d16..c06e60c6 100644 --- a/hi-ml-histopathology/src/histopathology/models/transforms.py +++ b/hi-ml-histopathology/src/histopathology/models/transforms.py @@ -4,7 +4,7 @@ # ------------------------------------------------------------------------------------------ from pathlib import Path -from typing import Mapping, Sequence, Union +from typing import Mapping, Sequence, Union, Callable, Dict import torch import numpy as np @@ -43,6 +43,37 @@ def load_image_stack_as_tensor(image_paths: Sequence[PathOrString], return torch.stack(image_tensors, dim=0) +def transform_dict_adaptor(function: Callable, k_input: str = None, k_output: str = None) -> Callable: + """Adapt transformations to work with an input dictionary (rather than a tensor). + We can't reuse monai.transforms.adaptors because it is only compatible with transformations that accept + a dict as input. + + :param function: a transformation function + :param k_input: key of the input dictionary that contains the object + to which function should be applied + :param k_output: key of the input dictionary where to place the function output. If None the ouput of + the transformation is returned + + :return: adapted transformation + """ + def _inner(ditems: dict) -> Dict: + if k_input is None: + dinputs = ditems + else: + dinputs = ditems[k_input] + ret = function(dinputs) + if k_output is None: + ditems = ret + else: + if isinstance(ret, type(ditems[k_output])): + ditems[k_output] = ret + else: + raise ValueError("The transformation is not expect to change the type." + "Check input and output are used correctly ") + return ditems + return _inner + + class LoadTiled(MapTransform): """Dictionary transform to load an individual image tile as a tensor from an input path""" diff --git a/hi-ml-histopathology/testhisto/testhisto/models/test_transforms.py b/hi-ml-histopathology/testhisto/testhisto/models/test_transforms.py index a7b0ad53..802fd1d7 100644 --- a/hi-ml-histopathology/testhisto/testhisto/models/test_transforms.py +++ b/hi-ml-histopathology/testhisto/testhisto/models/test_transforms.py @@ -14,13 +14,15 @@ from monai.transforms import Compose from torch.utils.data import Dataset as TorchDataset from torch.utils.data import Subset from torchvision.models import resnet18 +from torchvision.transforms import RandomHorizontalFlip from health_ml.utils.bag_utils import BagDataset +from health_ml.utils.data_augmentations import HEDJitter from histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR from histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset from histopathology.models.encoders import ImageNetEncoder -from histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd +from histopathology.models.transforms import EncodeTilesBatchd, LoadTiled, LoadTilesBatchd, transform_dict_adaptor import testhisto @@ -163,3 +165,22 @@ def test_encode_tiles(tmp_path: Path, use_gpu: bool, chunk_size: int) -> None: bagged_subset, transform=transform, cache_subdir="TCGA-CRCk_embed_cache") + + +def test_transform_dict_adaptor() -> None: + key = "key" + transf1 = transform_dict_adaptor(RandomHorizontalFlip(p=0), key, key) + transf2 = transform_dict_adaptor(RandomHorizontalFlip(p=1), key, key) + transf3 = transform_dict_adaptor(HEDJitter(0), key, key) + input_tensor = torch.arange(24).view(2, 3, 2, 2) + input_dict = {'dummy': [], key: input_tensor} + output_dict1 = transf1(input_dict) + output_dict2 = transf2(input_dict) + output_dict3 = transf3(input_dict) + + expected_output_dict2 = input_dict + expected_output_dict2[key] = torch.flip(input_dict[key], [2]) # type: ignore + + assert output_dict1 == input_dict + assert output_dict2 == expected_output_dict2 + assert output_dict3 == input_dict diff --git a/hi-ml/src/health_ml/utils/data_augmentations.py b/hi-ml/src/health_ml/utils/data_augmentations.py index 84b069d7..e6203ab1 100644 --- a/hi-ml/src/health_ml/utils/data_augmentations.py +++ b/hi-ml/src/health_ml/utils/data_augmentations.py @@ -72,6 +72,8 @@ class HEDJitter(object): return img def __call__(self, img: torch.Tensor) -> torch.Tensor: + if img.shape[1] != 3: + raise ValueError("HED jitter can only be applied to images with 3 channels (RGB).") return self.adjust_hed(img, self.theta, self.hed_from_rgb, self.rgb_from_hed)