зеркало из https://github.com/microsoft/hi-ml.git
Updated augmentations for histology images (#179)
Fixed HEDJitter augmentation; the previous version was also jittering the D channel. Removed dependency of StainNormalization on skimage, which also makes it faster. Added GaussianBlur that is faster than the torchvision version. Added rotation by multiples of 90 degrees. Adjusted/added the tests for all augmentations. All the above augmentations are frequently used for histopathology. Added torchvision to the environment since it is necessary for the augmentations.
This commit is contained in:
Родитель
7902a6e138
Коммит
a33c1ed07d
16
CHANGELOG.md
16
CHANGELOG.md
|
@ -10,6 +10,22 @@ For each Pull Request, the affected code parts should be briefly described and a
|
|||
release. In the first PR after a release has been made, a section for the upcoming release should be added, by copying
|
||||
the section headers (Added/Changed/...) and incrementing the package version.
|
||||
|
||||
## 0.1.14
|
||||
|
||||
### Added
|
||||
- ([#179](https://github.com/microsoft/hi-ml/pull/179)) Add GaussianBlur and RotationByMultiplesOf90 augmentations. Added torchvision and opencv to
|
||||
the environment file since they are necessary for the augmentations.
|
||||
|
||||
### Changed
|
||||
|
||||
### Fixed
|
||||
- ([#179](https://github.com/microsoft/hi-ml/pull/179)) HEDJitter was jittering the D channel as well. StainNormalization was relying on skimage.
|
||||
|
||||
### Removed
|
||||
|
||||
### Deprecated
|
||||
|
||||
|
||||
## 0.1.13
|
||||
|
||||
### Added
|
||||
|
|
|
@ -70,7 +70,7 @@ class WrappedTensorboard(Tensorboard):
|
|||
self._run_watchers.append(run_watcher)
|
||||
|
||||
for w in self._run_watchers:
|
||||
self._executor.submit(w.refresh_requeue)
|
||||
self._executor.submit(w.refresh_requeue) # type: ignore
|
||||
|
||||
# We use sys.executable here to ensure that we can import modules from the same environment
|
||||
# as the current process.
|
||||
|
|
|
@ -4,5 +4,5 @@ matplotlib==3.4.3
|
|||
opencv-python-headless==4.5.1.48
|
||||
pandas==1.3.4
|
||||
pytorch-lightning>=1.4.9
|
||||
scikit-image==0.17.2
|
||||
torchvision==0.9.0
|
||||
torch>=1.8
|
||||
|
|
|
@ -1,57 +1,92 @@
|
|||
from skimage import color
|
||||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
|
||||
class HEDJitter(object):
|
||||
"""
|
||||
A class to randomly perturb the HEAD color space value of an RGB image
|
||||
"""
|
||||
def __init__(self, theta: float = 0.) -> None: # HED_light: theta=0.05; HED_strong: theta=0.2
|
||||
Randomly perturb the HED color space value of an RGB image.
|
||||
|
||||
First, it disentangles the hematoxylin and eosin color channels by the color deconvolution method using a fixed matrix,
|
||||
taken from Ruifrok and Johnston (2001): "Quantification of histochemical staining by color deconvolution."
|
||||
Second, it perturbs the hematoxylin and eosin stains independently.
|
||||
Third, it transforms the resulting stains back into regular RGB color space.
|
||||
|
||||
Usage example:
|
||||
>>> transform = HEDJitter(0.05)
|
||||
>>> img = transform(img)
|
||||
"""
|
||||
def __init__(self, theta: float = 0.) -> None:
|
||||
"""
|
||||
:param theta: How much to jitter HED color space.
|
||||
HED_light: theta=0.05; HED_strong: theta=0.2.
|
||||
alpha is chosen from a uniform distribution [1-theta, 1+theta].
|
||||
beta is chosen from a uniform distribution [-theta, theta].
|
||||
The jitter formula is :math:`s' = \alpha * s + \beta`.
|
||||
"""
|
||||
self.theta = theta
|
||||
self.rgb_from_hed = torch.tensor([[0.65, 0.70, 0.29],
|
||||
[0.07, 0.99, 0.11],
|
||||
[0.27, 0.57, 0.78]])
|
||||
self.hed_from_rgb = torch.tensor([[1.87798274, -1.00767869, -0.55611582],
|
||||
[-0.06590806, 1.13473037, -0.1355218],
|
||||
[-0.60190736, -0.48041419, 1.57358807]])
|
||||
|
||||
@staticmethod
|
||||
def adjust_hed(img: torch.Tensor, theta: float) -> torch.Tensor:
|
||||
def adjust_hed(img: torch.Tensor,
|
||||
theta: float,
|
||||
stain_from_rgb_mat: torch.Tensor,
|
||||
rgb_from_stain_mat: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Randomly perturb the hematoxylin-Eosin-DAB (HED) color space value of an RGB image
|
||||
Steps involved in this process:
|
||||
1. separate the stains (RGB to HED color space conversion)
|
||||
2. perturb the stains independently
|
||||
3. convert the resulting stains back to RGB color space
|
||||
Applies HED jitter to image.
|
||||
|
||||
:param img: A Torch Tensor representing the image to be transformed
|
||||
:param theta: A float representing how much to jitter HED color space by
|
||||
:return: a Torch Tensor of stains transformed into RGB color space.
|
||||
:param img: Input image.
|
||||
:param theta: Strength of the jitter. HED_light: theta=0.05; HED_strong: theta=0.2.
|
||||
:param stain_from_rgb_mat: Transformation matrix from RGB to HED.
|
||||
:param rgb_from_stain_mat: Transformation matrix from HED to RGB.
|
||||
"""
|
||||
# alpha is chosen from a uniform distribution [1 - theta, 1 + theta]
|
||||
alpha = np.random.uniform(1 - theta, 1 + theta, (1, 3))
|
||||
# beta is chosen from a uniform distribution [-theta, theta]
|
||||
beta = np.random.uniform(-theta, theta, (1, 3))
|
||||
alpha = torch.FloatTensor(1, 3).uniform_(1 - theta, 1 + theta)
|
||||
beta = torch.FloatTensor(1, 3).uniform_(-theta, theta)
|
||||
|
||||
assert img.ndim == 4, "Expected a Tensor with 4 dimensions"
|
||||
# channel dim must be last for next function
|
||||
img = img.permute([0, 2, 3, 1]).numpy()
|
||||
s = color.rgb2hed(img)
|
||||
# Only perturb the H (=0) and E (=1) channels
|
||||
alpha[0][-1] = 1.
|
||||
beta[0][-1] = 0.
|
||||
|
||||
# the jitter formula (perturbations in HED color space) is **s' = \alpha * s + \beta**
|
||||
ns = alpha * s + beta
|
||||
# Separate stains
|
||||
img = img.permute([0, 2, 3, 1])
|
||||
img = img + 2 # for consistency with skimage
|
||||
stains = -torch.log10(img) @ stain_from_rgb_mat
|
||||
stains = alpha * stains + beta # perturbations in HED color space
|
||||
|
||||
nimg = color.hed2rgb(ns)
|
||||
nimg = np.clip(nimg, 0, 1)
|
||||
nimg = torch.Tensor(nimg).permute(0, 3, 1, 2)
|
||||
# Combine stains
|
||||
img = 10 ** (-stains @ rgb_from_stain_mat) - 2
|
||||
img = torch.clip(img, 0, 1)
|
||||
img = img.permute(0, 3, 1, 2)
|
||||
|
||||
return nimg
|
||||
return img
|
||||
|
||||
def __call__(self, img: torch.Tensor) -> torch.Tensor:
|
||||
return self.adjust_hed(img, self.theta)
|
||||
return self.adjust_hed(img, self.theta, self.hed_from_rgb, self.rgb_from_hed)
|
||||
|
||||
|
||||
class StainNormalization(object):
|
||||
"""
|
||||
A class to normalize the stain of an image given a reference image. Following
|
||||
Erik Reinhard,Bruce Gooch., “Color Transfer between Images,” IEEE ComputerGraphics and Applications.
|
||||
"""Normalize the stain of an image given a reference image.
|
||||
|
||||
Following Erik Reinhard, Bruce Gooch (2001): “Color Transfer between Images.”
|
||||
First, mask all white pixels.
|
||||
Second, convert remaining pixels to lab space and normalize each channel.
|
||||
Third, add mean and std of reference image.
|
||||
Fourth, convert back to rgb and add white pixels back.
|
||||
|
||||
Usage example:
|
||||
>>> transform = StainNormalization()
|
||||
>>> img = transform(img)
|
||||
"""
|
||||
def __init__(self) -> None:
|
||||
# mean and std per channel of a reference image
|
||||
|
@ -61,22 +96,13 @@ class StainNormalization(object):
|
|||
@staticmethod
|
||||
def stain_normalize(img: torch.Tensor, reference_mean: np.ndarray, reference_std: np.ndarray) -> torch.Tensor:
|
||||
"""
|
||||
Normalize the stain of an image given a reference image
|
||||
Applies stain normalization to image.
|
||||
|
||||
Steps involved:
|
||||
1. mask all white pixels
|
||||
2. convert remaining pixels to lab space and normalize each channel
|
||||
3. add mean and std of reference image
|
||||
4. convert back to rgb and add white pixels back
|
||||
|
||||
:param img: the image whose stain should be normalised
|
||||
:param reference_mean: the mean of the reference image, for normalisation
|
||||
:param reference_std: the standard deviation of the reference image, for normalisation
|
||||
:return: A Torch tensor representing the image with normalized stain
|
||||
:param img: Input image.
|
||||
:param reference_mean: Mean per channel of a reference image.
|
||||
:param reference_std: STD per channel of a reference image.
|
||||
"""
|
||||
assert img.ndim == 4, "Expected a Tensor with 4 dimensions"
|
||||
# only 3 channels, color channel last, range 0 - 255
|
||||
img = img.permute([0, 2, 3, 1]).squeeze().numpy() * 255
|
||||
img = img.permute([0, 2, 3, 1]).squeeze().numpy() * 255 # only 3 channels, color channel last, range 0 - 255
|
||||
img = img.astype(np.uint8) # type: ignore
|
||||
|
||||
whitemask = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
||||
|
@ -86,8 +112,7 @@ class StainNormalization(object):
|
|||
imagelab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
|
||||
imagelab_masked = np.ma.MaskedArray(imagelab, whitemask) # type: np.ma.MaskedArray
|
||||
|
||||
# Sometimes STD is near 0, add epsilon to avoid div by 0
|
||||
epsilon = 1e-11
|
||||
epsilon = 1e-11 # Sometimes STD is near 0, add epsilon to avoid div by 0
|
||||
imagelab_masked_mean = imagelab_masked.mean(axis=(0, 1))
|
||||
imagelab_masked_std = imagelab_masked.std(axis=(0, 1)) + epsilon
|
||||
|
||||
|
@ -96,13 +121,65 @@ class StainNormalization(object):
|
|||
imagelab = np.clip(imagelab, 0, 255)
|
||||
imagelab = imagelab.astype(np.uint8)
|
||||
nimg = cv2.cvtColor(imagelab, cv2.COLOR_LAB2RGB)
|
||||
|
||||
# add back white pixels
|
||||
nimg[whitemask] = img[whitemask]
|
||||
# convert back to Tensor
|
||||
nimg = torch.Tensor(nimg).unsqueeze(0).permute(0, 3, 1, 2) / 255.
|
||||
nimg[whitemask] = img[whitemask] # add back white pixels
|
||||
nimg = torch.Tensor(nimg).unsqueeze(0).permute(0, 3, 1, 2) / 255. # back to pytorch format
|
||||
|
||||
return nimg
|
||||
|
||||
def __call__(self, img: torch.Tensor) -> torch.Tensor:
|
||||
return self.stain_normalize(img, self.reference_mean, self.reference_std)
|
||||
|
||||
|
||||
class GaussianBlur(object):
|
||||
"""
|
||||
Implements Gaussian blur as described in the SimCLR paper (https://arxiv.org/abs/2002.05709).
|
||||
|
||||
Blur image using a Gaussian kernel with a randomly sampled STD.
|
||||
Slight modification of the code in pl_bolts to make it work with our transform pipeline.
|
||||
|
||||
Usage example:
|
||||
>>> transform = GaussianBlur(kernel_size=int(224 * 0.1) + 1)
|
||||
>>> img = transform(img)
|
||||
"""
|
||||
def __init__(self, kernel_size: int, p: float = 0.5, min: float = 0.1, max: float = 2.0) -> None:
|
||||
"""
|
||||
:param kernel_size: Size of the Gaussian kernel, e.g., about 10% of the image size.
|
||||
:param p: Probability of applying blur.
|
||||
:param min: lower bound of the interval from which we sample the STD
|
||||
:param max: upper bound of the interval from which we sample the STD
|
||||
"""
|
||||
self.min = min
|
||||
self.max = max
|
||||
self.kernel_size = kernel_size
|
||||
self.p = p
|
||||
|
||||
def __call__(self, sample: torch.Tensor) -> torch.Tensor:
|
||||
prob = np.random.random_sample()
|
||||
|
||||
if prob < self.p:
|
||||
sigma = (self.max - self.min) * np.random.random_sample() + self.min
|
||||
sample = np.array(sample.squeeze()) # type: ignore
|
||||
sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
|
||||
sample = torch.Tensor(sample).unsqueeze(0)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class RandomRotationByMultiplesOf90(object):
|
||||
"""
|
||||
Rotation of input image by 0, 90, 180 or 270 degrees.
|
||||
|
||||
Usage example:
|
||||
>>> transform = RandomRotationByMultiplesOf90()
|
||||
>>> img = transform(img)
|
||||
"""
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, sample: torch.Tensor) -> torch.Tensor:
|
||||
angle = np.random.choice([0., 90., 180., 270.])
|
||||
|
||||
if angle != 0.:
|
||||
sample = TF.rotate(sample, angle)
|
||||
|
||||
return sample
|
||||
|
|
|
@ -2,29 +2,31 @@ import torch
|
|||
import numpy as np
|
||||
import random
|
||||
|
||||
from torch import Tensor
|
||||
from typing import Callable
|
||||
|
||||
from health_ml.utils.data_augmentations import HEDJitter, StainNormalization
|
||||
|
||||
from health_ml.utils.data_augmentations import HEDJitter, StainNormalization, GaussianBlur, \
|
||||
RandomRotationByMultiplesOf90
|
||||
|
||||
# global dummy image
|
||||
dummy_img = torch.Tensor(
|
||||
[[
|
||||
[[0.4767, 0.0415], [0.8325, 0.8420]],
|
||||
[[0.9859, 0.9119], [0.8717, 0.9098]],
|
||||
[[0.1592, 0.7216], [0.8305, 0.1127]]
|
||||
]]
|
||||
)
|
||||
[[[[0.4767, 0.0415],
|
||||
[0.8325, 0.8420]],
|
||||
[[0.9859, 0.9119],
|
||||
[0.8717, 0.9098]],
|
||||
[[0.1592, 0.7216],
|
||||
[0.8305, 0.1127]]]])
|
||||
|
||||
|
||||
def _test_data_augmentation(data_augmentation: Callable[[torch.Tensor], torch.Tensor],
|
||||
def _test_data_augmentation(data_augmentation: Callable[[Tensor], Tensor],
|
||||
input_img: torch.Tensor,
|
||||
expected_output_img: torch.Tensor,
|
||||
stochastic: bool) -> None:
|
||||
stochastic: bool,
|
||||
seed: int = 0) -> None:
|
||||
if stochastic:
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
random.seed(0)
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
augmented_img = data_augmentation(input_img)
|
||||
|
||||
|
@ -51,12 +53,12 @@ def _test_data_augmentation(data_augmentation: Callable[[torch.Tensor], torch.Te
|
|||
def test_stain_normalization() -> None:
|
||||
data_augmentation = StainNormalization()
|
||||
expected_output_img = torch.Tensor(
|
||||
[[
|
||||
[[0.8627, 0.4510], [0.8314, 0.9373]],
|
||||
[[0.6157, 0.2353], [0.8706, 0.4863]],
|
||||
[[0.8235, 0.5294], [0.8275, 0.7725]]
|
||||
]]
|
||||
)
|
||||
[[[[0.8627, 0.4510],
|
||||
[0.8314, 0.9373]],
|
||||
[[0.6157, 0.2353],
|
||||
[0.8706, 0.4863]],
|
||||
[[0.8235, 0.5294],
|
||||
[0.8275, 0.7725]]]])
|
||||
|
||||
_test_data_augmentation(data_augmentation, dummy_img, expected_output_img, stochastic=False)
|
||||
|
||||
|
@ -64,11 +66,37 @@ def test_stain_normalization() -> None:
|
|||
def test_hed_jitter() -> None:
|
||||
data_augmentation = HEDJitter(0.05)
|
||||
expected_output_img = torch.Tensor(
|
||||
[[
|
||||
[[0.4536, 0.0221], [0.8084, 0.8164]],
|
||||
[[0.9781, 0.9108], [0.8522, 0.8933]],
|
||||
[[0.1138, 0.6730], [0.7773, 0.0666]]
|
||||
]]
|
||||
)
|
||||
[[[[0.6241, 0.1635],
|
||||
[0.9993, 1.0000]],
|
||||
[[1.0000, 1.0000],
|
||||
[1.0000, 1.0000]],
|
||||
[[0.2232, 0.8028],
|
||||
[0.9117, 0.1742]]]])
|
||||
|
||||
_test_data_augmentation(data_augmentation, dummy_img, expected_output_img, stochastic=True)
|
||||
|
||||
|
||||
def test_gaussian_blur() -> None:
|
||||
data_augmentation = GaussianBlur(3, p=1.0)
|
||||
expected_output_img = torch.Tensor(
|
||||
[[[[0.8302, 0.7639],
|
||||
[0.8149, 0.6943]],
|
||||
[[0.7423, 0.6225],
|
||||
[0.6815, 0.6094]],
|
||||
[[0.7821, 0.6929],
|
||||
[0.7393, 0.7463]]]])
|
||||
|
||||
_test_data_augmentation(data_augmentation, dummy_img, expected_output_img, stochastic=True)
|
||||
|
||||
|
||||
def test_random_rotation() -> None:
|
||||
data_augmentation = RandomRotationByMultiplesOf90()
|
||||
expected_output_img = torch.Tensor(
|
||||
[[[[0.0415, 0.8420],
|
||||
[0.4767, 0.8325]],
|
||||
[[0.9119, 0.9098],
|
||||
[0.9859, 0.8717]],
|
||||
[[0.7216, 0.1127],
|
||||
[0.1592, 0.8305]]]])
|
||||
|
||||
_test_data_augmentation(data_augmentation, dummy_img, expected_output_img, stochastic=True, seed=1)
|
||||
|
|
Загрузка…
Ссылка в новой задаче