diff --git a/tests/transforms/test_color.py b/tests/transforms/test_color.py index 2e271f89b..2cea90b39 100644 --- a/tests/transforms/test_color.py +++ b/tests/transforms/test_color.py @@ -1,11 +1,12 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import kornia.augmentation as K import pytest import torch from torch import Tensor -from torchgeo.transforms import AugmentationSequential, RandomGrayscale +from torchgeo.transforms import RandomGrayscale @pytest.fixture @@ -33,12 +34,15 @@ def batch() -> dict[str, Tensor]: ], ) def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) -> None: - aug = AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=['image']) + aug = K.AugmentationSequential( + RandomGrayscale(weights, p=1), keepdim=True, data_keys=None + ) + # https://github.com/kornia/kornia/issues/2848 + aug.keepdim = True output = aug(sample) assert output['image'].shape == sample['image'].shape - assert output['image'].sum() == sample['image'].sum() for i in range(1, 3): - assert torch.allclose(output['image'][0, 0], output['image'][0, i]) + assert torch.allclose(output['image'][0], output['image'][i]) @pytest.mark.parametrize( @@ -50,9 +54,8 @@ def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) -> ], ) def test_random_grayscale_batch(weights: Tensor, batch: dict[str, Tensor]) -> None: - aug = AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=['image']) + aug = K.AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=None) output = aug(batch) assert output['image'].shape == batch['image'].shape - assert output['image'].sum() == batch['image'].sum() for i in range(1, 3): assert torch.allclose(output['image'][0, 0], output['image'][0, i]) diff --git a/tests/transforms/test_indices.py b/tests/transforms/test_indices.py index 3d83f8573..9e6f54e48 100644 --- a/tests/transforms/test_indices.py +++ b/tests/transforms/test_indices.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import kornia.augmentation as K import pytest import torch from torch import Tensor @@ -20,7 +21,6 @@ from torchgeo.transforms import ( AppendRBNDVI, AppendSWI, AppendTriBandNormalizedDifferenceIndex, - AugmentationSequential, ) @@ -42,9 +42,8 @@ def batch() -> dict[str, Tensor]: def test_append_index_sample(sample: dict[str, Tensor]) -> None: c, h, w = sample['image'].shape - aug = AugmentationSequential( - AppendNormalizedDifferenceIndex(index_a=0, index_b=1), - data_keys=['image', 'mask'], + aug = K.AugmentationSequential( + AppendNormalizedDifferenceIndex(index_a=0, index_b=1), data_keys=None ) output = aug(sample) assert output['image'].shape == (1, c + 1, h, w) @@ -52,9 +51,8 @@ def test_append_index_sample(sample: dict[str, Tensor]) -> None: def test_append_index_batch(batch: dict[str, Tensor]) -> None: b, c, h, w = batch['image'].shape - aug = AugmentationSequential( - AppendNormalizedDifferenceIndex(index_a=0, index_b=1), - data_keys=['image', 'mask'], + aug = K.AugmentationSequential( + AppendNormalizedDifferenceIndex(index_a=0, index_b=1), data_keys=None ) output = aug(batch) assert output['image'].shape == (b, c + 1, h, w) @@ -62,9 +60,9 @@ def test_append_index_batch(batch: dict[str, Tensor]) -> None: def test_append_triband_index_batch(batch: dict[str, Tensor]) -> None: b, c, h, w = batch['image'].shape - aug = AugmentationSequential( + aug = K.AugmentationSequential( AppendTriBandNormalizedDifferenceIndex(index_a=0, index_b=1, index_c=2), - data_keys=['image', 'mask'], + data_keys=None, ) output = aug(batch) assert output['image'].shape == (b, c + 1, h, w) @@ -88,7 +86,7 @@ def test_append_normalized_difference_indices( sample: dict[str, Tensor], index: AppendNormalizedDifferenceIndex ) -> None: c, h, w = sample['image'].shape - aug = AugmentationSequential(index(0, 1), data_keys=['image', 'mask']) + aug = K.AugmentationSequential(index(0, 1), data_keys=None) output = aug(sample) assert output['image'].shape == (1, c + 1, h, w) @@ -98,6 +96,6 @@ def test_append_tri_band_normalized_difference_indices( sample: dict[str, Tensor], index: AppendTriBandNormalizedDifferenceIndex ) -> None: c, h, w = sample['image'].shape - aug = AugmentationSequential(index(0, 1, 2), data_keys=['image', 'mask']) + aug = K.AugmentationSequential(index(0, 1, 2), data_keys=None) output = aug(sample) assert output['image'].shape == (1, c + 1, h, w)