Updated the HEDjitter augmentation to reflect the update in
https://github.com/scikit-image/scikit-image/blob/main/skimage/color/colorconv.py#L1463-L1504

In addition, I included a normalization step. Our implementation is now
following exactly
https://github.com/gatsby2016/Augmentation-PyTorch-Transforms/blob/master/myTransforms.py
while using only Pytorch (instead of skimage).
This commit is contained in:
maxilse 2022-10-28 12:32:37 +02:00 коммит произвёл GitHub
Родитель 431ff2769c
Коммит 21c30d8cc1
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 41 добавлений и 29 удалений

Просмотреть файл

@ -12,10 +12,10 @@ class HEDJitter(object):
"""
Randomly perturbe the HED color space value an RGB image.
First, it disentangled the hematoxylin and eosin color channels by color deconvolution method using a fixed matrix,
taken from Ruifrok and Johnston (2001): "Quantification of histochemical staining by color deconvolution."
First, it disentangled the hematoxylin and eosin color channels by color deconvolution method using a fixed matrix.
Second, it perturbed the hematoxylin, eosin stains independently.
Third, it transformed the resulting stains into regular RGB color space.
PyTorch version of: https://github.com/gatsby2016/Augmentation-PyTorch-Transforms/blob/master/myTransforms.py
Usage example:
>>> transform = HEDJitter(0.05)
@ -37,45 +37,45 @@ class HEDJitter(object):
self.hed_from_rgb = torch.tensor([[1.87798274, -1.00767869, -0.55611582],
[-0.06590806, 1.13473037, -0.1355218],
[-0.60190736, -0.48041419, 1.57358807]])
self.log_adjust = torch.log(torch.tensor(1E-6))
@staticmethod
def adjust_hed(img: torch.Tensor,
theta: float,
stain_from_rgb_mat: torch.Tensor,
rgb_from_stain_mat: torch.Tensor
) -> torch.Tensor:
def adjust_hed(self, img: torch.Tensor) -> torch.Tensor:
"""
Applies HED jitter to image.
:param img: Input image.
:param theta: Strength of the jitter. HED_light: theta=0.05; HED_strong: theta=0.2.
:param stain_from_rgb_mat: Transformation matrix from HED to RGB.
:param rgb_from_stain_mat: Transformation matrix from RGB to HED.
"""
alpha = torch.FloatTensor(1, 3).uniform_(1 - theta, 1 + theta)
beta = torch.FloatTensor(1, 3).uniform_(-theta, theta)
# Only perturb the H (=0) and E (=1) channels
alpha[0][-1] = 1.
beta[0][-1] = 0.
alpha = torch.FloatTensor(img.shape[0], 1, 1, 3).uniform_(1 - self.theta, 1 + self.theta)
beta = torch.FloatTensor(img.shape[0], 1, 1, 3).uniform_(-self.theta, self.theta)
# Separate stains
img = img.permute([0, 2, 3, 1])
img = img + 2 # for consistency with skimage
stains = -torch.log10(img) @ stain_from_rgb_mat
stains = alpha * stains + beta # perturbations in HED color space
img = torch.maximum(img, 1E-6 * torch.ones(img.shape))
stains = (torch.log(img) / self.log_adjust) @ self.hed_from_rgb
stains = torch.maximum(stains, torch.zeros(stains.shape))
# perturbations in HED color space
stains = alpha * stains + beta
# Combine stains
img = 10 ** (-stains @ rgb_from_stain_mat) - 2
img = -(stains * (-self.log_adjust)) @ self.rgb_from_hed
img = torch.exp(img)
img = torch.clip(img, 0, 1)
img = img.permute(0, 3, 1, 2)
# Normalize
imin = torch.amin(img, dim=[1, 2, 3], keepdim=True)
imax = torch.amax(img, dim=[1, 2, 3], keepdim=True)
img = (img - imin) / (imax - imin)
return img
def __call__(self, img: torch.Tensor) -> torch.Tensor:
if img.shape[1] != 3:
raise ValueError("HED jitter can only be applied to images with 3 channels (RGB).")
return self.adjust_hed(img, self.theta, self.hed_from_rgb, self.rgb_from_hed)
return self.adjust_hed(img)
class StainNormalization(object):

Просмотреть файл

@ -71,15 +71,27 @@ def test_stain_normalization() -> None:
def test_hed_jitter() -> None:
data_augmentation = HEDJitter(0.05)
expected_output_img = torch.Tensor(
[[[[0.6241, 0.1635],
[0.9993, 1.0000]],
[[1.0000, 1.0000],
[1.0000, 1.0000]],
[[0.2232, 0.8028],
[0.9117, 0.1742]]]])
expected_output_img1 = torch.Tensor(
[[[[0.9639, 0.4130],
[0.9134, 1.0000]],
[[0.3125, 0.0000],
[0.4474, 0.1820]],
[[0.9195, 0.5265],
[0.9118, 0.8291]]]])
expected_output_img2 = torch.Tensor(
[[[[0.8411, 0.2361],
[0.7857, 0.8766]],
[[0.7075, 0.0000],
[1.0000, 0.4138]],
[[0.9694, 0.4674],
[0.9577, 0.8476]]]])
expected_output_bag = torch.vstack([expected_output_img1,
expected_output_img2])
_test_data_augmentation(data_augmentation, dummy_img, expected_output_img, stochastic=True)
_test_data_augmentation(data_augmentation,
dummy_bag,
expected_output_bag,
stochastic=True)
def test_gaussian_blur() -> None: