diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py
index e433b40d2..a9cb9e6f2 100644
--- a/torchgeo/datasets/cowc.py
+++ b/torchgeo/datasets/cowc.py
@@ -75,6 +75,7 @@ class _COWC(VisionDataset, abc.ABC):
         split: str = "train",
         transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
         download: bool = False,
+        checksum: bool = False,
     ) -> None:
         """Initialize a new COWC dataset instance.
 
@@ -84,6 +85,7 @@ class _COWC(VisionDataset, abc.ABC):
             transforms: a function/transform that takes input sample and its target as
                 entry and returns a transformed version
             download: if True, download dataset and store it in the root directory
+            checksum: if True, check the MD5 of the downloaded files (may be slow)
 
         Raises:
             AssertionError: if ``split`` argument is invalid
@@ -93,7 +95,9 @@ class _COWC(VisionDataset, abc.ABC):
         assert split in ["train", "test"]
 
         self.root = root
+        self.split = split
         self.transforms = transforms
+        self.checksum = checksum
 
         if download:
             self.download()
@@ -176,11 +180,11 @@ class _COWC(VisionDataset, abc.ABC):
         """Check integrity of dataset.
 
         Returns:
-            True if dataset MD5s match, else False
+            True if dataset files are found and/or MD5s match, else False
         """
         for filename, md5 in zip(self.filenames, self.md5s):
             filepath = os.path.join(self.root, self.base_folder, filename)
-            if not check_integrity(filepath, md5):
+            if not check_integrity(filepath, md5 if self.checksum else None):
                 return False
         return True
 
@@ -196,7 +200,7 @@ class _COWC(VisionDataset, abc.ABC):
                 self.base_url + filename,
                 os.path.join(self.root, self.base_folder),
                 filename=filename,
-                md5=md5,
+                md5=md5 if self.checksum else None,
             )
             if filename.endswith(".tbz"):
                 filepath = os.path.join(self.root, self.base_folder, filename)
diff --git a/torchgeo/datasets/cv4a_kenya_crop_type.py b/torchgeo/datasets/cv4a_kenya_crop_type.py
index 8fdbf50ea..a32aa7403 100644
--- a/torchgeo/datasets/cv4a_kenya_crop_type.py
+++ b/torchgeo/datasets/cv4a_kenya_crop_type.py
@@ -109,6 +109,7 @@ class CV4AKenyaCropType(VisionDataset):
         transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
         download: bool = False,
         api_key: Optional[str] = None,
+        checksum: bool = False,
         verbose: bool = False,
     ) -> None:
         """Initialize a new CV4A Kenya Crop Type Dataset instance.
@@ -123,16 +124,21 @@ class CV4AKenyaCropType(VisionDataset):
                 entry and returns a transformed version
             download: if True, download dataset and store it in the root directory
             api_key: a RadiantEarth MLHub API key to use for downloading the dataset
+            checksum: if True, check the MD5 of the downloaded files (may be slow)
             verbose: if True, print messages when new tiles are loaded
 
         Raises:
-            RuntimeError: if download=True but api_key=None, or download=False but
-                dataset is missing or checksum fails
+            RuntimeError: if ``download=True`` but ``api_key=None``, or
+                ``download=False`` but dataset is missing or checksum fails
         """
         self._validate_bands(bands)
 
         self.root = root
+        self.chip_size = chip_size
+        self.stride = stride
+        self.bands = bands
         self.transforms = transforms
+        self.checksum = checksum
         self.verbose = verbose
 
         if download:
@@ -151,8 +157,6 @@ class CV4AKenyaCropType(VisionDataset):
             )
 
         # Calculate the indices that we will use over all tiles
-        self.bands = bands
-        self.chip_size = chip_size
         self.chips_metadata = []
         for tile_index in range(len(self.tile_names)):
             for y in list(range(0, self.tile_height - self.chip_size, stride)) + [
@@ -215,7 +219,7 @@ class CV4AKenyaCropType(VisionDataset):
             tuple of labels and field ids
 
         Raises:
-            AssertionError: if tile_name is invalid
+            AssertionError: if ``tile_name`` is invalid
         """
         assert tile_name in self.tile_names
 
@@ -246,7 +250,7 @@ class CV4AKenyaCropType(VisionDataset):
             bands: user-provided tuple of bands to load
 
         Raises:
-            AssertionError: if bands is not a tuple
+            AssertionError: if ``bands`` is not a tuple
             ValueError: if an invalid band name is provided
         """
 
@@ -271,7 +275,7 @@ class CV4AKenyaCropType(VisionDataset):
             points in time, 3035 is the tile height, and 2016 is the tile width
 
         Raises:
-            AssertionError: if tile_name is invalid
+            AssertionError: if ``tile_name`` is invalid
         """
         assert tile_name in self.tile_names
 
@@ -307,7 +311,7 @@ class CV4AKenyaCropType(VisionDataset):
             array containing a single image tile
 
         Raises:
-            AssertionError: if tile_name or date is invalid
+            AssertionError: if ``tile_name`` or ``date`` is invalid
         """
         assert tile_name in self.tile_names
         assert date in self.dates
@@ -339,16 +343,16 @@ class CV4AKenyaCropType(VisionDataset):
         """Check integrity of dataset.
 
         Returns:
-            True if the MD5s of the dataset's archives match, else False
+            True if dataset files are found and/or MD5s match, else False
         """
         images: bool = check_integrity(
             os.path.join(self.root, self.base_folder, self.image_meta["filename"]),
-            self.image_meta["md5"],
+            self.image_meta["md5"] if self.checksum else None,
         )
 
         targets: bool = check_integrity(
             os.path.join(self.root, self.base_folder, self.target_meta["filename"]),
-            self.target_meta["md5"],
+            self.target_meta["md5"] if self.checksum else None,
         )
 
         return images and targets
diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py
index cc41078ae..4c62bf097 100644
--- a/torchgeo/datasets/landcoverai.py
+++ b/torchgeo/datasets/landcoverai.py
@@ -59,6 +59,7 @@ class LandCoverAI(VisionDataset):
         split: str = "train",
         transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
         download: bool = False,
+        checksum: bool = False,
     ) -> None:
         """Initialize a new LandCover.ai dataset instance.
 
@@ -68,6 +69,7 @@ class LandCoverAI(VisionDataset):
             transforms: a function/transform that takes input sample and its target as
                 entry and returns a transformed version
             download: if True, download dataset and store it in the root directory
+            checksum: if True, check the MD5 of the downloaded files (may be slow)
 
         Raises:
             AssertionError: if ``split`` argument is invalid
@@ -77,7 +79,9 @@ class LandCoverAI(VisionDataset):
         assert split in ["train", "val", "test"]
 
         self.root = root
+        self.split = split
         self.transforms = transforms
+        self.checksum = checksum
 
         if download:
             self.download()
@@ -155,11 +159,11 @@ class LandCoverAI(VisionDataset):
         """Check integrity of dataset.
 
         Returns:
-            True if dataset MD5s match, else False
+            True if dataset files are found and/or MD5s match, else False
         """
         integrity: bool = check_integrity(
             os.path.join(self.root, self.base_folder, self.filename),
-            self.md5,
+            self.md5 if self.checksum else None,
         )
 
         return integrity
@@ -175,7 +179,7 @@ class LandCoverAI(VisionDataset):
             self.url,
             os.path.join(self.root, self.base_folder),
             filename=self.filename,
-            md5=self.md5,
+            md5=self.md5 if self.checksum else None,
         )
 
         # Generate train/val/test splits
diff --git a/torchgeo/datasets/nwpu.py b/torchgeo/datasets/nwpu.py
index 44717cc16..b111d7cae 100644
--- a/torchgeo/datasets/nwpu.py
+++ b/torchgeo/datasets/nwpu.py
@@ -80,6 +80,7 @@ class VHR10(VisionDataset):
         split: str = "positive",
         transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
         download: bool = False,
+        checksum: bool = False,
     ) -> None:
         """Initialize a new VHR-10 dataset instance.
 
@@ -89,6 +90,7 @@ class VHR10(VisionDataset):
             transforms: a function/transform that takes input sample and its target as
                 entry and returns a transformed version
             download: if True, download dataset and store it in the root directory
+            checksum: if True, check the MD5 of the downloaded files (may be slow)
 
         Raises:
             AssertionError: if ``split`` argument is invalid
@@ -100,6 +102,7 @@ class VHR10(VisionDataset):
         self.root = root
         self.split = split
         self.transforms = transforms
+        self.checksum = checksum
 
         if download:
             self.download()
@@ -199,11 +202,11 @@ class VHR10(VisionDataset):
         """Check integrity of dataset.
 
         Returns:
-            True if dataset MD5s match, else False
+            True if dataset files are found and/or MD5s match, else False
         """
         image: bool = check_integrity(
             os.path.join(self.root, self.base_folder, self.image_meta["filename"]),
-            self.image_meta["md5"],
+            self.image_meta["md5"] if self.checksum else None,
        )
 
         # Annotations only needed for "positive" image set
@@ -216,7 +219,7 @@ class VHR10(VisionDataset):
                     "NWPU VHR-10 dataset",
                     self.target_meta["filename"],
                 ),
-                self.target_meta["md5"],
+                self.target_meta["md5"] if self.checksum else None,
             )
 
         return image and target
@@ -233,7 +236,7 @@ class VHR10(VisionDataset):
             self.image_meta["file_id"],
             os.path.join(self.root, self.base_folder),
             self.image_meta["filename"],
-            self.image_meta["md5"],
+            self.image_meta["md5"] if self.checksum else None,
         )
 
         # Must be installed to extract RAR file
@@ -251,5 +254,5 @@ class VHR10(VisionDataset):
                 self.target_meta["url"],
                 os.path.join(self.root, self.base_folder, "NWPU VHR-10 dataset"),
                 self.target_meta["filename"],
-                self.target_meta["md5"],
+                self.target_meta["md5"] if self.checksum else None,
             )
diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py
index d576b36e1..581126862 100644
--- a/torchgeo/datasets/sen12ms.py
+++ b/torchgeo/datasets/sen12ms.py
@@ -36,14 +36,11 @@ class SEN12MS(GeoDataset):
 
     .. code-block: bash
 
-       wget "ftp://m1474000:m1474000@dataserv.ub.tum.de/checksum.sha512"
-
        for season in 1158_spring 1868_summer 1970_fall 2017_winter
        do
            for source in lc s1 s2
            do
                wget "ftp://m1474000:m1474000@dataserv.ub.tum.de/ROIs${season}_${source}.tar.gz"
-               shasum -a 512 "ROIs${season}_${source}.tar.gz"
                tar xvzf "ROIs${season}_${source}.tar.gz"
            done
        done
@@ -55,8 +52,7 @@ class SEN12MS(GeoDataset):
       or manually downloaded from https://dataserv.ub.tum.de/s/m1474000
       and https://github.com/schmitt-muc/SEN12MS/tree/master/splits.
 
-      This download will likely take several hours. The checksums.sha512
-      file should be used to confirm the integrity of the downloads.
+      This download will likely take several hours.
    """  # noqa: E501
 
    base_folder = "sen12ms"
@@ -76,12 +72,29 @@ class SEN12MS(GeoDataset):
        "train_list.txt",
        "test_list.txt",
    ]
+   md5s = [
+       "6e2e8fa8b8cba77ddab49fd20ff5c37b",
+       "fba019bb27a08c1db96b31f718c34d79",
+       "d58af2c15a16f376eb3308dc9b685af2",
+       "2c5bd80244440b6f9d54957c6b1f23d4",
+       "01044b7f58d33570c6b57fec28a3d449",
+       "4dbaf72ecb704a4794036fe691427ff3",
+       "9b126a68b0e3af260071b3139cb57cee",
+       "19132e0aab9d4d6862fd42e8e6760847",
+       "b8f117818878da86b5f5e06400eb1866",
+       "0fa0420ef7bcfe4387c7e6fe226dc728",
+       "bb8cbfc16b95a4f054a3d5380e0130ed",
+       "3807545661288dcca312c9c538537b63",
+       "0a68d4e1eb24f128fccdb930000b2546",
+       "c7faad064001e646445c4c634169484d",
+   ]
 
    def __init__(
        self,
        root: str = "data",
        split: str = "train",
        transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
+       checksum: bool = False,
    ) -> None:
        """Initialize a new SEN12MS dataset instance.
 
@@ -90,6 +103,7 @@ class SEN12MS(GeoDataset):
            split: one of "train" or "test"
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
+           checksum: if True, check the MD5 of the downloaded files (may be slow)
 
        Raises:
            AssertionError: if ``split`` argument is invalid
@@ -98,7 +112,9 @@ class SEN12MS(GeoDataset):
        assert split in ["train", "test"]
 
        self.root = root
+       self.split = split
        self.transforms = transforms
+       self.checksum = checksum
 
        if not self._check_integrity():
            raise RuntimeError("Dataset not found.")
@@ -169,10 +185,10 @@ class SEN12MS(GeoDataset):
        """Check integrity of dataset.
 
        Returns:
-           True if files exist, else False
+           True if dataset files are found and/or MD5s match, else False
        """
-       # We could also check md5s, but it would take ~30 min to compute
-       for filename in self.filenames:
-           if not check_integrity(os.path.join(self.root, self.base_folder, filename)):
+       for filename, md5 in zip(self.filenames, self.md5s):
+           filepath = os.path.join(self.root, self.base_folder, filename)
+           if not check_integrity(filepath, md5 if self.checksum else None):
                return False
        return True
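
A minimal usage sketch of the new flag (the `root` path here is hypothetical; the class and parameters are taken from this diff). `checksum` defaults to `False`, so existing callers keep the fast existence-only integrity check, while `checksum=True` opts in to the slower MD5 verification:

    from torchgeo.datasets import LandCoverAI

    # Default behavior: download if needed, then only verify that files exist on disk.
    ds = LandCoverAI(root="data", split="train", download=True)

    # Opt in to MD5 verification of the downloaded files (may be slow).
    ds = LandCoverAI(root="data", split="train", download=True, checksum=True)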