diff --git a/tests/datasets/test_cbf.py b/tests/datasets/test_cbf.py index eed082816..e90a309ae 100644 --- a/tests/datasets/test_cbf.py +++ b/tests/datasets/test_cbf.py @@ -61,7 +61,7 @@ class TestCanadianBuildingFootprints: assert isinstance(ds, UnionDataset) def test_already_downloaded(self, dataset: CanadianBuildingFootprints) -> None: - CanadianBuildingFootprints(root=dataset.root, download=True) + CanadianBuildingFootprints(dataset.paths, download=True) def test_plot(self, dataset: CanadianBuildingFootprints) -> None: query = dataset.bounds diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py index ba33de7c4..a659e039c 100644 --- a/tests/datasets/test_chesapeake.py +++ b/tests/datasets/test_chesapeake.py @@ -141,7 +141,7 @@ class TestChesapeakeCVPR: ) monkeypatch.setattr( ChesapeakeCVPR, - "files", + "_files", ["de_1m_2013_extended-debuffered-test_tiles", "spatial_index.geojson"], ) root = str(tmp_path) diff --git a/tests/datasets/test_enviroatlas.py b/tests/datasets/test_enviroatlas.py index 53760f3a5..8f65119fa 100644 --- a/tests/datasets/test_enviroatlas.py +++ b/tests/datasets/test_enviroatlas.py @@ -47,7 +47,7 @@ class TestEnviroAtlas: ) monkeypatch.setattr( EnviroAtlas, - "files", + "_files", ["pittsburgh_pa-2010_1m-train_tiles-debuffered", "spatial_index.geojson"], ) root = str(tmp_path) diff --git a/tests/datasets/test_openbuildings.py b/tests/datasets/test_openbuildings.py index fe66b596d..461159995 100644 --- a/tests/datasets/test_openbuildings.py +++ b/tests/datasets/test_openbuildings.py @@ -37,7 +37,7 @@ class TestOpenBuildings: monkeypatch.setattr(OpenBuildings, "md5s", md5s) transforms = nn.Identity() - return OpenBuildings(root=root, transforms=transforms) + return OpenBuildings(root, transforms=transforms) def test_no_shapes_to_rasterize( self, dataset: OpenBuildings, tmp_path: Path @@ -61,19 +61,19 @@ class TestOpenBuildings: with pytest.raises( RuntimeError, match="have manually downloaded the dataset as suggested " ): - OpenBuildings(root=false_root) + OpenBuildings(false_root) def test_corrupted(self, dataset: OpenBuildings, tmp_path: Path) -> None: with open(os.path.join(tmp_path, "000_buildings.csv.gz"), "w") as f: f.write("bad") with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): - OpenBuildings(dataset.root, checksum=True) + OpenBuildings(dataset.paths, checksum=True) def test_no_meta_data_found(self, tmp_path: Path) -> None: false_root = os.path.join(tmp_path, "empty") os.makedirs(false_root) with pytest.raises(FileNotFoundError, match="Meta data file"): - OpenBuildings(root=false_root) + OpenBuildings(false_root) def test_nothing_in_index(self, dataset: OpenBuildings, tmp_path: Path) -> None: # change meta data to another 'title_url' so that there is no match found @@ -85,7 +85,7 @@ class TestOpenBuildings: json.dump(content, f) with pytest.raises(FileNotFoundError, match="data was found in"): - OpenBuildings(dataset.root) + OpenBuildings(dataset.paths) def test_getitem(self, dataset: OpenBuildings) -> None: x = dataset[dataset.bounds] diff --git a/torchgeo/datasets/cbf.py b/torchgeo/datasets/cbf.py index d9282a095..f2fe2b064 100644 --- a/torchgeo/datasets/cbf.py +++ b/torchgeo/datasets/cbf.py @@ -4,7 +4,8 @@ """Canadian Building Footprints dataset.""" import os -from typing import Any, Callable, Optional +from collections.abc import Iterable +from typing import Any, Callable, Optional, Union import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -60,7 +61,7 @@ class CanadianBuildingFootprints(VectorDataset): def __init__( self, - root: str = "data", + paths: Union[str, Iterable[str]] = "data", crs: Optional[CRS] = None, res: float = 0.00001, transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, @@ -70,7 +71,7 @@ class CanadianBuildingFootprints(VectorDataset): """Initialize a new Dataset instance. Args: - root: root directory where dataset can be found + paths: one or more root directories to search or files to load crs: :term:`coordinate reference system (CRS)` to warp to (defaults to the CRS of the first file found) res: resolution of the dataset in units of CRS @@ -83,8 +84,11 @@ class CanadianBuildingFootprints(VectorDataset): FileNotFoundError: if no files are found in ``root`` RuntimeError: if ``download=False`` and data is not found, or ``checksum=True`` and checksums don't match + + .. versionchanged:: 0.5 + *root* was renamed to *paths*. """ - self.root = root + self.paths = paths self.checksum = checksum if download: @@ -96,7 +100,7 @@ class CanadianBuildingFootprints(VectorDataset): + "You can use download=True to download it" ) - super().__init__(root, crs, res, transforms) + super().__init__(paths, crs, res, transforms) def _check_integrity(self) -> bool: """Check integrity of dataset. @@ -104,8 +108,9 @@ class CanadianBuildingFootprints(VectorDataset): Returns: True if dataset files are found and/or MD5s match, else False """ + assert isinstance(self.paths, str) for prov_terr, md5 in zip(self.provinces_territories, self.md5s): - filepath = os.path.join(self.root, prov_terr + ".zip") + filepath = os.path.join(self.paths, prov_terr + ".zip") if not check_integrity(filepath, md5 if self.checksum else None): return False return True @@ -115,11 +120,11 @@ class CanadianBuildingFootprints(VectorDataset): if self._check_integrity(): print("Files already downloaded and verified") return - + assert isinstance(self.paths, str) for prov_terr, md5 in zip(self.provinces_territories, self.md5s): download_and_extract_archive( self.url + prov_terr + ".zip", - self.root, + self.paths, md5=md5 if self.checksum else None, ) diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 1c17fe84b..5dd3acb1b 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -495,7 +495,7 @@ class ChesapeakeCVPR(GeoDataset): ) # these are used to check the integrity of the dataset - files = [ + _files = [ "de_1m_2013_extended-debuffered-test_tiles", "de_1m_2013_extended-debuffered-train_tiles", "de_1m_2013_extended-debuffered-val_tiles", @@ -704,7 +704,7 @@ class ChesapeakeCVPR(GeoDataset): return os.path.exists(os.path.join(self.root, filename)) # Check if the extracted files already exist - if all(map(exists, self.files)): + if all(map(exists, self._files)): return # Check if the zip files have already been downloaded diff --git a/torchgeo/datasets/enviroatlas.py b/torchgeo/datasets/enviroatlas.py index 56732ff11..551c142ea 100644 --- a/torchgeo/datasets/enviroatlas.py +++ b/torchgeo/datasets/enviroatlas.py @@ -80,7 +80,7 @@ class EnviroAtlas(GeoDataset): ) # these are used to check the integrity of the dataset - files = [ + _files = [ "austin_tx-2012_1m-test_tiles-debuffered", "austin_tx-2012_1m-val5_tiles-debuffered", "durham_nc-2012_1m-test_tiles-debuffered", @@ -422,7 +422,7 @@ class EnviroAtlas(GeoDataset): return os.path.exists(os.path.join(self.root, "enviroatlas_lotp", filename)) # Check if the extracted files already exist - if all(map(exists, self.files)): + if all(map(exists, self._files)): return # Check if the zip files have already been downloaded diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 09564dbab..7c5c68ac6 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -72,9 +72,17 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC): dataset = landsat7 | landsat8 """ + paths: Union[str, Iterable[str]] _crs = CRS.from_epsg(4326) _res = 0.0 + #: Glob expression used to search for files. + #: + #: This expression should be specific enough that it will not pick up files from + #: other datasets. It should not include a file extension, as the dataset may be in + #: a different file format than what it was originally downloaded as. + filename_glob = "*" + # NOTE: according to the Python docs: # # * https://docs.python.org/3/library/exceptions.html#NotImplementedError @@ -269,17 +277,36 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC): print(f"Converting {self.__class__.__name__} res from {self.res} to {new_res}") self._res = new_res + @property + def files(self) -> set[str]: + """A list of all files in the dataset. + + Returns: + All files in the dataset. + + .. versionadded:: 0.5 + """ + # Make iterable + if isinstance(self.paths, str): + paths: Iterable[str] = [self.paths] + else: + paths = self.paths + + # Using set to remove any duplicates if directories are overlapping + files: set[str] = set() + for path in paths: + if os.path.isdir(path): + pathname = os.path.join(path, "**", self.filename_glob) + files |= set(glob.iglob(pathname, recursive=True)) + else: + files.add(path) + + return files + class RasterDataset(GeoDataset): """Abstract base class for :class:`GeoDataset` stored as raster files.""" - #: Glob expression used to search for files. - #: - #: This expression should be specific enough that it will not pick up files from - #: other datasets. It should not include a file extension, as the dataset may be in - #: a different file format than what it was originally downloaded as. - filename_glob = "*" - #: Regular expression used to extract date from filename. #: #: The expression should use named groups. The expression may contain any number of @@ -423,32 +450,6 @@ class RasterDataset(GeoDataset): self._crs = cast(CRS, crs) self._res = cast(float, res) - @property - def files(self) -> set[str]: - """A list of all files in the dataset. - - Returns: - All files in the dataset. - - .. versionadded:: 0.5 - """ - # Make iterable - if isinstance(self.paths, str): - paths: Iterable[str] = [self.paths] - else: - paths = self.paths - - # Using set to remove any duplicates if directories are overlapping - files: set[str] = set() - for path in paths: - if os.path.isdir(path): - pathname = os.path.join(path, "**", self.filename_glob) - files |= set(glob.iglob(pathname, recursive=True)) - else: - files.add(path) - - return files - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query. @@ -571,16 +572,9 @@ class RasterDataset(GeoDataset): class VectorDataset(GeoDataset): """Abstract base class for :class:`GeoDataset` stored as vector files.""" - #: Glob expression used to search for files. - #: - #: This expression should be specific enough that it will not pick up files from - #: other datasets. It should not include a file extension, as the dataset may be in - #: a different file format than what it was originally downloaded as. - filename_glob = "*" - def __init__( self, - root: str = "data", + paths: Union[str, Iterable[str]] = "data", crs: Optional[CRS] = None, res: float = 0.0001, transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, @@ -589,7 +583,7 @@ class VectorDataset(GeoDataset): """Initialize a new Dataset instance. Args: - root: root directory where dataset can be found + paths: one or more root directories to search or files to load crs: :term:`coordinate reference system (CRS)` to warp to (defaults to the CRS of the first file found) res: resolution of the dataset in units of CRS @@ -603,16 +597,18 @@ class VectorDataset(GeoDataset): .. versionadded:: 0.4 The *label_name* parameter. + + .. versionchanged:: 0.5 + *root* was renamed to *paths*. """ super().__init__(transforms) - self.root = root + self.paths = paths self.label_name = label_name # Populate the dataset index i = 0 - pathname = os.path.join(root, "**", self.filename_glob) - for filepath in glob.iglob(pathname, recursive=True): + for filepath in self.files: try: with fiona.open(filepath) as src: if crs is None: @@ -633,7 +629,7 @@ class VectorDataset(GeoDataset): i += 1 if i == 0: - msg = f"No {self.__class__.__name__} data was found in `root='{root}'`" + msg = f"No {self.__class__.__name__} data was found in `root='{paths}'`" raise FileNotFoundError(msg) self._crs = crs diff --git a/torchgeo/datasets/openbuildings.py b/torchgeo/datasets/openbuildings.py index a714beb9a..32c66d731 100644 --- a/torchgeo/datasets/openbuildings.py +++ b/torchgeo/datasets/openbuildings.py @@ -7,7 +7,8 @@ import glob import json import os import sys -from typing import Any, Callable, Optional, cast +from collections.abc import Iterable +from typing import Any, Callable, Optional, Union, cast import fiona import fiona.transform @@ -205,7 +206,7 @@ class OpenBuildings(VectorDataset): def __init__( self, - root: str = "data", + paths: Union[str, Iterable[str]] = "data", crs: Optional[CRS] = None, res: float = 0.0001, transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, @@ -214,7 +215,7 @@ class OpenBuildings(VectorDataset): """Initialize a new Dataset instance. Args: - root: root directory where dataset can be found + paths: one or more root directories to search or files to load crs: :term:`coordinate reference system (CRS)` to warp to (defaults to the CRS of the first file found) res: resolution of the dataset in units of CRS @@ -224,11 +225,13 @@ class OpenBuildings(VectorDataset): Raises: FileNotFoundError: if no files are found in ``root`` + + .. versionchanged:: 0.5 + *root* was renamed to *paths*. """ - self.root = root + self.paths = paths self.res = res self.checksum = checksum - self.root = root self.res = res self.transforms = transforms @@ -237,7 +240,8 @@ class OpenBuildings(VectorDataset): # Create an R-tree to index the dataset using the polygon centroid as bounds self.index = Index(interleaved=False, properties=Property(dimension=3)) - with open(os.path.join(root, "tiles.geojson")) as f: + assert isinstance(self.paths, str) + with open(os.path.join(self.paths, "tiles.geojson")) as f: data = json.load(f) features = data["features"] @@ -245,7 +249,7 @@ class OpenBuildings(VectorDataset): feature["properties"]["tile_url"].split("/")[-1] for feature in features ] # get csv filename - polygon_files = glob.glob(os.path.join(self.root, self.zipfile_glob)) + polygon_files = glob.glob(os.path.join(self.paths, self.zipfile_glob)) polygon_filenames = [f.split(os.sep)[-1] for f in polygon_files] matched_features = [ @@ -274,14 +278,14 @@ class OpenBuildings(VectorDataset): coords = (minx, maxx, miny, maxy, mint, maxt) filepath = os.path.join( - self.root, feature["properties"]["tile_url"].split("/")[-1] + self.paths, feature["properties"]["tile_url"].split("/")[-1] ) self.index.insert(i, coords, filepath) i += 1 if i == 0: raise FileNotFoundError( - f"No {self.__class__.__name__} data was found in '{self.root}'" + f"No {self.__class__.__name__} data was found in '{self.paths}'" ) self._crs = crs @@ -398,7 +402,8 @@ class OpenBuildings(VectorDataset): FileNotFoundError: if metadata file is not found in root """ # Check if the zip files have already been downloaded and checksum - pathname = os.path.join(self.root, self.zipfile_glob) + assert isinstance(self.paths, str) + pathname = os.path.join(self.paths, self.zipfile_glob) i = 0 for zipfile in glob.iglob(pathname): filename = os.path.basename(zipfile) @@ -410,14 +415,14 @@ class OpenBuildings(VectorDataset): return # check if the metadata file has been downloaded - if not os.path.exists(os.path.join(self.root, self.meta_data_filename)): + if not os.path.exists(os.path.join(self.paths, self.meta_data_filename)): raise FileNotFoundError( f"Meta data file {self.meta_data_filename} " - f"not found in in `root={self.root}`." + f"not found in in `root={self.paths}`." ) raise RuntimeError( - f"Dataset not found in `root={self.root}` " + f"Dataset not found in `root={self.paths}` " "either specify a different `root` directory or make sure you " "have manually downloaded the dataset as suggested in the documentation." )