Adding lightweight check for dataset integrity to SEN12MS

This commit is contained in:
Caleb Robinson 2021-07-20 19:10:16 +00:00
Родитель b8b6e69cce
Коммит cfc5884311
2 изменённых файлов: 31 добавлений и 2 удалений

Просмотреть файл

@ -65,3 +65,8 @@ class TestSEN12MS:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
SEN12MS(str(tmp_path))
def test_check_integrity_light(self) -> None:
root = os.path.join("tests", "data")
ds = SEN12MS(root, checksum=False)
assert isinstance(ds, SEN12MS)

Просмотреть файл

@ -76,6 +76,14 @@ class SEN12MS(VisionDataset):
"train_list.txt",
"test_list.txt",
]
light_filenames = [
"ROIs1158_spring/",
"ROIs1868_summer/",
"ROIs1970_fall/",
"ROIs2017_winter/",
"train_list.txt",
"test_list.txt",
]
md5s = [
"6e2e8fa8b8cba77ddab49fd20ff5c37b",
"fba019bb27a08c1db96b31f718c34d79",
@ -120,8 +128,12 @@ class SEN12MS(VisionDataset):
self.transforms = transforms
self.checksum = checksum
if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted.")
if checksum:
if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted.")
else:
if not self._check_integrity_light():
raise RuntimeError("Dataset not found or corrupted.")
with open(os.path.join(self.root, self.base_folder, split + "_list.txt")) as f:
self.ids = f.readlines()
@ -185,6 +197,18 @@ class SEN12MS(VisionDataset):
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
return tensor
def _check_integrity_light(self) -> bool:
"""Checks the integrity of the dataset structure.
Returns:
True if the dataset directories and split files are found, else False
"""
for filename in self.light_filenames:
filepath = os.path.join(self.root, self.base_folder, filename)
if not os.path.exists(filepath):
return False
return True
def _check_integrity(self) -> bool:
"""Check integrity of dataset.