* Update levircd

* fix test_getitem
This commit is contained in:
Robin Cole 2023-10-28 14:23:59 +01:00 коммит произвёл GitHub
Родитель 7641755e6f
Коммит 9e78960b58
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 6 добавлений и 6 удалений

Просмотреть файл

@ -38,9 +38,11 @@ class TestLEVIRCDPlus:
def test_getitem(self, dataset: LEVIRCDPlus) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["image1"], torch.Tensor)
assert isinstance(x["image2"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)
assert x["image"].shape[0] == 2
assert x["image1"].shape[0] == 3
assert x["image2"].shape[0] == 3
def test_len(self, dataset: LEVIRCDPlus) -> None:
assert len(dataset) == 2

Просмотреть файл

@ -106,9 +106,7 @@ class LEVIRCDPlus(NonGeoDataset):
image1 = self._load_image(files["image1"])
image2 = self._load_image(files["image2"])
mask = self._load_target(files["mask"])
image = torch.stack(tensors=[image1, image2], dim=0)
sample = {"image": image, "mask": mask}
sample = {"image1": image1, "image2": image2, "mask": mask}
if self.transforms is not None:
sample = self.transforms(sample)
@ -227,7 +225,7 @@ class LEVIRCDPlus(NonGeoDataset):
.. versionadded:: 0.2
"""
image1, image2, mask = (sample["image"][0], sample["image"][1], sample["mask"])
image1, image2, mask = (sample["image1"], sample["image2"], sample["mask"])
ncols = 3
if "prediction" in sample: