зеркало из https://github.com/microsoft/hi-ml.git
ENH: HEDJitter update (#643)
Updated the HEDjitter augmentation to reflect the update in https://github.com/scikit-image/scikit-image/blob/main/skimage/color/colorconv.py#L1463-L1504 In addition, I included a normalization step. Our implementation is now following exactly https://github.com/gatsby2016/Augmentation-PyTorch-Transforms/blob/master/myTransforms.py while using only Pytorch (instead of skimage).
This commit is contained in:
Родитель
431ff2769c
Коммит
21c30d8cc1
|
@ -12,10 +12,10 @@ class HEDJitter(object):
|
|||
"""
|
||||
Randomly perturbe the HED color space value an RGB image.
|
||||
|
||||
First, it disentangled the hematoxylin and eosin color channels by color deconvolution method using a fixed matrix,
|
||||
taken from Ruifrok and Johnston (2001): "Quantification of histochemical staining by color deconvolution."
|
||||
First, it disentangled the hematoxylin and eosin color channels by color deconvolution method using a fixed matrix.
|
||||
Second, it perturbed the hematoxylin, eosin stains independently.
|
||||
Third, it transformed the resulting stains into regular RGB color space.
|
||||
PyTorch version of: https://github.com/gatsby2016/Augmentation-PyTorch-Transforms/blob/master/myTransforms.py
|
||||
|
||||
Usage example:
|
||||
>>> transform = HEDJitter(0.05)
|
||||
|
@ -37,45 +37,45 @@ class HEDJitter(object):
|
|||
self.hed_from_rgb = torch.tensor([[1.87798274, -1.00767869, -0.55611582],
|
||||
[-0.06590806, 1.13473037, -0.1355218],
|
||||
[-0.60190736, -0.48041419, 1.57358807]])
|
||||
self.log_adjust = torch.log(torch.tensor(1E-6))
|
||||
|
||||
@staticmethod
|
||||
def adjust_hed(img: torch.Tensor,
|
||||
theta: float,
|
||||
stain_from_rgb_mat: torch.Tensor,
|
||||
rgb_from_stain_mat: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
def adjust_hed(self, img: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Applies HED jitter to image.
|
||||
|
||||
:param img: Input image.
|
||||
:param theta: Strength of the jitter. HED_light: theta=0.05; HED_strong: theta=0.2.
|
||||
:param stain_from_rgb_mat: Transformation matrix from HED to RGB.
|
||||
:param rgb_from_stain_mat: Transformation matrix from RGB to HED.
|
||||
"""
|
||||
alpha = torch.FloatTensor(1, 3).uniform_(1 - theta, 1 + theta)
|
||||
beta = torch.FloatTensor(1, 3).uniform_(-theta, theta)
|
||||
|
||||
# Only perturb the H (=0) and E (=1) channels
|
||||
alpha[0][-1] = 1.
|
||||
beta[0][-1] = 0.
|
||||
alpha = torch.FloatTensor(img.shape[0], 1, 1, 3).uniform_(1 - self.theta, 1 + self.theta)
|
||||
beta = torch.FloatTensor(img.shape[0], 1, 1, 3).uniform_(-self.theta, self.theta)
|
||||
|
||||
# Separate stains
|
||||
img = img.permute([0, 2, 3, 1])
|
||||
img = img + 2 # for consistency with skimage
|
||||
stains = -torch.log10(img) @ stain_from_rgb_mat
|
||||
stains = alpha * stains + beta # perturbations in HED color space
|
||||
img = torch.maximum(img, 1E-6 * torch.ones(img.shape))
|
||||
stains = (torch.log(img) / self.log_adjust) @ self.hed_from_rgb
|
||||
stains = torch.maximum(stains, torch.zeros(stains.shape))
|
||||
|
||||
# perturbations in HED color space
|
||||
stains = alpha * stains + beta
|
||||
|
||||
# Combine stains
|
||||
img = 10 ** (-stains @ rgb_from_stain_mat) - 2
|
||||
img = -(stains * (-self.log_adjust)) @ self.rgb_from_hed
|
||||
img = torch.exp(img)
|
||||
img = torch.clip(img, 0, 1)
|
||||
img = img.permute(0, 3, 1, 2)
|
||||
|
||||
# Normalize
|
||||
imin = torch.amin(img, dim=[1, 2, 3], keepdim=True)
|
||||
imax = torch.amax(img, dim=[1, 2, 3], keepdim=True)
|
||||
img = (img - imin) / (imax - imin)
|
||||
|
||||
return img
|
||||
|
||||
def __call__(self, img: torch.Tensor) -> torch.Tensor:
|
||||
if img.shape[1] != 3:
|
||||
raise ValueError("HED jitter can only be applied to images with 3 channels (RGB).")
|
||||
return self.adjust_hed(img, self.theta, self.hed_from_rgb, self.rgb_from_hed)
|
||||
|
||||
return self.adjust_hed(img)
|
||||
|
||||
|
||||
class StainNormalization(object):
|
||||
|
|
|
@ -71,15 +71,27 @@ def test_stain_normalization() -> None:
|
|||
|
||||
def test_hed_jitter() -> None:
|
||||
data_augmentation = HEDJitter(0.05)
|
||||
expected_output_img = torch.Tensor(
|
||||
[[[[0.6241, 0.1635],
|
||||
[0.9993, 1.0000]],
|
||||
[[1.0000, 1.0000],
|
||||
[1.0000, 1.0000]],
|
||||
[[0.2232, 0.8028],
|
||||
[0.9117, 0.1742]]]])
|
||||
expected_output_img1 = torch.Tensor(
|
||||
[[[[0.9639, 0.4130],
|
||||
[0.9134, 1.0000]],
|
||||
[[0.3125, 0.0000],
|
||||
[0.4474, 0.1820]],
|
||||
[[0.9195, 0.5265],
|
||||
[0.9118, 0.8291]]]])
|
||||
expected_output_img2 = torch.Tensor(
|
||||
[[[[0.8411, 0.2361],
|
||||
[0.7857, 0.8766]],
|
||||
[[0.7075, 0.0000],
|
||||
[1.0000, 0.4138]],
|
||||
[[0.9694, 0.4674],
|
||||
[0.9577, 0.8476]]]])
|
||||
expected_output_bag = torch.vstack([expected_output_img1,
|
||||
expected_output_img2])
|
||||
|
||||
_test_data_augmentation(data_augmentation, dummy_img, expected_output_img, stochastic=True)
|
||||
_test_data_augmentation(data_augmentation,
|
||||
dummy_bag,
|
||||
expected_output_bag,
|
||||
stochastic=True)
|
||||
|
||||
|
||||
def test_gaussian_blur() -> None:
|
||||
|
|
Загрузка…
Ссылка в новой задаче