Update augmentation pipeline (#458)
* Refactor augmentations initial commit * Test custom transforms and transform pipeline * Make transforms independent of config * Add placeholder to doc as reminder * Update tests * Update tests * Add test for affine * Add tests for color transform * Add a todo list * Adding new test and simplifying the transforms API further * New tests for new custom transforms * Remove deprecated file * Update all transforms * Add test for custom pipeline transform * Fix the full pipeline test * Rename * Simplify feeding in transforms * Updating CHANGELOG.md * Updating CHANGELOG.md * Flake8 * Simplify more * Fixes * Fixes * Flake8 * Fix few things * Fix few things * Updating tests for new API * Simplify image versus sample * Update * Update * Update * Update * Error in test copy paste * Style * Add augmentation doc * Add augmentation doc * Doc strings * Flake8 * Flake8 * Mypy * Fix test and docstring * Put back skipif windows * flake8 * mypy * Skip if OOM * Skip if OOM * Import was missing * Separate transform pipeline from ScalarItem management. * Update documentation to new API * Accidental commit * PR suggestion * Make test for elastic human checkable * Make test for elastic human checkable * Update with p=0 * Update with p=0 * Add human readable test output for gamma transform * Solve OOM, remove duplicate test input definitions * Flake8 * Mypy * Fix config * Add check PR comment * PR comment
This commit is contained in:
Parent
2af8c6099c
Commit
51274c8bdc
14
CHANGELOG.md
@ -108,6 +108,18 @@ console for easier diagnostics.
- ([#444](https://github.com/microsoft/InnerEye-DeepLearning/pull/444)) The method `before_training_on_rank_zero` of
  the `LightningContainer` class has been renamed to `before_training_on_global_rank_zero`. The order in which the
  hooks are called has been changed.
- ([#458](https://github.com/microsoft/InnerEye-DeepLearning/pull/458)) Simplified and generalized the way data
  augmentations are handled for classification models. The pipelining logic is now handled by an
  `ImageTransformationPipeline` class that takes as input a list of transforms to chain together. This pipeline takes
  care of applying the transforms to 3D or 2D images. The user can choose to apply the same transformation to all
  channels (e.g. RGB images) or a different transformation to each channel (e.g. if each channel represents a
  different modality or time point). The pipeline can now work directly with out-of-the-box torchvision transforms
  (as long as they support [..., C, H, W] inputs), which lets us remove nearly all of our custom augmentation
  functions. The conversion from an image transformation pipeline to a `ScalarItemAugmentation` is now taken care of
  under the hood; the user does not need to call this wrapper for each config class. In models derived from
  `ScalarModelBase`, users can override `get_image_transform` (resp. `get_segmentation_transform`) to change which
  augmentations are applied to the image inputs (resp. segmentation inputs). These two functions replace the old
  `get_image_sample_transforms` method (an illustrative sketch follows below). See `docs/building_models.md` for more
  information on augmentations.
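  As a minimal sketch (not part of this PR's diff; the config class name and transform parameters below are made-up
  examples), overriding `get_image_transform` in a classification config could look like this:

  ```python
  from torchvision.transforms import RandomAffine, RandomHorizontalFlip

  from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline
  from InnerEye.ML.model_config_base import ModelTransformsPerExecutionMode
  from InnerEye.ML.scalar_config import ScalarModelBase


  class MyClassificationConfig(ScalarModelBase):
      def get_image_transform(self) -> ModelTransformsPerExecutionMode:
          # Augment only the training images; validation and test keep the default (no-op) transform.
          return ModelTransformsPerExecutionMode(
              train=ImageTransformationPipeline(
                  transforms=[RandomAffine(degrees=30), RandomHorizontalFlip(p=0.5)]))
  ```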
### Fixed

@ -128,6 +140,8 @@ console for easier diagnostics.
- ([#450](https://github.com/microsoft/InnerEye-DeepLearning/pull/450)) Delete unused `classification_report.ipynb`.
- ([#455](https://github.com/microsoft/InnerEye-DeepLearning/pull/455)) Removed the AzureRunner conda environment.
  The full InnerEye conda environment is needed to submit a training job to AzureML.
- ([#458](https://github.com/microsoft/InnerEye-DeepLearning/pull/458)) Removed all of the unused code for
  RandAugment & Co. Users now have complete freedom to specify the set of augmentations to use.

### Deprecated

@ -2,195 +2,15 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import random
|
||||
from collections import Callable
|
||||
from typing import Any, List, Tuple
|
||||
from typing import Any, Callable, Tuple
|
||||
|
||||
import PIL
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform
|
||||
from scipy.ndimage import gaussian_filter, map_coordinates
|
||||
from torchvision.transforms import ToTensor
|
||||
from yacs.config import CfgNode
|
||||
|
||||
|
||||
class BaseTransform:
|
||||
def transform(self, x: Any) -> Any:
|
||||
raise NotImplementedError("Transform needs to be overridden in the child classes")
|
||||
|
||||
def __call__(self, data: PIL.Image.Image) -> PIL.Image.Image:
|
||||
return self.transform(data)
|
||||
|
||||
|
||||
class CenterCrop(BaseTransform):
|
||||
def __init__(self, config: CfgNode) -> None:
|
||||
super().__init__()
|
||||
self.center_crop_size = config.preprocess.center_crop_size
|
||||
|
||||
def transform(self, x: Any) -> Any:
|
||||
return torchvision.transforms.CenterCrop(self.center_crop_size)(x)
|
||||
|
||||
|
||||
class RandomResizeCrop(BaseTransform):
|
||||
def __init__(self, config: CfgNode) -> None:
|
||||
super().__init__()
|
||||
self.resize_size = config.preprocess.resize
|
||||
self.crop_scale = config.augmentation.random_crop.scale
|
||||
|
||||
def transform(self, x: Any) -> Any:
|
||||
return torchvision.transforms.RandomResizedCrop(
|
||||
size=self.resize_size,
|
||||
scale=self.crop_scale)(x)
|
||||
|
||||
|
||||
class RandomHorizontalFlip(BaseTransform):
|
||||
def __init__(self, config: CfgNode) -> None:
|
||||
super().__init__()
|
||||
self.p_apply = config.augmentation.random_horizontal_flip.prob
|
||||
|
||||
def transform(self, x: Any) -> Any:
|
||||
return torchvision.transforms.RandomHorizontalFlip(self.p_apply)(x)
|
||||
|
||||
|
||||
class RandomAffine(BaseTransform):
|
||||
def __init__(self, config: CfgNode) -> None:
|
||||
super().__init__()
|
||||
self.max_angle = config.augmentation.random_affine.max_angle
|
||||
self.max_horizontal_shift = config.augmentation.random_affine.max_horizontal_shift
|
||||
self.max_vertical_shift = config.augmentation.random_affine.max_vertical_shift
|
||||
self.max_shear = config.augmentation.random_affine.max_shear
|
||||
|
||||
def transform(self, x: Any) -> Any:
|
||||
return torchvision.transforms.RandomAffine(degrees=self.max_angle,
|
||||
translate=(self.max_horizontal_shift, self.max_vertical_shift),
|
||||
shear=self.max_shear)(x)
|
||||
|
||||
|
||||
class Resize(BaseTransform):
|
||||
def __init__(self, config: CfgNode) -> None:
|
||||
super().__init__()
|
||||
self.resize_size = config.preprocess.resize
|
||||
|
||||
def transform(self, x: Any) -> Any:
|
||||
return torchvision.transforms.Resize(self.resize_size)(x)
|
||||
|
||||
|
||||
class RandomColorJitter(BaseTransform):
|
||||
def __init__(self, config: CfgNode) -> None:
|
||||
super().__init__()
|
||||
self.max_brightness = config.augmentation.random_color.brightness
|
||||
self.max_contrast = config.augmentation.random_color.contrast
|
||||
self.max_saturation = config.augmentation.random_color.saturation
|
||||
|
||||
def transform(self, x: Any) -> Any:
|
||||
return torchvision.transforms.ColorJitter(brightness=self.max_brightness,
|
||||
contrast=self.max_contrast,
|
||||
saturation=self.max_saturation)(x)
|
||||
|
||||
|
||||
class RandomErasing(BaseTransform):
|
||||
def __init__(self, config: CfgNode) -> None:
|
||||
super().__init__()
|
||||
self.scale = config.augmentation.random_erasing.scale
|
||||
self.ratio = config.augmentation.random_erasing.ratio
|
||||
|
||||
def transform(self, x: Any) -> Any:
|
||||
return torchvision.transforms.RandomErasing(p=0.5,
|
||||
scale=self.scale,
|
||||
ratio=self.ratio)(x)
|
||||
|
||||
|
||||
class RandomGamma(BaseTransform):
|
||||
|
||||
def __init__(self, config: CfgNode) -> None:
|
||||
super().__init__()
|
||||
self.scale = config.augmentation.gamma.scale
|
||||
|
||||
def transform(self, image: PIL.Image.Image) -> PIL.Image.Image:
|
||||
gamma = random.uniform(*self.scale)
|
||||
return torchvision.transforms.functional.adjust_gamma(image, gamma=gamma)
|
||||
|
||||
|
||||
class ExpandChannels(BaseTransform):
|
||||
"""
|
||||
Transforms an image with 1 channel to an image with 3 channels by copying pixel intensities of the image along
|
||||
the 0th dimension.
|
||||
"""
|
||||
|
||||
def transform(self, data: torch.Tensor) -> torch.Tensor:
|
||||
return torch.repeat_interleave(data, 3, dim=0)
|
||||
|
||||
|
||||
class AddGaussianNoise(BaseTransform):
|
||||
|
||||
def __init__(self, config: CfgNode) -> None:
|
||||
"""
|
||||
Transformation to add Gaussian noise N(0, std) to an image. Where std is set with the
|
||||
config.augmentation.gaussian_noise.std argument. The transformation will be applied with probability
|
||||
config.augmentation.gaussian_noise.p_apply
|
||||
"""
|
||||
super().__init__()
|
||||
self.p_apply = config.augmentation.gaussian_noise.p_apply
|
||||
self.std = config.augmentation.gaussian_noise.std
|
||||
|
||||
def transform(self, data: torch.Tensor) -> torch.Tensor:
|
||||
assert data.max() <= 1 and data.min() >= 0
|
||||
if np.random.random(1) > self.p_apply:
|
||||
return data
|
||||
noise = torch.randn(size=data.shape) * self.std
|
||||
data = torch.clamp(data + noise, 0, 1)
|
||||
return data
|
||||
|
||||
|
||||
class ElasticTransform(BaseTransform):
|
||||
"""Elastic deformation of images as described in [Simard2003]_.
|
||||
.. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
|
||||
Convolutional Neural Networks applied to Visual Document Analysis", in
|
||||
Proc. of the International Conference on Document Analysis and
|
||||
Recognition, 2003.
|
||||
|
||||
https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.160.8494&rep=rep1&type=pdf
|
||||
|
||||
:param sigma: elasticity coefficient
|
||||
:param alpha: intensity of the deformation
|
||||
:param p_apply: probability of applying the transformation
|
||||
"""
|
||||
|
||||
def __init__(self, config: CfgNode) -> None:
|
||||
super().__init__()
|
||||
self.alpha = config.augmentation.elastic_transform.alpha
|
||||
self.sigma = config.augmentation.elastic_transform.sigma
|
||||
self.p_apply = config.augmentation.elastic_transform.p_apply
|
||||
|
||||
def transform(self, image: PIL.Image) -> PIL.Image:
|
||||
if np.random.random(1) > self.p_apply:
|
||||
return image
|
||||
image = np.asarray(image).squeeze()
|
||||
assert len(image.shape) == 2
|
||||
shape = image.shape
|
||||
|
||||
dx = gaussian_filter((np.random.random(shape) * 2 - 1), self.sigma, mode="constant", cval=0) * self.alpha
|
||||
dy = gaussian_filter((np.random.random(shape) * 2 - 1), self.sigma, mode="constant", cval=0) * self.alpha
|
||||
|
||||
x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
|
||||
indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))
|
||||
return PIL.Image.fromarray(map_coordinates(image, indices, order=1).reshape(shape))
|
||||
|
||||
|
||||
class DualViewTransformWrapper:
|
||||
"""
|
||||
Returns two versions of one image, given a random transformation function.
|
||||
"""
|
||||
|
||||
def __init__(self, transform: Callable):
|
||||
self.transform = transform
|
||||
|
||||
def __call__(self, sample: PIL.Image.Image) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
xi = self.transform(sample)
|
||||
xj = self.transform(sample)
|
||||
return xi, xj
|
||||
from InnerEye.ML.augmentations.transform_pipeline import create_cxr_transforms_from_config
|
||||
|
||||
|
||||
def get_cxr_ssl_transforms(config: CfgNode,
|
||||
|
@ -214,54 +34,19 @@ def get_cxr_ssl_transforms(config: CfgNode,
|
|||
This is required for SSL validation loss to be meaningful. If False, only apply basic processing step
|
||||
(no augmentations)
|
||||
"""
|
||||
train_transforms = create_chest_xray_transform(config, apply_augmentations=True)
|
||||
val_transforms = create_chest_xray_transform(config, apply_augmentations=use_training_augmentations_for_validation)
|
||||
train_transforms = create_cxr_transforms_from_config(config, apply_augmentations=True)
|
||||
val_transforms = create_cxr_transforms_from_config(config,
|
||||
apply_augmentations=use_training_augmentations_for_validation)
|
||||
if return_two_views_per_sample:
|
||||
train_transforms = DualViewTransformWrapper(train_transforms)
|
||||
val_transforms = DualViewTransformWrapper(val_transforms)
|
||||
train_transforms = DualViewTransformWrapper(train_transforms) # type: ignore
|
||||
val_transforms = DualViewTransformWrapper(val_transforms) # type: ignore
|
||||
return train_transforms, val_transforms
|
||||
|
||||
|
||||
def create_chest_xray_transform(config: CfgNode,
|
||||
apply_augmentations: bool) -> Callable:
|
||||
"""
|
||||
Defines the image transformations pipeline used in Chest-Xray datasets.
|
||||
Type of augmentation and strength are defined in the config.
|
||||
:param config: config yaml file fixing strength and type of augmentation to apply
|
||||
:param apply_augmentations: if True return transformation pipeline with augmentations. Else,
|
||||
disable augmentations i.e.
|
||||
only resize and center crop the image.
|
||||
"""
|
||||
transforms: List[Any] = []
|
||||
if apply_augmentations:
|
||||
if config.augmentation.use_random_affine:
|
||||
transforms.append(RandomAffine(config))
|
||||
if config.augmentation.use_random_crop:
|
||||
transforms.append(RandomResizeCrop(config))
|
||||
else:
|
||||
transforms.append(Resize(config))
|
||||
if config.augmentation.use_random_horizontal_flip:
|
||||
transforms.append(RandomHorizontalFlip(config))
|
||||
if config.augmentation.use_gamma_transform:
|
||||
transforms.append(RandomGamma(config))
|
||||
if config.augmentation.use_random_color:
|
||||
transforms.append(RandomColorJitter(config))
|
||||
if config.augmentation.use_elastic_transform:
|
||||
transforms.append(ElasticTransform(config))
|
||||
transforms += [CenterCrop(config), ToTensor()]
|
||||
if config.augmentation.use_random_erasing:
|
||||
transforms.append(RandomErasing(config))
|
||||
if config.augmentation.add_gaussian_noise:
|
||||
transforms.append(AddGaussianNoise(config))
|
||||
else:
|
||||
transforms += [Resize(config), CenterCrop(config), ToTensor()]
|
||||
transforms.append(ExpandChannels())
|
||||
return torchvision.transforms.Compose(transforms)
|
||||
|
||||
|
||||
class InnerEyeCIFARTrainTransform(SimCLRTrainDataTransform):
|
||||
"""
|
||||
Overload lightning-bolts SimCLRTrainDataTransform, to avoid return unused eval transform. Used for training and
|
||||
Overload lightning-bolts SimCLRTrainDataTransform, to avoid returning the unused eval transform. Used for
|
||||
training and
|
||||
val of SSL models.
|
||||
"""
|
||||
|
||||
|
@ -279,3 +64,17 @@ class InnerEyeCIFARLinearHeadTransform(SimCLRTrainDataTransform):
|
|||
|
||||
def __call__(self, sample: Any) -> Any:
|
||||
return self.online_transform(sample)
|
||||
|
||||
|
||||
class DualViewTransformWrapper:
|
||||
"""
|
||||
Returns two versions of one image, given a random transformation function.
|
||||
"""
|
||||
|
||||
def __init__(self, transform: Callable):
|
||||
self.transform = transform
|
||||
|
||||
def __call__(self, sample: PIL.Image.Image) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
xi = self.transform(sample)
|
||||
xj = self.transform(sample)
|
||||
return xi, xj
|
||||
|
|
|
@ -22,7 +22,7 @@ from InnerEye.ML.SSL.encoders import get_encoder_output_dim
|
|||
from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye
|
||||
from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye
|
||||
from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import SSLOnlineEvaluatorInnerEye
|
||||
from InnerEye.ML.SSL.utils import SSLDataModuleType, SSLTrainingType, load_ssl_augmentation_config
|
||||
from InnerEye.ML.SSL.utils import SSLDataModuleType, SSLTrainingType, load_yaml_augmentation_config
|
||||
from InnerEye.ML.lightning_container import LightningContainer
|
||||
|
||||
|
||||
|
@ -127,10 +127,10 @@ class SSLContainer(LightningContainer):
|
|||
|
||||
def _load_config(self) -> None:
|
||||
# For Chest-XRay you need to specify the parameters of the augmentations via a config file.
|
||||
self.ssl_augmentation_params = load_ssl_augmentation_config(
|
||||
self.ssl_augmentation_params = load_yaml_augmentation_config(
|
||||
self.ssl_augmentation_config) if self.ssl_augmentation_config is not None \
|
||||
else None
|
||||
self.classifier_augmentation_params = load_ssl_augmentation_config(
|
||||
self.classifier_augmentation_params = load_yaml_augmentation_config(
|
||||
self.linear_head_augmentation_config) if self.linear_head_augmentation_config is not None else \
|
||||
self.ssl_augmentation_params
|
||||
|
||||
|
|
|
@ -25,10 +25,9 @@ class SSLTrainingType(Enum):
|
|||
BYOL = "BYOL"
|
||||
|
||||
|
||||
def load_ssl_augmentation_config(config_path: Path) -> CfgNode:
|
||||
def load_yaml_augmentation_config(config_path: Path) -> CfgNode:
|
||||
"""
|
||||
Loads configs required for self supervised learning. Does not setup cudann as this is being
|
||||
taken care of by lightning.
|
||||
Loads augmentation configs defined as YAML files.
|
||||
"""
|
||||
config = ssl_augmentation_config.get_default_model_config()
|
||||
config.merge_from_file(config_path)
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import random
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from InnerEye.Common.common_util import any_pairwise_larger
|
||||
from InnerEye.Common.type_annotations import TupleInt3
|
||||
from InnerEye.ML.config import SegmentationModelBase
|
||||
from InnerEye.ML.dataset.sample import Sample
|
||||
|
||||
|
||||
def random_select_patch_center(sample: Sample, class_weights: List[float] = None) -> np.ndarray:
|
||||
"""
|
||||
Samples a point to use as the coordinates of the patch center. First samples one
|
||||
class among the available classes then samples a center point among the pixels of the sampled
|
||||
class.
|
||||
|
||||
:param sample: A set of Image channels, ground truth labels and mask to randomly crop.
|
||||
:param class_weights: A weighting vector with values [0, 1] to influence the class the center crop
|
||||
voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
|
||||
:return numpy int array (3x1) containing patch center spatial coordinates
|
||||
"""
|
||||
num_classes = sample.labels.shape[0]
|
||||
|
||||
if class_weights is not None:
|
||||
if len(class_weights) != num_classes:
|
||||
raise Exception("A weight must be provided for each class, found weights:{}, expected:{}"
|
||||
.format(len(class_weights), num_classes))
|
||||
SegmentationModelBase.validate_class_weights(class_weights)
|
||||
|
||||
# If class weights are not initialised, selection is made with equal probability for all classes
|
||||
available_classes = list(range(num_classes))
|
||||
original_class_weights = class_weights
|
||||
while len(available_classes) > 0:
|
||||
selected_label_class = random.choices(population=available_classes, weights=class_weights, k=1)[0]
|
||||
# Check pixels where mask and label maps are both foreground
|
||||
indices = np.argwhere(np.logical_and(sample.labels[selected_label_class] == 1.0, sample.mask == 1))
|
||||
if not np.any(indices):
|
||||
available_classes.remove(selected_label_class)
|
||||
if class_weights is not None:
|
||||
assert original_class_weights is not None # for mypy
|
||||
class_weights = [original_class_weights[i] for i in available_classes]
|
||||
if sum(class_weights) <= 0.0:
|
||||
raise ValueError("Cannot sample a class: no class present in the sample has a positive weight")
|
||||
else:
|
||||
break
|
||||
|
||||
# Raise an exception if none of the foreground classes overlap with the mask
|
||||
if len(available_classes) == 0:
|
||||
raise Exception("No non-mask voxels found, please check your mask and labels map")
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
choice = random.randint(0, len(indices) - 1)
|
||||
|
||||
return indices[choice].astype(int) # Numpy usually stores as floats
|
||||
|
||||
|
||||
def slicers_for_random_crop(sample: Sample,
|
||||
crop_size: TupleInt3,
|
||||
class_weights: List[float] = None) -> Tuple[List[slice], np.ndarray]:
|
||||
"""
|
||||
Computes array slicers that produce random crops of the given crop_size.
|
||||
The selection of the center is dependent on background probability.
|
||||
By default it does not center on background.
|
||||
|
||||
:param sample: A set of Image channels, ground truth labels and mask to randomly crop.
|
||||
:param crop_size: The size of the crop expressed as a list of 3 ints, one per spatial dimension.
|
||||
:param class_weights: A weighting vector with values [0, 1] to influence the class the center crop
|
||||
voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
|
||||
:return: Tuple element 1: The slicers that convert the input image to the chosen crop. Tuple element 2: The
|
||||
indices of the center point of the crop.
|
||||
:raises ValueError: If there are shape mismatches among the arguments or if the crop size is larger than the image.
|
||||
"""
|
||||
shape = sample.image.shape[1:]
|
||||
|
||||
if any_pairwise_larger(crop_size, shape):
|
||||
raise ValueError("The crop_size across each dimension should be greater than zero and less than or equal "
|
||||
"to the current value (crop_size: {}, spatial shape: {})"
|
||||
.format(crop_size, shape))
|
||||
|
||||
# Sample a center pixel location for patch extraction.
|
||||
center = random_select_patch_center(sample, class_weights)
|
||||
|
||||
# Verify and fix overflow for each dimension
|
||||
left = []
|
||||
for i in range(3):
|
||||
margin_left = int(crop_size[i] / 2)
|
||||
margin_right = crop_size[i] - margin_left
|
||||
left_index = center[i] - margin_left
|
||||
right_index = center[i] + margin_right
|
||||
if right_index > shape[i]:
|
||||
left_index = left_index - (right_index - shape[i])
|
||||
if left_index < 0:
|
||||
left_index = 0
|
||||
left.append(left_index)
|
||||
|
||||
return [slice(left[x], left[x] + crop_size[x]) for x in range(0, 3)], center
|
||||
|
||||
|
||||
def random_crop(sample: Sample,
|
||||
crop_size: TupleInt3,
|
||||
class_weights: List[float] = None) -> Tuple[Sample, np.ndarray]:
|
||||
"""
|
||||
Randomly crops images, mask, and labels arrays according to the crop_size argument.
|
||||
The selection of the center is dependent on background probability.
|
||||
By default it does not center on background.
|
||||
|
||||
:param sample: A set of Image channels, ground truth labels and mask to randomly crop.
|
||||
:param crop_size: The size of the crop expressed as a list of 3 ints, one per spatial dimension.
|
||||
:param class_weights: A weighting vector with values [0, 1] to influence the class the center crop
|
||||
voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
|
||||
:return: Tuple item 1: The cropped images, labels, and mask. Tuple item 2: The center that was chosen for the crop,
before shifting to be inside of the image.
|
||||
:raises ValueError: If there are shape mismatches among the arguments or if the crop size is larger than the image.
|
||||
"""
|
||||
slicers, center = slicers_for_random_crop(sample, crop_size, class_weights)
|
||||
sample = Sample(
|
||||
image=sample.image[:, slicers[0], slicers[1], slicers[2]],
|
||||
labels=sample.labels[:, slicers[0], slicers[1], slicers[2]],
|
||||
mask=sample.mask[slicers[0], slicers[1], slicers[2]],
|
||||
metadata=sample.metadata
|
||||
)
|
||||
return sample, center
|
|
@ -0,0 +1,116 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import random
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
from scipy.ndimage import gaussian_filter, map_coordinates
|
||||
|
||||
|
||||
class RandomGamma:
|
||||
"""
|
||||
Custom function to apply a random gamma transform within a specified range of possible gamma values.
|
||||
See documentation of
|
||||
[`adjust_gamma`](https://pytorch.org/vision/0.8/transforms.html#torchvision.transforms.functional.adjust_gamma) for
|
||||
more details.
|
||||
"""
|
||||
|
||||
def __init__(self, scale: Tuple[float, float]) -> None:
|
||||
"""
|
||||
:param scale: a tuple (min_gamma, max_gamma) that specifies the range of possible values to sample the gamma
|
||||
value from when the transformation is called.
|
||||
"""
|
||||
self.scale = scale
|
||||
|
||||
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
||||
gamma = random.uniform(*self.scale)
|
||||
if len(image.shape) != 4:
|
||||
raise ValueError(f"Expected input of shape [Z, C, H, W], but only got {len(image.shape)} dimensions")
|
||||
for z in range(image.shape[0]):
|
||||
for c in range(image.shape[1]):
|
||||
image[z, c] = torchvision.transforms.functional.adjust_gamma(image[z, c], gamma=gamma)
|
||||
return image
|
||||
|
||||
|
||||
class ExpandChannels:
|
||||
"""
|
||||
Transforms an image with 1 channel to an image with 3 channels by copying pixel intensities of the image along
|
||||
the 1st dimension.
|
||||
"""
|
||||
|
||||
def __call__(self, data: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
:param data: tensor of shape [Z, 1, H, W]
|
||||
:return: data with channel copied 3 times, shape [Z, 3, H, W]
|
||||
"""
|
||||
shape = data.shape
|
||||
if len(shape) != 4 or shape[1] != 1:
|
||||
raise ValueError(f"Expected input of shape [Z, 1, H, W], found {shape}")
|
||||
return torch.repeat_interleave(data, 3, dim=1)
|
||||
|
||||
|
||||
class AddGaussianNoise:
|
||||
|
||||
def __init__(self, p_apply: float, std: float) -> None:
|
||||
"""
|
||||
Transformation to add Gaussian noise N(0, std) to an image.
|
||||
:param p_apply: probability of applying the transformation.
:param std: standard deviation of the Gaussian noise to add to the image.
|
||||
"""
|
||||
super().__init__()
|
||||
self.p_apply = p_apply
|
||||
self.std = std
|
||||
|
||||
def __call__(self, data: torch.Tensor) -> torch.Tensor:
|
||||
if np.random.random(1) > self.p_apply:
|
||||
return data
|
||||
noise = torch.randn(size=data.shape[-2:]) * self.std
|
||||
data = torch.clamp(data + noise, data.min(), data.max()) # type: ignore
|
||||
return data
|
||||
|
||||
|
||||
class ElasticTransform:
|
||||
"""Elastic deformation of images as described in [Simard2003]_.
|
||||
.. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
|
||||
Convolutional Neural Networks applied to Visual Document Analysis", in
|
||||
Proc. of the International Conference on Document Analysis and
|
||||
Recognition, 2003.
|
||||
https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.160.8494&rep=rep1&type=pdf
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
sigma: float,
|
||||
alpha: float,
|
||||
p_apply: float
|
||||
) -> None:
|
||||
"""
|
||||
:param sigma: elasticity coefficient
|
||||
:param alpha: intensity of the deformation
|
||||
:param p_apply: probability of applying the transformation
|
||||
"""
|
||||
super().__init__()
|
||||
self.alpha = alpha
|
||||
self.sigma = sigma
|
||||
self.p_apply = p_apply
|
||||
|
||||
def __call__(self, data: torch.Tensor) -> torch.Tensor:
|
||||
if np.random.random(1) > self.p_apply:
|
||||
return data
|
||||
result_type = data.dtype
|
||||
data = data.cpu().numpy()
|
||||
shape = data.shape
|
||||
|
||||
dx = gaussian_filter((np.random.random(shape[-2:]) * 2 - 1), self.sigma, mode="constant", cval=0) * self.alpha
|
||||
dy = gaussian_filter((np.random.random(shape[-2:]) * 2 - 1), self.sigma, mode="constant", cval=0) * self.alpha
|
||||
all_dimensions_axes = [np.arange(dim) for dim in shape]
|
||||
grid = np.meshgrid(*all_dimensions_axes, indexing='ij')
|
||||
grid[-2] = grid[-2] + dx
|
||||
grid[-1] = grid[-1] + dy
|
||||
indices = [np.reshape(grid[i], (-1, 1)) for i in range(len(grid))]
|
||||
|
||||
return torch.tensor(map_coordinates(data, indices, order=1).reshape(shape), dtype=result_type)
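An illustrative usage sketch (not part of this diff): the transforms above compose directly on a [Z, C, H, W] tensor.
The tensor size and transform parameters below are made-up examples.
scan = torch.rand(4, 1, 256, 256)  # hypothetical single-channel volume [Z, 1, H, W], intensities in [0, 1]
scan = ExpandChannels()(scan)  # -> [Z, 3, H, W]
scan = RandomGamma(scale=(0.8, 1.2))(scan)  # same random gamma applied to every slice and channel
scan = AddGaussianNoise(p_apply=0.5, std=0.05)(scan)  # noise added with probability 0.5
scan = ElasticTransform(sigma=4.0, alpha=34.0, p_apply=0.4)(scan)  # shape is preserved: [4, 3, 256, 256]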
|
|
@ -0,0 +1,145 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
from typing import Any, Callable, List, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from torchvision.transforms import CenterCrop, ColorJitter, Compose, RandomAffine, RandomErasing, \
|
||||
RandomHorizontalFlip, RandomResizedCrop, Resize
|
||||
from torchvision.transforms.functional import to_tensor
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from InnerEye.ML.augmentations.image_transforms import AddGaussianNoise, ElasticTransform, ExpandChannels, RandomGamma
|
||||
|
||||
ImageData = Union[PIL.Image.Image, torch.Tensor]
|
||||
|
||||
|
||||
class ImageTransformationPipeline:
|
||||
"""
|
||||
This class defines the data augmentation transformations applied to 3D or 2D image inputs (tensor or PIL.Image).
In the case of 3D images, the transformations are applied slice by slice along the Z dimension (the same
transformation is applied to each slice).
The transformations are applied channel by channel; the user can specify whether to apply the same transformation
to each channel (no random shuffling) or whether each channel should use a different transformation (random
parameters of transforms shuffled for each channel).
|
||||
"""
|
||||
|
||||
# noinspection PyMissingConstructor
|
||||
def __init__(self,
|
||||
transforms: Union[Callable, List[Callable]],
|
||||
use_different_transformation_per_channel: bool = False):
|
||||
"""
|
||||
:param transforms: List of transformations to apply to images. Supports out-of-the-box torchvision transforms
as they accept data of arbitrary dimension. You can also define your own transform class, but be aware that your
function should expect an input of shape [Z, C, H, W] and apply the same transformation to each Z slice.
:param use_different_transformation_per_channel: if True, apply a different version of the augmentation pipeline
for each channel. If False, the same transformation (with the same random parameters) is applied to every channel.
|
||||
"""
|
||||
self.use_different_transformation_per_channel = use_different_transformation_per_channel
|
||||
self.pipeline = Compose(transforms) if isinstance(transforms, List) else transforms
|
||||
|
||||
def transform_image(self, image: ImageData) -> torch.Tensor:
|
||||
"""
|
||||
Main function to apply the transformation pipeline, either slice by slice on one 3D image or
directly on a 2D image.
|
||||
|
||||
Note for 3D images: Assumes the same transformations have to be applied on each 2D-slice along the Z-axis.
|
||||
Assumes the Z axis is the first dimension.
|
||||
|
||||
:param image: batch of tensor images of size [C, Z, Y, X] or batch of 2D images as PIL Image
|
||||
"""
|
||||
|
||||
def _convert_to_tensor_if_necessary(data: ImageData) -> torch.Tensor:
|
||||
return to_tensor(data) if isinstance(data, PIL.Image.Image) else data
|
||||
|
||||
image = _convert_to_tensor_if_necessary(image)
|
||||
original_input_is_2d = len(image.shape) == 3
|
||||
# If we have a 2D image [C, H, W] expand to [Z, C, H, W]. Built-in torchvision transforms allow such 4D inputs.
|
||||
if original_input_is_2d:
|
||||
image = image.unsqueeze(0)
|
||||
else:
|
||||
# Some transforms assume the order of dimension is [..., C, H, W] so permute first and last dimension to
|
||||
# obtain [Z, C, H, W]
|
||||
if len(image.shape) != 4:
|
||||
raise ValueError(f"ScalarDataset should load images as 4D tensor [C, Z, H, W]. The input tensor here"
|
||||
f"was of shape {image.shape}. This is unexpected.")
|
||||
image = torch.transpose(image, 1, 0)
|
||||
|
||||
if not self.use_different_transformation_per_channel:
|
||||
image = _convert_to_tensor_if_necessary(self.pipeline(image))
|
||||
else:
|
||||
channels = []
|
||||
for channel in range(image.shape[1]):
|
||||
channels.append(_convert_to_tensor_if_necessary(self.pipeline(image[:, channel, :, :].unsqueeze(1))))
|
||||
image = torch.cat(channels, dim=1)
|
||||
# Back to [C, Z, H, W]
|
||||
image = torch.transpose(image, 1, 0)
|
||||
if original_input_is_2d:
|
||||
image = image.squeeze(1)
|
||||
return image.to(dtype=image.dtype)
|
||||
|
||||
def __call__(self, data: ImageData) -> torch.Tensor:
|
||||
return self.transform_image(data)
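An illustrative usage sketch (not part of this diff): chaining an out-of-the-box torchvision transform with one of
the custom transforms. The input size and parameters are made-up examples.
pipeline = ImageTransformationPipeline(
    transforms=[CenterCrop(224), RandomGamma(scale=(0.8, 1.2))],
    use_different_transformation_per_channel=False)
image = torch.rand(2, 4, 256, 256)  # [C, Z, H, W], as loaded by ScalarDataset
transformed = pipeline(image)  # -> tensor of shape [2, 4, 224, 224]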
|
||||
|
||||
|
||||
def create_cxr_transforms_from_config(config: CfgNode,
|
||||
apply_augmentations: bool) -> ImageTransformationPipeline:
|
||||
"""
|
||||
Defines the image transformation pipeline used in chest X-ray datasets. Can be used for other types of
image data; the type of augmentations to use and their strength are expected to be defined in the config.
:param config: config loaded from a yaml file, fixing the strength and type of augmentation to apply
:param apply_augmentations: if True, return the transformation pipeline with augmentations. Else,
disable augmentations, i.e. only resize and center crop the image.
|
||||
"""
|
||||
transforms: List[Any] = [ExpandChannels()]
|
||||
if apply_augmentations:
|
||||
if config.augmentation.use_random_affine:
|
||||
transforms.append(RandomAffine(
|
||||
degrees=config.augmentation.random_affine.max_angle,
|
||||
translate=(config.augmentation.random_affine.max_horizontal_shift,
|
||||
config.augmentation.random_affine.max_vertical_shift),
|
||||
shear=config.augmentation.random_affine.max_shear
|
||||
))
|
||||
if config.augmentation.use_random_crop:
|
||||
transforms.append(RandomResizedCrop(
|
||||
scale=config.augmentation.random_crop.scale,
|
||||
size=config.preprocess.resize
|
||||
))
|
||||
else:
|
||||
transforms.append(Resize(size=config.preprocess.resize))
|
||||
if config.augmentation.use_random_horizontal_flip:
|
||||
transforms.append(RandomHorizontalFlip(p=config.augmentation.random_horizontal_flip.prob))
|
||||
if config.augmentation.use_gamma_transform:
|
||||
transforms.append(RandomGamma(scale=config.augmentation.gamma.scale))
|
||||
if config.augmentation.use_random_color:
|
||||
transforms.append(ColorJitter(
|
||||
brightness=config.augmentation.random_color.brightness,
|
||||
contrast=config.augmentation.random_color.contrast,
|
||||
saturation=config.augmentation.random_color.saturation
|
||||
))
|
||||
if config.augmentation.use_elastic_transform:
|
||||
transforms.append(ElasticTransform(
|
||||
alpha=config.augmentation.elastic_transform.alpha,
|
||||
sigma=config.augmentation.elastic_transform.sigma,
|
||||
p_apply=config.augmentation.elastic_transform.p_apply
|
||||
))
|
||||
transforms.append(CenterCrop(config.preprocess.center_crop_size))
|
||||
if config.augmentation.use_random_erasing:
|
||||
transforms.append(RandomErasing(
|
||||
scale=config.augmentation.random_erasing.scale,
|
||||
ratio=config.augmentation.random_erasing.ratio
|
||||
))
|
||||
if config.augmentation.add_gaussian_noise:
|
||||
transforms.append(AddGaussianNoise(
|
||||
p_apply=config.augmentation.gaussian_noise.p_apply,
|
||||
std=config.augmentation.gaussian_noise.std
|
||||
))
|
||||
else:
|
||||
transforms += [Resize(size=config.preprocess.resize),
|
||||
CenterCrop(config.preprocess.center_crop_size)]
|
||||
pipeline = ImageTransformationPipeline(transforms)
|
||||
return pipeline
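An illustrative usage sketch (not part of this diff), reusing names that already exist in the codebase:
from InnerEye.ML.SSL.utils import load_yaml_augmentation_config
from InnerEye.ML.configs.ssl.CXR_SSL_configs import path_linear_head_augmentation_cxr

yaml_config = load_yaml_augmentation_config(path_linear_head_augmentation_cxr)
train_pipeline = create_cxr_transforms_from_config(yaml_config, apply_augmentations=True)
val_pipeline = create_cxr_transforms_from_config(yaml_config, apply_augmentations=False)  # expand channels, resize, center crop only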
|
|
@ -14,11 +14,12 @@ from pytorch_lightning import LightningModule
|
|||
from torchvision.transforms import Compose
|
||||
|
||||
from InnerEye.Common.common_util import ModelProcessing, get_best_epoch_results_path
|
||||
from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import create_chest_xray_transform
|
||||
|
||||
from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName
|
||||
|
||||
from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier
|
||||
from InnerEye.ML.SSL.utils import create_ssl_encoder, create_ssl_image_classifier, load_ssl_augmentation_config
|
||||
from InnerEye.ML.SSL.utils import create_ssl_encoder, create_ssl_image_classifier, load_yaml_augmentation_config
|
||||
from InnerEye.ML.augmentations.transform_pipeline import create_cxr_transforms_from_config
|
||||
from InnerEye.ML.common import ModelExecutionMode
|
||||
|
||||
from InnerEye.ML.configs.ssl.CXR_SSL_configs import path_linear_head_augmentation_cxr
|
||||
|
@ -32,7 +33,6 @@ from InnerEye.ML.models.architectures.classification.image_encoder_with_mlp impo
|
|||
from InnerEye.ML.reports.notebook_report import generate_notebook, get_ipynb_report_name, str_or_empty
|
||||
|
||||
from InnerEye.ML.scalar_config import ScalarLoss, ScalarModelBase
|
||||
from InnerEye.ML.utils.augmentation import ScalarItemAugmentation
|
||||
from InnerEye.ML.utils.run_recovery import RunRecovery
|
||||
from InnerEye.ML.utils.split_dataset import DatasetSplits
|
||||
|
||||
|
@ -114,13 +114,11 @@ class CovidHierarchicalModel(ScalarModelBase):
|
|||
|
||||
# noinspection PyTypeChecker
|
||||
def get_image_sample_transforms(self) -> ModelTransformsPerExecutionMode:
|
||||
config = load_ssl_augmentation_config(path_linear_head_augmentation_cxr)
|
||||
train_transforms = ScalarItemAugmentation(
|
||||
Compose(
|
||||
[DicomPreparation(), create_chest_xray_transform(config, apply_augmentations=True)]))
|
||||
val_transforms = ScalarItemAugmentation(
|
||||
Compose(
|
||||
[DicomPreparation(), create_chest_xray_transform(config, apply_augmentations=False)]))
|
||||
config = load_yaml_augmentation_config(path_linear_head_augmentation_cxr)
|
||||
train_transforms = Compose(
|
||||
[DicomPreparation(), create_cxr_transforms_from_config(config, apply_augmentations=True)])
|
||||
val_transforms = Compose(
|
||||
[DicomPreparation(), create_cxr_transforms_from_config(config, apply_augmentations=False)])
|
||||
|
||||
return ModelTransformsPerExecutionMode(train=train_transforms,
|
||||
val=val_transforms,
|
||||
|
|
|
@ -8,12 +8,13 @@ from typing import Any, Dict, List, Optional
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from InnerEye.ML.augmentations.augmentation_for_segmentation_utils import random_crop
|
||||
from InnerEye.Common.common_util import any_pairwise_larger
|
||||
from InnerEye.Common.type_annotations import TupleInt3
|
||||
from InnerEye.ML.config import PaddingMode, SegmentationModelBase
|
||||
from InnerEye.ML.dataset.full_image_dataset import FullImageDataset
|
||||
from InnerEye.ML.dataset.sample import CroppedSample, Sample
|
||||
from InnerEye.ML.utils import augmentation, image_util
|
||||
from InnerEye.ML.utils import image_util
|
||||
from InnerEye.ML.utils.image_util import pad_images
|
||||
from InnerEye.ML.utils.io_util import ImageDataType
|
||||
from InnerEye.ML.utils.transforms import Compose3D
|
||||
|
@ -96,7 +97,7 @@ class CroppingDataset(FullImageDataset):
|
|||
:return: CroppedSample
|
||||
"""
|
||||
# crop the original raw sample
|
||||
sample, center_point = augmentation.random_crop(
|
||||
sample, center_point = random_crop(
|
||||
sample=sample,
|
||||
crop_size=crop_size,
|
||||
class_weights=class_weights
|
||||
|
|
|
@ -26,7 +26,6 @@ from InnerEye.ML.sequence_config import SequenceModelBase
|
|||
from InnerEye.ML.utils.csv_util import CSV_CHANNEL_HEADER, CSV_SUBJECT_HEADER
|
||||
from InnerEye.ML.utils.dataset_util import CategoricalToOneHotEncoder
|
||||
from InnerEye.ML.utils.features_util import FeatureStatistics
|
||||
from InnerEye.ML.utils.transforms import Compose3D, Transform3D
|
||||
|
||||
T = TypeVar('T', bound=ScalarDataSource)
|
||||
|
||||
|
@ -67,12 +66,13 @@ def extract_label_classification(label_string: str, sample_id: str, num_classes:
|
|||
if isinstance(label_string, float):
|
||||
if math.isnan(label_string):
|
||||
if num_classes == 1:
|
||||
# Pandas special case: When loading a dataframe with dtype=str, missing values can be encoded as NaN, and get into here.
|
||||
# Pandas special case: When loading a dataframe with dtype=str, missing values can be encoded as NaN,
|
||||
# and get into here.
|
||||
return [label_string]
|
||||
else:
|
||||
return [0] * num_classes
|
||||
else:
|
||||
raise ValueError(f"Subject {sample_id}: Unexpected float input {label_string} - did you read the "
|
||||
raise ValueError(f"Subject {sample_id}: Unexpected float input {label_string} - did you read the "
|
||||
f"dataframe column as a string?")
|
||||
|
||||
if not label_string:
|
||||
|
@ -92,7 +92,7 @@ def extract_label_classification(label_string: str, sample_id: str, num_classes:
|
|||
return [float(label_string)]
|
||||
else:
|
||||
raise ValueError(f"Subject {sample_id}: Label string not recognized: '{label_string}'. "
|
||||
f"Should be one of true/false, yes/no or 0/1.")
|
||||
f"Should be one of true/false, yes/no or 0/1.")
|
||||
|
||||
if '|' in label_string or label_string.isdigit():
|
||||
classes = [int(a) for a in label_string.split('|')]
|
||||
|
@ -648,6 +648,36 @@ You now want to get the label from the "week0" row, and read out Scalar1 at week
|
|||
"""
|
||||
|
||||
|
||||
class ScalarItemAugmentation:
|
||||
"""
|
||||
Wrapper around augmentation pipeline to apply image or/and segmentation transformations
|
||||
to a ScalarItem inputs.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
image_transform: Optional[Callable] = None,
|
||||
segmentation_transform: Optional[Callable] = None) -> None:
|
||||
"""
|
||||
:param image_transform: transformation function to apply to images field. If None, images field is unchanged by
|
||||
call.
|
||||
:param segmentation_transform: transformation function to apply to segmentations field. If None segmentations
|
||||
field is unchanged by call.
|
||||
"""
|
||||
self.image_transform = image_transform
|
||||
self.segmentation_transform = segmentation_transform
|
||||
|
||||
def __call__(self, item: ScalarItem) -> ScalarItem:
|
||||
if self.image_transform is not None:
|
||||
if self.segmentation_transform is not None:
|
||||
return item.clone_with_overrides(images=self.image_transform(item.images),
|
||||
segmentations=self.segmentation_transform(item.segmentations))
|
||||
return item.clone_with_overrides(images=self.image_transform(item.images))
|
||||
|
||||
if self.segmentation_transform is not None:
|
||||
return item.clone_with_overrides(segmentations=self.segmentation_transform(item.segmentations))
|
||||
return item
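An illustrative usage sketch (not part of this diff): wrapping an image-only pipeline so it can be applied to a
ScalarItem. The transform parameters are made-up examples and `loaded_item` is assumed to be a ScalarItem produced
by the dataset.
from torchvision.transforms import RandomHorizontalFlip
from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline

augmentation = ScalarItemAugmentation(image_transform=ImageTransformationPipeline([RandomHorizontalFlip(p=0.5)]))
augmented_item = augmentation(loaded_item)  # only the `images` field of the ScalarItem is transformed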
|
||||
|
||||
|
||||
class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
|
||||
"""
|
||||
A base class for datasets for classification tasks. It contains logic for loading images from disk,
|
||||
|
@ -661,7 +691,7 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
|
|||
data_frame: Optional[pd.DataFrame] = None,
|
||||
feature_statistics: Optional[FeatureStatistics] = None,
|
||||
name: Optional[str] = None,
|
||||
sample_transforms: Optional[Union[Compose3D[ScalarItem], Transform3D[ScalarItem]]] = None):
|
||||
sample_transform: Callable[[ScalarItem], ScalarItem] = ScalarItemAugmentation()):
|
||||
"""
|
||||
High level class for the scalar dataset designed to be inherited for specific behaviour
|
||||
:param args: The model configuration object.
|
||||
|
@ -670,7 +700,7 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
|
|||
:param name: Name of the dataset, used for diagnostics logging
|
||||
"""
|
||||
super().__init__(args, data_frame, name)
|
||||
self.transforms = sample_transforms
|
||||
self.transform = sample_transform
|
||||
self.feature_statistics = feature_statistics
|
||||
self.file_to_full_path: Optional[Dict[str, Path]] = None
|
||||
if args.traverse_dirs_when_loading:
|
||||
|
@ -725,7 +755,7 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], Generic[T]):
|
|||
center_crop_size=self.args.center_crop_size,
|
||||
image_size=self.args.image_size)
|
||||
|
||||
return Compose3D.apply(self.transforms, sample)
|
||||
return self.transform(sample)
|
||||
|
||||
def create_status_string(self, items: List[T]) -> str:
|
||||
"""
|
||||
|
@ -746,21 +776,21 @@ class ScalarDataset(ScalarDatasetBase[ScalarDataSource]):
|
|||
data_frame: Optional[pd.DataFrame] = None,
|
||||
feature_statistics: Optional[FeatureStatistics[ScalarDataSource]] = None,
|
||||
name: Optional[str] = None,
|
||||
sample_transforms: Optional[Union[Compose3D[ScalarItem], Transform3D[ScalarItem]]] = None):
|
||||
sample_transform: Callable[[ScalarItem], ScalarItem] = ScalarItemAugmentation()):
|
||||
"""
|
||||
Creates a new scalar dataset from a dataframe.
|
||||
:param args: The model configuration object.
|
||||
:param data_frame: The dataframe to read from.
|
||||
:param feature_statistics: If given, the normalization factor for the non-image features is taken
|
||||
from the values provided. If None, the normalization factor is computed from the data in the present dataset.
|
||||
:param sample_transforms: Sample transforms that should be applied.
|
||||
:param sample_transform: Sample transform that should be applied to each sample.
|
||||
:param name: Name of the dataset, used for diagnostics logging
|
||||
"""
|
||||
super().__init__(args,
|
||||
data_frame=data_frame,
|
||||
feature_statistics=feature_statistics,
|
||||
name=name,
|
||||
sample_transforms=sample_transforms)
|
||||
sample_transform=sample_transform)
|
||||
self.items = self.load_all_data_sources()
|
||||
self.standardize_non_imaging_features()
|
||||
|
||||
|
|
|
@ -6,18 +6,18 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Any, DefaultDict, Dict, Iterable, List, Optional, Union
|
||||
from typing import Any, Callable, DefaultDict, Dict, Iterable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from InnerEye.ML.dataset.scalar_dataset import ScalarDatasetBase, filter_valid_classification_data_sources_items
|
||||
from InnerEye.ML.dataset.scalar_dataset import ScalarDatasetBase, ScalarItemAugmentation, \
|
||||
filter_valid_classification_data_sources_items
|
||||
from InnerEye.ML.dataset.scalar_sample import ScalarItem, SequenceDataSource
|
||||
from InnerEye.ML.dataset.sequence_sample import ClassificationItemSequence, ListOfSequences
|
||||
from InnerEye.ML.sequence_config import SequenceModelBase
|
||||
from InnerEye.ML.utils.features_util import FeatureStatistics
|
||||
from InnerEye.ML.utils.transforms import Compose3D, Transform3D
|
||||
|
||||
|
||||
def get_longest_contiguous_sequence(items: List[SequenceDataSource],
|
||||
|
@ -214,13 +214,14 @@ class SequenceDataset(ScalarDatasetBase[SequenceDataSource]):
|
|||
feature_statistics: Optional[
|
||||
FeatureStatistics[ClassificationItemSequence[SequenceDataSource]]] = None,
|
||||
name: Optional[str] = None,
|
||||
sample_transforms: Optional[Union[Compose3D[ScalarItem], Transform3D[ScalarItem]]] = None):
|
||||
sample_transform: Callable[[ScalarItem], ScalarItem] = ScalarItemAugmentation()):
|
||||
"""
|
||||
Creates a new sequence dataset from a dataframe.
|
||||
:param args: The model configuration object.
|
||||
:param data_frame: The dataframe to read from.
|
||||
:param feature_statistics: If given, the normalization factor for the non-image features is taken
|
||||
:param sample_transforms: optional transformation to apply to each sample in the loading step.
|
||||
:param sample_transform: Transformation to apply to each sample in the loading step. By default, no
|
||||
transformation is applied.
|
||||
from the values provided. If None, the normalization factor is computed from the data in the present dataset.
|
||||
:param name: Name of the dataset, used for logging
|
||||
"""
|
||||
|
@ -228,7 +229,7 @@ class SequenceDataset(ScalarDatasetBase[SequenceDataSource]):
|
|||
data_frame=data_frame,
|
||||
feature_statistics=feature_statistics,
|
||||
name=name,
|
||||
sample_transforms=sample_transforms)
|
||||
sample_transform=sample_transform)
|
||||
if self.args.sequence_column is None:
|
||||
raise ValueError("This class requires a value in the `sequence_column`, specifying where the "
|
||||
"sequence index should be read from.")
|
||||
|
@ -288,7 +289,8 @@ class SequenceDataset(ScalarDatasetBase[SequenceDataSource]):
|
|||
"""
|
||||
all_labels_per_target = torch.stack([seq.get_labels_at_target_indices(self.args.get_target_indices())
|
||||
for seq in self.items]) # [N, T, 1]
|
||||
non_nan_and_nonzero_labels = list(filter(lambda x: not np.isnan(x) and x != 0, all_labels_per_target.flatten().tolist()))
|
||||
non_nan_and_nonzero_labels = list(
|
||||
filter(lambda x: not np.isnan(x) and x != 0, all_labels_per_target.flatten().tolist()))
|
||||
counts = dict(Counter(non_nan_and_nonzero_labels))
|
||||
if not len(counts.keys()) == 1 or 1 not in counts.keys():
|
||||
raise ValueError("get_class_counts supports only binary targets.")
|
||||
|
|
|
@ -15,6 +15,7 @@ from pandas import DataFrame
|
|||
from InnerEye.Azure.azure_util import CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY
|
||||
from InnerEye.Common.common_util import ModelProcessing
|
||||
from InnerEye.Common.metrics_constants import TrackedMetrics
|
||||
|
||||
from InnerEye.ML.common import DATASET_CSV_FILE_NAME, ModelExecutionMode, STORED_CSV_FILE_NAMES
|
||||
from InnerEye.ML.deep_learning_config import DeepLearningConfig
|
||||
from InnerEye.ML.utils.split_dataset import DatasetSplits
|
||||
|
|
|
@ -14,6 +14,7 @@ from azureml.train.hyperdrive import HyperDriveConfig
|
|||
from InnerEye.Common.common_util import print_exception
|
||||
from InnerEye.Common.generic_parsing import ListOrDictParam
|
||||
from InnerEye.Common.type_annotations import TupleInt3
|
||||
|
||||
from InnerEye.ML.common import ModelExecutionMode, OneHotEncoderBase
|
||||
from InnerEye.ML.deep_learning_config import ModelCategory
|
||||
from InnerEye.ML.model_config_base import ModelConfigBase, ModelTransformsPerExecutionMode
|
||||
|
@ -22,6 +23,7 @@ from InnerEye.ML.utils.split_dataset import DatasetSplits
|
|||
|
||||
DEFAULT_KEY = "Default"
|
||||
|
||||
|
||||
class AggregationType(Enum):
|
||||
"""
|
||||
The type of global pooling aggregation to use between the encoder and the classifier.
|
||||
|
@ -113,7 +115,8 @@ class ScalarModelBase(ModelConfigBase):
|
|||
"in dataset.csv"
|
||||
"For multi-label classification, this field is required."
|
||||
"For binary classification, this field must be a list of size 1, and "
|
||||
"is by default ['Default'], but can optionally be set to a more descriptive "
|
||||
"is by default ['Default'], but can optionally be set to a more "
|
||||
"descriptive "
|
||||
"name for the positive class.")
|
||||
target_names: List[str] = param.List(class_=str,
|
||||
default=None,
|
||||
|
@ -133,7 +136,8 @@ class ScalarModelBase(ModelConfigBase):
|
|||
image_file_column: Optional[str] = param.String(default=None, allow_None=True,
|
||||
doc="The column that contains the path to image files.")
|
||||
expected_column_values: List[Tuple[str, str]] = \
|
||||
param.List(default=None, doc="List of tuples with column name and expected value to filter rows in the dataset csv file",
|
||||
param.List(default=None,
|
||||
doc="List of tuples with column name and expected value to filter rows in the dataset csv file",
|
||||
allow_None=True)
|
||||
label_channels: Optional[List[str]] = \
|
||||
param.List(default=None, allow_None=True,
|
||||
|
@ -384,13 +388,27 @@ class ScalarModelBase(ModelConfigBase):
|
|||
|
||||
def create_torch_datasets(self, dataset_splits: DatasetSplits) -> Dict[ModelExecutionMode, Any]:
|
||||
from InnerEye.ML.dataset.scalar_dataset import ScalarDataset
|
||||
image_transforms = self.get_image_sample_transforms()
|
||||
train = ScalarDataset(args=self, data_frame=dataset_splits.train,
|
||||
name="training", sample_transforms=image_transforms.train) # type: ignore
|
||||
val = ScalarDataset(args=self, data_frame=dataset_splits.val, feature_statistics=train.feature_statistics,
|
||||
name="validation", sample_transforms=image_transforms.val) # type: ignore
|
||||
test = ScalarDataset(args=self, data_frame=dataset_splits.test, feature_statistics=train.feature_statistics,
|
||||
name="test", sample_transforms=image_transforms.test) # type: ignore
|
||||
sample_transform = self.get_scalar_item_transform()
|
||||
assert sample_transform.train is not None # for mypy
|
||||
assert sample_transform.val is not None # for mypy
|
||||
assert sample_transform.test is not None # for mypy
|
||||
train = ScalarDataset(
|
||||
args=self,
|
||||
data_frame=dataset_splits.train,
|
||||
name="training",
|
||||
sample_transform=sample_transform.train)
|
||||
val = ScalarDataset(
|
||||
args=self,
|
||||
data_frame=dataset_splits.val,
|
||||
feature_statistics=train.feature_statistics,
|
||||
name="validation",
|
||||
sample_transform=sample_transform.val)
|
||||
test = ScalarDataset(
|
||||
args=self,
|
||||
data_frame=dataset_splits.test,
|
||||
feature_statistics=train.feature_statistics,
|
||||
name="test",
|
||||
sample_transform=sample_transform.test)
|
||||
|
||||
return {
|
||||
ModelExecutionMode.TRAIN: train,
|
||||
|
@ -459,14 +477,29 @@ class ScalarModelBase(ModelConfigBase):
|
|||
def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
|
||||
return super().get_model_train_test_dataset_splits(dataset_df)
|
||||
|
||||
def get_image_sample_transforms(self) -> ModelTransformsPerExecutionMode:
|
||||
def get_image_transform(self) -> ModelTransformsPerExecutionMode:
|
||||
"""
|
||||
Get transforms to perform on samples for each model execution mode.
|
||||
Get transforms to apply to images for each model execution mode.
|
||||
By default, no transformation is performed.
For data augmentation, specify a transformation pipeline (e.g. an ImageTransformationPipeline) for the training
execution mode.
|
||||
"""
|
||||
return ModelTransformsPerExecutionMode()
|
||||
|
||||
def get_segmentation_transform(self) -> ModelTransformsPerExecutionMode:
|
||||
"""
|
||||
Get transforms to apply to the segmentation map inputs for each model execution mode.
By default, no transformation is performed.
|
||||
"""
|
||||
return ModelTransformsPerExecutionMode()
|
||||
|
||||
def get_scalar_item_transform(self) -> ModelTransformsPerExecutionMode:
|
||||
from InnerEye.ML.dataset.scalar_dataset import ScalarItemAugmentation
|
||||
image_transform = self.get_image_transform()
|
||||
segmentation_transform = self.get_segmentation_transform()
|
||||
return ModelTransformsPerExecutionMode(
|
||||
train=ScalarItemAugmentation(image_transform.train, segmentation_transform.train),
|
||||
val=ScalarItemAugmentation(image_transform.val, segmentation_transform.val),
|
||||
test=ScalarItemAugmentation(image_transform.test, segmentation_transform.test))
|
||||
|
||||
|
||||
def get_non_image_features_dict(default_channels: List[str],
|
||||
specific_channels: Optional[Dict[str, List[str]]] = None) -> Dict[str, List[str]]:
|
||||
|
|
|
@ -90,13 +90,25 @@ class SequenceModelBase(ScalarModelBase):
|
|||
|
||||
def create_torch_datasets(self, dataset_splits: DatasetSplits) -> Dict[ModelExecutionMode, Any]:
|
||||
from InnerEye.ML.dataset.sequence_dataset import SequenceDataset
|
||||
sample_transforms = self.get_image_sample_transforms()
|
||||
train = SequenceDataset(self, dataset_splits.train, name="training",
|
||||
sample_transforms=sample_transforms.train) # type: ignore
|
||||
val = SequenceDataset(self, dataset_splits.val, feature_statistics=train.feature_statistics, name="validation",
|
||||
sample_transforms=sample_transforms.val) # type: ignore
|
||||
test = SequenceDataset(self, dataset_splits.test, feature_statistics=train.feature_statistics, name="test",
|
||||
sample_transforms=sample_transforms.test) # type: ignore
|
||||
sample_transform = self.get_scalar_item_transform()
|
||||
assert sample_transform.train is not None # for mypy
|
||||
assert sample_transform.val is not None # for mypy
|
||||
assert sample_transform.test is not None # for mypy
|
||||
|
||||
train = SequenceDataset(self,
|
||||
dataset_splits.train,
|
||||
name="training",
|
||||
sample_transform=sample_transform.train)
|
||||
val = SequenceDataset(self,
|
||||
dataset_splits.val,
|
||||
feature_statistics=train.feature_statistics,
|
||||
name="validation",
|
||||
sample_transform=sample_transform.val)
|
||||
test = SequenceDataset(self,
|
||||
dataset_splits.test,
|
||||
feature_statistics=train.feature_statistics,
|
||||
name="test",
|
||||
sample_transform=sample_transform.test)
|
||||
|
||||
return {
|
||||
ModelExecutionMode.TRAIN: train,
|
||||
|
|
|
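# --- Illustrative sketch (not part of this diff): the ImageTransformationPipeline that backs the new
# sample transforms can also be called directly on tensors. It accepts [C, H, W] or [C, Z, H, W]
# inputs and applies the chained torchvision transforms slice by slice; the shapes below mirror the
# tests added by this PR.
import torch
from torchvision.transforms import CenterCrop, RandomHorizontalFlip

from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline

pipeline = ImageTransformationPipeline(
    transforms=[RandomHorizontalFlip(p=1.0), CenterCrop(24)],
    use_different_transformation_per_channel=False)
scan = torch.rand([5, 4, 32, 32])  # [C, Z, H, W]
augmented = pipeline(scan)
assert augmented.shape == torch.Size([5, 4, 24, 24])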
@ -1,475 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import random
|
||||
from typing import Any, Callable, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision.transforms import functional as TF
|
||||
|
||||
from InnerEye.Common.common_util import any_pairwise_larger
|
||||
from InnerEye.Common.type_annotations import TupleInt3
|
||||
from InnerEye.ML.config import SegmentationModelBase
|
||||
from InnerEye.ML.dataset.sample import Sample
|
||||
from InnerEye.ML.dataset.scalar_sample import ScalarItem
|
||||
from InnerEye.ML.utils.transforms import Transform3D
|
||||
|
||||
|
||||
def random_select_patch_center(sample: Sample, class_weights: List[float] = None) -> np.ndarray:
|
||||
"""
|
||||
Samples a point to use as the coordinates of the patch center. First samples one
|
||||
class among the available classes then samples a center point among the pixels of the sampled
|
||||
class.
|
||||
|
||||
:param sample: A set of Image channels, ground truth labels and mask to randomly crop.
|
||||
:param class_weights: A weighting vector with values [0, 1] to influence the class the center crop
|
||||
voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
|
||||
:return numpy int array (3x1) containing patch center spatial coordinates
|
||||
"""
|
||||
num_classes = sample.labels.shape[0]
|
||||
|
||||
if class_weights is not None:
|
||||
if len(class_weights) != num_classes:
|
||||
raise Exception("A weight must be provided for each class, found weights:{}, expected:{}"
|
||||
.format(len(class_weights), num_classes))
|
||||
SegmentationModelBase.validate_class_weights(class_weights)
|
||||
|
||||
# If class weights are not initialised, selection is made with equal probability for all classes
|
||||
available_classes = list(range(num_classes))
|
||||
original_class_weights = class_weights
|
||||
while len(available_classes) > 0:
|
||||
selected_label_class = random.choices(population=available_classes, weights=class_weights, k=1)[0]
|
||||
# Check pixels where mask and label maps are both foreground
|
||||
indices = np.argwhere(np.logical_and(sample.labels[selected_label_class] == 1.0, sample.mask == 1))
|
||||
if not np.any(indices):
|
||||
available_classes.remove(selected_label_class)
|
||||
if class_weights is not None:
|
||||
assert original_class_weights is not None # for mypy
|
||||
class_weights = [original_class_weights[i] for i in available_classes]
|
||||
if sum(class_weights) <= 0.0:
|
||||
raise ValueError("Cannot sample a class: no class present in the sample has a positive weight")
|
||||
else:
|
||||
break
|
||||
|
||||
# Raise an exception if none of the foreground classes overlap with the mask
|
||||
if len(available_classes) == 0:
|
||||
raise Exception("No non-mask voxels found, please check your mask and labels map")
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
choice = random.randint(0, len(indices) - 1)
|
||||
|
||||
return indices[choice].astype(int) # Numpy usually stores as floats
|
||||
|
||||
|
||||
def slicers_for_random_crop(sample: Sample,
|
||||
crop_size: TupleInt3,
|
||||
class_weights: List[float] = None) -> Tuple[List[slice], np.ndarray]:
|
||||
"""
|
||||
Computes array slicers that produce random crops of the given crop_size.
|
||||
The selection of the center is dependent on background probability.
|
||||
By default it does not center on background.
|
||||
|
||||
:param sample: A set of Image channels, ground truth labels and mask to randomly crop.
|
||||
:param crop_size: The size of the crop expressed as a list of 3 ints, one per spatial dimension.
|
||||
:param class_weights: A weighting vector with values [0, 1] to influence the class the center crop
|
||||
voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
|
||||
:return: Tuple element 1: The slicers that convert the input image to the chosen crop. Tuple element 2: The
|
||||
indices of the center point of the crop.
|
||||
:raises ValueError: If there are shape mismatches among the arguments or if the crop size is larger than the image.
|
||||
"""
|
||||
shape = sample.image.shape[1:]
|
||||
|
||||
if any_pairwise_larger(crop_size, shape):
|
||||
raise ValueError("The crop_size across each dimension should be greater than zero and less than or equal "
|
||||
"to the current value (crop_size: {}, spatial shape: {})"
|
||||
.format(crop_size, shape))
|
||||
|
||||
# Sample a center pixel location for patch extraction.
|
||||
center = random_select_patch_center(sample, class_weights)
|
||||
|
||||
# Verify and fix overflow for each dimension
|
||||
left = []
|
||||
for i in range(3):
|
||||
margin_left = int(crop_size[i] / 2)
|
||||
margin_right = crop_size[i] - margin_left
|
||||
left_index = center[i] - margin_left
|
||||
right_index = center[i] + margin_right
|
||||
if right_index > shape[i]:
|
||||
left_index = left_index - (right_index - shape[i])
|
||||
if left_index < 0:
|
||||
left_index = 0
|
||||
left.append(left_index)
|
||||
|
||||
return [slice(left[x], left[x] + crop_size[x]) for x in range(0, 3)], center
|
||||
|
||||
|
||||
def random_crop(sample: Sample,
|
||||
crop_size: TupleInt3,
|
||||
class_weights: List[float] = None) -> Tuple[Sample, np.ndarray]:
|
||||
"""
|
||||
Randomly crops images, mask, and labels arrays according to the crop_size argument.
|
||||
The selection of the center is dependent on background probability.
|
||||
By default it does not center on background.
|
||||
|
||||
:param sample: A set of Image channels, ground truth labels and mask to randomly crop.
|
||||
:param crop_size: The size of the crop expressed as a list of 3 ints, one per spatial dimension.
|
||||
:param class_weights: A weighting vector with values [0, 1] to influence the class the center crop
|
||||
voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
|
||||
:return: Tuple item 1: The cropped images, labels, and mask. Tuple item 2: The center that was chosen for the crop,
|
||||
before shifting to be inside of the image.
|
||||
:raises ValueError: If there are shape mismatches among the arguments or if the crop size is larger than the image.
|
||||
"""
|
||||
slicers, center = slicers_for_random_crop(sample, crop_size, class_weights)
|
||||
sample = Sample(
|
||||
image=sample.image[:, slicers[0], slicers[1], slicers[2]],
|
||||
labels=sample.labels[:, slicers[0], slicers[1], slicers[2]],
|
||||
mask=sample.mask[slicers[0], slicers[1], slicers[2]],
|
||||
metadata=sample.metadata
|
||||
)
|
||||
return sample, center
|
||||
|
||||
|
||||
class ImageTransformationBase(Transform3D):
|
||||
"""
|
||||
This class is the base class to classes built to define data augmentation transformations
|
||||
for 3D image inputs.
|
||||
"""
|
||||
|
||||
# noinspection PyMissingConstructor
|
||||
def __init__(self,
|
||||
is_transformation_for_segmentation_maps: bool = False,
|
||||
use_joint_channel_transformation: bool = False):
|
||||
"""
|
||||
:param is_transformation_for_segmentation_maps: if True, only use geometrical transformation suitable
|
||||
for segmentation maps. If False, additionally use color/contrast related transformation suitable for
|
||||
images.
|
||||
:param use_joint_channel_transformation: if True apply the exact same transformation for all channels of
|
||||
a given image. If False, apply a different transformation for each channel.
|
||||
"""
|
||||
self.for_segmentation_input_maps = is_transformation_for_segmentation_maps
|
||||
self.use_joint_channel_transformation = use_joint_channel_transformation
|
||||
|
||||
def draw_next_transform(self) -> List[Callable]:
|
||||
"""
|
||||
Samples all parameters defining the transformation pipeline.
|
||||
Returns a list of operations to apply to each 2D-slice in a given
|
||||
3D volume (defined by the sampled parameters).
|
||||
|
||||
:return: list of transformations to apply to each slice.
|
||||
"""
|
||||
raise NotImplementedError("The child class should implement the sampling of transforms")
|
||||
|
||||
@staticmethod
|
||||
def apply_transform_on_3d_image(image: torch.Tensor, transforms: List[Callable]) -> torch.Tensor:
|
||||
"""
|
||||
Apply a list of transforms sequentially to a 3D image. Each transformation is assumed to be a 2D-transform to
|
||||
be applied to each slice of the given 3D input separately.
|
||||
|
||||
:param image: a 3D tensor of dimensions [Z, X, Y]. Transforms are applied one after another
|
||||
separately for each [X, Y] slice along the Z-dimension (assumed to be the first dimension).
|
||||
:param transforms: a list of transformations to apply to each slice sequentially.
|
||||
:returns image: the transformed 3D-image
|
||||
"""
|
||||
for z in range(image.shape[0]):
|
||||
pil = TF.to_pil_image(image[z])
|
||||
for transform_fn in transforms:
|
||||
pil = transform_fn(pil)
|
||||
image[z] = TF.to_tensor(pil).squeeze()
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def _toss_fair_coin() -> bool:
|
||||
"""
|
||||
Simulates the toss of a fair coin.
|
||||
:returns the outcome of the toss.
|
||||
"""
|
||||
return random.random() > 0.5
|
||||
|
||||
@staticmethod
|
||||
def randomly_negate_level(value: Any) -> Any:
|
||||
"""
|
||||
Negate the value of the input with probability 0.5
|
||||
"""
|
||||
return -value if ImageTransformationBase._toss_fair_coin() else value
|
||||
|
||||
@staticmethod
|
||||
def identity() -> Callable:
|
||||
"""
|
||||
Identity transform.
|
||||
"""
|
||||
return lambda img: img
|
||||
|
||||
@staticmethod
|
||||
def rotate(angle: float) -> Callable:
|
||||
"""
|
||||
Returns a function that rotates a 2D image by
|
||||
a certain angle.
|
||||
"""
|
||||
return lambda img: TF.rotate(img, angle)
|
||||
|
||||
@staticmethod
|
||||
def translateX(shift: float) -> Callable:
|
||||
"""
|
||||
Returns a function that shifts a 2D image horizontally by
|
||||
a given shift.
|
||||
:param shift: a floating point value in (-1, 1); the shift is defined as the
|
||||
proportion of the image width by which to shift.
|
||||
"""
|
||||
return lambda img: TF.affine(img, 0, (shift * img.size[0], 0), 1, 0)
|
||||
|
||||
@staticmethod
|
||||
def translateY(shift: float) -> Callable:
|
||||
"""
|
||||
Returns a function that shifts a 2D image vertically by
|
||||
a given shift.
|
||||
:param shift: a floating point value in (-1, 1); the shift is defined as the
|
||||
proportion of the image height by which to shift.
|
||||
"""
|
||||
return lambda img: TF.affine(img, 0, (0, shift * img.size[1]), 1, 0)
|
||||
|
||||
@staticmethod
|
||||
def horizontal_flip() -> Callable:
|
||||
"""
|
||||
Returns a function that is flipping a 2D-image horizontally.
|
||||
"""
|
||||
return lambda img: TF.hflip(img)
|
||||
|
||||
@staticmethod
|
||||
def adjust_contrast(constrast_factor: float) -> Callable:
|
||||
"""
|
||||
Returns a function that modifies the contrast of a
|
||||
2D image by a certain factor.
|
||||
:param constrast_factor: Non-negative float. 0 means a black image,
|
||||
1 means no transformation, 2 means multiplying the contrast
|
||||
by two.
|
||||
"""
|
||||
return lambda img: TF.adjust_contrast(img, constrast_factor)
|
||||
|
||||
@staticmethod
|
||||
def adjust_brightness(brightness_factor: float) -> Callable:
|
||||
"""
|
||||
Returns a function that modifies the brightness of a
|
||||
2D image by a certain factor.
|
||||
:param brightness_factor: Non-negative float. 0 means a black image,
|
||||
1 means no transformation, 2 means multiplying the brightness
|
||||
by two.
|
||||
"""
|
||||
return lambda img: TF.adjust_brightness(img, brightness_factor)
|
||||
|
||||
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Main function to apply the transformation to one 3D-image.
|
||||
Assumes the same transformations have to be applied on
|
||||
each 2D-slice along the Z-axis.
|
||||
Assumes the Z axis is the first dimension.
|
||||
|
||||
:param image: batch of images of size [C, Z, Y, X]
|
||||
"""
|
||||
assert len(image.shape) == 4
|
||||
res = image.clone()
|
||||
if self.for_segmentation_input_maps:
|
||||
res = res.int()
|
||||
else:
|
||||
res = res.float()
|
||||
if res.max() > 1:
|
||||
raise ValueError("Image tensor should be in "
|
||||
"range 0-1 for conversion to PIL")
|
||||
|
||||
# Sample parameters defining the transformation
|
||||
transforms = self.draw_next_transform()
|
||||
for c in range(image.shape[0]):
|
||||
res[c] = self.apply_transform_on_3d_image(res[c], transforms)
|
||||
if not self.use_joint_channel_transformation:
|
||||
# Resample transformations for the next channel
|
||||
transforms = self.draw_next_transform()
|
||||
return res.to(dtype=image.dtype)
|
||||
|
||||
|
||||
class RandomSliceTransformation(ImageTransformationBase):
|
||||
"""
|
||||
Class to apply a random set of 2D affine transformations to all
|
||||
slices of a 3D volume separately along the z-dimension.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
probability_transformation: float = 0.8,
|
||||
max_angle: int = 10,
|
||||
max_x_shift: float = 0.05,
|
||||
max_y_shift: float = 0.1,
|
||||
max_contrast: float = 2,
|
||||
min_constrast: float = 0,
|
||||
max_brightness: float = 2,
|
||||
min_brightness: float = 0,
|
||||
**kwargs: Any) -> None:
|
||||
"""
|
||||
|
||||
:param probability_transformation: probability of applying the transformation pipeline.
|
||||
:param max_angle: maximum allowed angle for rotation. For each transformation
|
||||
the angle is drawn uniformly between -max_angle and max_angle.
|
||||
:param min_constrast: Minimum contrast factor to apply. 1 means no difference.
|
||||
2 means doubling the contrast. 0 means a black image. Parameter is sampled
|
||||
between min_contrast and max_contrast.
|
||||
:param max_contrast: maximum contrast factor
|
||||
:param max_brightness: Maximum brightness factor to apply. 1 means no difference.
|
||||
2 means doubling the brightness. 0 means a black image. Parameter is sampled
|
||||
between min_brightness and max_brightness.
|
||||
:param max_x_shift: maximum horizontal shift as a proportion of the image width.
|
||||
:param max_y_shift: maximum vertical shift as a proportion of the image height.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.probability_transformation = probability_transformation
|
||||
self.max_angle = max_angle
|
||||
self.max_x_shift = max_x_shift
|
||||
self.max_y_shift = max_y_shift
|
||||
self.max_constrast = max_contrast
|
||||
self.min_constrast = min_constrast
|
||||
self.min_brightness = min_brightness
|
||||
self.max_brightness = max_brightness
|
||||
|
||||
def draw_next_transform(self) -> List[Callable]:
|
||||
"""
|
||||
Samples all parameters defining the transformation pipeline.
|
||||
Returns a list of operations to apply to each slice in the
|
||||
3D volume.
|
||||
|
||||
:return: list of transformations to apply to each slice.
|
||||
"""
|
||||
# Sample parameters for each transformation
|
||||
angle = random.randint(-self.max_angle, self.max_angle)
|
||||
x_shift = random.uniform(-self.max_x_shift, self.max_x_shift)
|
||||
y_shift = random.uniform(-self.max_y_shift, self.max_y_shift)
|
||||
contrast = random.uniform(self.min_constrast, self.max_constrast)
|
||||
brightness = random.uniform(self.min_brightness, self.max_brightness)
|
||||
horizontal_flip = ImageTransformationBase._toss_fair_coin()
|
||||
# Returns the corresponding operations
|
||||
if random.random() < self.probability_transformation:
|
||||
ops = [self.rotate(angle),
|
||||
self.translateX(x_shift),
|
||||
self.translateY(y_shift)]
|
||||
if horizontal_flip:
|
||||
ops.append(self.horizontal_flip())
|
||||
if self.for_segmentation_input_maps:
|
||||
return ops
|
||||
ops.extend([self.adjust_contrast(contrast),
|
||||
self.adjust_brightness(brightness)])
|
||||
else:
|
||||
ops = []
|
||||
return ops
|
||||
|
||||
|
||||
class RandAugmentSlice(ImageTransformationBase):
|
||||
"""
|
||||
Implements the RandAugment procedure on a restricted set of
|
||||
transformations. https://arxiv.org/abs/1909.13719
|
||||
|
||||
Possible transformations for segmentation maps are: rotation, horizontal
|
||||
and vertical shift, horizontal flip, identity. Additional transformations
|
||||
for images are brightness adjustment and contrast adjustment.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
magnitude: int = 3,
|
||||
n_transforms: int = 2,
|
||||
**kwargs: Any) -> None:
|
||||
"""
|
||||
:param magnitude: magnitude to apply to the transformations as defined in the RandAugment paper.
|
||||
1 means a weak transform, 10 is the strongest transform.
|
||||
:param n_transforms: number of transformations to sample for each image.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.magnitude = magnitude
|
||||
self.n_transforms = n_transforms
|
||||
self._max_magnitude = 10.0
|
||||
self._max_x_shift = 0.1
|
||||
self._max_y_shift = 0.2
|
||||
self._max_angle = 30
|
||||
self._max_contrast = 1
|
||||
self._max_brightness = 1
|
||||
|
||||
def get_all_transforms(self) -> Dict[str, Callable]:
|
||||
"""
|
||||
Defines the possible transformations for one fixed magnitude level
|
||||
to sample from.
|
||||
"""
|
||||
# Convert magnitude to argument for each transform
|
||||
level = self.magnitude / self._max_magnitude
|
||||
angle = self.randomly_negate_level(level) * self._max_angle
|
||||
x_shift = self.randomly_negate_level(level) * self._max_x_shift
|
||||
y_shift = self.randomly_negate_level(level) * self._max_y_shift
|
||||
# Contrast / brightness factor of 1 means no change. 0 means black, 2 means times 2.
|
||||
contrast = self.randomly_negate_level(level) * self._max_contrast + 1
|
||||
brightness = self.randomly_negate_level(level) * self._max_brightness + 1
|
||||
transforms_dict = {
|
||||
"identity": self.identity(),
|
||||
"rotate": self.rotate(angle),
|
||||
"translateX": self.translateX(x_shift),
|
||||
"translateY": self.translateY(y_shift),
|
||||
"hFlip": self.horizontal_flip()
|
||||
}
|
||||
if self.for_segmentation_input_maps:
|
||||
return transforms_dict
|
||||
|
||||
transforms_dict.update(
|
||||
{"constrast": self.adjust_contrast(contrast),
|
||||
"brightness": self.adjust_brightness(brightness),
|
||||
})
|
||||
|
||||
return transforms_dict
|
||||
|
||||
def draw_next_transform(self) -> List[Callable]:
|
||||
"""
|
||||
Samples all parameters defining the transformation pipeline.
|
||||
Returns a list of operations to apply to each slice of a 3D volume
|
||||
(defined by the sampled parameters).
|
||||
|
||||
:return: list of transformations to apply to each slice.
|
||||
"""
|
||||
available_transforms = self.get_all_transforms()
|
||||
transform_names = np.random.choice(list(available_transforms), self.n_transforms)
|
||||
ops = [available_transforms[name] for name in transform_names]
|
||||
return ops
|
||||
|
||||
|
||||
class ScalarItemAugmentation(Transform3D[ScalarItem]):
|
||||
"""
|
||||
Wrapper around an augmentation pipeline for applying an image transformation
|
||||
to a ScalarItem input and returning the transformed sample. Applies the
|
||||
transformation either to the images or the segmentation maps depending on the
|
||||
defined transformation to apply. Several objects of this class can be applied
|
||||
in a row inside a Compose3D object.
|
||||
"""
|
||||
|
||||
# noinspection PyMissingConstructor
|
||||
def __init__(self, transform: ImageTransformationBase):
|
||||
"""
|
||||
|
||||
:param transform: the transformation to apply to the image.
|
||||
"""
|
||||
self.transform = transform
|
||||
|
||||
def __call__(self, item: ScalarItem) -> ScalarItem:
|
||||
if hasattr(self.transform, "for_segmentation_input_maps") and self.transform.for_segmentation_input_maps:
|
||||
if item.segmentations is None:
|
||||
raise ValueError("A segmentation data augmentation transform has been"
|
||||
"specified but no segmentations has been loaded.")
|
||||
return item.clone_with_overrides(segmentations=self.transform(item.segmentations))
|
||||
else:
|
||||
return item.clone_with_overrides(images=self.transform(item.images))
|
||||
|
||||
|
||||
class SampleImageAugmentation(Transform3D[Sample]):
|
||||
"""
|
||||
Wrapper around augmentation pipeline for applying an image transformation
|
||||
to a Sample input (for segmentation models).
|
||||
"""
|
||||
|
||||
# noinspection PyMissingConstructor
|
||||
def __init__(self, transform: ImageTransformationBase) -> None:
|
||||
self.transform = transform
|
||||
|
||||
def __call__(self, item: Sample) -> Sample:
|
||||
return item.clone_with_overrides(image=self.transform(item.image))
|
|
@ -10,13 +10,14 @@ import matplotlib.pyplot as plt
|
|||
import numpy as np
|
||||
import param
|
||||
|
||||
from InnerEye.ML.augmentations.augmentation_for_segmentation_utils import slicers_for_random_crop
|
||||
from InnerEye.Common.generic_parsing import GenericConfig
|
||||
from InnerEye.ML.config import SegmentationModelBase
|
||||
from InnerEye.ML.dataset.cropping_dataset import CroppingDataset
|
||||
from InnerEye.ML.dataset.full_image_dataset import FullImageDataset
|
||||
from InnerEye.ML.dataset.sample import Sample
|
||||
from InnerEye.ML.plotting import resize_and_save, scan_with_transparent_overlay
|
||||
from InnerEye.ML.utils import augmentation, io_util
|
||||
from InnerEye.ML.utils import io_util
|
||||
# The name of the folder inside the default outputs folder that will hold plots that show the effect of
|
||||
# sampling random patches
|
||||
from InnerEye.ML.utils.image_util import get_unit_image_header
|
||||
|
@ -60,9 +61,9 @@ def visualize_random_crops(sample: Sample,
|
|||
# Nifti file of that datatype.
|
||||
repeats = 200
|
||||
for _ in range(repeats):
|
||||
slicers, _ = augmentation.slicers_for_random_crop(sample=sample,
|
||||
crop_size=config.crop_size,
|
||||
class_weights=config.class_weights)
|
||||
slicers, _ = slicers_for_random_crop(sample=sample,
|
||||
crop_size=config.crop_size,
|
||||
class_weights=config.class_weights)
|
||||
heatmap[slicers[0], slicers[1], slicers[2]] += 1
|
||||
is_3dim = heatmap.shape[0] > 1
|
||||
header = sample.metadata.image_header
|
||||
|
|
|
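# --- Illustrative sketch (not part of this diff): slicers_for_random_crop, as used above to build the
# heatmap, returns one slice object per spatial axis plus the sampled center voxel; applying the
# slicers to the image is exactly what random_crop does internally. The shapes below are arbitrary.
import numpy as np

from InnerEye.ML.augmentations.augmentation_for_segmentation_utils import slicers_for_random_crop
from InnerEye.ML.dataset.sample import Sample
from Tests.ML.util import DummyPatientMetadata

size = (8, 8, 8)
labels = np.zeros((2,) + size)
labels[1] = 1  # every voxel belongs to class 1
sample = Sample(image=np.random.uniform(size=(1,) + size),
                labels=labels,
                mask=np.ones(size, dtype=int),
                metadata=DummyPatientMetadata)
slicers, center = slicers_for_random_crop(sample=sample, crop_size=(2, 2, 2), class_weights=None)
assert sample.image[:, slicers[0], slicers[1], slicers[2]].shape == (1, 2, 2, 2)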
@ -0,0 +1,162 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Any, List
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from InnerEye.ML.augmentations.augmentation_for_segmentation_utils import random_crop
|
||||
from InnerEye.ML.dataset.sample import Sample
|
||||
from InnerEye.ML.utils import ml_util
|
||||
|
||||
from Tests.ML.util import DummyPatientMetadata
|
||||
|
||||
ml_util.set_random_seed(1)
|
||||
|
||||
image_size = (8, 8, 8)
|
||||
valid_image_4d = np.random.uniform(size=((5,) + image_size)) * 10
|
||||
valid_mask = np.random.randint(2, size=image_size)
|
||||
number_of_classes = 5
|
||||
class_assignments = np.random.randint(2, size=image_size)
|
||||
valid_labels = np.zeros((number_of_classes,) + image_size)
|
||||
for c in range(number_of_classes):
|
||||
valid_labels[c, class_assignments == c] = 1
|
||||
valid_crop_size = (2, 2, 2)
|
||||
valid_full_crop_size = image_size
|
||||
valid_class_weights = [0.5] + [0.5 / (number_of_classes - 1)] * (number_of_classes - 1)
|
||||
crop_size_requires_padding = (9, 8, 12)
|
||||
|
||||
|
||||
def test_valid_full_crop() -> None:
|
||||
metadata = DummyPatientMetadata
|
||||
sample, _ = random_crop(sample=Sample(image=valid_image_4d,
|
||||
labels=valid_labels,
|
||||
mask=valid_mask,
|
||||
metadata=metadata),
|
||||
crop_size=valid_full_crop_size,
|
||||
class_weights=valid_class_weights)
|
||||
|
||||
assert np.array_equal(sample.image, valid_image_4d)
|
||||
assert np.array_equal(sample.labels, valid_labels)
|
||||
assert np.array_equal(sample.mask, valid_mask)
|
||||
assert sample.metadata == metadata
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image", [None, list(), valid_image_4d])
|
||||
@pytest.mark.parametrize("labels", [None, list(), valid_labels])
|
||||
@pytest.mark.parametrize("mask", [None, list(), valid_mask])
|
||||
@pytest.mark.parametrize("class_weights", [[0, 0, 0], [0], [-1, 0, 1], [-1, -2, -3], valid_class_weights])
|
||||
def test_invalid_arrays(image: Any, labels: Any, mask: Any, class_weights: Any) -> None:
|
||||
"""
|
||||
Tests failure cases of the random_crop function for invalid image, labels, mask or class
|
||||
weights arguments.
|
||||
"""
|
||||
# Skip the final combination, because it is valid
|
||||
if not (np.array_equal(image, valid_image_4d) and np.array_equal(labels, valid_labels)
|
||||
and np.array_equal(mask, valid_mask) and class_weights == valid_class_weights):
|
||||
with pytest.raises(Exception):
|
||||
random_crop(Sample(metadata=DummyPatientMetadata, image=image, labels=labels, mask=mask),
|
||||
valid_crop_size, class_weights)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("crop_size", [None, ["a"], 5])
|
||||
def test_invalid_crop_arg(crop_size: Any) -> None:
|
||||
with pytest.raises(Exception):
|
||||
random_crop(
|
||||
Sample(metadata=DummyPatientMetadata, image=valid_image_4d, labels=valid_labels, mask=valid_mask),
|
||||
crop_size, valid_class_weights)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("crop_size", [[2, 2], [2, 2, 2, 2], [10, 10, 10]])
|
||||
def test_invalid_crop_size(crop_size: Any) -> None:
|
||||
with pytest.raises(Exception):
|
||||
random_crop(
|
||||
Sample(metadata=DummyPatientMetadata, image=valid_image_4d, labels=valid_labels, mask=valid_mask),
|
||||
crop_size, valid_class_weights)
|
||||
|
||||
|
||||
def test_random_crop_no_fg() -> None:
|
||||
with pytest.raises(Exception):
|
||||
random_crop(Sample(metadata=DummyPatientMetadata, image=valid_image_4d, labels=valid_labels,
|
||||
mask=np.zeros_like(valid_mask)),
|
||||
valid_crop_size, valid_class_weights)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
random_crop(Sample(metadata=DummyPatientMetadata, image=valid_image_4d,
|
||||
labels=np.zeros_like(valid_labels), mask=valid_mask),
|
||||
valid_crop_size, valid_class_weights)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("crop_size", [valid_crop_size])
|
||||
def test_random_crop(crop_size: Any) -> None:
|
||||
labels = valid_labels
|
||||
# create labels such that there are no foreground voxels in a particular class
|
||||
# this should be handled gracefully (the class is ignored during sampling)
|
||||
labels[0] = 1
|
||||
labels[1] = 0
|
||||
sample, _ = random_crop(Sample(
|
||||
image=valid_image_4d,
|
||||
labels=valid_labels,
|
||||
mask=valid_mask,
|
||||
metadata=DummyPatientMetadata
|
||||
), crop_size, valid_class_weights)
|
||||
|
||||
expected_img_crop_size = (valid_image_4d.shape[0], *crop_size)
|
||||
expected_labels_crop_size = (valid_labels.shape[0], *crop_size)
|
||||
|
||||
assert sample.image.shape == expected_img_crop_size
|
||||
assert sample.labels.shape == expected_labels_crop_size
|
||||
assert sample.mask.shape == tuple(crop_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("class_weights",
|
||||
[None, [0, 0.5, 0.5, 0, 0], [0.1, 0.45, 0.45, 0, 0], [0.5, 0.25, 0.25, 0, 0],
|
||||
[1, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0.04, 0.12, 0.20, 0.28, 0.36], [0, 0.5, 0, 0.5, 0]])
|
||||
def test_valid_class_weights(class_weights: List[float]) -> None:
|
||||
"""
|
||||
Produce a large number of crops and make sure the crop center class proportions respect class weights
|
||||
"""
|
||||
ml_util.set_random_seed(1)
|
||||
num_classes = len(valid_labels)
|
||||
image = np.zeros_like(valid_image_4d)
|
||||
labels = np.zeros_like(valid_labels)
|
||||
class0, class1, class2 = non_empty_classes = [0, 2, 4]
|
||||
labels[class0] = 1
|
||||
labels[class0][3, 3, 3] = 0
|
||||
labels[class0][3, 2, 3] = 0
|
||||
labels[class1][3, 3, 3] = 1
|
||||
labels[class2][3, 2, 3] = 1
|
||||
|
||||
mask = np.ones_like(valid_mask)
|
||||
sample = Sample(image=image, labels=labels, mask=mask, metadata=DummyPatientMetadata)
|
||||
|
||||
crop_size = (1, 1, 1)
|
||||
total_crops = 200
|
||||
sampled_label_center_distribution = np.zeros(num_classes)
|
||||
|
||||
# If there is no class that has a non-zero weight and is present in the sample, there is no valid
|
||||
# way to select a class, so we expect an exception to be thrown.
|
||||
if class_weights is not None and sum(class_weights[c] for c in non_empty_classes) == 0:
|
||||
with pytest.raises(ValueError):
|
||||
random_crop(sample, crop_size, class_weights)
|
||||
return
|
||||
|
||||
for _ in range(0, total_crops):
|
||||
crop_sample, center = random_crop(sample, crop_size, class_weights)
|
||||
sampled_class = list(labels[:, center[0], center[1], center[2]]).index(1)
|
||||
sampled_label_center_distribution[sampled_class] += 1
|
||||
|
||||
sampled_label_center_distribution /= total_crops
|
||||
|
||||
if class_weights is None:
|
||||
weight = 1.0 / len(non_empty_classes)
|
||||
expected_label_center_distribution = [weight if c in non_empty_classes else 0.0
|
||||
for c in range(number_of_classes)]
|
||||
else:
|
||||
total = sum(class_weights[c] for c in non_empty_classes)
|
||||
expected_label_center_distribution = [class_weights[c] / total if c in non_empty_classes else 0.0
|
||||
for c in range(number_of_classes)]
|
||||
assert np.allclose(sampled_label_center_distribution, expected_label_center_distribution, atol=0.1)
|
|
@ -0,0 +1,96 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from torchvision.transforms.functional import to_pil_image, to_tensor
|
||||
|
||||
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
|
||||
from InnerEye.ML.augmentations.image_transforms import AddGaussianNoise, ElasticTransform, ExpandChannels, RandomGamma
|
||||
|
||||
image_size = (256, 256)
|
||||
|
||||
test_tensor_1channel_1slice = torch.ones([1, 1, *image_size], dtype=torch.float)
|
||||
test_tensor_1channel_1slice[..., 100:150, 100:200] = 1 / 255.
|
||||
min_test_image, max_test_image = test_tensor_1channel_1slice.min(), test_tensor_1channel_1slice.max()
|
||||
test_tensor_2channels_2slices = torch.ones([2, 2, *image_size], dtype=torch.float)
|
||||
test_tensor_2channels_2slices[..., 100:150, 100:200] = 1 / 255.
|
||||
invalid_test_tensor = torch.ones([1, *image_size])
|
||||
test_pil_image = Image.open(str(full_ml_test_data_path() / "image_and_contour.png")).convert("RGB")
|
||||
test_image_as_tensor = to_tensor(test_pil_image).unsqueeze(0) # put in a [1, C, H, W] format
|
||||
|
||||
|
||||
def test_add_gaussian_noise() -> None:
|
||||
"""
|
||||
Tests the functionality of AddGaussianNoise.
|
||||
"""
|
||||
# Test case of image with 1 channel, 1 slice (2D)
|
||||
torch.manual_seed(10)
|
||||
transformed = AddGaussianNoise(std=0.05, p_apply=1)(test_tensor_1channel_1slice.clone())
|
||||
torch.manual_seed(10)
|
||||
noise = torch.randn(size=(1, *image_size)) * 0.05
|
||||
assert torch.isclose(
|
||||
torch.clamp(test_tensor_1channel_1slice + noise, min_test_image, max_test_image), # type: ignore
|
||||
transformed).all()
|
||||
|
||||
# Test p_apply = 0
|
||||
untransformed = AddGaussianNoise(std=0.05, p_apply=0)(test_tensor_1channel_1slice.clone())
|
||||
assert torch.isclose(untransformed, test_tensor_1channel_1slice).all()
|
||||
|
||||
# Check that it applies the same transform to all slices if number of slices > 1
|
||||
torch.manual_seed(10)
|
||||
transformed = AddGaussianNoise(std=0.05, p_apply=1)(test_tensor_2channels_2slices.clone())
|
||||
assert torch.isclose(
|
||||
torch.clamp(test_tensor_2channels_2slices + noise, min_test_image, max_test_image), # type: ignore
|
||||
transformed).all()
|
||||
|
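# --- For clarity, a minimal reference of what the expected value above assumes AddGaussianNoise does
# (a sketch, not the implementation in image_transforms.py): draw one [1, H, W] noise map, scale it by
# std, add it to every channel and slice, and clamp the result to the original intensity range.
import torch


def add_gaussian_noise_reference(image: torch.Tensor, std: float) -> torch.Tensor:
    noise = torch.randn(size=(1, *image.shape[-2:])) * std
    return torch.clamp(image + noise, image.min().item(), image.max().item())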
||||
|
||||
def test_elastic_transform() -> None:
|
||||
"""
|
||||
Tests elastic transform
|
||||
"""
|
||||
np.random.seed(7)
|
||||
transformed_image = ElasticTransform(sigma=4, alpha=34, p_apply=1.0)(test_image_as_tensor.clone())
|
||||
transformed_pil = to_pil_image(transformed_image.squeeze(0))
|
||||
expected_pil_image = Image.open(full_ml_test_data_path() / "elastic_transformed_image_and_contour.png").convert(
|
||||
"RGB")
|
||||
assert expected_pil_image == transformed_pil
|
||||
untransformed_image = ElasticTransform(sigma=4, alpha=34, p_apply=0.0)(test_image_as_tensor.clone())
|
||||
assert torch.isclose(test_image_as_tensor, untransformed_image).all()
|
||||
|
||||
|
||||
def test_expand_channels() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
ExpandChannels()(invalid_test_tensor)
|
||||
|
||||
tensor_img = ExpandChannels()(test_tensor_1channel_1slice.clone())
|
||||
assert tensor_img.shape == torch.Size([1, 3, *image_size])
|
||||
|
||||
|
||||
def test_random_gamma() -> None:
|
||||
# This is invalid input (expects 4 dimensions)
|
||||
with pytest.raises(ValueError):
|
||||
RandomGamma(scale=(0.3, 3))(invalid_test_tensor)
|
||||
|
||||
random.seed(0)
|
||||
transformed_1 = RandomGamma(scale=(0.3, 3))(test_tensor_1channel_1slice.clone())
|
||||
assert transformed_1.shape == test_tensor_1channel_1slice.shape
|
||||
|
||||
tensor_img = torch.ones([2, 3, *image_size])
|
||||
transformed_2 = RandomGamma(scale=(0.3, 3))(tensor_img)
|
||||
# If you run on 1 channel, 1 Z dimension the gamma transform applied should be the same for all slices.
|
||||
assert transformed_2.shape == torch.Size([2, 3, *image_size])
|
||||
assert torch.isclose(transformed_2[0], transformed_2[1]).all()
|
||||
assert torch.isclose(transformed_2[0, 1], transformed_2[0, 2]).all() and \
|
||||
torch.isclose(transformed_2[0, 0], transformed_2[0, 2]).all()
|
||||
|
||||
human_readable_transformed = to_pil_image(RandomGamma(scale=(2, 3))(test_image_as_tensor).squeeze(0))
|
||||
expected_pil_image = Image.open(full_ml_test_data_path() / "gamma_transformed_image_and_contour.png").convert("RGB")
|
||||
assert expected_pil_image == human_readable_transformed
|
|
@ -0,0 +1,165 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import random
|
||||
|
||||
import PIL
|
||||
import pytest
|
||||
import torch
|
||||
from torchvision.transforms import CenterCrop, ColorJitter, RandomAffine, RandomErasing, RandomHorizontalFlip, \
|
||||
RandomResizedCrop, Resize, ToTensor
|
||||
from torchvision.transforms.functional import to_tensor
|
||||
|
||||
from InnerEye.ML.augmentations.image_transforms import AddGaussianNoise, ElasticTransform, ExpandChannels, RandomGamma
|
||||
from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline, \
|
||||
create_cxr_transforms_from_config
|
||||
|
||||
from Tests.SSL.test_data_modules import cxr_augmentation_config
|
||||
|
||||
import numpy as np
|
||||
|
||||
image_size = (32, 32)
|
||||
crop_size = 24
|
||||
test_image_as_array = np.ones(list(image_size)) * 255.
|
||||
test_image_as_array[10:15, 10:20] = 1
|
||||
test_image_as_pil = PIL.Image.fromarray(test_image_as_array).convert("L")
|
||||
test_2d_image_as_CHW_tensor = to_tensor(test_image_as_array)
|
||||
|
||||
test_2d_image_as_ZCHW_tensor = test_2d_image_as_CHW_tensor.unsqueeze(0)
|
||||
|
||||
test_4d_scan_as_tensor = torch.ones([5, 4, *image_size]) * 255.
|
||||
test_4d_scan_as_tensor[..., 10:15, 10:20] = 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_different_transformation_per_channel", [True, False])
|
||||
def test_torchvision_on_various_input(use_different_transformation_per_channel: bool) -> None:
|
||||
"""
|
||||
This tests that we can run the transformation pipeline with out-of-the-box torchvision transforms on various types
|
||||
of input: PIL image, 3D tensor, 4D tensors. Tests that use_different_transformation_per_channel has the correct
|
||||
behavior.
|
||||
"""
|
||||
|
||||
transform = ImageTransformationPipeline(
|
||||
[CenterCrop(crop_size),
|
||||
RandomErasing(),
|
||||
RandomAffine(degrees=(10, 12), shear=15, translate=(0.1, 0.3))
|
||||
],
|
||||
use_different_transformation_per_channel)
|
||||
|
||||
# Test PIL image input
|
||||
transformed = transform(test_image_as_pil)
|
||||
assert isinstance(transformed, torch.Tensor)
|
||||
assert transformed.shape == torch.Size([1, crop_size, crop_size])
|
||||
|
||||
# Test image as [C, H, W] tensor
|
||||
transformed = transform(test_2d_image_as_CHW_tensor.clone())
|
||||
assert isinstance(transformed, torch.Tensor)
|
||||
assert transformed.shape == torch.Size([1, crop_size, crop_size])
|
||||
|
||||
# Test image as [1, 1, H, W]
|
||||
transformed = transform(test_2d_image_as_ZCHW_tensor)
|
||||
assert isinstance(transformed, torch.Tensor)
|
||||
assert transformed.shape == torch.Size([1, 1, crop_size, crop_size])
|
||||
|
||||
# Test with a fake 4D scan [C, Z, H, W] -> [5, 4, 32, 32]
|
||||
transformed = transform(test_4d_scan_as_tensor)
|
||||
assert isinstance(transformed, torch.Tensor)
|
||||
assert transformed.shape == torch.Size([5, 4, crop_size, crop_size])
|
||||
|
||||
# Same transformation should be applied to all slices and channels.
|
||||
assert torch.isclose(transformed[0, 0], transformed[1, 1]).all() != use_different_transformation_per_channel
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_different_transformation_per_channel", [True, False])
|
||||
def test_custom_tf_on_various_input(use_different_transformation_per_channel: bool) -> None:
|
||||
"""
|
||||
This tests that we can run the transformation pipeline with our custom transforms on various types
|
||||
of input: PIL image, 3D tensor, 4D tensors. Tests that use_different_transformation_per_channel has the correct
|
||||
behavior. The transforms are tested individually in test_image_transforms.py
|
||||
"""
|
||||
pipeline = ImageTransformationPipeline(
|
||||
[ElasticTransform(sigma=4, alpha=34, p_apply=1),
|
||||
AddGaussianNoise(p_apply=1, std=0.05),
|
||||
RandomGamma(scale=(0.3, 3))
|
||||
],
|
||||
use_different_transformation_per_channel)
|
||||
|
||||
# Test PIL image input
|
||||
transformed = pipeline(test_image_as_pil)
|
||||
assert transformed.shape == test_2d_image_as_CHW_tensor.shape
|
||||
|
||||
# Test image as [C, H, W] tensor
|
||||
transformed = pipeline(test_2d_image_as_CHW_tensor)
|
||||
assert transformed.shape == test_2d_image_as_CHW_tensor.shape
|
||||
|
||||
# Test image as [1, 1, H, W]
|
||||
transformed = pipeline(test_2d_image_as_ZCHW_tensor)
|
||||
assert isinstance(transformed, torch.Tensor)
|
||||
assert transformed.shape == torch.Size([1, 1, *image_size])
|
||||
|
||||
# Test with a fake scan [C, Z, H, W] -> [5, 4, 32, 32]
|
||||
transformed = pipeline(test_4d_scan_as_tensor)
|
||||
assert isinstance(transformed, torch.Tensor)
|
||||
assert transformed.shape == test_4d_scan_as_tensor.shape
|
||||
|
||||
# Same transformation should be applied to all slices and channels.
|
||||
assert torch.isclose(transformed[0, 0], transformed[1, 1]).all() != use_different_transformation_per_channel
|
||||
|
||||
|
||||
def test_create_transform_pipeline_from_config() -> None:
|
||||
"""
|
||||
Tests that the pipeline returned by create_cxr_transforms_from_config returns the expected transformation.
|
||||
"""
|
||||
transformation_pipeline = create_cxr_transforms_from_config(cxr_augmentation_config, apply_augmentations=True)
|
||||
fake_cxr_as_array = np.ones([256, 256]) * 255.
|
||||
fake_cxr_as_array[100:150, 100:200] = 1
|
||||
fake_cxr_image = PIL.Image.fromarray(fake_cxr_as_array).convert("L")
|
||||
|
||||
all_transforms = [ExpandChannels(),
|
||||
RandomAffine(degrees=180, translate=(0, 0), shear=40),
|
||||
RandomResizedCrop(scale=(0.4, 1.0), size=256),
|
||||
RandomHorizontalFlip(p=0.5),
|
||||
RandomGamma(scale=(0.5, 1.5)),
|
||||
ColorJitter(saturation=0, brightness=0.2, contrast=0.2),
|
||||
ElasticTransform(sigma=4, alpha=34, p_apply=0.4),
|
||||
CenterCrop(size=224),
|
||||
RandomErasing(scale=(0.15, 0.4), ratio=(0.33, 3)),
|
||||
AddGaussianNoise(std=0.05, p_apply=0.5)
|
||||
]
|
||||
|
||||
np.random.seed(3)
|
||||
torch.manual_seed(3)
|
||||
random.seed(3)
|
||||
|
||||
transformed_image = transformation_pipeline(fake_cxr_image)
|
||||
assert isinstance(transformed_image, torch.Tensor)
|
||||
# Expected pipeline
|
||||
image = np.ones([256, 256]) * 255.
|
||||
image[100:150, 100:200] = 1
|
||||
image = PIL.Image.fromarray(image).convert("L")
|
||||
# In the pipeline the image is converted to tensor before applying the transformations. Do the same here.
|
||||
image = ToTensor()(image).reshape([1, 1, 256, 256])
|
||||
|
||||
np.random.seed(3)
|
||||
torch.manual_seed(3)
|
||||
random.seed(3)
|
||||
|
||||
expected_transformed = image
|
||||
for t in all_transforms:
|
||||
expected_transformed = t(expected_transformed)
|
||||
# The pipeline takes as input [C, Z, H, W] and returns [C, Z, H, W]
|
||||
# But the transforms list expects [Z, C, H, W] and returns [Z, C, H, W], so we need to permute dimensions to compare
|
||||
expected_transformed = torch.transpose(expected_transformed, 1, 0).squeeze(1)
|
||||
assert torch.isclose(expected_transformed, transformed_image).all()
|
||||
|
||||
# Test the evaluation pipeline
|
||||
transformation_pipeline = create_cxr_transforms_from_config(cxr_augmentation_config, apply_augmentations=False)
|
||||
transformed_image = transformation_pipeline(image)
|
||||
assert isinstance(transformed_image, torch.Tensor)
|
||||
all_transforms = [ExpandChannels(), Resize(size=256), CenterCrop(size=224)]
|
||||
expected_transformed = image
|
||||
for t in all_transforms:
|
||||
expected_transformed = t(expected_transformed)
|
||||
expected_transformed = torch.transpose(expected_transformed, 1, 0).squeeze(1)
|
||||
assert torch.isclose(expected_transformed, transformed_image).all()
|
|
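# --- Illustrative usage (not part of this diff): create_cxr_transforms_from_config builds the training
# or evaluation pipeline from an augmentation config, as exercised by the test above.
# cxr_augmentation_config is reused from the SSL test helpers purely as an example config.
import torch

from InnerEye.ML.augmentations.transform_pipeline import create_cxr_transforms_from_config
from Tests.SSL.test_data_modules import cxr_augmentation_config

val_pipeline = create_cxr_transforms_from_config(cxr_augmentation_config, apply_augmentations=False)
fake_cxr = torch.ones([1, 1, 256, 256])  # [C, Z, H, W]
output = val_pipeline(fake_cxr)
assert isinstance(output, torch.Tensor)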
@ -652,7 +652,7 @@ S4,label,,False,3.0
|
|||
traverse_dirs_when_loading=True,
|
||||
local_dataset=test_output_dirs.root_dir)
|
||||
raw_dataset = ScalarDataset(args, data_frame=df)
|
||||
normalized = ScalarDataset(args, data_frame=df, sample_transforms=WindowNormalizationForScalarItem())
|
||||
normalized = ScalarDataset(args, data_frame=df, sample_transform=WindowNormalizationForScalarItem())
|
||||
assert len(raw_dataset) == 4
|
||||
for i in range(4):
|
||||
raw_item = raw_dataset[i]
|
||||
|
|
|
@ -10,12 +10,15 @@ import numpy as np
|
|||
import pandas as pd
|
||||
import pytest
|
||||
import torch
|
||||
from torchvision.transforms import ColorJitter, RandomAffine
|
||||
|
||||
from InnerEye.Common import common_util
|
||||
from InnerEye.Common.common_util import SUBJECT_METRICS_FILE_NAME, logging_to_stdout
|
||||
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
|
||||
from InnerEye.Common.metrics_constants import LoggingColumns, MetricType, SEQUENCE_POSITION_HUE_NAME_PREFIX
|
||||
from InnerEye.Common.output_directories import OutputFolderForTests
|
||||
from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline
|
||||
|
||||
from InnerEye.ML.dataset.sequence_dataset import SequenceDataset
|
||||
from InnerEye.ML.deep_learning_config import TemperatureScalingConfig
|
||||
from InnerEye.ML.lightning_models import transfer_batch_to_device
|
||||
|
@ -27,7 +30,7 @@ from InnerEye.ML.run_ml import MLRunner
|
|||
from InnerEye.ML.scalar_config import ScalarLoss
|
||||
from InnerEye.ML.sequence_config import SEQUENCE_LENGTH_FILE, SEQUENCE_LENGTH_STATS_FILE, SequenceModelBase
|
||||
from InnerEye.ML.utils import ml_util
|
||||
from InnerEye.ML.utils.augmentation import RandAugmentSlice, ScalarItemAugmentation
|
||||
|
||||
from InnerEye.ML.utils.dataset_util import CategoricalToOneHotEncoder
|
||||
from InnerEye.ML.utils.io_util import ImageAndSegmentations
|
||||
from InnerEye.ML.utils.model_util import create_model_with_temperature_scaling, get_scalar_model_inputs_and_labels
|
||||
|
@ -107,12 +110,12 @@ class ToySequenceModel(SequenceModelBase):
|
|||
proportion_val=0.1,
|
||||
)
|
||||
|
||||
def get_image_sample_transforms(self) -> ModelTransformsPerExecutionMode:
|
||||
def get_image_transform(self) -> ModelTransformsPerExecutionMode:
|
||||
if self.use_combined_model:
|
||||
return ModelTransformsPerExecutionMode(
|
||||
train=ScalarItemAugmentation(
|
||||
transform=RandAugmentSlice(use_joint_channel_transformation=False,
|
||||
is_transformation_for_segmentation_maps=True)))
|
||||
train=ImageTransformationPipeline(
|
||||
transforms=[RandomAffine(degrees=30, translate=(0.1, 0.1), shear=15),
|
||||
ColorJitter(brightness=0.2)]))
|
||||
else:
|
||||
return ModelTransformsPerExecutionMode()
|
||||
|
||||
|
|
|
@ -12,19 +12,20 @@ import numpy as np
|
|||
import pandas as pd
|
||||
import pytest
|
||||
import torch
|
||||
from torchvision.transforms import ColorJitter, RandomAffine
|
||||
|
||||
from InnerEye.Common import common_util
|
||||
from InnerEye.Common.common_util import logging_to_stdout
|
||||
from InnerEye.Common.output_directories import OutputFolderForTests
|
||||
from InnerEye.Common.type_annotations import TupleInt3
|
||||
from InnerEye.ML.dataset.scalar_dataset import ScalarDataset
|
||||
from InnerEye.ML.augmentations.transform_pipeline import ImageTransformationPipeline
|
||||
from InnerEye.ML.dataset.scalar_dataset import ScalarDataset, ScalarItemAugmentation
|
||||
from InnerEye.ML.lightning_models import transfer_batch_to_device
|
||||
from InnerEye.ML.model_config_base import ModelTransformsPerExecutionMode
|
||||
from InnerEye.ML.models.architectures.classification.image_encoder_with_mlp import ImageEncoderWithMlp, \
|
||||
ImagingFeatureType
|
||||
from InnerEye.ML.run_ml import MLRunner
|
||||
from InnerEye.ML.scalar_config import AggregationType, ScalarLoss, ScalarModelBase, get_non_image_features_dict
|
||||
from InnerEye.ML.utils.augmentation import RandAugmentSlice, ScalarItemAugmentation
|
||||
from InnerEye.ML.utils.dataset_util import CategoricalToOneHotEncoder
|
||||
from InnerEye.ML.utils.image_util import HDF5_NUM_SEGMENTATION_CLASSES, segmentation_to_one_hot
|
||||
from InnerEye.ML.utils.io_util import ImageAndSegmentations, NumpyFile
|
||||
|
@ -101,15 +102,24 @@ class ImageEncoder(ScalarModelBase):
|
|||
def get_post_loss_logits_normalization_function(self) -> Callable:
|
||||
return torch.nn.Sigmoid()
|
||||
|
||||
def get_image_sample_transforms(self) -> ModelTransformsPerExecutionMode:
|
||||
def get_image_transform(self) -> ModelTransformsPerExecutionMode:
|
||||
"""
|
||||
Get transforms to perform on image samples for each model execution mode.
|
||||
"""
|
||||
return ModelTransformsPerExecutionMode(
|
||||
train=ScalarItemAugmentation(
|
||||
RandAugmentSlice(is_transformation_for_segmentation_maps=(
|
||||
self.imaging_feature_type == ImagingFeatureType.Segmentation
|
||||
or self.imaging_feature_type == ImagingFeatureType.ImageAndSegmentation))))
|
||||
if self.imaging_feature_type in [ImagingFeatureType.Image, ImagingFeatureType.ImageAndSegmentation]:
|
||||
return ModelTransformsPerExecutionMode(
|
||||
train=ImageTransformationPipeline(
|
||||
transforms=[RandomAffine(10), ColorJitter(0.2)],
|
||||
use_different_transformation_per_channel=True))
|
||||
return ModelTransformsPerExecutionMode()
|
||||
|
||||
def get_segmentation_transform(self) -> ModelTransformsPerExecutionMode:
|
||||
if self.imaging_feature_type in [ImagingFeatureType.Segmentation, ImagingFeatureType.ImageAndSegmentation]:
|
||||
return ModelTransformsPerExecutionMode(
|
||||
train=ImageTransformationPipeline(
|
||||
transforms=[RandomAffine(10), ColorJitter(0.2)],
|
||||
use_different_transformation_per_channel=True))
|
||||
return ModelTransformsPerExecutionMode()
|
||||
|
||||
|
||||
@pytest.mark.skipif(common_util.is_windows(), reason="Too slow on windows")
|
||||
|
@ -177,9 +187,10 @@ S3,week1,scan3.npy,True,6,60,Male,Val2
|
|||
)
|
||||
config_for_dataset.read_dataset_into_dataframe_and_pre_process()
|
||||
|
||||
dataset = ScalarDataset(config_for_dataset,
|
||||
sample_transforms=ScalarItemAugmentation(
|
||||
RandAugmentSlice(is_transformation_for_segmentation_maps=False)))
|
||||
dataset = ScalarDataset(
|
||||
config_for_dataset,
|
||||
sample_transform=ScalarItemAugmentation(ImageTransformationPipeline([RandomAffine(10), ColorJitter(0.2)],
|
||||
use_different_transformation_per_channel=True)))
|
||||
assert len(dataset) == 3
|
||||
|
||||
config = ImageEncoder(
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d4fb26944ad1fef4265cb52b64160865a940cd28ed4ad2bb7bf1c90458622bf6
|
||||
size 39050
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:775d542598a607881a313c0ac828ffc47dd86b2567bcb4366646b7ea14a121d8
|
||||
size 3923
|
|
@ -1,405 +0,0 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import random
|
||||
from typing import Any, Callable, List
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from torchvision.transforms import functional as TF
|
||||
|
||||
from InnerEye.ML.dataset.sample import Sample
|
||||
from InnerEye.ML.utils import augmentation, ml_util
|
||||
from InnerEye.ML.utils.augmentation import ImageTransformationBase
|
||||
from Tests.ML.util import DummyPatientMetadata
|
||||
|
||||
ml_util.set_random_seed(1)
|
||||
|
||||
image_size = (8, 8, 8)
|
||||
valid_image_4d = np.random.uniform(size=((5,) + image_size)) * 10
|
||||
valid_mask = np.random.randint(2, size=image_size)
|
||||
number_of_classes = 5
|
||||
class_assignments = np.random.randint(2, size=image_size)
|
||||
valid_labels = np.zeros((number_of_classes,) + image_size)
|
||||
for c in range(number_of_classes):
|
||||
valid_labels[c, class_assignments == c] = 1
|
||||
valid_crop_size = (2, 2, 2)
|
||||
valid_full_crop_size = image_size
|
||||
valid_class_weights = [0.5] + [0.5 / (number_of_classes - 1)] * (number_of_classes - 1)
|
||||
crop_size_requires_padding = (9, 8, 12)
|
||||
|
||||
|
||||
# Random Crop Tests
|
||||
def test_valid_full_crop() -> None:
|
||||
metadata = DummyPatientMetadata
|
||||
sample, _ = augmentation.random_crop(sample=Sample(image=valid_image_4d,
|
||||
labels=valid_labels,
|
||||
mask=valid_mask,
|
||||
metadata=metadata),
|
||||
crop_size=valid_full_crop_size,
|
||||
class_weights=valid_class_weights)
|
||||
|
||||
assert np.array_equal(sample.image, valid_image_4d)
|
||||
assert np.array_equal(sample.labels, valid_labels)
|
||||
assert np.array_equal(sample.mask, valid_mask)
|
||||
assert sample.metadata == metadata
|
||||
|
||||
|
||||
@pytest.mark.parametrize("image", [None, list(), valid_image_4d])
|
||||
@pytest.mark.parametrize("labels", [None, list(), valid_labels])
|
||||
@pytest.mark.parametrize("mask", [None, list(), valid_mask])
|
||||
@pytest.mark.parametrize("class_weights", [[0, 0, 0], [0], [-1, 0, 1], [-1, -2, -3], valid_class_weights])
|
||||
def test_invalid_arrays(image: Any, labels: Any, mask: Any, class_weights: Any) -> None:
|
||||
"""
|
||||
Tests failure cases of the random_crop function for invalid image, labels, mask or class
|
||||
weights arguments.
|
||||
"""
|
||||
# Skip the final combination, because it is valid
|
||||
if not (np.array_equal(image, valid_image_4d) and np.array_equal(labels, valid_labels)
|
||||
and np.array_equal(mask, valid_mask) and class_weights == valid_class_weights):
|
||||
with pytest.raises(Exception):
|
||||
augmentation.random_crop(Sample(metadata=DummyPatientMetadata, image=image, labels=labels, mask=mask),
|
||||
valid_crop_size, class_weights)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("crop_size", [None, ["a"], 5])
|
||||
def test_invalid_crop_arg(crop_size: Any) -> None:
|
||||
with pytest.raises(Exception):
|
||||
augmentation.random_crop(
|
||||
Sample(metadata=DummyPatientMetadata, image=valid_image_4d, labels=valid_labels, mask=valid_mask),
|
||||
crop_size, valid_class_weights)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("crop_size", [[2, 2], [2, 2, 2, 2], [10, 10, 10]])
|
||||
def test_invalid_crop_size(crop_size: Any) -> None:
|
||||
with pytest.raises(Exception):
|
||||
augmentation.random_crop(
|
||||
Sample(metadata=DummyPatientMetadata, image=valid_image_4d, labels=valid_labels, mask=valid_mask),
|
||||
crop_size, valid_class_weights)
|
||||
|
||||
|
||||
def test_random_crop_no_fg() -> None:
|
||||
with pytest.raises(Exception):
|
||||
augmentation.random_crop(Sample(metadata=DummyPatientMetadata, image=valid_image_4d, labels=valid_labels,
|
||||
mask=np.zeros_like(valid_mask)),
|
||||
valid_crop_size, valid_class_weights)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
augmentation.random_crop(Sample(metadata=DummyPatientMetadata, image=valid_image_4d,
|
||||
labels=np.zeros_like(valid_labels), mask=valid_mask),
|
||||
valid_crop_size, valid_class_weights)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("crop_size", [valid_crop_size])
|
||||
def test_random_crop(crop_size: Any) -> None:
|
||||
labels = valid_labels
|
||||
# create labels such that there are no foreground voxels in a particular class
|
||||
# this should be handled gracefully (the class is ignored during sampling)
|
||||
labels[0] = 1
|
||||
labels[1] = 0
|
||||
sample, _ = augmentation.random_crop(Sample(
|
||||
image=valid_image_4d,
|
||||
labels=valid_labels,
|
||||
mask=valid_mask,
|
||||
metadata=DummyPatientMetadata
|
||||
), crop_size, valid_class_weights)
|
||||
|
||||
expected_img_crop_size = (valid_image_4d.shape[0], *crop_size)
|
||||
expected_labels_crop_size = (valid_labels.shape[0], *crop_size)
|
||||
|
||||
assert sample.image.shape == expected_img_crop_size
|
||||
assert sample.labels.shape == expected_labels_crop_size
|
||||
assert sample.mask.shape == tuple(crop_size)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("class_weights",
|
||||
[None, [0, 0.5, 0.5, 0, 0], [0.1, 0.45, 0.45, 0, 0], [0.5, 0.25, 0.25, 0, 0],
|
||||
[1, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0.04, 0.12, 0.20, 0.28, 0.36], [0, 0.5, 0, 0.5, 0]])
|
||||
def test_valid_class_weights(class_weights: List[float]) -> None:
|
||||
"""
|
||||
Produce a large number of crops and make sure the crop center class proportions respect class weights
|
||||
"""
|
||||
ml_util.set_random_seed(1)
|
||||
num_classes = len(valid_labels)
|
||||
image = np.zeros_like(valid_image_4d)
|
||||
labels = np.zeros_like(valid_labels)
|
||||
class0, class1, class2 = non_empty_classes = [0, 2, 4]
|
||||
labels[class0] = 1
|
||||
labels[class0][3, 3, 3] = 0
|
||||
labels[class0][3, 2, 3] = 0
|
||||
labels[class1][3, 3, 3] = 1
|
||||
labels[class2][3, 2, 3] = 1
|
||||
|
||||
mask = np.ones_like(valid_mask)
|
||||
sample = Sample(image=image, labels=labels, mask=mask, metadata=DummyPatientMetadata)
|
||||
|
||||
crop_size = (1, 1, 1)
|
||||
total_crops = 200
|
||||
sampled_label_center_distribution = np.zeros(num_classes)
|
||||
|
||||
# If there is no class that has a non-zero weight and is present in the sample, there is no valid
|
||||
# way to select a class, so we expect an exception to be thrown.
|
||||
if class_weights is not None and sum(class_weights[c] for c in non_empty_classes) == 0:
|
||||
with pytest.raises(ValueError):
|
||||
augmentation.random_crop(sample, crop_size, class_weights)
|
||||
return
|
||||
|
||||
for _ in range(0, total_crops):
|
||||
crop_sample, center = augmentation.random_crop(sample, crop_size, class_weights)
|
||||
sampled_class = list(labels[:, center[0], center[1], center[2]]).index(1)
|
||||
sampled_label_center_distribution[sampled_class] += 1
|
||||
|
||||
sampled_label_center_distribution /= total_crops
|
||||
|
||||
if class_weights is None:
|
||||
weight = 1.0 / len(non_empty_classes)
|
||||
expected_label_center_distribution = [weight if c in non_empty_classes else 0.0
|
||||
for c in range(number_of_classes)]
|
||||
else:
|
||||
total = sum(class_weights[c] for c in non_empty_classes)
|
||||
expected_label_center_distribution = [class_weights[c] / total if c in non_empty_classes else 0.0
|
||||
for c in range(number_of_classes)]
|
||||
assert np.allclose(sampled_label_center_distribution, expected_label_center_distribution, atol=0.1)
|
||||
|
||||
|
||||
def _check_transformation_result(image_as_tensor: torch.Tensor,
                                 transformation: Callable,
                                 expected: torch.Tensor) -> None:
    test_tensor_pil = TF.to_pil_image(image_as_tensor)
    transformed = TF.to_tensor(transformation(test_tensor_pil)).squeeze()
    np.testing.assert_allclose(transformed, expected, rtol=0.02)


# Augmentation pipeline tests

@pytest.mark.parametrize(["transformation", "expected"],
                         [(ImageTransformationBase.horizontal_flip(), torch.tensor([[0, 1, 2],
                                                                                    [0, 2, 1],
                                                                                    [2, 0, 0]])),
                          (ImageTransformationBase.rotate(45), torch.tensor([[1, 0, 0],
                                                                             [2, 2, 2],
                                                                             [1, 0, 0]])),
                          (ImageTransformationBase.translateX(0.3), torch.tensor([[0, 2, 1],
                                                                                  [0, 1, 2],
                                                                                  [0, 0, 0]])),
                          (ImageTransformationBase.translateY(0.3), torch.tensor([[0, 0, 0],
                                                                                  [2, 1, 0],
                                                                                  [1, 2, 0]])),
                          (ImageTransformationBase.identity(), torch.tensor([[2, 1, 0],
                                                                             [1, 2, 0],
                                                                             [0, 0, 2]]))])
def test_transformations_for_segmentations(transformation: Callable, expected: torch.Tensor) -> None:
    """
    Tests each individual transformation of the ImageTransformationBase class on a 2D input representing
    a segmentation map.
    """
    image_as_tensor = torch.tensor([[2, 1, 0],
                                    [1, 2, 0],
                                    [0, 0, 2]], dtype=torch.int32)
    _check_transformation_result(image_as_tensor, transformation, expected)


def test_invalid_segmentation_type() -> None:
    """
    Validates the necessity of converting segmentation maps to int before PIL
    conversion.
    """
    image_as_tensor = torch.tensor([[2, 1, 0],
                                    [1, 2, 0],
                                    [0, 0, 2]], dtype=torch.float32)
    expected = torch.tensor([[1, 0, 0], [2, 2, 2], [1, 0, 0]])
    with pytest.raises(AssertionError):
        _check_transformation_result(image_as_tensor, ImageTransformationBase.rotate(45), expected)


@pytest.mark.parametrize(["transformation", "expected"],
                         [(ImageTransformationBase.horizontal_flip(), torch.tensor([[0.1, 0.5, 1],
                                                                                    [0.1, 1, 0.5],
                                                                                    [1, 0.1, 0.1]])),
                          (ImageTransformationBase.adjust_contrast(2), torch.tensor([[1., 0.509804, 0.],
                                                                                     [0.509804, 1., 0.],
                                                                                     [0., 0., 1.]])),
                          (ImageTransformationBase.adjust_brightness(2), torch.tensor([[1.0000, 0.9961, 0.1961],
                                                                                       [0.9961, 1.0000, 0.1961],
                                                                                       [0.1961, 0.1961, 1.0000]])),
                          (ImageTransformationBase.adjust_contrast(0), torch.tensor([[0.4863, 0.4863, 0.4863],
                                                                                     [0.4863, 0.4863, 0.4863],
                                                                                     [0.4863, 0.4863, 0.4863]])),
                          (ImageTransformationBase.adjust_brightness(0), torch.tensor([[0.0, 0.0, 0.0],
                                                                                       [0.0, 0.0, 0.0],
                                                                                       [0.0, 0.0, 0.0]]))])
def test_transformation_image(transformation: Callable, expected: torch.Tensor) -> None:
    """
    Tests each individual transformation of the ImageTransformationBase class on a 2D input representing
    a natural image.
    """
    image_as_tensor = torch.tensor([[1, 0.5, 0.1],
                                    [0.5, 1, 0.1],
                                    [0.1, 0.1, 1]], dtype=torch.float32)
    _check_transformation_result(image_as_tensor, transformation, expected)
def test_apply_transformations() -> None:
    """
    Testing the function applying a series of transformations to a given image.
    """
    operations = [ImageTransformationBase.identity(), ImageTransformationBase.translateX(0.3),
                  ImageTransformationBase.horizontal_flip()]

    # Case 1 on segmentations
    image_as_tensor = torch.tensor([[[2, 1, 0], [1, 2, 0], [0, 0, 2]],
                                    [[2, 1, 0], [1, 2, 0], [0, 0, 2]]], dtype=torch.int32)
    transformed_tensor = ImageTransformationBase.apply_transform_on_3d_image(image=image_as_tensor,
                                                                             transforms=operations)
    expected = torch.tensor([[[1, 2, 0], [2, 1, 0], [0, 0, 0]],
                             [[1, 2, 0], [2, 1, 0], [0, 0, 0]]])
    assert torch.all(expected == transformed_tensor)

    # Case 2 on image
    image_as_tensor = torch.tensor([[[1, 0.5, 0.1], [0.5, 1, 0.1], [0.1, 0.1, 1]],
                                    [[1, 0.5, 0.1], [0.5, 1, 0.1], [0.1, 0.1, 1]]], dtype=torch.float32)
    transformed_tensor = ImageTransformationBase.apply_transform_on_3d_image(image=image_as_tensor,
                                                                             transforms=operations)
    expected = torch.tensor([[[0.5, 1, 0], [1, 0.5, 0], [0.1, 0.1, 0]],
                             [[0.5, 1, 0], [1, 0.5, 0], [0.1, 0.1, 0]]])
    np.testing.assert_allclose(transformed_tensor, expected, rtol=0.02)


def _compute_expected_pipeline_result(transformations: List[List[Callable]],
                                      input_image: torch.Tensor) -> torch.Tensor:
    expected = input_image.clone()
    expected[0] = ImageTransformationBase.apply_transform_on_3d_image(expected[0],
                                                                      transformations[0])
    expected[1] = ImageTransformationBase.apply_transform_on_3d_image(expected[1],
                                                                      transformations[1])
    return expected
def test_RandAugment_pipeline() -> None:
    """
    Test the RandAugment transformation pipeline for online data augmentation.
    """
    # Set random seeds for transformations
    np.random.seed(1)
    random.seed(0)

    # Get inputs
    one_channel_image = torch.tensor([[[2, 1, 0], [1, 2, 0], [0, 0, 2]],
                                      [[2, 1, 0], [1, 2, 0], [0, 0, 2]]], dtype=torch.int32)
    two_channel_image = torch.stack((one_channel_image, one_channel_image), dim=0)

    # Case no transformation applied
    pipeline = augmentation.RandAugmentSlice(magnitude=3,
                                             n_transforms=0,
                                             is_transformation_for_segmentation_maps=True)
    transformed = pipeline(two_channel_image)
    assert torch.all(two_channel_image == transformed)

    # Case separate transformation per channel
    pipeline = augmentation.RandAugmentSlice(magnitude=3,
                                             n_transforms=1,
                                             is_transformation_for_segmentation_maps=True,
                                             use_joint_channel_transformation=False)
    expected_sampled_ops_channel_1 = [ImageTransformationBase.translateY(-0.3 * 0.2)]
    expected_sampled_ops_channel_2 = [ImageTransformationBase.horizontal_flip()]
    expected = _compute_expected_pipeline_result(transformations=[expected_sampled_ops_channel_1,
                                                                  expected_sampled_ops_channel_2],
                                                 input_image=two_channel_image)
    transformed = pipeline(two_channel_image)
    assert torch.all(transformed == expected)

    # Case same transformation for all channels
    pipeline = augmentation.RandAugmentSlice(magnitude=5,
                                             n_transforms=2,
                                             is_transformation_for_segmentation_maps=True,
                                             use_joint_channel_transformation=True)
    transformed = pipeline(two_channel_image)

    expected_sampled_ops_channel = [ImageTransformationBase.rotate(0.5 * 30),
                                    ImageTransformationBase.translateY(0.5 * 0.2)]

    expected = _compute_expected_pipeline_result(transformations=[expected_sampled_ops_channel,
                                                                  expected_sampled_ops_channel],
                                                 input_image=two_channel_image)
    assert torch.all(transformed == expected)

    # Case for images
    two_channel_image = two_channel_image / 2.0
    pipeline = augmentation.RandAugmentSlice(magnitude=3, n_transforms=1,
                                             use_joint_channel_transformation=True,
                                             is_transformation_for_segmentation_maps=False)
    transformed = pipeline(two_channel_image)
    expected_sampled_ops_channel = [ImageTransformationBase.adjust_contrast(1 - 0.3)]
    expected = _compute_expected_pipeline_result(transformations=[expected_sampled_ops_channel,
                                                                  expected_sampled_ops_channel],
                                                 input_image=two_channel_image)
    assert torch.all(transformed == expected)
def test_RandomSliceTransformation_pipeline() -> None:
    """
    Test the RandomSliceTransformation pipeline for online data augmentation.
    """
    # Set random seeds for transformations
    np.random.seed(1)
    random.seed(0)

    one_channel_image = torch.tensor([[[2, 1, 0], [1, 2, 0], [0, 0, 2]],
                                      [[2, 1, 0], [1, 2, 0], [0, 0, 2]]], dtype=torch.int32)
    image_with_two_channels = torch.stack((one_channel_image, one_channel_image), dim=0)

    # Case no transformation applied
    pipeline = augmentation.RandomSliceTransformation(probability_transformation=0,
                                                      is_transformation_for_segmentation_maps=True)
    transformed = pipeline(image_with_two_channels)
    assert torch.all(image_with_two_channels == transformed)

    # Case separate transformation per channel
    pipeline = augmentation.RandomSliceTransformation(probability_transformation=1,
                                                      is_transformation_for_segmentation_maps=True,
                                                      use_joint_channel_transformation=False)
    transformed = pipeline(image_with_two_channels)
    expected_transformations_channel_1 = [ImageTransformationBase.rotate(-7),
                                          ImageTransformationBase.translateX(0.011836899667533166),
                                          ImageTransformationBase.translateY(-0.04989873172751189),
                                          ImageTransformationBase.horizontal_flip()]
    expected_transformations_channel_2 = [ImageTransformationBase.rotate(-1),
                                          ImageTransformationBase.translateX(-0.04012366553408523),
                                          ImageTransformationBase.translateY(-0.08525151327425498),
                                          ImageTransformationBase.horizontal_flip()]
    expected = _compute_expected_pipeline_result(transformations=[expected_transformations_channel_1,
                                                                  expected_transformations_channel_2],
                                                 input_image=image_with_two_channels)
    assert torch.all(transformed == expected)

    # Case same transformation for all channels
    pipeline = augmentation.RandomSliceTransformation(probability_transformation=1,
                                                      is_transformation_for_segmentation_maps=True,
                                                      use_joint_channel_transformation=True)
    transformed = pipeline(image_with_two_channels)
    expected_transformations_channel_1 = [ImageTransformationBase.rotate(9),
                                          ImageTransformationBase.translateX(-0.0006422133534675356),
                                          ImageTransformationBase.translateY(0.07352055509855618),
                                          ImageTransformationBase.horizontal_flip()]
    expected = _compute_expected_pipeline_result(transformations=[expected_transformations_channel_1,
                                                                  expected_transformations_channel_1],
                                                 input_image=image_with_two_channels)
    assert torch.all(transformed == expected)

    # Case for images - convert to range 0-1 first
    image_with_two_channels = image_with_two_channels / 4.0
    pipeline = augmentation.RandomSliceTransformation(probability_transformation=1,
                                                      is_transformation_for_segmentation_maps=False,
                                                      use_joint_channel_transformation=True)
    transformed = pipeline(image_with_two_channels)
    expected_transformations_channel_1 = [ImageTransformationBase.rotate(8),
                                          ImageTransformationBase.translateX(-0.02782961037858135),
                                          ImageTransformationBase.translateY(0.06066901082924045),
                                          ImageTransformationBase.adjust_contrast(0.2849887878176489),
                                          ImageTransformationBase.adjust_brightness(1.0859799153800245)]
    expected = _compute_expected_pipeline_result(transformations=[expected_transformations_channel_1,
                                                                  expected_transformations_channel_1],
                                                 input_image=image_with_two_channels)
    assert torch.all(transformed == expected)
@@ -18,13 +18,13 @@ from InnerEye.ML.SSL.datamodules_and_datasets.datamodules import CombinedDataMod
 from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import InnerEyeCIFARLinearHeadTransform, \
     InnerEyeCIFARTrainTransform, get_cxr_ssl_transforms
 from InnerEye.ML.SSL.lightning_containers.ssl_container import SSLContainer, SSLDatasetName
-from InnerEye.ML.SSL.utils import SSLDataModuleType, load_ssl_augmentation_config
+from InnerEye.ML.SSL.utils import SSLDataModuleType, load_yaml_augmentation_config
 from InnerEye.ML.configs.ssl.CXR_SSL_configs import path_encoder_augmentation_cxr
 from Tests.SSL.test_ssl_containers import _create_test_cxr_data

 path_to_test_dataset = full_ml_test_data_path("cxr_test_dataset")
 _create_test_cxr_data(path_to_test_dataset)
-cxr_augmentation_config = load_ssl_augmentation_config(path_encoder_augmentation_cxr)
+cxr_augmentation_config = load_yaml_augmentation_config(path_encoder_augmentation_cxr)


 @pytest.mark.skipif(is_windows(), reason="Too slow on windows")
@@ -1,108 +0,0 @@
import random

import PIL
import numpy as np
import pytest
import torch
from scipy.ndimage import gaussian_filter, map_coordinates
from torchvision.transforms import ToTensor

from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import AddGaussianNoise, CenterCrop, ElasticTransform, \
    ExpandChannels, RandomAffine, RandomColorJitter, RandomErasing, RandomGamma, RandomHorizontalFlip, RandomResizeCrop, \
    Resize, create_chest_xray_transform
from Tests.SSL.test_data_modules import cxr_augmentation_config


def test_add_gaussian_noise() -> None:
    """
    Tests functionality of add gaussian noise
    """
    np.random.seed(1)
    torch.manual_seed(10)
    array = np.ones([1, 256, 256]) * 255.
    array[0, 100:150, 100:200] = 1
    tensor_img = torch.tensor(array / 255.)
    transformed = AddGaussianNoise(cxr_augmentation_config)(tensor_img)
    torch.manual_seed(10)
    noise = torch.randn(size=(1, 256, 256)) * 0.05
    assert torch.isclose(torch.clamp(tensor_img + noise, 0, 1), transformed).all()
    with pytest.raises(AssertionError):
        AddGaussianNoise(cxr_augmentation_config)(tensor_img * 255.)


def test_elastic_transform() -> None:
    """
    Tests elastic transform
    """
    image = np.ones([256, 256]) * 255.
    image[100:150, 100:200] = 1

    # Computed expected transform
    np.random.seed(7)
    np.random.random(1)

    shape = (256, 256)
    dx = gaussian_filter((np.random.random(shape) * 2 - 1), 4, mode="constant", cval=0) * 34
    dy = gaussian_filter((np.random.random(shape) * 2 - 1), 4, mode="constant", cval=0) * 34
    x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
    indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))
    expected_array = map_coordinates(image, indices, order=1).reshape(shape)
    # Actual transform
    np.random.seed(7)
    transformed_image = np.asarray(ElasticTransform(cxr_augmentation_config)(PIL.Image.fromarray(image)))
    assert np.isclose(expected_array, transformed_image).all()


def test_expand_channels() -> None:
    image = np.ones([1, 256, 256]) * 255.
    tensor_img = torch.tensor(image)
    tensor_img = ExpandChannels()(tensor_img)
    assert tensor_img.shape == torch.Size([3, 256, 256])
    assert torch.isclose(tensor_img[0], tensor_img[1]).all() and torch.isclose(tensor_img[1], tensor_img[2]).all()


def test_create_chest_xray_transform() -> None:
    """
    Tests that the pipeline returned by create_chest_xray_transform returns the expected transformation.
    """
    transform = create_chest_xray_transform(cxr_augmentation_config, apply_augmentations=True)
    image = np.ones([256, 256]) * 255.
    image[100:150, 100:200] = 1
    image = PIL.Image.fromarray(image).convert("L")
    np.random.seed(3)
    torch.manual_seed(3)
    random.seed(3)
    transformed_image = transform(image)

    # Expected pipeline
    np.random.seed(3)
    torch.manual_seed(3)
    random.seed(3)
    image = RandomAffine(cxr_augmentation_config)(image)
    image = RandomResizeCrop(cxr_augmentation_config)(image)
    image = Resize(cxr_augmentation_config)(image)
    image = RandomHorizontalFlip(cxr_augmentation_config)(image)
    image = RandomGamma(cxr_augmentation_config)(image)
    image = RandomColorJitter(cxr_augmentation_config)(image)
    image = ElasticTransform(cxr_augmentation_config)(image)
    image = CenterCrop(cxr_augmentation_config)(image)
    image = ToTensor()(image)
    image = RandomErasing(cxr_augmentation_config)(image)
    image = AddGaussianNoise(cxr_augmentation_config)(image)
    image = ExpandChannels()(image)

    assert torch.isclose(image, transformed_image).all()

    # Test the evaluation pipeline
    transform = create_chest_xray_transform(cxr_augmentation_config, apply_augmentations=False)
    image = np.ones([256, 256]) * 255.
    image[100:150, 100:200] = 1
    image = PIL.Image.fromarray(image).convert("L")
    transformed_image = transform(image)

    # Expected pipeline
    image = Resize(cxr_augmentation_config)(image)
    image = CenterCrop(cxr_augmentation_config)(image)
    image = ToTensor()(image)
    image = ExpandChannels()(image)
    assert torch.isclose(image, transformed_image).all()
@@ -283,3 +283,35 @@ runs are uploaded to the parent run, in the `CrossValResults` directory. This co
There is also a directory `BaselineComparisons`, containing the Wilcoxon test results and
scatterplots for the ensemble, as described above for single runs.

### Augmentations for classification models

For classification models, you can define an augmentation pipeline to apply to your image inputs (resp. segmentation
inputs) at training, validation and test time. To define such a series of transformations, override the
`get_image_transform` (resp. `get_segmentation_transform`) method of your config class. This method is expected to
return a `ModelTransformsPerExecutionMode` object, which maps each execution mode to one transform function. We also
provide `ImageTransformationPipeline`, a class that builds a pipeline of transforms from a list of individual
transforms and ensures that 2D or 3D PIL.Image or tensor inputs are correctly converted before the pipeline is applied.

`ImageTransformationPipeline` takes two arguments for its constructor:

* `transforms`: a list of image transforms. In particular, you can feed in standard
  [torchvision transforms](https://pytorch.org/vision/0.8/transforms.html) or any other transforms, as long as they
  support an input of shape `[Z, C, H, W]`, where Z is the third dimension (1 for 2D images), C the number of
  channels, and H and W the height and width of each 2D slice; standard torchvision transforms support this. You can
  also define your own transforms, as long as they expect such a `[Z, C, H, W]` input. You can find examples of
  custom transform classes in `InnerEye/ML/augmentation/image_transforms.py`.
* `use_different_transformation_per_channel`: if True, apply a different version of the augmentation pipeline
  to each channel. If False, apply the same transformation to each channel, separately. Defaults to False.

Below you can find an example of `get_image_transform` that resizes your input images to 256 x 256 and, at
training time only, applies a random rotation of +/- 10 degrees and some brightness distortion,
using standard torchvision transforms.

```python
def get_image_transform(self) -> ModelTransformsPerExecutionMode:
    """
    Get transforms to perform on image samples for each model execution mode.
    """
    return ModelTransformsPerExecutionMode(
        train=ImageTransformationPipeline(transforms=[Resize(256), RandomAffine(degrees=10), ColorJitter(brightness=0.2)]),
        val=ImageTransformationPipeline(transforms=[Resize(256)]),
        test=ImageTransformationPipeline(transforms=[Resize(256)]))
```
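
If none of the built-in or torchvision transforms covers your use case, a custom transform only needs to be a callable
that accepts and returns a `[Z, C, H, W]` tensor, so that it can be chained inside `ImageTransformationPipeline`. The
snippet below is a minimal sketch of such a transform; the class name `RandomIntensityScale` and its parameters are
purely illustrative and are not part of the repository (the custom transforms that ship with the code live in
`InnerEye/ML/augmentation/image_transforms.py`).

```python
from typing import Tuple

import torch


class RandomIntensityScale:
    """
    Illustrative custom transform: multiplies all intensities by a random factor.
    It takes and returns a [Z, C, H, W] tensor, which is the contract that
    ImageTransformationPipeline expects from the transforms it chains together.
    """

    def __init__(self, factor_range: Tuple[float, float] = (0.9, 1.1)) -> None:
        self.factor_range = factor_range

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        # Draw one scaling factor per call and apply it to the whole [Z, C, H, W] tensor.
        factor = torch.empty(1).uniform_(*self.factor_range).item()
        return image * factor
```

Such a transform could then be appended to the list passed to `ImageTransformationPipeline`, for example
`ImageTransformationPipeline(transforms=[Resize(256), RandomIntensityScale()])` in the training branch of the example
above.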