Updated augmentations for histology images (#179)

Fixed HEDJitter augmentation, the previous version was also jittering the D channel
Removed dependency of StainNormalization on skimage, which makes it also faster
Added GaussianBlur that is faster than the torchvision version
Added rotation by multiples of 90 degrees
Adjusted/added the test for all augmentations.
All the above augmentations are frequently used for histopathology.

Added torchvision to the environment since it is necessary for the augmentations.
This commit is contained in:
maxilse 2022-01-11 14:16:39 +01:00 коммит произвёл GitHub
Родитель 7902a6e138
Коммит a33c1ed07d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 200 добавлений и 79 удалений

Просмотреть файл

@ -10,6 +10,22 @@ For each Pull Request, the affected code parts should be briefly described and a
release. In the first PR after a release has been made, a section for the upcoming release should be added, by copying
the section headers (Added/Changed/...) and incrementing the package version.
## 0.1.14
### Added
- ([#179](https://github.com/microsoft/hi-ml/pull/179)) Add GaussianBlur and RotationByMultiplesOf90 augmentations. Added torchvision and opencv to
the environment file since it is necessary for the augmentations.
### Changed
### Fixed
- ([#179](https://github.com/microsoft/hi-ml/pull/179)) HEDJitter was jittering the D channel as well. StainNormalization was relying on skimage.
### Removed
### Deprecated
## 0.1.13
### Added

Просмотреть файл

@ -70,7 +70,7 @@ class WrappedTensorboard(Tensorboard):
self._run_watchers.append(run_watcher)
for w in self._run_watchers:
self._executor.submit(w.refresh_requeue)
self._executor.submit(w.refresh_requeue) # type: ignore
# We use sys.executable here to ensure that we can import modules from the same environment
# as the current process.

Просмотреть файл

@ -4,5 +4,5 @@ matplotlib==3.4.3
opencv-python-headless==4.5.1.48
pandas==1.3.4
pytorch-lightning>=1.4.9
scikit-image==0.17.2
torchvision==0.9.0
torch>=1.8

Просмотреть файл

@ -1,57 +1,92 @@
from skimage import color
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import cv2
import numpy as np
import torch
import torchvision.transforms.functional as TF
class HEDJitter(object):
"""
A class to randomly perturb the HEAD color space value of an RGB image
"""
def __init__(self, theta: float = 0.) -> None: # HED_light: theta=0.05; HED_strong: theta=0.2
Randomly perturbe the HED color space value an RGB image.
First, it disentangled the hematoxylin and eosin color channels by color deconvolution method using a fixed matrix,
taken from Ruifrok and Johnston (2001): "Quantification of histochemical staining by color deconvolution."
Second, it perturbed the hematoxylin, eosin stains independently.
Third, it transformed the resulting stains into regular RGB color space.
Usage example:
>>> transform = HEDJitter(0.05)
>>> img = transform(img)
"""
def __init__(self, theta: float = 0.) -> None:
"""
:param theta: How much to jitter HED color space.
HED_light: theta=0.05; HED_strong: theta=0.2.
alpha is chosen from a uniform distribution [1-theta, 1+theta].
beta is chosen from a uniform distribution [-theta, theta].
The jitter formula is :math:`s' = \alpha * s + \beta`.
"""
self.theta = theta
self.rgb_from_hed = torch.tensor([[0.65, 0.70, 0.29],
[0.07, 0.99, 0.11],
[0.27, 0.57, 0.78]])
self.hed_from_rgb = torch.tensor([[1.87798274, -1.00767869, -0.55611582],
[-0.06590806, 1.13473037, -0.1355218],
[-0.60190736, -0.48041419, 1.57358807]])
@staticmethod
def adjust_hed(img: torch.Tensor, theta: float) -> torch.Tensor:
def adjust_hed(img: torch.Tensor,
theta: float,
stain_from_rgb_mat: torch.Tensor,
rgb_from_stain_mat: torch.Tensor
) -> torch.Tensor:
"""
Randomly perturb the hematoxylin-Eosin-DAB (HED) color space value of an RGB image
Steps involved in this process:
1. separate the stains (RGB to HED color space conversion)
2. perturb the stains independently
3. convert the resulting stains back to RGB color space
Applies HED jitter to image.
:param img: A Torch Tensor representing the image to be transformed
:param theta: A float representing how much to jitter HED color space by
:return: a Torch Tensor of stains transformed into RGB color space.
:param img: Input image.
:param theta: Strength of the jitter. HED_light: theta=0.05; HED_strong: theta=0.2.
:param stain_from_rgb_mat: Transformation matrix from HED to RGB.
:param rgb_from_stain_mat: Transformation matrix from RGB to HED.
"""
# alpha is chosen from a uniform distribution [1 - theta, 1 + theta]
alpha = np.random.uniform(1 - theta, 1 + theta, (1, 3))
# beta is chosen from a uniform distribution [-theta, theta]
beta = np.random.uniform(-theta, theta, (1, 3))
alpha = torch.FloatTensor(1, 3).uniform_(1 - theta, 1 + theta)
beta = torch.FloatTensor(1, 3).uniform_(-theta, theta)
assert img.ndim == 4, "Expected a Tensor with 4 dimensions"
# channel dim must be last for next function
img = img.permute([0, 2, 3, 1]).numpy()
s = color.rgb2hed(img)
# Only perturb the H (=0) and E (=1) channels
alpha[0][-1] = 1.
beta[0][-1] = 0.
# the jitter formula (perturbations in HED color space) is **s' = \alpha * s + \beta**
ns = alpha * s + beta
# Separate stains
img = img.permute([0, 2, 3, 1])
img = img + 2 # for consistency with skimage
stains = -torch.log10(img) @ stain_from_rgb_mat
stains = alpha * stains + beta # perturbations in HED color space
nimg = color.hed2rgb(ns)
nimg = np.clip(nimg, 0, 1)
nimg = torch.Tensor(nimg).permute(0, 3, 1, 2)
# Combine stains
img = 10 ** (-stains @ rgb_from_stain_mat) - 2
img = torch.clip(img, 0, 1)
img = img.permute(0, 3, 1, 2)
return nimg
return img
def __call__(self, img: torch.Tensor) -> torch.Tensor:
return self.adjust_hed(img, self.theta)
return self.adjust_hed(img, self.theta, self.hed_from_rgb, self.rgb_from_hed)
class StainNormalization(object):
"""
A class to normalize the stain of an image given a reference image. Following
Erik Reinhard,Bruce Gooch., Color Transfer between Images, IEEE ComputerGraphics and Applications.
"""Normalize the stain of an image given a reference image.
Following Erik Reinhard, Bruce Gooch (2001): Color Transfer between Images.
First, mask all white pixels.
Second, convert remaining pixels to lab space and normalize each channel.
Third, add mean and std of reference image.
Fourth, convert back to rgb and add white pixels back.
Usage example:
>>> transform = StainNormalization()
>>> img = transform(img)
"""
def __init__(self) -> None:
# mean and std per channel of a reference image
@ -61,22 +96,13 @@ class StainNormalization(object):
@staticmethod
def stain_normalize(img: torch.Tensor, reference_mean: np.ndarray, reference_std: np.ndarray) -> torch.Tensor:
"""
Normalize the stain of an image given a reference image
Applies stain normalization to image.
Steps involved:
1. mask all white pixels
2. convert remaining pixels to lab space and normalize each channel
3. add mean and std of reference image
4. convert back to rgb and add white pixels back
:param img: the image whose stain should be normalised
:param reference_mean: the mean of the reference image, for normalisation
:param reference_std: the standard deviation of the reference image, for normalisation
:return: A Torch tensor representing the image with normalized stain
:param img: Input image.
:param reference_mean: Mean per channel of a reference image.
:param reference_std: STD per channel of a reference image.
"""
assert img.ndim == 4, "Expected a Tensor with 4 dimensions"
# only 3 channels, color channel last, range 0 - 255
img = img.permute([0, 2, 3, 1]).squeeze().numpy() * 255
img = img.permute([0, 2, 3, 1]).squeeze().numpy() * 255 # only 3 channels, color channel last, range 0 - 255
img = img.astype(np.uint8) # type: ignore
whitemask = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
@ -86,8 +112,7 @@ class StainNormalization(object):
imagelab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
imagelab_masked = np.ma.MaskedArray(imagelab, whitemask) # type: np.ma.MaskedArray
# Sometimes STD is near 0, add epsilon to avoid div by 0
epsilon = 1e-11
epsilon = 1e-11 # Sometimes STD is near 0, add epsilon to avoid div by 0
imagelab_masked_mean = imagelab_masked.mean(axis=(0, 1))
imagelab_masked_std = imagelab_masked.std(axis=(0, 1)) + epsilon
@ -96,13 +121,65 @@ class StainNormalization(object):
imagelab = np.clip(imagelab, 0, 255)
imagelab = imagelab.astype(np.uint8)
nimg = cv2.cvtColor(imagelab, cv2.COLOR_LAB2RGB)
# add back white pixels
nimg[whitemask] = img[whitemask]
# convert back to Tensor
nimg = torch.Tensor(nimg).unsqueeze(0).permute(0, 3, 1, 2) / 255.
nimg[whitemask] = img[whitemask] # add back white pixels
nimg = torch.Tensor(nimg).unsqueeze(0).permute(0, 3, 1, 2) / 255. # back to pytorch format
return nimg
def __call__(self, img: torch.Tensor) -> torch.Tensor:
return self.stain_normalize(img, self.reference_mean, self.reference_std)
class GaussianBlur(object):
"""
Implements Gaussian blur as described in the SimCLR paper (https://arxiv.org/abs/2002.05709).
Blur image using a Gaussian kernel with a randomly sampled STD.
Slight modification of the code in pl_bolts to make it work with our transform pipeline.
Usage example:
>>> transform = GaussianBlur(kernel_size=int(224 * 0.1) + 1)
>>> img = transform(img)
"""
def __init__(self, kernel_size: int, p: float = 0.5, min: float = 0.1, max: float = 2.0) -> None:
"""
:param kernel_size: Size of the Gaussian kernel, e.g., about 10% of the image size.
:param p: Probability of applying blur.
:param min: lower bound of the interval from which we sample the STD
:param max: upper bound of the interval from which we sample the STD
"""
self.min = min
self.max = max
self.kernel_size = kernel_size
self.p = p
def __call__(self, sample: torch.Tensor) -> torch.Tensor:
prob = np.random.random_sample()
if prob < self.p:
sigma = (self.max - self.min) * np.random.random_sample() + self.min
sample = np.array(sample.squeeze()) # type: ignore
sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
sample = torch.Tensor(sample).unsqueeze(0)
return sample
class RandomRotationByMultiplesOf90(object):
"""
Rotation of input image by 0, 90, 180 or 270 degrees.
Usage example:
>>> transform = RandomRotationByMultiplesOf90()
>>> img = transform(img)
"""
def __init__(self) -> None:
super().__init__()
def __call__(self, sample: torch.Tensor) -> torch.Tensor:
angle = np.random.choice([0., 90., 180., 270.])
if angle != 0.:
sample = TF.rotate(sample, angle)
return sample

Просмотреть файл

@ -2,29 +2,31 @@ import torch
import numpy as np
import random
from torch import Tensor
from typing import Callable
from health_ml.utils.data_augmentations import HEDJitter, StainNormalization
from health_ml.utils.data_augmentations import HEDJitter, StainNormalization, GaussianBlur, \
RandomRotationByMultiplesOf90
# global dummy image
dummy_img = torch.Tensor(
[[
[[0.4767, 0.0415], [0.8325, 0.8420]],
[[0.9859, 0.9119], [0.8717, 0.9098]],
[[0.1592, 0.7216], [0.8305, 0.1127]]
]]
)
[[[[0.4767, 0.0415],
[0.8325, 0.8420]],
[[0.9859, 0.9119],
[0.8717, 0.9098]],
[[0.1592, 0.7216],
[0.8305, 0.1127]]]])
def _test_data_augmentation(data_augmentation: Callable[[torch.Tensor], torch.Tensor],
def _test_data_augmentation(data_augmentation: Callable[[Tensor], Tensor],
input_img: torch.Tensor,
expected_output_img: torch.Tensor,
stochastic: bool) -> None:
stochastic: bool,
seed: int = 0) -> None:
if stochastic:
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
augmented_img = data_augmentation(input_img)
@ -51,12 +53,12 @@ def _test_data_augmentation(data_augmentation: Callable[[torch.Tensor], torch.Te
def test_stain_normalization() -> None:
data_augmentation = StainNormalization()
expected_output_img = torch.Tensor(
[[
[[0.8627, 0.4510], [0.8314, 0.9373]],
[[0.6157, 0.2353], [0.8706, 0.4863]],
[[0.8235, 0.5294], [0.8275, 0.7725]]
]]
)
[[[[0.8627, 0.4510],
[0.8314, 0.9373]],
[[0.6157, 0.2353],
[0.8706, 0.4863]],
[[0.8235, 0.5294],
[0.8275, 0.7725]]]])
_test_data_augmentation(data_augmentation, dummy_img, expected_output_img, stochastic=False)
@ -64,11 +66,37 @@ def test_stain_normalization() -> None:
def test_hed_jitter() -> None:
data_augmentation = HEDJitter(0.05)
expected_output_img = torch.Tensor(
[[
[[0.4536, 0.0221], [0.8084, 0.8164]],
[[0.9781, 0.9108], [0.8522, 0.8933]],
[[0.1138, 0.6730], [0.7773, 0.0666]]
]]
)
[[[[0.6241, 0.1635],
[0.9993, 1.0000]],
[[1.0000, 1.0000],
[1.0000, 1.0000]],
[[0.2232, 0.8028],
[0.9117, 0.1742]]]])
_test_data_augmentation(data_augmentation, dummy_img, expected_output_img, stochastic=True)
def test_gaussian_blur() -> None:
data_augmentation = GaussianBlur(3, p=1.0)
expected_output_img = torch.Tensor(
[[[[0.8302, 0.7639],
[0.8149, 0.6943]],
[[0.7423, 0.6225],
[0.6815, 0.6094]],
[[0.7821, 0.6929],
[0.7393, 0.7463]]]])
_test_data_augmentation(data_augmentation, dummy_img, expected_output_img, stochastic=True)
def test_random_rotation() -> None:
data_augmentation = RandomRotationByMultiplesOf90()
expected_output_img = torch.Tensor(
[[[[0.0415, 0.8420],
[0.4767, 0.8325]],
[[0.9119, 0.9098],
[0.9859, 0.8717]],
[[0.7216, 0.1127],
[0.1592, 0.8305]]]])
_test_data_augmentation(data_augmentation, dummy_img, expected_output_img, stochastic=True, seed=1)