This commit is contained in:
Adam J. Stewart 2021-06-24 21:36:01 +00:00
Родитель 4178d278c8
Коммит 4c6ddce84a
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: C66C0675661156FC
1 изменённых файлов: 28 добавлений и 10 удалений

Просмотреть файл

@ -10,24 +10,36 @@ from torchgeo.transforms import transforms
@pytest.fixture
def sample() -> Dict[str, Tensor]:
return {
"image": torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]),
"masks": torch.tensor([[0, 0, 1], [0, 1, 1], [1, 1, 1]]),
"boxes": torch.tensor([[0, 0, 2, 2], [1, 1, 3, 3]]),
"image": torch.tensor( # type: ignore[attr-defined]
[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
),
"masks": torch.tensor( # type: ignore[attr-defined]
[[0, 0, 1], [0, 1, 1], [1, 1, 1]]
),
"boxes": torch.tensor( # type: ignore[attr-defined]
[[0, 0, 2, 2], [1, 1, 3, 3]]
),
}
def assert_matching(output: Dict[str, Tensor], expected: Dict[str, Tensor]) -> None:
for key in expected:
assert torch.allclose(output[key], expected[key])
assert torch.allclose(output[key], expected[key]) # type: ignore[attr-defined]
def test_random_horizontal_flip(sample: Dict[str, Tensor]) -> None:
tr = transforms.RandomHorizontalFlip(p=1)
output = tr(sample)
expected = {
"image": torch.tensor([[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]),
"masks": torch.tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]),
"boxes": torch.tensor([[1, 0, 3, 2], [0, 1, 2, 3]]),
"image": torch.tensor( # type: ignore[attr-defined]
[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]
),
"masks": torch.tensor( # type: ignore[attr-defined]
[[1, 0, 0], [1, 1, 0], [1, 1, 1]]
),
"boxes": torch.tensor( # type: ignore[attr-defined]
[[1, 0, 3, 2], [0, 1, 2, 3]]
),
}
assert_matching(output, expected)
@ -36,9 +48,15 @@ def test_random_vertical_flip(sample: Dict[str, Tensor]) -> None:
tr = transforms.RandomVerticalFlip(p=1)
output = tr(sample)
expected = {
"image": torch.tensor([[[7, 8, 9], [4, 5, 6], [1, 2, 3]]]),
"masks": torch.tensor([[1, 1, 1], [0, 1, 1], [0, 0, 1]]),
"boxes": torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]),
"image": torch.tensor( # type: ignore[attr-defined]
[[[7, 8, 9], [4, 5, 6], [1, 2, 3]]]
),
"masks": torch.tensor( # type: ignore[attr-defined]
[[1, 1, 1], [0, 1, 1], [0, 0, 1]]
),
"boxes": torch.tensor( # type: ignore[attr-defined]
[[0, 1, 2, 3], [1, 0, 3, 2]]
),
}
assert_matching(output, expected)