* Refactor augmentations initial commit

* Test custom transforms and transform pipeline

* Make transforms independent of config

* Add placeholder to doc as reminder

* Update tests

* Update tests

* Add test for affine

* Add tests for color transform

* Add a todo list

* Adding new test and simplifying the transforms API further

* New tests for new custom transforms

* Remove deprecated file

* Update all transforms

* Add test for custom pipeline transform

* Fix the full pipeline test

* Rename

* Simplify feeding in transforms

* Updating CHANGELOG.md

* Updating CHANGELOG.md

* Flake8

* Simplify more

* Fixes

* Fixes

* Flake8

* Fix few things

* Fix few things

* Updating tests for new API

* Simplify image versus sample

* Update

* Update

* Update

* Update

* Error in test copy paste

* Style

* Add augmentation doc

* Add augmentation doc

* Doc strings

* Flake8

* Flake8

* Mypy

* Fix test and docstring

* Put back skipif windows

* flake8

* mypy

* Skip if OOM

* Skip if OOM

* Import was missing

* Separate transform pipeline from ScalarItem management.

* Update documentation to new API

* Accidental commit

* PR suggestion

* Make test for elastic human checkable

* Make test for elastic human checkable

* Update with p=0

* Update with p=0

* Add human readable test output for gamma transform

* Solve OOM, remove duplicate test input definitions

* Flake8

* Mypy

* Fix config

* Add check PR comment

* PR comment
This commit is contained in:
melanibe 2021-06-01 09:29:57 +01:00 committed by GitHub
Parent 2af8c6099c
Commit 51274c8bdc
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
28 changed files: 1056 additions and 1290 deletions


@ -108,6 +108,18 @@ console for easier diagnostics.
- ([#444](https://github.com/microsoft/InnerEye-DeepLearning/pull/444)) The method `before_training_on_rank_zero` of
the `LightningContainer` class has been renamed to `before_training_on_global_rank_zero`. The order in which the
hooks are called has been changed.
- ([#458](https://github.com/microsoft/InnerEye-DeepLearning/pull/458)) Simplified and generalized the way we handle
data augmentations for classification models. The pipelining logic is now handled by an `ImageTransformationPipeline`
class that takes as input a list of transforms to chain together. The pipeline takes care of applying transforms to 3D
or 2D images. The user can choose to apply the same transformation to all channels (e.g. RGB images) or a different
transformation to each channel (e.g. if each channel represents a different modality / time point). The pipeline can
now work directly with out-of-the-box torchvision transforms (as long as they support [..., C, H, W] inputs), which
lets us get rid of nearly all of our custom augmentation functions. The conversion from an image transformation
pipeline to `ScalarItemAugmentation` is now handled under the hood; the user does not need to call this wrapper in
each config class. In models derived from ScalarModelConfig, to change which augmentations are applied to the image
inputs (resp. segmentation inputs), users can override `get_image_transform` (resp. `get_segmentation_transform`).
These two functions replace the old `get_image_sample_transforms` method (an illustrative sketch follows below). See
`docs/building_models.md` for more information on augmentations.
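As an illustrative sketch (not part of this PR's diff), a classification config could override the new hook roughly as
follows. The config class name and the particular torchvision transforms are assumptions made for the example:

```python
from torchvision.transforms import CenterCrop, RandomHorizontalFlip, Resize

from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline
from InnerEye.ML.model_config_base import ModelTransformsPerExecutionMode
from InnerEye.ML.scalar_config import ScalarModelBase


class MyClassificationConfig(ScalarModelBase):  # hypothetical config class
    def get_image_transform(self) -> ModelTransformsPerExecutionMode:
        # Augment only during training; validation and test get a deterministic resize + crop.
        basic = [Resize(256), CenterCrop(224)]
        return ModelTransformsPerExecutionMode(
            train=ImageTransformationPipeline(transforms=[Resize(256), RandomHorizontalFlip(p=0.5), CenterCrop(224)]),
            val=ImageTransformationPipeline(transforms=basic),
            test=ImageTransformationPipeline(transforms=basic))
```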
### Fixed
@ -128,6 +140,8 @@ console for easier diagnostics.
- ([#450](https://github.com/microsoft/InnerEye-DeepLearning/pull/450)) Delete unused `classification_report.ipynb`.
- ([#455](https://github.com/microsoft/InnerEye-DeepLearning/pull/455)) Removed the AzureRunner conda environment.
The full InnerEye conda environment is needed to submit a training job to AzureML.
- ([#458](https://github.com/microsoft/InnerEye-DeepLearning/pull/458)) Removed all the unused code for RandAugment
and related augmentation policies. Users now instead have complete freedom to specify the set of augmentations to use.
### Deprecated


@ -2,195 +2,15 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import random
from collections import Callable
from typing import Any, List, Tuple
from typing import Any, Callable, Tuple
import PIL
import numpy as np
import torch
import torchvision
from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform
from scipy.ndimage import gaussian_filter, map_coordinates
from torchvision.transforms import ToTensor
from yacs.config import CfgNode
class BaseTransform:
def transform(self, x: Any) -> Any:
raise NotImplementedError("Transform needs to be overridden in the child classes")
def __call__(self, data: PIL.Image.Image) -> PIL.Image.Image:
return self.transform(data)
class CenterCrop(BaseTransform):
def __init__(self, config: CfgNode) -> None:
super().__init__()
self.center_crop_size = config.preprocess.center_crop_size
def transform(self, x: Any) -> Any:
return torchvision.transforms.CenterCrop(self.center_crop_size)(x)
class RandomResizeCrop(BaseTransform):
def __init__(self, config: CfgNode) -> None:
super().__init__()
self.resize_size = config.preprocess.resize
self.crop_scale = config.augmentation.random_crop.scale
def transform(self, x: Any) -> Any:
return torchvision.transforms.RandomResizedCrop(
size=self.resize_size,
scale=self.crop_scale)(x)
class RandomHorizontalFlip(BaseTransform):
def __init__(self, config: CfgNode) -> None:
super().__init__()
self.p_apply = config.augmentation.random_horizontal_flip.prob
def transform(self, x: Any) -> Any:
return torchvision.transforms.RandomHorizontalFlip(self.p_apply)(x)
class RandomAffine(BaseTransform):
def __init__(self, config: CfgNode) -> None:
super().__init__()
self.max_angle = config.augmentation.random_affine.max_angle
self.max_horizontal_shift = config.augmentation.random_affine.max_horizontal_shift
self.max_vertical_shift = config.augmentation.random_affine.max_vertical_shift
self.max_shear = config.augmentation.random_affine.max_shear
def transform(self, x: Any) -> Any:
return torchvision.transforms.RandomAffine(degrees=self.max_angle,
translate=(self.max_horizontal_shift, self.max_vertical_shift),
shear=self.max_shear)(x)
class Resize(BaseTransform):
def __init__(self, config: CfgNode) -> None:
super().__init__()
self.resize_size = config.preprocess.resize
def transform(self, x: Any) -> Any:
return torchvision.transforms.Resize(self.resize_size)(x)
class RandomColorJitter(BaseTransform):
def __init__(self, config: CfgNode) -> None:
super().__init__()
self.max_brightness = config.augmentation.random_color.brightness
self.max_contrast = config.augmentation.random_color.contrast
self.max_saturation = config.augmentation.random_color.saturation
def transform(self, x: Any) -> Any:
return torchvision.transforms.ColorJitter(brightness=self.max_brightness,
contrast=self.max_contrast,
saturation=self.max_saturation)(x)
class RandomErasing(BaseTransform):
def __init__(self, config: CfgNode) -> None:
super().__init__()
self.scale = config.augmentation.random_erasing.scale
self.ratio = config.augmentation.random_erasing.ratio
def transform(self, x: Any) -> Any:
return torchvision.transforms.RandomErasing(p=0.5,
scale=self.scale,
ratio=self.ratio)(x)
class RandomGamma(BaseTransform):
def __init__(self, config: CfgNode) -> None:
super().__init__()
self.scale = config.augmentation.gamma.scale
def transform(self, image: PIL.Image.Image) -> PIL.Image.Image:
gamma = random.uniform(*self.scale)
return torchvision.transforms.functional.adjust_gamma(image, gamma=gamma)
class ExpandChannels(BaseTransform):
"""
Transforms an image with 1 channel to an image with 3 channels by copying pixel intensities of the image along
the 0th dimension.
"""
def transform(self, data: torch.Tensor) -> torch.Tensor:
return torch.repeat_interleave(data, 3, dim=0)
class AddGaussianNoise(BaseTransform):
def __init__(self, config: CfgNode) -> None:
"""
Transformation to add Gaussian noise N(0, std) to an image. Where std is set with the
config.augmentation.gaussian_noise.std argument. The transformation will be applied with probability
config.augmentation.gaussian_noise.p_apply
"""
super().__init__()
self.p_apply = config.augmentation.gaussian_noise.p_apply
self.std = config.augmentation.gaussian_noise.std
def transform(self, data: torch.Tensor) -> torch.Tensor:
assert data.max() <= 1 and data.min() >= 0
if np.random.random(1) > self.p_apply:
return data
noise = torch.randn(size=data.shape) * self.std
data = torch.clamp(data + noise, 0, 1)
return data
class ElasticTransform(BaseTransform):
"""Elastic deformation of images as described in [Simard2003]_.
.. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
Convolutional Neural Networks applied to Visual Document Analysis", in
Proc. of the International Conference on Document Analysis and
Recognition, 2003.
https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.160.8494&rep=rep1&type=pdf
:param sigma: elasticity coefficient
:param alpha: intensity of the deformation
:param p_apply: probability of applying the transformation
"""
def __init__(self, config: CfgNode) -> None:
super().__init__()
self.alpha = config.augmentation.elastic_transform.alpha
self.sigma = config.augmentation.elastic_transform.sigma
self.p_apply = config.augmentation.elastic_transform.p_apply
def transform(self, image: PIL.Image) -> PIL.Image:
if np.random.random(1) > self.p_apply:
return image
image = np.asarray(image).squeeze()
assert len(image.shape) == 2
shape = image.shape
dx = gaussian_filter((np.random.random(shape) * 2 - 1), self.sigma, mode="constant", cval=0) * self.alpha
dy = gaussian_filter((np.random.random(shape) * 2 - 1), self.sigma, mode="constant", cval=0) * self.alpha
x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))
return PIL.Image.fromarray(map_coordinates(image, indices, order=1).reshape(shape))
class DualViewTransformWrapper:
"""
Returns two versions of one image, given a random transformation function.
"""
def __init__(self, transform: Callable):
self.transform = transform
def __call__(self, sample: PIL.Image.Image) -> Tuple[torch.Tensor, torch.Tensor]:
xi = self.transform(sample)
xj = self.transform(sample)
return xi, xj
from InnerEye.ML.augmentations.transform_pipeline import create_cxr_transforms_from_config
def get_cxr_ssl_transforms(config: CfgNode,
@ -214,54 +34,19 @@ def get_cxr_ssl_transforms(config: CfgNode,
This is required for SSL validation loss to be meaningful. If False, only apply basic processing step
(no augmentations)
"""
train_transforms = create_chest_xray_transform(config, apply_augmentations=True)
val_transforms = create_chest_xray_transform(config, apply_augmentations=use_training_augmentations_for_validation)
train_transforms = create_cxr_transforms_from_config(config, apply_augmentations=True)
val_transforms = create_cxr_transforms_from_config(config,
apply_augmentations=use_training_augmentations_for_validation)
if return_two_views_per_sample:
train_transforms = DualViewTransformWrapper(train_transforms)
val_transforms = DualViewTransformWrapper(val_transforms)
train_transforms = DualViewTransformWrapper(train_transforms) # type: ignore
val_transforms = DualViewTransformWrapper(val_transforms) # type: ignore
return train_transforms, val_transforms
def create_chest_xray_transform(config: CfgNode,
apply_augmentations: bool) -> Callable:
"""
Defines the image transformations pipeline used in Chest-Xray datasets.
Type of augmentation and strength are defined in the config.
:param config: config yaml file fixing strength and type of augmentation to apply
:param apply_augmentations: if True return transformation pipeline with augmentations. Else,
disable augmentations i.e.
only resize and center crop the image.
"""
transforms: List[Any] = []
if apply_augmentations:
if config.augmentation.use_random_affine:
transforms.append(RandomAffine(config))
if config.augmentation.use_random_crop:
transforms.append(RandomResizeCrop(config))
else:
transforms.append(Resize(config))
if config.augmentation.use_random_horizontal_flip:
transforms.append(RandomHorizontalFlip(config))
if config.augmentation.use_gamma_transform:
transforms.append(RandomGamma(config))
if config.augmentation.use_random_color:
transforms.append(RandomColorJitter(config))
if config.augmentation.use_elastic_transform:
transforms.append(ElasticTransform(config))
transforms += [CenterCrop(config), ToTensor()]
if config.augmentation.use_random_erasing:
transforms.append(RandomErasing(config))
if config.augmentation.add_gaussian_noise:
transforms.append(AddGaussianNoise(config))
else:
transforms += [Resize(config), CenterCrop(config), ToTensor()]
transforms.append(ExpandChannels())
return torchvision.transforms.Compose(transforms)
class InnerEyeCIFARTrainTransform(SimCLRTrainDataTransform):
"""
Overload lightning-bolts SimCLRTrainDataTransform, to avoid return unused eval transform. Used for training and
Overload lightning-bolts SimCLRTrainDataTransform, to avoid returning the unused eval transform. Used for
training and
val of SSL models.
"""
@ -279,3 +64,17 @@ class InnerEyeCIFARLinearHeadTransform(SimCLRTrainDataTransform):
def __call__(self, sample: Any) -> Any:
return self.online_transform(sample)
class DualViewTransformWrapper:
"""
Returns two versions of one image, given a random transformation function.
"""
def __init__(self, transform: Callable):
self.transform = transform
def __call__(self, sample: PIL.Image.Image) -> Tuple[torch.Tensor, torch.Tensor]:
xi = self.transform(sample)
xj = self.transform(sample)
return xi, xj
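A minimal usage sketch (assuming an out-of-the-box torchvision pipeline as the wrapped transform): the wrapper simply
calls the same random transform twice, producing two independently augmented views of one sample, as SimCLR/BYOL-style
training expects. The dummy image and transform choices below are illustrative only.

```python
import numpy as np
import PIL.Image
from torchvision import transforms

from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import DualViewTransformWrapper

base_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
two_views = DualViewTransformWrapper(base_transform)

# Dummy grayscale image, purely for illustration
image = PIL.Image.fromarray(np.zeros((256, 256), dtype=np.uint8))
view_1, view_2 = two_views(image)  # two different random augmentations of the same image
```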


@ -22,7 +22,7 @@ from InnerEye.ML.SSL.encoders import get_encoder_output_dim
from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye
from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye
from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import SSLOnlineEvaluatorInnerEye
from InnerEye.ML.SSL.utils import SSLDataModuleType, SSLTrainingType, load_ssl_augmentation_config
from InnerEye.ML.SSL.utils import SSLDataModuleType, SSLTrainingType, load_yaml_augmentation_config
from InnerEye.ML.lightning_container import LightningContainer
@ -127,10 +127,10 @@ class SSLContainer(LightningContainer):
def _load_config(self) -> None:
# For Chest-XRay you need to specify the parameters of the augmentations via a config file.
self.ssl_augmentation_params = load_ssl_augmentation_config(
self.ssl_augmentation_params = load_yaml_augmentation_config(
self.ssl_augmentation_config) if self.ssl_augmentation_config is not None \
else None
self.classifier_augmentation_params = load_ssl_augmentation_config(
self.classifier_augmentation_params = load_yaml_augmentation_config(
self.linear_head_augmentation_config) if self.linear_head_augmentation_config is not None else \
self.ssl_augmentation_params


@ -25,10 +25,9 @@ class SSLTrainingType(Enum):
BYOL = "BYOL"
def load_ssl_augmentation_config(config_path: Path) -> CfgNode:
def load_yaml_augmentation_config(config_path: Path) -> CfgNode:
"""
Loads configs required for self supervised learning. Does not setup cudann as this is being
taken care of by lightning.
Loads augmentation configs defined as YAML files.
"""
config = ssl_augmentation_config.get_default_model_config()
config.merge_from_file(config_path)


@ -0,0 +1,128 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import random
from typing import List, Tuple
import numpy as np
from InnerEye.Common.common_util import any_pairwise_larger
from InnerEye.Common.type_annotations import TupleInt3
from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.dataset.sample import Sample
def random_select_patch_center(sample: Sample, class_weights: List[float] = None) -> np.ndarray:
"""
Samples a point to use as the coordinates of the patch center. First samples one
class among the available classes then samples a center point among the pixels of the sampled
class.
:param sample: A set of Image channels, ground truth labels and mask to randomly crop.
:param class_weights: A weighting vector with values [0, 1] to influence the class the center crop
voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
:return numpy int array (3x1) containing patch center spatial coordinates
"""
num_classes = sample.labels.shape[0]
if class_weights is not None:
if len(class_weights) != num_classes:
raise Exception("A weight must be provided for each class, found weights:{}, expected:{}"
.format(len(class_weights), num_classes))
SegmentationModelBase.validate_class_weights(class_weights)
# If class weights are not initialised, selection is made with equal probability for all classes
available_classes = list(range(num_classes))
original_class_weights = class_weights
while len(available_classes) > 0:
selected_label_class = random.choices(population=available_classes, weights=class_weights, k=1)[0]
# Check pixels where mask and label maps are both foreground
indices = np.argwhere(np.logical_and(sample.labels[selected_label_class] == 1.0, sample.mask == 1))
if not np.any(indices):
available_classes.remove(selected_label_class)
if class_weights is not None:
assert original_class_weights is not None # for mypy
class_weights = [original_class_weights[i] for i in available_classes]
if sum(class_weights) <= 0.0:
raise ValueError("Cannot sample a class: no class present in the sample has a positive weight")
else:
break
# Raise an exception if none of the foreground classes overlap with the mask
if len(available_classes) == 0:
raise Exception("No non-mask voxels found, please check your mask and labels map")
# noinspection PyUnboundLocalVariable
choice = random.randint(0, len(indices) - 1)
return indices[choice].astype(int) # Numpy usually stores as floats
def slicers_for_random_crop(sample: Sample,
crop_size: TupleInt3,
class_weights: List[float] = None) -> Tuple[List[slice], np.ndarray]:
"""
Computes array slicers that produce random crops of the given crop_size.
The selection of the center is dependent on background probability.
By default it does not center on background.
:param sample: A set of Image channels, ground truth labels and mask to randomly crop.
:param crop_size: The size of the crop expressed as a list of 3 ints, one per spatial dimension.
:param class_weights: A weighting vector with values [0, 1] to influence the class the center crop
voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
:return: Tuple element 1: The slicers that convert the input image to the chosen crop. Tuple element 2: The
indices of the center point of the crop.
:raises ValueError: If there are shape mismatches among the arguments or if the crop size is larger than the image.
"""
shape = sample.image.shape[1:]
if any_pairwise_larger(crop_size, shape):
raise ValueError("The crop_size across each dimension should be greater than zero and less than or equal "
"to the current value (crop_size: {}, spatial shape: {})"
.format(crop_size, shape))
# Sample a center pixel location for patch extraction.
center = random_select_patch_center(sample, class_weights)
# Verify and fix overflow for each dimension
left = []
for i in range(3):
margin_left = int(crop_size[i] / 2)
margin_right = crop_size[i] - margin_left
left_index = center[i] - margin_left
right_index = center[i] + margin_right
if right_index > shape[i]:
left_index = left_index - (right_index - shape[i])
if left_index < 0:
left_index = 0
left.append(left_index)
return [slice(left[x], left[x] + crop_size[x]) for x in range(0, 3)], center
def random_crop(sample: Sample,
crop_size: TupleInt3,
class_weights: List[float] = None) -> Tuple[Sample, np.ndarray]:
"""
Randomly crops images, mask, and labels arrays according to the crop_size argument.
The selection of the center is dependent on background probability.
By default it does not center on background.
:param sample: A set of Image channels, ground truth labels and mask to randomly crop.
:param crop_size: The size of the crop expressed as a list of 3 ints, one per spatial dimension.
:param class_weights: A weighting vector with values [0, 1] to influence the class the center crop
voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
:return: Tuple item 1: The cropped images, labels, and mask. Tuple item 2: The center that was chosen for the crop,
before shifting to be inside of the image. Tuple item 3: The slicers that convert the input image to the chosen
crop.
:raises ValueError: If there are shape mismatches among the arguments or if the crop size is larger than the image.
"""
slicers, center = slicers_for_random_crop(sample, crop_size, class_weights)
sample = Sample(
image=sample.image[:, slicers[0], slicers[1], slicers[2]],
labels=sample.labels[:, slicers[0], slicers[1], slicers[2]],
mask=sample.mask[slicers[0], slicers[1], slicers[2]],
metadata=sample.metadata
)
return sample, center


@ -0,0 +1,116 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import random
from typing import Tuple
import numpy as np
import torch
import torchvision
from scipy.ndimage import gaussian_filter, map_coordinates
class RandomGamma:
"""
Custom function to apply a random gamma transform within a specified range of possible gamma values.
See documentation of
[`adjust_gamma`](https://pytorch.org/vision/0.8/transforms.html#torchvision.transforms.functional.adjust_gamma) for
more details.
"""
def __init__(self, scale: Tuple[float, float]) -> None:
"""
:param scale: a tuple (min_gamma, max_gamma) that specifies the range of possible values to sample the gamma
value from when the transformation is called.
"""
self.scale = scale
def __call__(self, image: torch.Tensor) -> torch.Tensor:
gamma = random.uniform(*self.scale)
if len(image.shape) != 4:
raise ValueError(f"Expected input of shape [Z, C, H, W], but only got {len(image.shape)} dimensions")
for z in range(image.shape[0]):
for c in range(image.shape[1]):
image[z, c] = torchvision.transforms.functional.adjust_gamma(image[z, c], gamma=gamma)
return image
class ExpandChannels:
"""
Transforms an image with 1 channel to an image with 3 channels by copying pixel intensities of the image along
the 1st dimension.
"""
def __call__(self, data: torch.Tensor) -> torch.Tensor:
"""
:param data: tensor of shape [Z, 1, H, W]
:return: data with channel copied 3 times, shape [Z, 3, H, W]
"""
shape = data.shape
if len(shape) != 4 or shape[1] != 1:
raise ValueError(f"Expected input of shape [Z, 1, H, W], found {shape}")
return torch.repeat_interleave(data, 3, dim=1)
class AddGaussianNoise:
def __init__(self, p_apply: float, std: float) -> None:
"""
Transformation to add Gaussian noise N(0, std) to an image.
:param p_apply: probability of applying the transformation.
:param std: standard deviation of the Gaussian noise to add to the image.
"""
super().__init__()
self.p_apply = p_apply
self.std = std
def __call__(self, data: torch.Tensor) -> torch.Tensor:
if np.random.random(1) > self.p_apply:
return data
noise = torch.randn(size=data.shape[-2:]) * self.std
data = torch.clamp(data + noise, data.min(), data.max()) # type: ignore
return data
class ElasticTransform:
"""Elastic deformation of images as described in [Simard2003]_.
.. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
Convolutional Neural Networks applied to Visual Document Analysis", in
Proc. of the International Conference on Document Analysis and
Recognition, 2003.
https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.160.8494&rep=rep1&type=pdf
"""
def __init__(self,
sigma: float,
alpha: float,
p_apply: float
) -> None:
"""
:param sigma: elasticity coefficient
:param alpha: intensity of the deformation
:param p_apply: probability of applying the transformation
"""
super().__init__()
self.alpha = alpha
self.sigma = sigma
self.p_apply = p_apply
def __call__(self, data: torch.Tensor) -> torch.Tensor:
if np.random.random(1) > self.p_apply:
return data
result_type = data.dtype
data = data.cpu().numpy()
shape = data.shape
dx = gaussian_filter((np.random.random(shape[-2:]) * 2 - 1), self.sigma, mode="constant", cval=0) * self.alpha
dy = gaussian_filter((np.random.random(shape[-2:]) * 2 - 1), self.sigma, mode="constant", cval=0) * self.alpha
all_dimensions_axes = [np.arange(dim) for dim in shape]
grid = np.meshgrid(*all_dimensions_axes, indexing='ij')
grid[-2] = grid[-2] + dx
grid[-1] = grid[-1] + dy
indices = [np.reshape(grid[i], (-1, 1)) for i in range(len(grid))]
return torch.tensor(map_coordinates(data, indices, order=1).reshape(shape), dtype=result_type)
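A small sketch of how these custom transforms could be chained, assuming the [Z, C, H, W] layout described in the
docstrings above; the parameter values are illustrative only:

```python
import torch
from torchvision.transforms import Compose

from InnerEye.ML.augmentations.image_transforms import AddGaussianNoise, ElasticTransform, ExpandChannels, RandomGamma

pipeline = Compose([
    ExpandChannels(),                                  # [Z, 1, H, W] -> [Z, 3, H, W]
    RandomGamma(scale=(0.8, 1.2)),                     # one gamma value sampled per call
    AddGaussianNoise(p_apply=0.5, std=0.05),           # applied with probability 0.5
    ElasticTransform(sigma=4, alpha=34, p_apply=0.5),
])

scan = torch.rand(3, 1, 64, 64)   # 3 slices, single channel
augmented = pipeline(scan)        # shape [3, 3, 64, 64]
```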


@ -0,0 +1,145 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any, Callable, List, Union
import PIL
import torch
from torchvision.transforms import CenterCrop, ColorJitter, Compose, RandomAffine, RandomErasing, \
RandomHorizontalFlip, RandomResizedCrop, Resize
from torchvision.transforms.functional import to_tensor
from yacs.config import CfgNode
from InnerEye.ML.augmentations.image_transforms import AddGaussianNoise, ElasticTransform, ExpandChannels, RandomGamma
ImageData = Union[PIL.Image.Image, torch.Tensor]
class ImageTransformationPipeline:
"""
This class is the base class to classes built to define data augmentation transformations
for 3D or 2D image inputs (tensor or PIL.Image).
In the case of 3D images, the transformations are applied slice by slice along the Z dimension (same transformation
applied for each slice).
The transformations are applied channel by channel, the user can specify whether to apply the same transformation
to each channel (no random shuffling) or whether each channel should use a different transformation (random
parameters of transforms shuffled for each channel).
"""
# noinspection PyMissingConstructor
def __init__(self,
transforms: Union[Callable, List[Callable]],
use_different_transformation_per_channel: bool = False):
"""
:param transforms: List of transformations to apply to images. Supports out-of-the-box torchvision transforms
as they accept data of arbitrary dimension. You can also define your own transform class, but be aware that your
function should expect input of shape [C, Z, H, W] and apply the same transformation to each Z slice.
:param use_different_transformation_per_channel: if True, apply a different version of the augmentation pipeline
for each channel. If False, applies the same transformation to each channel, separately.
"""
self.use_different_transformation_per_channel = use_different_transformation_per_channel
self.pipeline = Compose(transforms) if isinstance(transforms, List) else transforms
def transform_image(self, image: ImageData) -> torch.Tensor:
"""
Main function to apply the transformation pipeline to either slice by slice on one 3D-image or
on the 2D image.
Note for 3D images: Assumes the same transformations have to be applied on each 2D-slice along the Z-axis.
Assumes the Z axis is the first dimension.
:param image: batch of tensor images of size [C, Z, Y, X] or batch of 2D images as PIL Image
"""
def _convert_to_tensor_if_necessary(data: ImageData) -> torch.Tensor:
return to_tensor(data) if isinstance(data, PIL.Image.Image) else data
image = _convert_to_tensor_if_necessary(image)
original_input_is_2d = len(image.shape) == 3
# If we have a 2D image [C, H, W] expand to [Z, C, H, W]. Built-in torchvision transforms allow such 4D inputs.
if original_input_is_2d:
image = image.unsqueeze(0)
else:
# Some transforms assume the order of dimension is [..., C, H, W] so permute first and last dimension to
# obtain [Z, C, H, W]
if len(image.shape) != 4:
raise ValueError(f"ScalarDataset should load images as 4D tensor [C, Z, H, W]. The input tensor here"
f"was of shape {image.shape}. This is unexpected.")
image = torch.transpose(image, 1, 0)
if not self.use_different_transformation_per_channel:
image = _convert_to_tensor_if_necessary(self.pipeline(image))
else:
channels = []
for channel in range(image.shape[1]):
channels.append(_convert_to_tensor_if_necessary(self.pipeline(image[:, channel, :, :].unsqueeze(1))))
image = torch.cat(channels, dim=1)
# Back to [C, Z, H, W]
image = torch.transpose(image, 1, 0)
if original_input_is_2d:
image = image.squeeze(1)
return image.to(dtype=image.dtype)
def __call__(self, data: ImageData) -> torch.Tensor:
return self.transform_image(data)
def create_cxr_transforms_from_config(config: CfgNode,
apply_augmentations: bool) -> ImageTransformationPipeline:
"""
Defines the image transformation pipeline used in Chest X-ray datasets. It can also be used for other types of
image data; the types of augmentations to use and their strength are expected to be defined in the config.
:param config: config yaml file fixing strength and type of augmentation to apply
:param apply_augmentations: if True return transformation pipeline with augmentations. Else,
disable augmentations i.e. only resize and center crop the image.
"""
transforms: List[Any] = [ExpandChannels()]
if apply_augmentations:
if config.augmentation.use_random_affine:
transforms.append(RandomAffine(
degrees=config.augmentation.random_affine.max_angle,
translate=(config.augmentation.random_affine.max_horizontal_shift,
config.augmentation.random_affine.max_vertical_shift),
shear=config.augmentation.random_affine.max_shear
))
if config.augmentation.use_random_crop:
transforms.append(RandomResizedCrop(
scale=config.augmentation.random_crop.scale,
size=config.preprocess.resize
))
else:
transforms.append(Resize(size=config.preprocess.resize))
if config.augmentation.use_random_horizontal_flip:
transforms.append(RandomHorizontalFlip(p=config.augmentation.random_horizontal_flip.prob))
if config.augmentation.use_gamma_transform:
transforms.append(RandomGamma(scale=config.augmentation.gamma.scale))
if config.augmentation.use_random_color:
transforms.append(ColorJitter(
brightness=config.augmentation.random_color.brightness,
contrast=config.augmentation.random_color.contrast,
saturation=config.augmentation.random_color.saturation
))
if config.augmentation.use_elastic_transform:
transforms.append(ElasticTransform(
alpha=config.augmentation.elastic_transform.alpha,
sigma=config.augmentation.elastic_transform.sigma,
p_apply=config.augmentation.elastic_transform.p_apply
))
transforms.append(CenterCrop(config.preprocess.center_crop_size))
if config.augmentation.use_random_erasing:
transforms.append(RandomErasing(
scale=config.augmentation.random_erasing.scale,
ratio=config.augmentation.random_erasing.ratio
))
if config.augmentation.add_gaussian_noise:
transforms.append(AddGaussianNoise(
p_apply=config.augmentation.gaussian_noise.p_apply,
std=config.augmentation.gaussian_noise.std
))
else:
transforms += [Resize(size=config.preprocess.resize),
CenterCrop(config.preprocess.center_crop_size)]
pipeline = ImageTransformationPipeline(transforms)
return pipeline
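A brief usage sketch for the pipeline itself (again not part of the diff): out-of-the-box torchvision transforms and
the custom transforms above can be mixed, and the pipeline applied directly to a 3D scan stored as [C, Z, H, W]; all
parameter values below are illustrative assumptions.

```python
import torch
from torchvision.transforms import CenterCrop, RandomAffine

from InnerEye.ML.augmentations.image_transforms import RandomGamma
from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline

pipeline = ImageTransformationPipeline(
    transforms=[RandomAffine(degrees=10, translate=(0.05, 0.05)),
                RandomGamma(scale=(0.9, 1.1)),
                CenterCrop(48)],
    use_different_transformation_per_channel=False)

scan = torch.rand(2, 5, 64, 64)   # [C=2, Z=5, H=64, W=64]
augmented = pipeline(scan)        # the same sampled parameters are reused for both channels
print(augmented.shape)            # torch.Size([2, 5, 48, 48])
```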


@ -14,11 +14,12 @@ from pytorch_lightning import LightningModule
from torchvision.transforms import Compose
from InnerEye.Common.common_util import ModelProcessing, get_best_epoch_results_path
from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import create_chest_xray_transform
from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName
from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier
from InnerEye.ML.SSL.utils import create_ssl_encoder, create_ssl_image_classifier, load_ssl_augmentation_config
from InnerEye.ML.SSL.utils import create_ssl_encoder, create_ssl_image_classifier, load_yaml_augmentation_config
from InnerEye.ML.augmentations.transform_pipeline import create_cxr_transforms_from_config
from InnerEye.ML.common import ModelExecutionMode
from InnerEye.ML.configs.ssl.CXR_SSL_configs import path_linear_head_augmentation_cxr
@ -32,7 +33,6 @@ from InnerEye.ML.models.architectures.classification.image_encoder_with_mlp impo
from InnerEye.ML.reports.notebook_report import generate_notebook, get_ipynb_report_name, str_or_empty
from InnerEye.ML.scalar_config import ScalarLoss, ScalarModelBase
from InnerEye.ML.utils.augmentation import ScalarItemAugmentation
from InnerEye.ML.utils.run_recovery import RunRecovery
from InnerEye.ML.utils.split_dataset import DatasetSplits
@ -114,13 +114,11 @@ class CovidHierarchicalModel(ScalarModelBase):
# noinspection PyTypeChecker
def get_image_sample_transforms(self) -> ModelTransformsPerExecutionMode:
config = load_ssl_augmentation_config(path_linear_head_augmentation_cxr)
train_transforms = ScalarItemAugmentation(
Compose(
[DicomPreparation(), create_chest_xray_transform(config, apply_augmentations=True)]))
val_transforms = ScalarItemAugmentation(
Compose(
[DicomPreparation(), create_chest_xray_transform(config, apply_augmentations=False)]))
config = load_yaml_augmentation_config(path_linear_head_augmentation_cxr)
train_transforms = Compose(
[DicomPreparation(), create_cxr_transforms_from_config(config, apply_augmentations=True)])
val_transforms = Compose(
[DicomPreparation(), create_cxr_transforms_from_config(config, apply_augmentations=False)])
return ModelTransformsPerExecutionMode(train=train_transforms,
val=val_transforms,


@ -8,12 +8,13 @@ from typing import Any, Dict, List, Optional
import numpy as np
import pandas as pd
from InnerEye.ML.augmentations.augmentation_for_segmentation_utils import random_crop
from InnerEye.Common.common_util import any_pairwise_larger
from InnerEye.Common.type_annotations import TupleInt3
from InnerEye.ML.config import PaddingMode, SegmentationModelBase
from InnerEye.ML.dataset.full_image_dataset import FullImageDataset
from InnerEye.ML.dataset.sample import CroppedSample, Sample
from InnerEye.ML.utils import augmentation, image_util
from InnerEye.ML.utils import image_util
from InnerEye.ML.utils.image_util import pad_images
from InnerEye.ML.utils.io_util import ImageDataType
from InnerEye.ML.utils.transforms import Compose3D
@ -96,7 +97,7 @@ class CroppingDataset(FullImageDataset):
:return: CroppedSample
"""
# crop the original raw sample
sample, center_point = augmentation.random_crop(
sample, center_point = random_crop(
sample=sample,
crop_size=crop_size,
class_weights=class_weights


@ -26,7 +26,6 @@ from InnerEye.ML.sequence_config import SequenceModelBase
from InnerEye.ML.utils.csv_util import CSV_CHANNEL_HEADER, CSV_SUBJECT_HEADER
from InnerEye.ML.utils.dataset_util import CategoricalToOneHotEncoder
from InnerEye.ML.utils.features_util import FeatureStatistics
from InnerEye.ML.utils.transforms import Compose3D, Transform3D
T = TypeVar('T', bound=ScalarDataSource)
@ -67,12 +66,13 @@ def extract_label_classification(label_string: str, sample_id: str, num_classes:
if isinstance(label_string, float):
if math.isnan(label_string):
if num_classes == 1:
# Pandas special case: When loading a dataframe with dtype=str, missing values can be encoded as NaN, and get into here.
# Pandas special case: When loading a dataframe with dtype=str, missing values can be encoded as NaN,
# and get into here.
return [label_string]
else:
return [0] * num_classes
else:
raise ValueError(f"Subject {sample_id}: Unexpected float input {label_string} - did you read the "
raise ValueError(f"Subject {sample_id}: Unexpected float input {label_string} - did you read the "
f"dataframe column as a string?")
if not label_string:
@ -92,7 +92,7 @@ def extract_label_classification(label_string: str, sample_id: str, num_classes:
return [float(label_string)]
else:
raise ValueError(f"Subject {sample_id}: Label string not recognized: '{label_string}'. "
f"Should be one of true/false, yes/no or 0/1.")
f"Should be one of true/false, yes/no or 0/1.")
if '|' in label_string or label_string.isdigit():
classes = [int(a) for a in label_string.split('|')]
@ -648,6 +648,36 @@ You now want to get the label from the "week0" row, and read out Scalar1 at week
"""
class ScalarItemAugmentation:
"""
Wrapper around an augmentation pipeline to apply image and/or segmentation transformations
to the inputs of a ScalarItem.
"""
def __init__(self,
image_transform: Optional[Callable] = None,
segmentation_transform: Optional[Callable] = None) -> None:
"""
:param image_transform: transformation function to apply to images field. If None, images field is unchanged by
call.
:param segmentation_transform: transformation function to apply to segmentations field. If None segmentations
field is unchanged by call.
"""
self.image_transform = image_transform
self.segmentation_transform = segmentation_transform
def __call__(self, item: ScalarItem) -> ScalarItem:
if self.image_transform is not None:
if self.segmentation_transform is not None:
return item.clone_with_overrides(images=self.image_transform(item.images),
segmentations=self.segmentation_transform(item.segmentations))
return item.clone_with_overrides(images=self.image_transform(item.images))
if self.segmentation_transform is not None:
return item.clone_with_overrides(segmentations=self.segmentation_transform(item.segmentations))
return item
class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
"""
A base class for datasets for classification tasks. It contains logic for loading images from disk,
@ -661,7 +691,7 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
data_frame: Optional[pd.DataFrame] = None,
feature_statistics: Optional[FeatureStatistics] = None,
name: Optional[str] = None,
sample_transforms: Optional[Union[Compose3D[ScalarItem], Transform3D[ScalarItem]]] = None):
sample_transform: Callable[[ScalarItem], ScalarItem] = ScalarItemAugmentation()):
"""
High level class for the scalar dataset designed to be inherited for specific behaviour
:param args: The model configuration object.
@ -670,7 +700,7 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
:param name: Name of the dataset, used for diagnostics logging
"""
super().__init__(args, data_frame, name)
self.transforms = sample_transforms
self.transform = sample_transform
self.feature_statistics = feature_statistics
self.file_to_full_path: Optional[Dict[str, Path]] = None
if args.traverse_dirs_when_loading:
@ -725,7 +755,7 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
center_crop_size=self.args.center_crop_size,
image_size=self.args.image_size)
return Compose3D.apply(self.transforms, sample)
return self.transform(sample)
def create_status_string(self, items: List[T]) -> str:
"""
@ -746,21 +776,21 @@ class ScalarDataset(ScalarDatasetBase[ScalarDataSource]):
data_frame: Optional[pd.DataFrame] = None,
feature_statistics: Optional[FeatureStatistics[ScalarDataSource]] = None,
name: Optional[str] = None,
sample_transforms: Optional[Union[Compose3D[ScalarItem], Transform3D[ScalarItem]]] = None):
sample_transform: Callable[[ScalarItem], ScalarItem] = ScalarItemAugmentation()):
"""
Creates a new scalar dataset from a dataframe.
:param args: The model configuration object.
:param data_frame: The dataframe to read from.
:param feature_statistics: If given, the normalization factor for the non-image features is taken
from the values provided. If None, the normalization factor is computed from the data in the present dataset.
:param sample_transforms: Sample transforms that should be applied.
:param sample_transform: Sample transforms that should be applied.
:param name: Name of the dataset, used for diagnostics logging
"""
super().__init__(args,
data_frame=data_frame,
feature_statistics=feature_statistics,
name=name,
sample_transforms=sample_transforms)
sample_transform=sample_transform)
self.items = self.load_all_data_sources()
self.standardize_non_imaging_features()


@ -6,18 +6,18 @@ from __future__ import annotations
import logging
from collections import Counter, defaultdict
from typing import Any, DefaultDict, Dict, Iterable, List, Optional, Union
from typing import Any, Callable, DefaultDict, Dict, Iterable, List, Optional
import numpy as np
import pandas as pd
import torch
from InnerEye.ML.dataset.scalar_dataset import ScalarDatasetBase, filter_valid_classification_data_sources_items
from InnerEye.ML.dataset.scalar_dataset import ScalarDatasetBase, ScalarItemAugmentation, \
filter_valid_classification_data_sources_items
from InnerEye.ML.dataset.scalar_sample import ScalarItem, SequenceDataSource
from InnerEye.ML.dataset.sequence_sample import ClassificationItemSequence, ListOfSequences
from InnerEye.ML.sequence_config import SequenceModelBase
from InnerEye.ML.utils.features_util import FeatureStatistics
from InnerEye.ML.utils.transforms import Compose3D, Transform3D
def get_longest_contiguous_sequence(items: List[SequenceDataSource],
@ -214,13 +214,14 @@ class SequenceDataset(ScalarDatasetBase[SequenceDataSource]):
feature_statistics: Optional[
FeatureStatistics[ClassificationItemSequence[SequenceDataSource]]] = None,
name: Optional[str] = None,
sample_transforms: Optional[Union[Compose3D[ScalarItem], Transform3D[ScalarItem]]] = None):
sample_transform: Callable[[ScalarItem], ScalarItem] = ScalarItemAugmentation()):
"""
Creates a new sequence dataset from a dataframe.
:param args: The model configuration object.
:param data_frame: The dataframe to read from.
:param feature_statistics: If given, the normalization factor for the non-image features is taken
:param sample_transforms: optional transformation to apply to each sample in the loading step.
:param sample_transform: Transformation to apply to each sample in the loading step. By default, no
transformation is applied.
from the values provided. If None, the normalization factor is computed from the data in the present dataset.
:param name: Name of the dataset, used for logging
"""
@ -228,7 +229,7 @@ class SequenceDataset(ScalarDatasetBase[SequenceDataSource]):
data_frame=data_frame,
feature_statistics=feature_statistics,
name=name,
sample_transforms=sample_transforms)
sample_transform=sample_transform)
if self.args.sequence_column is None:
raise ValueError("This class requires a value in the `sequence_column`, specifying where the "
"sequence index should be read from.")
@ -288,7 +289,8 @@ class SequenceDataset(ScalarDatasetBase[SequenceDataSource]):
"""
all_labels_per_target = torch.stack([seq.get_labels_at_target_indices(self.args.get_target_indices())
for seq in self.items]) # [N, T, 1]
non_nan_and_nonzero_labels = list(filter(lambda x: not np.isnan(x) and x != 0, all_labels_per_target.flatten().tolist()))
non_nan_and_nonzero_labels = list(
filter(lambda x: not np.isnan(x) and x != 0, all_labels_per_target.flatten().tolist()))
counts = dict(Counter(non_nan_and_nonzero_labels))
if not len(counts.keys()) == 1 or 1 not in counts.keys():
raise ValueError("get_class_counts supports only binary targets.")


@ -15,6 +15,7 @@ from pandas import DataFrame
from InnerEye.Azure.azure_util import CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY
from InnerEye.Common.common_util import ModelProcessing
from InnerEye.Common.metrics_constants import TrackedMetrics
from InnerEye.ML.common import DATASET_CSV_FILE_NAME, ModelExecutionMode, STORED_CSV_FILE_NAMES
from InnerEye.ML.deep_learning_config import DeepLearningConfig
from InnerEye.ML.utils.split_dataset import DatasetSplits


@ -14,6 +14,7 @@ from azureml.train.hyperdrive import HyperDriveConfig
from InnerEye.Common.common_util import print_exception
from InnerEye.Common.generic_parsing import ListOrDictParam
from InnerEye.Common.type_annotations import TupleInt3
from InnerEye.ML.common import ModelExecutionMode, OneHotEncoderBase
from InnerEye.ML.deep_learning_config import ModelCategory
from InnerEye.ML.model_config_base import ModelConfigBase, ModelTransformsPerExecutionMode
@ -22,6 +23,7 @@ from InnerEye.ML.utils.split_dataset import DatasetSplits
DEFAULT_KEY = "Default"
class AggregationType(Enum):
"""
The type of global pooling aggregation to use between the encoder and the classifier.
@ -113,7 +115,8 @@ class ScalarModelBase(ModelConfigBase):
"in dataset.csv"
"For multi-label classification, this field is required."
"For binary classification, this field must be a list of size 1, and "
"is by default ['Default'], but can optionally be set to a more descriptive "
"is by default ['Default'], but can optionally be set to a more "
"descriptive "
"name for the positive class.")
target_names: List[str] = param.List(class_=str,
default=None,
@ -133,7 +136,8 @@ class ScalarModelBase(ModelConfigBase):
image_file_column: Optional[str] = param.String(default=None, allow_None=True,
doc="The column that contains the path to image files.")
expected_column_values: List[Tuple[str, str]] = \
param.List(default=None, doc="List of tuples with column name and expected value to filter rows in the dataset csv file",
param.List(default=None,
doc="List of tuples with column name and expected value to filter rows in the dataset csv file",
allow_None=True)
label_channels: Optional[List[str]] = \
param.List(default=None, allow_None=True,
@ -384,13 +388,27 @@ class ScalarModelBase(ModelConfigBase):
def create_torch_datasets(self, dataset_splits: DatasetSplits) -> Dict[ModelExecutionMode, Any]:
from InnerEye.ML.dataset.scalar_dataset import ScalarDataset
image_transforms = self.get_image_sample_transforms()
train = ScalarDataset(args=self, data_frame=dataset_splits.train,
name="training", sample_transforms=image_transforms.train) # type: ignore
val = ScalarDataset(args=self, data_frame=dataset_splits.val, feature_statistics=train.feature_statistics,
name="validation", sample_transforms=image_transforms.val) # type: ignore
test = ScalarDataset(args=self, data_frame=dataset_splits.test, feature_statistics=train.feature_statistics,
name="test", sample_transforms=image_transforms.test) # type: ignore
sample_transform = self.get_scalar_item_transform()
assert sample_transform.train is not None # for mypy
assert sample_transform.val is not None # for mypy
assert sample_transform.test is not None # for mypy
train = ScalarDataset(
args=self,
data_frame=dataset_splits.train,
name="training",
sample_transform=sample_transform.train)
val = ScalarDataset(
args=self,
data_frame=dataset_splits.val,
feature_statistics=train.feature_statistics,
name="validation",
sample_transform=sample_transform.val)
test = ScalarDataset(
args=self,
data_frame=dataset_splits.test,
feature_statistics=train.feature_statistics,
name="test",
sample_transform=sample_transform.test)
return {
ModelExecutionMode.TRAIN: train,
@ -459,14 +477,29 @@ class ScalarModelBase(ModelConfigBase):
def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
return super().get_model_train_test_dataset_splits(dataset_df)
def get_image_sample_transforms(self) -> ModelTransformsPerExecutionMode:
def get_image_transform(self) -> ModelTransformsPerExecutionMode:
"""
Get transforms to perform on samples for each model execution mode.
Get transforms to apply to images for each model execution mode.
By default, no transformation is performed.
For data augmentation, return a transformation pipeline (e.g. an ImageTransformationPipeline) for the training
execution mode.
"""
return ModelTransformsPerExecutionMode()
def get_segmentation_transform(self) -> ModelTransformsPerExecutionMode:
"""
Get transforms to apply to segmentation map inputs for each model execution mode.
By default, no transformation is performed.
"""
return ModelTransformsPerExecutionMode()
def get_scalar_item_transform(self) -> ModelTransformsPerExecutionMode:
from InnerEye.ML.dataset.scalar_dataset import ScalarItemAugmentation
image_transform = self.get_image_transform()
segmentation_transform = self.get_segmentation_transform()
return ModelTransformsPerExecutionMode(
train=ScalarItemAugmentation(image_transform.train, segmentation_transform.train),
val=ScalarItemAugmentation(image_transform.val, segmentation_transform.val),
test=ScalarItemAugmentation(image_transform.test, segmentation_transform.test))
def get_non_image_features_dict(default_channels: List[str],
specific_channels: Optional[Dict[str, List[str]]] = None) -> Dict[str, List[str]]:


@ -90,13 +90,25 @@ class SequenceModelBase(ScalarModelBase):
def create_torch_datasets(self, dataset_splits: DatasetSplits) -> Dict[ModelExecutionMode, Any]:
from InnerEye.ML.dataset.sequence_dataset import SequenceDataset
sample_transforms = self.get_image_sample_transforms()
train = SequenceDataset(self, dataset_splits.train, name="training",
sample_transforms=sample_transforms.train) # type: ignore
val = SequenceDataset(self, dataset_splits.val, feature_statistics=train.feature_statistics, name="validation",
sample_transforms=sample_transforms.val) # type: ignore
test = SequenceDataset(self, dataset_splits.test, feature_statistics=train.feature_statistics, name="test",
sample_transforms=sample_transforms.test) # type: ignore
sample_transform = self.get_scalar_item_transform()
assert sample_transform.train is not None # for mypy
assert sample_transform.val is not None # for mypy
assert sample_transform.test is not None # for mypy
train = SequenceDataset(self,
dataset_splits.train,
name="training",
sample_transform=sample_transform.train)
val = SequenceDataset(self,
dataset_splits.val,
feature_statistics=train.feature_statistics,
name="validation",
sample_transform=sample_transform.val)
test = SequenceDataset(self,
dataset_splits.test,
feature_statistics=train.feature_statistics,
name="test",
sample_transform=sample_transform.test)
return {
ModelExecutionMode.TRAIN: train,


@ -1,475 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import random
from typing import Any, Callable, Dict, List, Tuple
import numpy as np
import torch
from torchvision.transforms import functional as TF
from InnerEye.Common.common_util import any_pairwise_larger
from InnerEye.Common.type_annotations import TupleInt3
from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.dataset.sample import Sample
from InnerEye.ML.dataset.scalar_sample import ScalarItem
from InnerEye.ML.utils.transforms import Transform3D
def random_select_patch_center(sample: Sample, class_weights: List[float] = None) -> np.ndarray:
"""
Samples a point to use as the coordinates of the patch center. First samples one
class among the available classes then samples a center point among the pixels of the sampled
class.
:param sample: A set of Image channels, ground truth labels and mask to randomly crop.
:param class_weights: A weighting vector with values [0, 1] to influence the class the center crop
voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
:return numpy int array (3x1) containing patch center spatial coordinates
"""
num_classes = sample.labels.shape[0]
if class_weights is not None:
if len(class_weights) != num_classes:
raise Exception("A weight must be provided for each class, found weights:{}, expected:{}"
.format(len(class_weights), num_classes))
SegmentationModelBase.validate_class_weights(class_weights)
# If class weights are not initialised, selection is made with equal probability for all classes
available_classes = list(range(num_classes))
original_class_weights = class_weights
while len(available_classes) > 0:
selected_label_class = random.choices(population=available_classes, weights=class_weights, k=1)[0]
# Check pixels where mask and label maps are both foreground
indices = np.argwhere(np.logical_and(sample.labels[selected_label_class] == 1.0, sample.mask == 1))
if not np.any(indices):
available_classes.remove(selected_label_class)
if class_weights is not None:
assert original_class_weights is not None # for mypy
class_weights = [original_class_weights[i] for i in available_classes]
if sum(class_weights) <= 0.0:
raise ValueError("Cannot sample a class: no class present in the sample has a positive weight")
else:
break
# Raise an exception if non of the foreground classes are overlapping with the mask
if len(available_classes) == 0:
raise Exception("No non-mask voxels found, please check your mask and labels map")
# noinspection PyUnboundLocalVariable
choice = random.randint(0, len(indices) - 1)
return indices[choice].astype(int) # Numpy usually stores as floats
def slicers_for_random_crop(sample: Sample,
crop_size: TupleInt3,
class_weights: List[float] = None) -> Tuple[List[slice], np.ndarray]:
"""
Computes array slicers that produce random crops of the given crop_size.
The selection of the center is dependant on background probability.
By default it does not center on background.
:param sample: A set of Image channels, ground truth labels and mask to randomly crop.
:param crop_size: The size of the crop expressed as a list of 3 ints, one per spatial dimension.
:param class_weights: A weighting vector with values [0, 1] to influence the class the center crop
voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
:return: Tuple element 1: The slicers that convert the input image to the chosen crop. Tuple element 2: The
indices of the center point of the crop.
:raises ValueError: If there are shape mismatches among the arguments or if the crop size is larger than the image.
"""
shape = sample.image.shape[1:]
if any_pairwise_larger(crop_size, shape):
raise ValueError("The crop_size across each dimension should be greater than zero and less than or equal "
"to the current value (crop_size: {}, spatial shape: {})"
.format(crop_size, shape))
# Sample a center pixel location for patch extraction.
center = random_select_patch_center(sample, class_weights)
# Verify and fix overflow for each dimension
left = []
for i in range(3):
margin_left = int(crop_size[i] / 2)
margin_right = crop_size[i] - margin_left
left_index = center[i] - margin_left
right_index = center[i] + margin_right
if right_index > shape[i]:
left_index = left_index - (right_index - shape[i])
if left_index < 0:
left_index = 0
left.append(left_index)
return [slice(left[x], left[x] + crop_size[x]) for x in range(0, 3)], center
def random_crop(sample: Sample,
crop_size: TupleInt3,
class_weights: List[float] = None) -> Tuple[Sample, np.ndarray]:
"""
Randomly crops images, mask, and labels arrays according to the crop_size argument.
The selection of the center is dependant on background probability.
By default it does not center on background.
:param sample: A set of Image channels, ground truth labels and mask to randomly crop.
:param crop_size: The size of the crop expressed as a list of 3 ints, one per spatial dimension.
:param class_weights: A weighting vector with values [0, 1] to influence the class the center crop
voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
:return: Tuple item 1: The cropped images, labels, and mask. Tuple item 2: The center that was chosen for the crop,
before shifting to be inside of the image. Tuple item 3: The slicers that convert the input image to the chosen
crop.
:raises ValueError: If there are shape mismatches among the arguments or if the crop size is larger than the image.
"""
slicers, center = slicers_for_random_crop(sample, crop_size, class_weights)
sample = Sample(
image=sample.image[:, slicers[0], slicers[1], slicers[2]],
labels=sample.labels[:, slicers[0], slicers[1], slicers[2]],
mask=sample.mask[slicers[0], slicers[1], slicers[2]],
metadata=sample.metadata
)
return sample, center
class ImageTransformationBase(Transform3D):
"""
This class is the base class to classes built to define data augmentation transformations
for 3D image inputs.
"""
# noinspection PyMissingConstructor
def __init__(self,
is_transformation_for_segmentation_maps: bool = False,
use_joint_channel_transformation: bool = False):
"""
:param is_transformation_for_segmentation_maps: if True, only use geometrical transformations suitable
for segmentation maps. If False, additionally use color/contrast related transformations suitable for
images.
:param use_joint_channel_transformation: if True, apply the exact same transformation to all channels of
a given image. If False, apply a different transformation to each channel.
"""
self.for_segmentation_input_maps = is_transformation_for_segmentation_maps
self.use_joint_channel_transformation = use_joint_channel_transformation
def draw_next_transform(self) -> List[Callable]:
"""
Samples all parameters defining the transformation pipeline.
Returns the list of operations (defined by the sampled parameters) to apply
to each 2D slice of a given 3D volume.
:return: list of transformations to apply to each slice.
"""
raise NotImplementedError("The child class should implement the sampling of transforms")
@staticmethod
def apply_transform_on_3d_image(image: torch.Tensor, transforms: List[Callable]) -> torch.Tensor:
"""
Apply a list of transforms sequentially to a 3D image. Each transformation is assumed to be a 2D transform
that is applied to each slice of the given 3D input separately.
:param image: a 3D tensor of dimension [Z, X, Y]. Transforms are applied one after another
separately for each [X, Y] slice along the Z-dimension (assumed to be the first dimension).
:param transforms: a list of transformations to apply to each slice sequentially.
:return: the transformed 3D image
"""
for z in range(image.shape[0]):
pil = TF.to_pil_image(image[z])
for transform_fn in transforms:
pil = transform_fn(pil)
image[z] = TF.to_tensor(pil).squeeze()
return image
@staticmethod
def _toss_fair_coin() -> bool:
"""
Simulates the toss of a fair coin.
:returns the outcome of the toss.
"""
return random.random() > 0.5
@staticmethod
def randomly_negate_level(value: Any) -> Any:
"""
Negate the value of the input with probability 0.5
"""
return -value if ImageTransformationBase._toss_fair_coin() else value
@staticmethod
def identity() -> Callable:
"""
Identity transform.
"""
return lambda img: img
@staticmethod
def rotate(angle: float) -> Callable:
"""
Returns a function that rotates a 2D image by
a certain angle.
"""
return lambda img: TF.rotate(img, angle)
@staticmethod
def translateX(shift: float) -> Callable:
"""
Returns a function that shifts a 2D image horizontally by
a given shift.
:param shift: a floating point number in (-1, 1); the shift is defined as the
proportion of the image width by which to shift.
"""
return lambda img: TF.affine(img, 0, (shift * img.size[0], 0), 1, 0)
@staticmethod
def translateY(shift: float) -> Callable:
"""
Returns a function that shifts a 2D image vertically by
a given shift.
:param shift: a floating point number in (-1, 1); the shift is defined as the
proportion of the image height by which to shift.
"""
return lambda img: TF.affine(img, 0, (0, shift * img.size[1]), 1, 0)
@staticmethod
def horizontal_flip() -> Callable:
"""
Returns a function that is flipping a 2D-image horizontally.
"""
return lambda img: TF.hflip(img)
@staticmethod
def adjust_contrast(constrast_factor: float) -> Callable:
"""
Returns a function that modifies the contrast of a
2D image by a certain factor.
:param constrast_factor: Non-negative float. 0 gives a solid grey image,
1 means no transformation, 2 means multiplying the contrast
by two.
"""
return lambda img: TF.adjust_contrast(img, constrast_factor)
@staticmethod
def adjust_brightness(brightness_factor: float) -> Callable:
"""
Returns a function that modifies the brightness of a
2D image by a certain factor.
:param brightness_factor: Non-negative float. 0 gives a black image,
1 means no transformation, 2 means multiplying the brightness
by two.
"""
return lambda img: TF.adjust_brightness(img, brightness_factor)
def __call__(self, image: torch.Tensor) -> torch.Tensor:
"""
Main function to apply the transformation to one 3D image.
Assumes the same transformations have to be applied to
each 2D slice along the Z-axis.
Assumes the Z-axis is the first spatial dimension.
:param image: image tensor of size [C, Z, Y, X]
"""
assert len(image.shape) == 4
res = image.clone()
if self.for_segmentation_input_maps:
res = res.int()
else:
res = res.float()
if res.max() > 1:
raise ValueError("Image tensor should be in "
"range 0-1 for conversion to PIL")
# Sample parameters defining the transformation
transforms = self.draw_next_transform()
for c in range(image.shape[0]):
res[c] = self.apply_transform_on_3d_image(res[c], transforms)
if not self.use_joint_channel_transformation:
# Resample transformations for the next channel
transforms = self.draw_next_transform()
return res.to(dtype=image.dtype)
class RandomSliceTransformation(ImageTransformationBase):
"""
Class to apply a random set of 2D affine transformations to all
slices of a 3D volume separately along the z-dimension.
"""
def __init__(self,
probability_transformation: float = 0.8,
max_angle: int = 10,
max_x_shift: float = 0.05,
max_y_shift: float = 0.1,
max_contrast: float = 2,
min_constrast: float = 0,
max_brightness: float = 2,
min_brightness: float = 0,
**kwargs: Any) -> None:
"""
:param probability_transformation: probability of applying the transformation pipeline.
:param max_angle: maximum allowed angle for rotation. For each transformation
the angle is drawn uniformly between -max_angle and max_angle.
:param min_constrast: Minimum contrast factor to apply. 1 means no difference.
2 means doubling the contrast. 0 means a solid grey image. The parameter is sampled
between min_contrast and max_contrast.
:param max_contrast: maximum contrast factor
:param max_brightness: Maximum brightness factor to apply. 1 means no difference.
2 means doubling the brightness. 0 means a black image. Parameter is sampled
between min_brightness and max_brightness.
:param max_x_shift: maximum horizontal shift as a proportion of the image width.
:param max_y_shift: maximum vertical shift as a proportion of the image height.
"""
super().__init__(**kwargs)
self.probability_transformation = probability_transformation
self.max_angle = max_angle
self.max_x_shift = max_x_shift
self.max_y_shift = max_y_shift
self.max_constrast = max_contrast
self.min_constrast = min_constrast
self.min_brightness = min_brightness
self.max_brightness = max_brightness
def draw_next_transform(self) -> List[Callable]:
"""
Samples all parameters defining the transformation pipeline.
Returns a list of operations to apply to each slice in the
3D volume.
:return: list of transformations to apply to each slice.
"""
# Sample parameters for each transformation
angle = random.randint(-self.max_angle, self.max_angle)
x_shift = random.uniform(-self.max_x_shift, self.max_x_shift)
y_shift = random.uniform(-self.max_y_shift, self.max_y_shift)
contrast = random.uniform(self.min_constrast, self.max_constrast)
brightness = random.uniform(self.min_brightness, self.max_brightness)
horizontal_flip = ImageTransformationBase._toss_fair_coin()
# Returns the corresponding operations
if random.random() < self.probability_transformation:
ops = [self.rotate(angle),
self.translateX(x_shift),
self.translateY(y_shift)]
if horizontal_flip:
ops.append(self.horizontal_flip())
if self.for_segmentation_input_maps:
return ops
ops.extend([self.adjust_contrast(contrast),
self.adjust_brightness(brightness)])
else:
ops = []
return ops
class RandAugmentSlice(ImageTransformationBase):
"""
Implements the RandAugment procedure on a restricted set of
transformations. https://arxiv.org/abs/1909.13719
Possible transformations for segmentation maps are: rotation, horizontal
and vertical shift, horizontal flip, identity. Additional transformations
for images are brightness adjustment and contrast adjustment.
"""
def __init__(self,
magnitude: int = 3,
n_transforms: int = 2,
**kwargs: Any) -> None:
"""
:param magnitude: magnitude to apply to the transformations as defined in the RandAugment paper.
1 means a weak transform, 10 is the strongest transform.
:param n_transforms: number of transformations to sample for each image.
"""
super().__init__(**kwargs)
self.magnitude = magnitude
self.n_transforms = n_transforms
self._max_magnitude = 10.0
self._max_x_shift = 0.1
self._max_y_shift = 0.2
self._max_angle = 30
self._max_contrast = 1
self._max_brightness = 1
def get_all_transforms(self) -> Dict[str, Callable]:
"""
Defines the possible transformations for one fixed magnitude level
to sample from.
"""
# Convert magnitude to argument for each transform
level = self.magnitude / self._max_magnitude
angle = self.randomly_negate_level(level) * self._max_angle
x_shift = self.randomly_negate_level(level) * self._max_x_shift
y_shift = self.randomly_negate_level(level) * self._max_y_shift
# Contrast / brightness factor of 1 means no change. 0 means black, 2 means times 2.
contrast = self.randomly_negate_level(level) * self._max_contrast + 1
brightness = self.randomly_negate_level(level) * self._max_brightness + 1
transforms_dict = {
"identity": self.identity(),
"rotate": self.rotate(angle),
"translateX": self.translateX(x_shift),
"translateY": self.translateY(y_shift),
"hFlip": self.horizontal_flip()
}
if self.for_segmentation_input_maps:
return transforms_dict
transforms_dict.update(
{"constrast": self.adjust_contrast(contrast),
"brightness": self.adjust_brightness(brightness),
})
return transforms_dict
def draw_next_transform(self) -> List[Callable]:
"""
Samples all parameters defining the transformation pipeline.
Returns a list of operations to apply to each slice of a 3D volume
(defined by the sampled parameters).
:return: list of transformations to apply to each slice.
"""
available_transforms = self.get_all_transforms()
transform_names = np.random.choice(list(available_transforms), self.n_transforms)
ops = [available_transforms[name] for name in transform_names]
return ops
class ScalarItemAugmentation(Transform3D[ScalarItem]):
"""
Wrapper around an augmentation pipeline for applying an image transformation
to a ScalarItem input and returning the transformed sample. Applies the
transformation either to the images or to the segmentation maps, depending on the
transformation provided. Several objects of this class can be applied
in a row inside a Compose3D object.
"""
# noinspection PyMissingConstructor
def __init__(self, transform: ImageTransformationBase):
"""
:param transform: the transformation to apply to the image.
"""
self.transform = transform
def __call__(self, item: ScalarItem) -> ScalarItem:
if hasattr(self.transform, "for_segmentation_input_maps") and self.transform.for_segmentation_input_maps:
if item.segmentations is None:
raise ValueError("A segmentation data augmentation transform has been "
"specified but no segmentations have been loaded.")
return item.clone_with_overrides(segmentations=self.transform(item.segmentations))
else:
return item.clone_with_overrides(images=self.transform(item.images))
class SampleImageAugmentation(Transform3D[Sample]):
"""
Wrapper around an augmentation pipeline for applying an image transformation
to a Sample input (for segmentation models).
"""
# noinspection PyMissingConstructor
def __init__(self, transform: ImageTransformationBase) -> None:
self.transform = transform
def __call__(self, item: Sample) -> Sample:
return item.clone_with_overrides(image=self.transform(item.image))

@@ -10,13 +10,14 @@ import matplotlib.pyplot as plt
import numpy as np
import param
from InnerEye.ML.augmentations.augmentation_for_segmentation_utils import slicers_for_random_crop
from InnerEye.Common.generic_parsing import GenericConfig
from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.dataset.cropping_dataset import CroppingDataset
from InnerEye.ML.dataset.full_image_dataset import FullImageDataset
from InnerEye.ML.dataset.sample import Sample
from InnerEye.ML.plotting import resize_and_save, scan_with_transparent_overlay
from InnerEye.ML.utils import augmentation, io_util
from InnerEye.ML.utils import io_util
# The name of the folder inside the default outputs folder that will hold plots that show the effect of
# sampling random patches
from InnerEye.ML.utils.image_util import get_unit_image_header
@@ -60,9 +61,9 @@ def visualize_random_crops(sample: Sample,
# Nifti file of that datatype.
repeats = 200
for _ in range(repeats):
slicers, _ = augmentation.slicers_for_random_crop(sample=sample,
crop_size=config.crop_size,
class_weights=config.class_weights)
slicers, _ = slicers_for_random_crop(sample=sample,
crop_size=config.crop_size,
class_weights=config.class_weights)
heatmap[slicers[0], slicers[1], slicers[2]] += 1
is_3dim = heatmap.shape[0] > 1
header = sample.metadata.image_header

@@ -0,0 +1,162 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Any, List
import numpy as np
import pytest
from InnerEye.ML.augmentations.augmentation_for_segmentation_utils import random_crop
from InnerEye.ML.dataset.sample import Sample
from InnerEye.ML.utils import ml_util
from Tests.ML.util import DummyPatientMetadata
ml_util.set_random_seed(1)
image_size = (8, 8, 8)
valid_image_4d = np.random.uniform(size=((5,) + image_size)) * 10
valid_mask = np.random.randint(2, size=image_size)
number_of_classes = 5
class_assignments = np.random.randint(2, size=image_size)
valid_labels = np.zeros((number_of_classes,) + image_size)
for c in range(number_of_classes):
valid_labels[c, class_assignments == c] = 1
valid_crop_size = (2, 2, 2)
valid_full_crop_size = image_size
valid_class_weights = [0.5] + [0.5 / (number_of_classes - 1)] * (number_of_classes - 1)
crop_size_requires_padding = (9, 8, 12)
def test_valid_full_crop() -> None:
metadata = DummyPatientMetadata
sample, _ = random_crop(sample=Sample(image=valid_image_4d,
labels=valid_labels,
mask=valid_mask,
metadata=metadata),
crop_size=valid_full_crop_size,
class_weights=valid_class_weights)
assert np.array_equal(sample.image, valid_image_4d)
assert np.array_equal(sample.labels, valid_labels)
assert np.array_equal(sample.mask, valid_mask)
assert sample.metadata == metadata
@pytest.mark.parametrize("image", [None, list(), valid_image_4d])
@pytest.mark.parametrize("labels", [None, list(), valid_labels])
@pytest.mark.parametrize("mask", [None, list(), valid_mask])
@pytest.mark.parametrize("class_weights", [[0, 0, 0], [0], [-1, 0, 1], [-1, -2, -3], valid_class_weights])
def test_invalid_arrays(image: Any, labels: Any, mask: Any, class_weights: Any) -> None:
"""
Tests failure cases of the random_crop function for invalid image, labels, mask or class
weights arguments.
"""
# Skip the final combination, because it is valid
if not (np.array_equal(image, valid_image_4d) and np.array_equal(labels, valid_labels)
and np.array_equal(mask, valid_mask) and class_weights == valid_class_weights):
with pytest.raises(Exception):
random_crop(Sample(metadata=DummyPatientMetadata, image=image, labels=labels, mask=mask),
valid_crop_size, class_weights)
@pytest.mark.parametrize("crop_size", [None, ["a"], 5])
def test_invalid_crop_arg(crop_size: Any) -> None:
with pytest.raises(Exception):
random_crop(
Sample(metadata=DummyPatientMetadata, image=valid_image_4d, labels=valid_labels, mask=valid_mask),
crop_size, valid_class_weights)
@pytest.mark.parametrize("crop_size", [[2, 2], [2, 2, 2, 2], [10, 10, 10]])
def test_invalid_crop_size(crop_size: Any) -> None:
with pytest.raises(Exception):
random_crop(
Sample(metadata=DummyPatientMetadata, image=valid_image_4d, labels=valid_labels, mask=valid_mask),
crop_size, valid_class_weights)
def test_random_crop_no_fg() -> None:
with pytest.raises(Exception):
random_crop(Sample(metadata=DummyPatientMetadata, image=valid_image_4d, labels=valid_labels,
mask=np.zeros_like(valid_mask)),
valid_crop_size, valid_class_weights)
with pytest.raises(Exception):
random_crop(Sample(metadata=DummyPatientMetadata, image=valid_image_4d,
labels=np.zeros_like(valid_labels), mask=valid_mask),
valid_crop_size, valid_class_weights)
@pytest.mark.parametrize("crop_size", [valid_crop_size])
def test_random_crop(crop_size: Any) -> None:
labels = valid_labels
# create labels such that there are no foreground voxels in a particular class
# this should be handled gracefully (the class being ignored from sampling)
labels[0] = 1
labels[1] = 0
sample, _ = random_crop(Sample(
image=valid_image_4d,
labels=valid_labels,
mask=valid_mask,
metadata=DummyPatientMetadata
), crop_size, valid_class_weights)
expected_img_crop_size = (valid_image_4d.shape[0], *crop_size)
expected_labels_crop_size = (valid_labels.shape[0], *crop_size)
assert sample.image.shape == expected_img_crop_size
assert sample.labels.shape == expected_labels_crop_size
assert sample.mask.shape == tuple(crop_size)
@pytest.mark.parametrize("class_weights",
[None, [0, 0.5, 0.5, 0, 0], [0.1, 0.45, 0.45, 0, 0], [0.5, 0.25, 0.25, 0, 0],
[1, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0.04, 0.12, 0.20, 0.28, 0.36], [0, 0.5, 0, 0.5, 0]])
def test_valid_class_weights(class_weights: List[float]) -> None:
"""
Produce a large number of crops and make sure the crop center class proportions respect class weights
"""
ml_util.set_random_seed(1)
num_classes = len(valid_labels)
image = np.zeros_like(valid_image_4d)
labels = np.zeros_like(valid_labels)
class0, class1, class2 = non_empty_classes = [0, 2, 4]
labels[class0] = 1
labels[class0][3, 3, 3] = 0
labels[class0][3, 2, 3] = 0
labels[class1][3, 3, 3] = 1
labels[class2][3, 2, 3] = 1
mask = np.ones_like(valid_mask)
sample = Sample(image=image, labels=labels, mask=mask, metadata=DummyPatientMetadata)
crop_size = (1, 1, 1)
total_crops = 200
sampled_label_center_distribution = np.zeros(num_classes)
# If there is no class that has a non-zero weight and is present in the sample, there is no valid
# way to select a class, so we expect an exception to be thrown.
if class_weights is not None and sum(class_weights[c] for c in non_empty_classes) == 0:
with pytest.raises(ValueError):
random_crop(sample, crop_size, class_weights)
return
for _ in range(0, total_crops):
crop_sample, center = random_crop(sample, crop_size, class_weights)
sampled_class = list(labels[:, center[0], center[1], center[2]]).index(1)
sampled_label_center_distribution[sampled_class] += 1
sampled_label_center_distribution /= total_crops
if class_weights is None:
weight = 1.0 / len(non_empty_classes)
expected_label_center_distribution = [weight if c in non_empty_classes else 0.0
for c in range(number_of_classes)]
else:
total = sum(class_weights[c] for c in non_empty_classes)
expected_label_center_distribution = [class_weights[c] / total if c in non_empty_classes else 0.0
for c in range(number_of_classes)]
assert np.allclose(sampled_label_center_distribution, expected_label_center_distribution, atol=0.1)

@@ -0,0 +1,96 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import random
import numpy as np
import pytest
import torch
from PIL import Image
from torchvision.transforms.functional import to_pil_image, to_tensor
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
from InnerEye.ML.augmentations.image_transforms import AddGaussianNoise, ElasticTransform, ExpandChannels, RandomGamma
image_size = (256, 256)
test_tensor_1channel_1slice = torch.ones([1, 1, *image_size], dtype=torch.float)
test_tensor_1channel_1slice[..., 100:150, 100:200] = 1 / 255.
min_test_image, max_test_image = test_tensor_1channel_1slice.min(), test_tensor_1channel_1slice.max()
test_tensor_2channels_2slices = torch.ones([2, 2, *image_size], dtype=torch.float)
test_tensor_2channels_2slices[..., 100:150, 100:200] = 1 / 255.
invalid_test_tensor = torch.ones([1, *image_size])
test_pil_image = Image.open(str(full_ml_test_data_path() / "image_and_contour.png")).convert("RGB")
test_image_as_tensor = to_tensor(test_pil_image).unsqueeze(0) # put in a [1, C, H, W] format
def test_add_gaussian_noise() -> None:
"""
Tests the functionality of the AddGaussianNoise transform.
"""
# Test case of image with 1 channel, 1 slice (2D)
torch.manual_seed(10)
transformed = AddGaussianNoise(std=0.05, p_apply=1)(test_tensor_1channel_1slice.clone())
torch.manual_seed(10)
noise = torch.randn(size=(1, *image_size)) * 0.05
assert torch.isclose(
torch.clamp(test_tensor_1channel_1slice + noise, min_test_image, max_test_image), # type: ignore
transformed).all()
# Test p_apply = 0
untransformed = AddGaussianNoise(std=0.05, p_apply=0)(test_tensor_1channel_1slice.clone())
assert torch.isclose(untransformed, test_tensor_1channel_1slice).all()
# Check that it applies the same transform to all slices if number of slices > 1
torch.manual_seed(10)
transformed = AddGaussianNoise(std=0.05, p_apply=1)(test_tensor_2channels_2slices.clone())
assert torch.isclose(
torch.clamp(test_tensor_2channels_2slices + noise, min_test_image, max_test_image), # type: ignore
transformed).all()
def test_elastic_transform() -> None:
"""
Tests elastic transform
"""
np.random.seed(7)
transformed_image = ElasticTransform(sigma=4, alpha=34, p_apply=1.0)(test_image_as_tensor.clone())
transformed_pil = to_pil_image(transformed_image.squeeze(0))
expected_pil_image = Image.open(full_ml_test_data_path() / "elastic_transformed_image_and_contour.png").convert(
"RGB")
assert expected_pil_image == transformed_pil
untransformed_image = ElasticTransform(sigma=4, alpha=34, p_apply=0.0)(test_image_as_tensor.clone())
assert torch.isclose(test_image_as_tensor, untransformed_image).all()
def test_expand_channels() -> None:
with pytest.raises(ValueError):
ExpandChannels()(invalid_test_tensor)
tensor_img = ExpandChannels()(test_tensor_1channel_1slice.clone())
assert tensor_img.shape == torch.Size([1, 3, *image_size])
def test_random_gamma() -> None:
# This is invalid input (expects 4 dimensions)
with pytest.raises(ValueError):
RandomGamma(scale=(0.3, 3))(invalid_test_tensor)
random.seed(0)
transformed_1 = RandomGamma(scale=(0.3, 3))(test_tensor_1channel_1slice.clone())
assert transformed_1.shape == test_tensor_1channel_1slice.shape
tensor_img = torch.ones([2, 3, *image_size])
transformed_2 = RandomGamma(scale=(0.3, 3))(tensor_img)
# If you run on 1 channel, 1 Z dimension the gamma transform applied should be the same for all slices.
assert transformed_2.shape == torch.Size([2, 3, *image_size])
assert torch.isclose(transformed_2[0], transformed_2[1]).all()
assert torch.isclose(transformed_2[0, 1], transformed_2[0, 2]).all() and \
torch.isclose(transformed_2[0, 0], transformed_2[0, 2]).all()
human_readable_transformed = to_pil_image(RandomGamma(scale=(2, 3))(test_image_as_tensor).squeeze(0))
expected_pil_image = Image.open(full_ml_test_data_path() / "gamma_transformed_image_and_contour.png").convert("RGB")
assert expected_pil_image == human_readable_transformed

@@ -0,0 +1,165 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import random
import PIL
import pytest
import torch
from torchvision.transforms import CenterCrop, ColorJitter, RandomAffine, RandomErasing, RandomHorizontalFlip, \
RandomResizedCrop, Resize, ToTensor
from torchvision.transforms.functional import to_tensor
from InnerEye.ML.augmentations.image_transforms import AddGaussianNoise, ElasticTransform, ExpandChannels, RandomGamma
from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline, \
create_cxr_transforms_from_config
from Tests.SSL.test_data_modules import cxr_augmentation_config
import numpy as np
image_size = (32, 32)
crop_size = 24
test_image_as_array = np.ones(list(image_size)) * 255.
test_image_as_array[10:15, 10:20] = 1
test_image_as_pil = PIL.Image.fromarray(test_image_as_array).convert("L")
test_2d_image_as_CHW_tensor = to_tensor(test_image_as_array)
test_2d_image_as_ZCHW_tensor = test_2d_image_as_CHW_tensor.unsqueeze(0)
test_4d_scan_as_tensor = torch.ones([5, 4, *image_size]) * 255.
test_4d_scan_as_tensor[..., 10:15, 10:20] = 1
@pytest.mark.parametrize("use_different_transformation_per_channel", [True, False])
def test_torchvision_on_various_input(use_different_transformation_per_channel: bool) -> None:
"""
This tests that we can run the transformation pipeline with out-of-the-box torchvision transforms on various types
of input: PIL image, 3D tensor, 4D tensors. Tests that use_different_transformation_per_channel has the correct
behavior.
"""
transform = ImageTransformationPipeline(
[CenterCrop(crop_size),
RandomErasing(),
RandomAffine(degrees=(10, 12), shear=15, translate=(0.1, 0.3))
],
use_different_transformation_per_channel)
# Test PIL image input
transformed = transform(test_image_as_pil)
assert isinstance(transformed, torch.Tensor)
assert transformed.shape == torch.Size([1, crop_size, crop_size])
# Test image as [C, H, W] tensor
transformed = transform(test_2d_image_as_CHW_tensor.clone())
assert isinstance(transformed, torch.Tensor)
assert transformed.shape == torch.Size([1, crop_size, crop_size])
# Test image as [1, 1, H, W]
transformed = transform(test_2d_image_as_ZCHW_tensor)
assert isinstance(transformed, torch.Tensor)
assert transformed.shape == torch.Size([1, 1, crop_size, crop_size])
# Test with a fake 4D scan [C, Z, H, W] -> [5, 4, 32, 32]
transformed = transform(test_4d_scan_as_tensor)
assert isinstance(transformed, torch.Tensor)
assert transformed.shape == torch.Size([5, 4, crop_size, crop_size])
# The same transformation should be applied to all slices and channels iff use_different_transformation_per_channel is False.
assert torch.isclose(transformed[0, 0], transformed[1, 1]).all() != use_different_transformation_per_channel
@pytest.mark.parametrize("use_different_transformation_per_channel", [True, False])
def test_custom_tf_on_various_input(use_different_transformation_per_channel: bool) -> None:
"""
This tests that we can run the transformation pipeline with our custom transforms on various types
of input: PIL image, 3D tensors, 4D tensors. Tests that use_different_transformation_per_channel has the correct
behavior. The transforms are tested individually in test_image_transforms.py.
"""
pipeline = ImageTransformationPipeline(
[ElasticTransform(sigma=4, alpha=34, p_apply=1),
AddGaussianNoise(p_apply=1, std=0.05),
RandomGamma(scale=(0.3, 3))
],
use_different_transformation_per_channel)
# Test PIL image input
transformed = pipeline(test_image_as_pil)
assert transformed.shape == test_2d_image_as_CHW_tensor.shape
# Test image as [C, H, W] tensor
transformed = pipeline(test_2d_image_as_CHW_tensor)
assert transformed.shape == test_2d_image_as_CHW_tensor.shape
# Test image as [1, 1, H, W]
transformed = pipeline(test_2d_image_as_ZCHW_tensor)
assert isinstance(transformed, torch.Tensor)
assert transformed.shape == torch.Size([1, 1, *image_size])
# Test with a fake scan [C, Z, H, W] -> [5, 4, 32, 32]
transformed = pipeline(test_4d_scan_as_tensor)
assert isinstance(transformed, torch.Tensor)
assert transformed.shape == test_4d_scan_as_tensor.shape
# The same transformation should be applied to all slices and channels iff use_different_transformation_per_channel is False.
assert torch.isclose(transformed[0, 0], transformed[1, 1]).all() != use_different_transformation_per_channel
def test_create_transform_pipeline_from_config() -> None:
"""
Tests that the pipeline returned by create_cxr_transforms_from_config returns the expected transformation.
"""
transformation_pipeline = create_cxr_transforms_from_config(cxr_augmentation_config, apply_augmentations=True)
fake_cxr_as_array = np.ones([256, 256]) * 255.
fake_cxr_as_array[100:150, 100:200] = 1
fake_cxr_image = PIL.Image.fromarray(fake_cxr_as_array).convert("L")
all_transforms = [ExpandChannels(),
RandomAffine(degrees=180, translate=(0, 0), shear=40),
RandomResizedCrop(scale=(0.4, 1.0), size=256),
RandomHorizontalFlip(p=0.5),
RandomGamma(scale=(0.5, 1.5)),
ColorJitter(saturation=0, brightness=0.2, contrast=0.2),
ElasticTransform(sigma=4, alpha=34, p_apply=0.4),
CenterCrop(size=224),
RandomErasing(scale=(0.15, 0.4), ratio=(0.33, 3)),
AddGaussianNoise(std=0.05, p_apply=0.5)
]
np.random.seed(3)
torch.manual_seed(3)
random.seed(3)
transformed_image = transformation_pipeline(fake_cxr_image)
assert isinstance(transformed_image, torch.Tensor)
# Expected pipeline
image = np.ones([256, 256]) * 255.
image[100:150, 100:200] = 1
image = PIL.Image.fromarray(image).convert("L")
# In the pipeline the image is converted to tensor before applying the transformations. Do the same here.
image = ToTensor()(image).reshape([1, 1, 256, 256])
np.random.seed(3)
torch.manual_seed(3)
random.seed(3)
expected_transformed = image
for t in all_transforms:
expected_transformed = t(expected_transformed)
# The pipeline takes as input [C, Z, H, W] and returns [C, Z, H, W]
# But the transforms in the list expect [Z, C, H, W] and return [Z, C, H, W], so we need to permute dimensions to compare
expected_transformed = torch.transpose(expected_transformed, 1, 0).squeeze(1)
assert torch.isclose(expected_transformed, transformed_image).all()
# Test the evaluation pipeline
transformation_pipeline = create_cxr_transforms_from_config(cxr_augmentation_config, apply_augmentations=False)
transformed_image = transformation_pipeline(image)
assert isinstance(transformed_image, torch.Tensor)
all_transforms = [ExpandChannels(), Resize(size=256), CenterCrop(size=224)]
expected_transformed = image
for t in all_transforms:
expected_transformed = t(expected_transformed)
expected_transformed = torch.transpose(expected_transformed, 1, 0).squeeze(1)
assert torch.isclose(expected_transformed, transformed_image).all()

@@ -652,7 +652,7 @@ S4,label,,False,3.0
traverse_dirs_when_loading=True,
local_dataset=test_output_dirs.root_dir)
raw_dataset = ScalarDataset(args, data_frame=df)
normalized = ScalarDataset(args, data_frame=df, sample_transforms=WindowNormalizationForScalarItem())
normalized = ScalarDataset(args, data_frame=df, sample_transform=WindowNormalizationForScalarItem())
assert len(raw_dataset) == 4
for i in range(4):
raw_item = raw_dataset[i]

@@ -10,12 +10,15 @@ import numpy as np
import pandas as pd
import pytest
import torch
from torchvision.transforms import ColorJitter, RandomAffine
from InnerEye.Common import common_util
from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME, logging_to_stdout
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
from InnerEye.Common.metrics_constants import LoggingColumns, MetricType, SEQUENCE_POSITION_HUE_NAME_PREFIX
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline
from InnerEye.ML.dataset.sequence_dataset import SequenceDataset
from InnerEye.ML.deep_learning_config import TemperatureScalingConfig
from InnerEye.ML.lightning_models import transfer_batch_to_device
@@ -27,7 +30,7 @@ from InnerEye.ML.run_ml import MLRunner
from InnerEye.ML.scalar_config import ScalarLoss
from InnerEye.ML.sequence_config import SEQUENCE_LENGTH_FILE, SEQUENCE_LENGTH_STATS_FILE, SequenceModelBase
from InnerEye.ML.utils import ml_util
from InnerEye.ML.utils.augmentation import RandAugmentSlice, ScalarItemAugmentation
from InnerEye.ML.utils.dataset_util import CategoricalToOneHotEncoder
from InnerEye.ML.utils.io_util import ImageAndSegmentations
from InnerEye.ML.utils.model_util import create_model_with_temperature_scaling, get_scalar_model_inputs_and_labels
@@ -107,12 +110,12 @@ class ToySequenceModel(SequenceModelBase):
proportion_val=0.1,
)
def get_image_sample_transforms(self) -> ModelTransformsPerExecutionMode:
def get_image_transform(self) -> ModelTransformsPerExecutionMode:
if self.use_combined_model:
return ModelTransformsPerExecutionMode(
train=ScalarItemAugmentation(
transform=RandAugmentSlice(use_joint_channel_transformation=False,
is_transformation_for_segmentation_maps=True)))
train=ImageTransformationPipeline(
transforms=[RandomAffine(degrees=30, translate=(0.1, 0.1), shear=15),
ColorJitter(brightness=0.2)]))
else:
return ModelTransformsPerExecutionMode()

@@ -12,19 +12,20 @@ import numpy as np
import pandas as pd
import pytest
import torch
from torchvision.transforms import ColorJitter, RandomAffine
from InnerEye.Common import common_util
from InnerEye.Common.common_util import logging_to_stdout
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.Common.type_annotations import TupleInt3
from InnerEye.ML.dataset.scalar_dataset import ScalarDataset
from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline
from InnerEye.ML.dataset.scalar_dataset import ScalarDataset, ScalarItemAugmentation
from InnerEye.ML.lightning_models import transfer_batch_to_device
from InnerEye.ML.model_config_base import ModelTransformsPerExecutionMode
from InnerEye.ML.models.architectures.classification.image_encoder_with_mlp import ImageEncoderWithMlp, \
ImagingFeatureType
from InnerEye.ML.run_ml import MLRunner
from InnerEye.ML.scalar_config import AggregationType, ScalarLoss, ScalarModelBase, get_non_image_features_dict
from InnerEye.ML.utils.augmentation import RandAugmentSlice, ScalarItemAugmentation
from InnerEye.ML.utils.dataset_util import CategoricalToOneHotEncoder
from InnerEye.ML.utils.image_util import HDF5_NUM_SEGMENTATION_CLASSES, segmentation_to_one_hot
from InnerEye.ML.utils.io_util import ImageAndSegmentations, NumpyFile
@@ -101,15 +102,24 @@ class ImageEncoder(ScalarModelBase):
def get_post_loss_logits_normalization_function(self) -> Callable:
return torch.nn.Sigmoid()
def get_image_sample_transforms(self) -> ModelTransformsPerExecutionMode:
def get_image_transform(self) -> ModelTransformsPerExecutionMode:
"""
Get transforms to perform on image samples for each model execution mode.
"""
return ModelTransformsPerExecutionMode(
train=ScalarItemAugmentation(
RandAugmentSlice(is_transformation_for_segmentation_maps=(
self.imaging_feature_type == ImagingFeatureType.Segmentation
or self.imaging_feature_type == ImagingFeatureType.ImageAndSegmentation))))
if self.imaging_feature_type in [ImagingFeatureType.Image, ImagingFeatureType.ImageAndSegmentation]:
return ModelTransformsPerExecutionMode(
train=ImageTransformationPipeline(
transforms=[RandomAffine(10), ColorJitter(0.2)],
use_different_transformation_per_channel=True))
return ModelTransformsPerExecutionMode()
def get_segmentation_transform(self) -> ModelTransformsPerExecutionMode:
if self.imaging_feature_type in [ImagingFeatureType.Segmentation, ImagingFeatureType.ImageAndSegmentation]:
return ModelTransformsPerExecutionMode(
train=ImageTransformationPipeline(
transforms=[RandomAffine(10), ColorJitter(0.2)],
use_different_transformation_per_channel=True))
return ModelTransformsPerExecutionMode()
@pytest.mark.skipif(common_util.is_windows(), reason="Too slow on windows")
@@ -177,9 +187,10 @@ S3,week1,scan3.npy,True,6,60,Male,Val2
)
config_for_dataset.read_dataset_into_dataframe_and_pre_process()
dataset = ScalarDataset(config_for_dataset,
sample_transforms=ScalarItemAugmentation(
RandAugmentSlice(is_transformation_for_segmentation_maps=False)))
dataset = ScalarDataset(
config_for_dataset,
sample_transform=ScalarItemAugmentation(ImageTransformationPipeline([RandomAffine(10), ColorJitter(0.2)],
use_different_transformation_per_channel=True)))
assert len(dataset) == 3
config = ImageEncoder(

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d4fb26944ad1fef4265cb52b64160865a940cd28ed4ad2bb7bf1c90458622bf6
size 39050

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:775d542598a607881a313c0ac828ffc47dd86b2567bcb4366646b7ea14a121d8
size 3923

@ -1,405 +0,0 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import random
from typing import Any, Callable, List
import numpy as np
import pytest
import torch
from torchvision.transforms import functional as TF
from InnerEye.ML.dataset.sample import Sample
from InnerEye.ML.utils import augmentation, ml_util
from InnerEye.ML.utils.augmentation import ImageTransformationBase
from Tests.ML.util import DummyPatientMetadata
ml_util.set_random_seed(1)
image_size = (8, 8, 8)
valid_image_4d = np.random.uniform(size=((5,) + image_size)) * 10
valid_mask = np.random.randint(2, size=image_size)
number_of_classes = 5
class_assignments = np.random.randint(2, size=image_size)
valid_labels = np.zeros((number_of_classes,) + image_size)
for c in range(number_of_classes):
valid_labels[c, class_assignments == c] = 1
valid_crop_size = (2, 2, 2)
valid_full_crop_size = image_size
valid_class_weights = [0.5] + [0.5 / (number_of_classes - 1)] * (number_of_classes - 1)
crop_size_requires_padding = (9, 8, 12)
# Random Crop Tests
def test_valid_full_crop() -> None:
metadata = DummyPatientMetadata
sample, _ = augmentation.random_crop(sample=Sample(image=valid_image_4d,
labels=valid_labels,
mask=valid_mask,
metadata=metadata),
crop_size=valid_full_crop_size,
class_weights=valid_class_weights)
assert np.array_equal(sample.image, valid_image_4d)
assert np.array_equal(sample.labels, valid_labels)
assert np.array_equal(sample.mask, valid_mask)
assert sample.metadata == metadata
@pytest.mark.parametrize("image", [None, list(), valid_image_4d])
@pytest.mark.parametrize("labels", [None, list(), valid_labels])
@pytest.mark.parametrize("mask", [None, list(), valid_mask])
@pytest.mark.parametrize("class_weights", [[0, 0, 0], [0], [-1, 0, 1], [-1, -2, -3], valid_class_weights])
def test_invalid_arrays(image: Any, labels: Any, mask: Any, class_weights: Any) -> None:
"""
Tests failure cases of the random_crop function for invalid image, labels, mask or class
weights arguments.
"""
# Skip the final combination, because it is valid
if not (np.array_equal(image, valid_image_4d) and np.array_equal(labels, valid_labels)
and np.array_equal(mask, valid_mask) and class_weights == valid_class_weights):
with pytest.raises(Exception):
augmentation.random_crop(Sample(metadata=DummyPatientMetadata, image=image, labels=labels, mask=mask),
valid_crop_size, class_weights)
@pytest.mark.parametrize("crop_size", [None, ["a"], 5])
def test_invalid_crop_arg(crop_size: Any) -> None:
with pytest.raises(Exception):
augmentation.random_crop(
Sample(metadata=DummyPatientMetadata, image=valid_image_4d, labels=valid_labels, mask=valid_mask),
crop_size, valid_class_weights)
@pytest.mark.parametrize("crop_size", [[2, 2], [2, 2, 2, 2], [10, 10, 10]])
def test_invalid_crop_size(crop_size: Any) -> None:
with pytest.raises(Exception):
augmentation.random_crop(
Sample(metadata=DummyPatientMetadata, image=valid_image_4d, labels=valid_labels, mask=valid_mask),
crop_size, valid_class_weights)
def test_random_crop_no_fg() -> None:
with pytest.raises(Exception):
augmentation.random_crop(Sample(metadata=DummyPatientMetadata, image=valid_image_4d, labels=valid_labels,
mask=np.zeros_like(valid_mask)),
valid_crop_size, valid_class_weights)
with pytest.raises(Exception):
augmentation.random_crop(Sample(metadata=DummyPatientMetadata, image=valid_image_4d,
labels=np.zeros_like(valid_labels), mask=valid_mask),
valid_crop_size, valid_class_weights)
@pytest.mark.parametrize("crop_size", [valid_crop_size])
def test_random_crop(crop_size: Any) -> None:
labels = valid_labels
# create labels such that there are no foreground voxels in a particular class
# this should ne handled gracefully (class being ignored from sampling)
labels[0] = 1
labels[1] = 0
sample, _ = augmentation.random_crop(Sample(
image=valid_image_4d,
labels=valid_labels,
mask=valid_mask,
metadata=DummyPatientMetadata
), crop_size, valid_class_weights)
expected_img_crop_size = (valid_image_4d.shape[0], *crop_size)
expected_labels_crop_size = (valid_labels.shape[0], *crop_size)
assert sample.image.shape == expected_img_crop_size
assert sample.labels.shape == expected_labels_crop_size
assert sample.mask.shape == tuple(crop_size)
@pytest.mark.parametrize("class_weights",
[None, [0, 0.5, 0.5, 0, 0], [0.1, 0.45, 0.45, 0, 0], [0.5, 0.25, 0.25, 0, 0],
[1, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0.04, 0.12, 0.20, 0.28, 0.36], [0, 0.5, 0, 0.5, 0]])
def test_valid_class_weights(class_weights: List[float]) -> None:
"""
Produce a large number of crops and make sure the crop center class proportions respect class weights
"""
ml_util.set_random_seed(1)
num_classes = len(valid_labels)
image = np.zeros_like(valid_image_4d)
labels = np.zeros_like(valid_labels)
class0, class1, class2 = non_empty_classes = [0, 2, 4]
labels[class0] = 1
labels[class0][3, 3, 3] = 0
labels[class0][3, 2, 3] = 0
labels[class1][3, 3, 3] = 1
labels[class2][3, 2, 3] = 1
mask = np.ones_like(valid_mask)
sample = Sample(image=image, labels=labels, mask=mask, metadata=DummyPatientMetadata)
crop_size = (1, 1, 1)
total_crops = 200
sampled_label_center_distribution = np.zeros(num_classes)
# If there is no class that has a non-zero weight and is present in the sample, there is no valid
# way to select a class, so we expect an exception to be thrown.
if class_weights is not None and sum(class_weights[c] for c in non_empty_classes) == 0:
with pytest.raises(ValueError):
augmentation.random_crop(sample, crop_size, class_weights)
return
for _ in range(0, total_crops):
crop_sample, center = augmentation.random_crop(sample, crop_size, class_weights)
sampled_class = list(labels[:, center[0], center[1], center[2]]).index(1)
sampled_label_center_distribution[sampled_class] += 1
sampled_label_center_distribution /= total_crops
if class_weights is None:
weight = 1.0 / len(non_empty_classes)
expected_label_center_distribution = [weight if c in non_empty_classes else 0.0
for c in range(number_of_classes)]
else:
total = sum(class_weights[c] for c in non_empty_classes)
expected_label_center_distribution = [class_weights[c] / total if c in non_empty_classes else 0.0
for c in range(number_of_classes)]
assert np.allclose(sampled_label_center_distribution, expected_label_center_distribution, atol=0.1)
def _check_transformation_result(image_as_tensor: torch.Tensor,
transformation: Callable,
expected: torch.Tensor) -> None:
test_tensor_pil = TF.to_pil_image(image_as_tensor)
transformed = TF.to_tensor(transformation(test_tensor_pil)).squeeze()
np.testing.assert_allclose(transformed, expected, rtol=0.02)
# Augmentation pipeline tests
@pytest.mark.parametrize(["transformation", "expected"],
[(ImageTransformationBase.horizontal_flip(), torch.tensor([[0, 1, 2],
[0, 2, 1],
[2, 0, 0]])),
(ImageTransformationBase.rotate(45), torch.tensor([[1, 0, 0],
[2, 2, 2],
[1, 0, 0]])),
(ImageTransformationBase.translateX(0.3), torch.tensor([[0, 2, 1],
[0, 1, 2],
[0, 0, 0]])),
(ImageTransformationBase.translateY(0.3), torch.tensor([[0, 0, 0],
[2, 1, 0],
[1, 2, 0]])),
(ImageTransformationBase.identity(), torch.tensor([[2, 1, 0],
[1, 2, 0],
[0, 0, 2]]))])
def test_transformations_for_segmentations(transformation: Callable, expected: torch.Tensor) -> None:
"""
Tests each individual transformation of the ImageTransformationBase class on a 2D input representing
a segmentation map.
"""
image_as_tensor = torch.tensor([[2, 1, 0],
[1, 2, 0],
[0, 0, 2]], dtype=torch.int32)
_check_transformation_result(image_as_tensor, transformation, expected)
def test_invalid_segmentation_type() -> None:
"""
Validates the necessity of converting segmentation maps to int before PIL
conversion.
"""
image_as_tensor = torch.tensor([[2, 1, 0],
[1, 2, 0],
[0, 0, 2]], dtype=torch.float32)
expected = torch.tensor([[1, 0, 0], [2, 2, 2], [1, 0, 0]])
with pytest.raises(AssertionError):
_check_transformation_result(image_as_tensor, ImageTransformationBase.rotate(45), expected)
@pytest.mark.parametrize(["transformation", "expected"],
[(ImageTransformationBase.horizontal_flip(), torch.tensor([[0.1, 0.5, 1],
[0.1, 1, 0.5],
[1, 0.1, 0.1]])),
(ImageTransformationBase.adjust_contrast(2), torch.tensor([[1., 0.509804, 0.],
[0.509804, 1., 0.],
[0., 0., 1.]])),
(ImageTransformationBase.adjust_brightness(2), torch.tensor([[1.0000, 0.9961, 0.1961],
[0.9961, 1.0000, 0.1961],
[0.1961, 0.1961, 1.0000]])),
(ImageTransformationBase.adjust_contrast(0), torch.tensor([[0.4863, 0.4863, 0.4863],
[0.4863, 0.4863, 0.4863],
[0.4863, 0.4863, 0.4863]])),
(ImageTransformationBase.adjust_brightness(0), torch.tensor([[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0]]))])
def test_transformation_image(transformation: Callable, expected: torch.Tensor) -> None:
"""
Tests each individual transformation of the ImageTransformationBase class on a 2D input representing
a natural image.
"""
image_as_tensor = torch.tensor([[1, 0.5, 0.1],
[0.5, 1, 0.1],
[0.1, 0.1, 1]], dtype=torch.float32)
_check_transformation_result(image_as_tensor, transformation, expected)
def test_apply_transformations() -> None:
"""
Testing the function applying a series of transformations to a given image
"""
operations = [ImageTransformationBase.identity(), ImageTransformationBase.translateX(0.3),
ImageTransformationBase.horizontal_flip()]
# Case 1 on segmentations
image_as_tensor = torch.tensor([[[2, 1, 0], [1, 2, 0], [0, 0, 2]],
[[2, 1, 0], [1, 2, 0], [0, 0, 2]]], dtype=torch.int32)
transformed_tensor = ImageTransformationBase.apply_transform_on_3d_image(image=image_as_tensor,
transforms=operations)
expected = torch.tensor([[[1, 2, 0], [2, 1, 0], [0, 0, 0]],
[[1, 2, 0], [2, 1, 0], [0, 0, 0]]])
assert torch.all(expected == transformed_tensor)
# Case 2 on image
image_as_tensor = torch.tensor([[[1, 0.5, 0.1], [0.5, 1, 0.1], [0.1, 0.1, 1]],
[[1, 0.5, 0.1], [0.5, 1, 0.1], [0.1, 0.1, 1]]], dtype=torch.float32)
transformed_tensor = ImageTransformationBase.apply_transform_on_3d_image(image=image_as_tensor,
transforms=operations)
expected = torch.tensor([[[0.5, 1, 0], [1, 0.5, 0], [0.1, 0.1, 0]],
[[0.5, 1, 0], [1, 0.5, 0], [0.1, 0.1, 0]]])
np.testing.assert_allclose(transformed_tensor, expected, rtol=0.02)
def _compute_expected_pipeline_result(transformations: List[List[Callable]],
input_image: torch.Tensor) -> torch.Tensor:
expected = input_image.clone()
expected[0] = ImageTransformationBase.apply_transform_on_3d_image(expected[0],
transformations[0])
expected[1] = ImageTransformationBase.apply_transform_on_3d_image(expected[1],
transformations[1])
return expected
def test_RandAugment_pipeline() -> None:
"""
Test the RandAugment transformation pipeline for online data augmentation.
"""
# Set random seeds for transformations
np.random.seed(1)
random.seed(0)
# Get inputs
one_channel_image = torch.tensor([[[2, 1, 0], [1, 2, 0], [0, 0, 2]],
[[2, 1, 0], [1, 2, 0], [0, 0, 2]]], dtype=torch.int32)
two_channel_image = torch.stack((one_channel_image, one_channel_image), dim=0)
# Case no transformation applied
pipeline = augmentation.RandAugmentSlice(magnitude=3,
n_transforms=0,
is_transformation_for_segmentation_maps=True)
transformed = pipeline(two_channel_image)
assert torch.all(two_channel_image == transformed)
# Case separate transformation per channel
pipeline = augmentation.RandAugmentSlice(magnitude=3,
n_transforms=1,
is_transformation_for_segmentation_maps=True,
use_joint_channel_transformation=False)
expected_sampled_ops_channel_1 = [ImageTransformationBase.translateY(-0.3 * 0.2)]
expected_sampled_ops_channel_2 = [ImageTransformationBase.horizontal_flip()]
expected = _compute_expected_pipeline_result(transformations=[expected_sampled_ops_channel_1,
expected_sampled_ops_channel_2],
input_image=two_channel_image)
transformed = pipeline(two_channel_image)
assert torch.all(transformed == expected)
# Case same transformation for all channels
pipeline = augmentation.RandAugmentSlice(magnitude=5,
n_transforms=2,
is_transformation_for_segmentation_maps=True,
use_joint_channel_transformation=True)
transformed = pipeline(two_channel_image)
expected_sampled_ops_channel = [ImageTransformationBase.rotate(0.5 * 30),
ImageTransformationBase.translateY(0.5 * 0.2)]
expected = _compute_expected_pipeline_result(transformations=[expected_sampled_ops_channel,
expected_sampled_ops_channel],
input_image=two_channel_image)
assert torch.all(transformed == expected)
# Case for images
two_channel_image = two_channel_image / 2.0
pipeline = augmentation.RandAugmentSlice(magnitude=3, n_transforms=1,
use_joint_channel_transformation=True,
is_transformation_for_segmentation_maps=False)
transformed = pipeline(two_channel_image)
expected_sampled_ops_channel = [ImageTransformationBase.adjust_contrast(1 - 0.3)]
expected = _compute_expected_pipeline_result(transformations=[expected_sampled_ops_channel,
expected_sampled_ops_channel],
input_image=two_channel_image)
assert torch.all(transformed == expected)
def test_RandomSliceTransformation_pipeline() -> None:
"""
Test the RandomSerial transformation pipeline for online data augmentation.
"""
# Set random seeds for transformations
np.random.seed(1)
random.seed(0)
one_channel_image = torch.tensor([[[2, 1, 0], [1, 2, 0], [0, 0, 2]],
[[2, 1, 0], [1, 2, 0], [0, 0, 2]]], dtype=torch.int32)
image_with_two_channels = torch.stack((one_channel_image, one_channel_image), dim=0)
# Case no transformation applied
pipeline = augmentation.RandomSliceTransformation(probability_transformation=0,
is_transformation_for_segmentation_maps=True)
transformed = pipeline(image_with_two_channels)
assert torch.all(image_with_two_channels == transformed)
# Case separate transformation per channel
pipeline = augmentation.RandomSliceTransformation(probability_transformation=1,
is_transformation_for_segmentation_maps=True,
use_joint_channel_transformation=False)
transformed = pipeline(image_with_two_channels)
expected_transformations_channel_1 = [ImageTransformationBase.rotate(-7),
ImageTransformationBase.translateX(0.011836899667533166),
ImageTransformationBase.translateY(-0.04989873172751189),
ImageTransformationBase.horizontal_flip()]
expected_transformations_channel_2 = [ImageTransformationBase.rotate(-1),
ImageTransformationBase.translateX(-0.04012366553408523),
ImageTransformationBase.translateY(-0.08525151327425498),
ImageTransformationBase.horizontal_flip()]
expected = _compute_expected_pipeline_result(transformations=[expected_transformations_channel_1,
expected_transformations_channel_2],
input_image=image_with_two_channels)
assert torch.all(transformed == expected)
# Case same transformation for all channels
pipeline = augmentation.RandomSliceTransformation(probability_transformation=1,
is_transformation_for_segmentation_maps=True,
use_joint_channel_transformation=True)
transformed = pipeline(image_with_two_channels)
expected_transformations_channel_1 = [ImageTransformationBase.rotate(9),
ImageTransformationBase.translateX(-0.0006422133534675356),
ImageTransformationBase.translateY(0.07352055509855618),
ImageTransformationBase.horizontal_flip()]
expected = _compute_expected_pipeline_result(transformations=[expected_transformations_channel_1,
expected_transformations_channel_1],
input_image=image_with_two_channels)
assert torch.all(transformed == expected)
# Case for images - convert to range 0-1 first
image_with_two_channels = image_with_two_channels / 4.0
pipeline = augmentation.RandomSliceTransformation(probability_transformation=1,
is_transformation_for_segmentation_maps=False,
use_joint_channel_transformation=True)
transformed = pipeline(image_with_two_channels)
expected_transformations_channel_1 = [ImageTransformationBase.rotate(8),
ImageTransformationBase.translateX(-0.02782961037858135),
ImageTransformationBase.translateY(0.06066901082924045),
ImageTransformationBase.adjust_contrast(0.2849887878176489),
ImageTransformationBase.adjust_brightness(1.0859799153800245)]
expected = _compute_expected_pipeline_result(transformations=[expected_transformations_channel_1,
expected_transformations_channel_1],
input_image=image_with_two_channels)
assert torch.all(transformed == expected)

@@ -18,13 +18,13 @@ from InnerEye.ML.SSL.datamodules_and_datasets.datamodules import CombinedDataMod
from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import InnerEyeCIFARLinearHeadTransform, \
InnerEyeCIFARTrainTransform, get_cxr_ssl_transforms
from InnerEye.ML.SSL.lightning_containers.ssl_container import SSLContainer, SSLDatasetName
from InnerEye.ML.SSL.utils import SSLDataModuleType, load_ssl_augmentation_config
from InnerEye.ML.SSL.utils import SSLDataModuleType, load_yaml_augmentation_config
from InnerEye.ML.configs.ssl.CXR_SSL_configs import path_encoder_augmentation_cxr
from Tests.SSL.test_ssl_containers import _create_test_cxr_data
path_to_test_dataset = full_ml_test_data_path("cxr_test_dataset")
_create_test_cxr_data(path_to_test_dataset)
cxr_augmentation_config = load_ssl_augmentation_config(path_encoder_augmentation_cxr)
cxr_augmentation_config = load_yaml_augmentation_config(path_encoder_augmentation_cxr)
@pytest.mark.skipif(is_windows(), reason="Too slow on windows")

@ -1,108 +0,0 @@
import random
import PIL
import numpy as np
import pytest
import torch
from scipy.ndimage import gaussian_filter, map_coordinates
from torchvision.transforms import ToTensor
from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import AddGaussianNoise, CenterCrop, ElasticTransform, \
ExpandChannels, RandomAffine, RandomColorJitter, RandomErasing, RandomGamma, RandomHorizontalFlip, RandomResizeCrop, \
Resize, create_chest_xray_transform
from Tests.SSL.test_data_modules import cxr_augmentation_config
def test_add_gaussian_noise() -> None:
"""
Tests functionality of add gaussian noise
"""
np.random.seed(1)
torch.manual_seed(10)
array = np.ones([1, 256, 256]) * 255.
array[0, 100:150, 100:200] = 1
tensor_img = torch.tensor(array / 255.)
transformed = AddGaussianNoise(cxr_augmentation_config)(tensor_img)
torch.manual_seed(10)
noise = torch.randn(size=(1, 256, 256)) * 0.05
assert torch.isclose(torch.clamp(tensor_img + noise, 0, 1), transformed).all()
with pytest.raises(AssertionError):
AddGaussianNoise(cxr_augmentation_config)(tensor_img * 255.)
def test_elastic_transform() -> None:
"""
Tests elastic transform
"""
image = np.ones([256, 256]) * 255.
image[100:150, 100:200] = 1
# Computed expected transform
np.random.seed(7)
np.random.random(1)
shape = (256, 256)
dx = gaussian_filter((np.random.random(shape) * 2 - 1), 4, mode="constant", cval=0) * 34
dy = gaussian_filter((np.random.random(shape) * 2 - 1), 4, mode="constant", cval=0) * 34
x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))
expected_array = map_coordinates(image, indices, order=1).reshape(shape)
# Actual transform
np.random.seed(7)
transformed_image = np.asarray(ElasticTransform(cxr_augmentation_config)(PIL.Image.fromarray(image)))
assert np.isclose(expected_array, transformed_image).all()
def test_expand_channels() -> None:
image = np.ones([1, 256, 256]) * 255.
tensor_img = torch.tensor(image)
tensor_img = ExpandChannels()(tensor_img)
assert tensor_img.shape == torch.Size([3, 256, 256])
assert torch.isclose(tensor_img[0], tensor_img[1]).all() and torch.isclose(tensor_img[1], tensor_img[2]).all()
def test_create_chest_xray_transform() -> None:
"""
Tests that the pipeline returned by create_chest_xray_transform returns the expected transformation.
"""
transform = create_chest_xray_transform(cxr_augmentation_config, apply_augmentations=True)
image = np.ones([256, 256]) * 255.
image[100:150, 100:200] = 1
image = PIL.Image.fromarray(image).convert("L")
np.random.seed(3)
torch.manual_seed(3)
random.seed(3)
transformed_image = transform(image)
# Expected pipeline
np.random.seed(3)
torch.manual_seed(3)
random.seed(3)
image = RandomAffine(cxr_augmentation_config)(image)
image = RandomResizeCrop(cxr_augmentation_config)(image)
image = Resize(cxr_augmentation_config)(image)
image = RandomHorizontalFlip(cxr_augmentation_config)(image)
image = RandomGamma(cxr_augmentation_config)(image)
image = RandomColorJitter(cxr_augmentation_config)(image)
image = ElasticTransform(cxr_augmentation_config)(image)
image = CenterCrop(cxr_augmentation_config)(image)
image = ToTensor()(image)
image = RandomErasing(cxr_augmentation_config)(image)
image = AddGaussianNoise(cxr_augmentation_config)(image)
image = ExpandChannels()(image)
assert torch.isclose(image, transformed_image).all()
# Test the evaluation pipeline
transform = create_chest_xray_transform(cxr_augmentation_config, apply_augmentations=False)
image = np.ones([256, 256]) * 255.
image[100:150, 100:200] = 1
image = PIL.Image.fromarray(image).convert("L")
transformed_image = transform(image)
# Expected pipeline
image = Resize(cxr_augmentation_config)(image)
image = CenterCrop(cxr_augmentation_config)(image)
image = ToTensor()(image)
image = ExpandChannels()(image)
assert torch.isclose(image, transformed_image).all()

@@ -283,3 +283,35 @@ runs are uploaded to the parent run, in the `CrossValResults` directory. This co
There is also a directory `BaselineComparisons`, containing the Wilcoxon test results and
scatterplots for the ensemble, as described above for single runs.
### Augmentations for classification models.
For classification models, you can define an augmentation pipeline to apply to your image inputs (resp. segmentation
inputs) at training, validation and test time. To define such a series of transformations, you need to override the
`get_image_transform` (resp. `get_segmentation_transform`) method of your config class. This method expects you to return
a `ModelTransformsPerExecutionMode` object, which maps each execution mode to one transform function. We also provide
`ImageTransformationPipeline`, a class that builds a pipeline of transforms from a list of individual transforms and
ensures the correct conversion of 2D or 3D PIL.Image or tensor inputs for the resulting pipeline.
`ImageTransformationPipeline` takes two arguments for its constructor:
* `transforms`: a list of image transforms. In particular, you can feed in standard [torchvision transforms](https://pytorch.org/vision/0.8/transforms.html) or
any other transforms, as long as they support an input of shape `[Z, C, H, W]`, where Z is the third dimension (1 for 2D images),
C the number of channels, and H and W the height and width of each 2D slice; standard torchvision transforms support
this input shape. You can also define your own transforms as long as they expect such a `[Z, C, H, W]` input. You can
find some examples of custom transform classes in `InnerEye/ML/augmentations/image_transforms.py`.
* `use_different_transformation_per_channel`: if True, apply a different version of the augmentation pipeline
to each channel. If False, apply the same transformation to all channels. Defaults to False. A minimal usage sketch follows this list.
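For illustration, here is a minimal sketch of how such a pipeline can be constructed and applied directly to a tensor. The shapes and transform parameters below are purely illustrative and not taken from any existing config:

```python
import torch
from torchvision.transforms import CenterCrop, RandomAffine

from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline

# Build a pipeline from standard torchvision transforms. With
# use_different_transformation_per_channel=False, the same sampled transformation
# is applied to every channel of the input.
pipeline = ImageTransformationPipeline(
    transforms=[RandomAffine(degrees=10), CenterCrop(224)],
    use_different_transformation_per_channel=False)

single_channel_image = torch.rand([1, 256, 256])  # a [C, H, W] tensor
transformed = pipeline(single_channel_image)      # returns a tensor cropped to [1, 224, 224]
```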
Below is an example of `get_image_transform` that resizes the input images to 256 x 256 and, at
training time only, applies a random rotation of +/- 10 degrees and some brightness distortion,
using standard torchvision transforms.
```python
def get_image_transform(self) -> ModelTransformsPerExecutionMode:
"""
Get transforms to perform on image samples for each model execution mode.
"""
return ModelTransformsPerExecutionMode(
train=ImageTransformationPipeline(transforms=[Resize(256), RandomAffine(degrees=10), ColorJitter(brightness=0.2)]),
val=ImageTransformationPipeline(transforms=[Resize(256)]),
test=ImageTransformationPipeline(transforms=[Resize(256)]))
```
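
Augmentations for segmentation inputs are defined in the same way, by overriding `get_segmentation_transform`. The snippet below is a minimal sketch following the same pattern; the chosen transform and its parameters are illustrative only, and typically only geometric transforms are meaningful for segmentation maps.

```python
def get_segmentation_transform(self) -> ModelTransformsPerExecutionMode:
    """
    Get transforms to perform on segmentation maps for each model execution mode.
    """
    # Illustrative only: apply a random affine transform at training time,
    # with a different transformation per channel.
    return ModelTransformsPerExecutionMode(
        train=ImageTransformationPipeline(
            transforms=[RandomAffine(degrees=10)],
            use_different_transformation_per_channel=True))
```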