зеркало из https://github.com/microsoft/torchgeo.git
transforms: Switch to kornia AugmentationSequential (#2008)
This commit is contained in:
Родитель
c8e1e09f3d
Коммит
87a5da24fa
|
@ -1,11 +1,12 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import kornia.augmentation as K
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from torchgeo.transforms import AugmentationSequential, RandomGrayscale
|
||||
from torchgeo.transforms import RandomGrayscale
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -33,12 +34,15 @@ def batch() -> dict[str, Tensor]:
|
|||
],
|
||||
)
|
||||
def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) -> None:
|
||||
aug = AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=['image'])
|
||||
aug = K.AugmentationSequential(
|
||||
RandomGrayscale(weights, p=1), keepdim=True, data_keys=None
|
||||
)
|
||||
# https://github.com/kornia/kornia/issues/2848
|
||||
aug.keepdim = True
|
||||
output = aug(sample)
|
||||
assert output['image'].shape == sample['image'].shape
|
||||
assert output['image'].sum() == sample['image'].sum()
|
||||
for i in range(1, 3):
|
||||
assert torch.allclose(output['image'][0, 0], output['image'][0, i])
|
||||
assert torch.allclose(output['image'][0], output['image'][i])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -50,9 +54,8 @@ def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) ->
|
|||
],
|
||||
)
|
||||
def test_random_grayscale_batch(weights: Tensor, batch: dict[str, Tensor]) -> None:
|
||||
aug = AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=['image'])
|
||||
aug = K.AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=None)
|
||||
output = aug(batch)
|
||||
assert output['image'].shape == batch['image'].shape
|
||||
assert output['image'].sum() == batch['image'].sum()
|
||||
for i in range(1, 3):
|
||||
assert torch.allclose(output['image'][0, 0], output['image'][0, i])
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import kornia.augmentation as K
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
@ -20,7 +21,6 @@ from torchgeo.transforms import (
|
|||
AppendRBNDVI,
|
||||
AppendSWI,
|
||||
AppendTriBandNormalizedDifferenceIndex,
|
||||
AugmentationSequential,
|
||||
)
|
||||
|
||||
|
||||
|
@ -42,9 +42,8 @@ def batch() -> dict[str, Tensor]:
|
|||
|
||||
def test_append_index_sample(sample: dict[str, Tensor]) -> None:
|
||||
c, h, w = sample['image'].shape
|
||||
aug = AugmentationSequential(
|
||||
AppendNormalizedDifferenceIndex(index_a=0, index_b=1),
|
||||
data_keys=['image', 'mask'],
|
||||
aug = K.AugmentationSequential(
|
||||
AppendNormalizedDifferenceIndex(index_a=0, index_b=1), data_keys=None
|
||||
)
|
||||
output = aug(sample)
|
||||
assert output['image'].shape == (1, c + 1, h, w)
|
||||
|
@ -52,9 +51,8 @@ def test_append_index_sample(sample: dict[str, Tensor]) -> None:
|
|||
|
||||
def test_append_index_batch(batch: dict[str, Tensor]) -> None:
|
||||
b, c, h, w = batch['image'].shape
|
||||
aug = AugmentationSequential(
|
||||
AppendNormalizedDifferenceIndex(index_a=0, index_b=1),
|
||||
data_keys=['image', 'mask'],
|
||||
aug = K.AugmentationSequential(
|
||||
AppendNormalizedDifferenceIndex(index_a=0, index_b=1), data_keys=None
|
||||
)
|
||||
output = aug(batch)
|
||||
assert output['image'].shape == (b, c + 1, h, w)
|
||||
|
@ -62,9 +60,9 @@ def test_append_index_batch(batch: dict[str, Tensor]) -> None:
|
|||
|
||||
def test_append_triband_index_batch(batch: dict[str, Tensor]) -> None:
|
||||
b, c, h, w = batch['image'].shape
|
||||
aug = AugmentationSequential(
|
||||
aug = K.AugmentationSequential(
|
||||
AppendTriBandNormalizedDifferenceIndex(index_a=0, index_b=1, index_c=2),
|
||||
data_keys=['image', 'mask'],
|
||||
data_keys=None,
|
||||
)
|
||||
output = aug(batch)
|
||||
assert output['image'].shape == (b, c + 1, h, w)
|
||||
|
@ -88,7 +86,7 @@ def test_append_normalized_difference_indices(
|
|||
sample: dict[str, Tensor], index: AppendNormalizedDifferenceIndex
|
||||
) -> None:
|
||||
c, h, w = sample['image'].shape
|
||||
aug = AugmentationSequential(index(0, 1), data_keys=['image', 'mask'])
|
||||
aug = K.AugmentationSequential(index(0, 1), data_keys=None)
|
||||
output = aug(sample)
|
||||
assert output['image'].shape == (1, c + 1, h, w)
|
||||
|
||||
|
@ -98,6 +96,6 @@ def test_append_tri_band_normalized_difference_indices(
|
|||
sample: dict[str, Tensor], index: AppendTriBandNormalizedDifferenceIndex
|
||||
) -> None:
|
||||
c, h, w = sample['image'].shape
|
||||
aug = AugmentationSequential(index(0, 1, 2), data_keys=['image', 'mask'])
|
||||
aug = K.AugmentationSequential(index(0, 1, 2), data_keys=None)
|
||||
output = aug(sample)
|
||||
assert output['image'].shape == (1, c + 1, h, w)
|
||||
|
|
Загрузка…
Ссылка в новой задаче