зеркало из https://github.com/microsoft/torchgeo.git
Родитель
7641755e6f
Коммит
9e78960b58
|
@ -38,9 +38,11 @@ class TestLEVIRCDPlus:
|
|||
def test_getitem(self, dataset: LEVIRCDPlus) -> None:
|
||||
x = dataset[0]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x["image"], torch.Tensor)
|
||||
assert isinstance(x["image1"], torch.Tensor)
|
||||
assert isinstance(x["image2"], torch.Tensor)
|
||||
assert isinstance(x["mask"], torch.Tensor)
|
||||
assert x["image"].shape[0] == 2
|
||||
assert x["image1"].shape[0] == 3
|
||||
assert x["image2"].shape[0] == 3
|
||||
|
||||
def test_len(self, dataset: LEVIRCDPlus) -> None:
|
||||
assert len(dataset) == 2
|
||||
|
|
|
@ -106,9 +106,7 @@ class LEVIRCDPlus(NonGeoDataset):
|
|||
image1 = self._load_image(files["image1"])
|
||||
image2 = self._load_image(files["image2"])
|
||||
mask = self._load_target(files["mask"])
|
||||
|
||||
image = torch.stack(tensors=[image1, image2], dim=0)
|
||||
sample = {"image": image, "mask": mask}
|
||||
sample = {"image1": image1, "image2": image2, "mask": mask}
|
||||
|
||||
if self.transforms is not None:
|
||||
sample = self.transforms(sample)
|
||||
|
@ -227,7 +225,7 @@ class LEVIRCDPlus(NonGeoDataset):
|
|||
|
||||
.. versionadded:: 0.2
|
||||
"""
|
||||
image1, image2, mask = (sample["image"][0], sample["image"][1], sample["mask"])
|
||||
image1, image2, mask = (sample["image1"], sample["image2"], sample["mask"])
|
||||
ncols = 3
|
||||
|
||||
if "prediction" in sample:
|
||||
|
|
Загрузка…
Ссылка в новой задаче