Allow RasterDataset to accept list of files (#1442)

* Make RasterDataset accept list of files

* Fix check if str

* Use isdir and isfile

* Rename root to paths and update type hint

* Update children of RasterDataset methods using root

* Fix check to cast str to list

* Update conf files for RasterDatasets

* Add initial suggested test

* Add workaround for lists LandCoverAIBase

* Add method handle_nonlocal_path for users to override

* Raise RuntimeError to support existing tests

* Remove redundant cast to set

* Remove required os.exists for paths

* Revert "Remove required os.exists for paths"

This reverts commit 84bf62b944326c33d5ba8efdcab615c65b124792.

* Use arg as positional argument, not kwarg

* Improve comments and logs about arg paths

* Remove misleading comment

* Change type hint of 'paths' to Iterable

* Change type hint of 'paths' to Iterable

* Remove premature handling of non-local paths

* Replace root with paths in docstrings

* Add versionadded to list_files docstring

* Add versionchanged to docstrings

* Update type of paths in children of Raster

* Replace docstring for paths in all raster

* Swap root with paths for conf files for raster

* Add newline before versionchanged

* Revert name to root in conf for ChesapeakeCVPR

* Simplify EUDEM tests

* paths must be a string if you want autodownload support

* Convert list_files to a property

* Fix type hints

* Test with a real empty directory

* More diverse tests

* LandCoverAI: don't yet support list of paths

* Black

* isort

---------

Co-authored-by: Adrian Tofting <adriantofting@mobmob14994.hq.k.grp>
Co-authored-by: Adrian Tofting <adrian@vake.ai>
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Adrian Tofting 2023-09-29 16:28:07 +02:00 коммит произвёл GitHub
Родитель 51ffb698ee
Коммит 3cef4fb21d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
36 изменённых файлов: 323 добавлений и 222 удалений

Просмотреть файл

@ -20,4 +20,4 @@ data:
patch_size: 224
num_workers: 16
dict_kwargs:
root: "data/l7irish"
paths: "data/l7irish"

Просмотреть файл

@ -20,4 +20,4 @@ data:
patch_size: 224
num_workers: 16
dict_kwargs:
root: "data/l8biome"
paths: "data/l8biome"

Просмотреть файл

@ -21,5 +21,5 @@ data:
num_workers: 4
patch_size: 32
dict_kwargs:
naip_root: "data/naip"
chesapeake_root: "data/chesapeake/BAYWIDE"
naip_paths: "data/naip"
chesapeake_paths: "data/chesapeake/BAYWIDE"

Просмотреть файл

@ -15,5 +15,5 @@ data:
patch_size: 32
length: 5
dict_kwargs:
root: "tests/data/l7irish"
paths: "tests/data/l7irish"
download: true

Просмотреть файл

@ -15,5 +15,5 @@ data:
patch_size: 32
length: 5
dict_kwargs:
root: "tests/data/l8biome"
paths: "tests/data/l8biome"
download: true

Просмотреть файл

@ -14,6 +14,6 @@ data:
batch_size: 2
patch_size: 32
dict_kwargs:
naip_root: "tests/data/naip"
chesapeake_root: "tests/data/chesapeake/BAYWIDE"
naip_paths: "tests/data/naip"
chesapeake_paths: "tests/data/chesapeake/BAYWIDE"
chesapeake_download: true

Просмотреть файл

@ -52,14 +52,14 @@ class TestAbovegroundLiveWoodyBiomassDensity:
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
def test_no_dataset(self) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in."):
AbovegroundLiveWoodyBiomassDensity(root="/test")
def test_no_dataset(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
AbovegroundLiveWoodyBiomassDensity(str(tmp_path))
def test_already_downloaded(
self, dataset: AbovegroundLiveWoodyBiomassDensity
) -> None:
AbovegroundLiveWoodyBiomassDensity(dataset.root)
AbovegroundLiveWoodyBiomassDensity(dataset.paths)
def test_and(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None:
ds = dataset & dataset

Просмотреть файл

@ -27,7 +27,7 @@ class TestAsterGDEM:
shutil.rmtree(tmp_path)
os.makedirs(tmp_path)
with pytest.raises(RuntimeError, match="Dataset not found in"):
AsterGDEM(root=str(tmp_path))
AsterGDEM(str(tmp_path))
def test_getitem(self, dataset: AsterGDEM) -> None:
x = dataset[dataset.bounds]

Просмотреть файл

@ -74,7 +74,7 @@ class TestCDL:
next(dataset.index.intersection(tuple(query)))
def test_already_extracted(self, dataset: CDL) -> None:
CDL(root=dataset.root, years=[2020, 2021])
CDL(dataset.paths, years=[2020, 2021])
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "cdl", "*_30m_cdls.zip")

Просмотреть файл

@ -59,7 +59,7 @@ class TestChesapeake13:
assert isinstance(ds, UnionDataset)
def test_already_extracted(self, dataset: Chesapeake13) -> None:
Chesapeake13(root=dataset.root, download=True)
Chesapeake13(dataset.paths, download=True)
def test_already_downloaded(self, tmp_path: Path) -> None:
url = os.path.join(

Просмотреть файл

@ -44,9 +44,9 @@ class TestCMSGlobalMangroveCanopy:
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
def test_no_dataset(self) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in."):
CMSGlobalMangroveCanopy(root="/test")
def test_no_dataset(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
CMSGlobalMangroveCanopy(str(tmp_path))
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join(
@ -65,7 +65,7 @@ class TestCMSGlobalMangroveCanopy:
) as f:
f.write("bad")
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
CMSGlobalMangroveCanopy(root=str(tmp_path), country="Angola", checksum=True)
CMSGlobalMangroveCanopy(str(tmp_path), country="Angola", checksum=True)
def test_invalid_country(self) -> None:
with pytest.raises(AssertionError):

Просмотреть файл

@ -47,7 +47,7 @@ class TestEsri2020:
assert isinstance(x["mask"], torch.Tensor)
def test_already_extracted(self, dataset: Esri2020) -> None:
Esri2020(root=dataset.root, download=True)
Esri2020(dataset.paths, download=True)
def test_not_extracted(self, tmp_path: Path) -> None:
url = os.path.join(
@ -57,7 +57,7 @@ class TestEsri2020:
"io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip",
)
shutil.copy(url, tmp_path)
Esri2020(root=str(tmp_path))
Esri2020(str(tmp_path))
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):

Просмотреть файл

@ -33,21 +33,22 @@ class TestEUDEM:
assert isinstance(x["mask"], torch.Tensor)
def test_extracted_already(self, dataset: EUDEM) -> None:
zipfile = os.path.join(dataset.root, "eu_dem_v11_E30N10.zip")
shutil.unpack_archive(zipfile, dataset.root, "zip")
EUDEM(dataset.root)
assert isinstance(dataset.paths, str)
zipfile = os.path.join(dataset.paths, "eu_dem_v11_E30N10.zip")
shutil.unpack_archive(zipfile, dataset.paths, "zip")
EUDEM(dataset.paths)
def test_no_dataset(self, tmp_path: Path) -> None:
shutil.rmtree(tmp_path)
os.makedirs(tmp_path)
with pytest.raises(RuntimeError, match="Dataset not found in"):
EUDEM(root=str(tmp_path))
EUDEM(str(tmp_path))
def test_corrupted(self, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, "eu_dem_v11_E30N10.zip"), "w") as f:
f.write("bad")
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
EUDEM(root=str(tmp_path), checksum=True)
EUDEM(str(tmp_path), checksum=True)
def test_and(self, dataset: EUDEM) -> None:
ds = dataset & dataset

Просмотреть файл

@ -1,9 +1,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pickle
from collections.abc import Iterable
from pathlib import Path
from typing import Union
import pytest
import torch
@ -178,6 +179,39 @@ class TestRasterDataset:
cache = request.param[1]
return Sentinel2(root, bands=bands, transforms=transforms, cache=cache)
@pytest.mark.parametrize(
"paths",
[
# Single directory
os.path.join("tests", "data", "naip"),
# Multiple directories
[
os.path.join("tests", "data", "naip"),
os.path.join("tests", "data", "naip"),
],
# Single file
os.path.join("tests", "data", "naip", "m_3807511_ne_18_060_20181104.tif"),
# Multiple files
(
os.path.join(
"tests", "data", "naip", "m_3807511_ne_18_060_20181104.tif"
),
os.path.join(
"tests", "data", "naip", "m_3807511_ne_18_060_20190605.tif"
),
),
# Combination
{
os.path.join("tests", "data", "naip"),
os.path.join(
"tests", "data", "naip", "m_3807511_ne_18_060_20181104.tif"
),
},
],
)
def test_files(self, paths: Union[str, Iterable[str]]) -> None:
assert 1 <= len(NAIP(paths).files) <= 2
def test_getitem_single_file(self, naip: NAIP) -> None:
x = naip[naip.bounds]
assert isinstance(x, dict)

Просмотреть файл

@ -47,7 +47,7 @@ class TestGlobBiomass:
assert isinstance(x["mask"], torch.Tensor)
def test_already_extracted(self, dataset: GlobBiomass) -> None:
GlobBiomass(root=dataset.root)
GlobBiomass(dataset.paths)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
@ -57,7 +57,7 @@ class TestGlobBiomass:
with open(os.path.join(tmp_path, "N00E020_agb.zip"), "w") as f:
f.write("bad")
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
GlobBiomass(root=str(tmp_path), checksum=True)
GlobBiomass(str(tmp_path), checksum=True)
def test_and(self, dataset: GlobBiomass) -> None:
ds = dataset & dataset

Просмотреть файл

@ -58,7 +58,7 @@ class TestL7Irish:
plt.close()
def test_already_extracted(self, dataset: L7Irish) -> None:
L7Irish(root=dataset.root, download=True)
L7Irish(dataset.paths, download=True)
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "l7irish", "*.tar.gz")
@ -88,7 +88,7 @@ class TestL7Irish:
with pytest.raises(
ValueError, match="Dataset doesn't contain some of the RGB bands"
):
ds = L7Irish(root=dataset.root, bands=["B10", "B20", "B50"])
ds = L7Irish(dataset.paths, bands=["B10", "B20", "B50"])
x = ds[ds.bounds]
ds.plot(x, suptitle="Test")
plt.close()

Просмотреть файл

@ -58,7 +58,7 @@ class TestL8Biome:
plt.close()
def test_already_extracted(self, dataset: L8Biome) -> None:
L8Biome(root=dataset.root, download=True)
L8Biome(dataset.paths, download=True)
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "l8biome", "*.tar.gz")
@ -88,7 +88,7 @@ class TestL8Biome:
with pytest.raises(
ValueError, match="Dataset doesn't contain some of the RGB bands"
):
ds = L8Biome(root=dataset.root, bands=["B1", "B2", "B5"])
ds = L8Biome(dataset.paths, bands=["B1", "B2", "B5"])
x = ds[ds.bounds]
ds.plot(x, suptitle="Test")
plt.close()

Просмотреть файл

@ -40,7 +40,7 @@ class TestLandCoverAIGeo:
assert isinstance(x["mask"], torch.Tensor)
def test_already_extracted(self, dataset: LandCoverAIGeo) -> None:
LandCoverAIGeo(root=dataset.root, download=True)
LandCoverAIGeo(dataset.root, download=True)
def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
url = os.path.join("tests", "data", "landcoverai", "landcover.ai.v1.zip")

Просмотреть файл

@ -52,7 +52,7 @@ class TestLandsat8:
def test_plot_wrong_bands(self, dataset: Landsat8) -> None:
bands = ("SR_B1",)
ds = Landsat8(root=dataset.root, bands=bands)
ds = Landsat8(dataset.paths, bands=bands)
x = dataset[dataset.bounds]
with pytest.raises(
ValueError, match="Dataset doesn't contain some of the RGB bands"

Просмотреть файл

@ -69,7 +69,7 @@ class TestNLCD:
assert isinstance(ds, UnionDataset)
def test_already_extracted(self, dataset: NLCD) -> None:
NLCD(root=dataset.root, download=True, years=[2019])
NLCD(dataset.paths, download=True, years=[2019])
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join(

Просмотреть файл

@ -133,7 +133,7 @@ class TestSentinel2:
def test_plot_wrong_bands(self, dataset: Sentinel2) -> None:
bands = ["B02"]
ds = Sentinel2(root=dataset.root, res=dataset.res, bands=bands)
ds = Sentinel2(dataset.paths, res=dataset.res, bands=bands)
x = dataset[dataset.bounds]
with pytest.raises(
ValueError, match="Dataset doesn't contain some of the RGB bands"

Просмотреть файл

@ -3,10 +3,10 @@
"""Aboveground Live Woody Biomass Density dataset."""
import glob
import json
import os
from typing import Any, Callable, Optional
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -59,7 +59,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
def __init__(
self,
root: str = "data",
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
@ -69,7 +69,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
"""Initialize a new Dataset instance.
Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@ -80,14 +80,17 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
cache: if True, cache file handle to speed up repeated sampling
Raises:
FileNotFoundError: if no files are found in ``root``
FileNotFoundError: if no files are found in ``paths``
.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
self.root = root
self.paths = paths
self.download = download
self._verify()
super().__init__(root, crs, res, transforms=transforms, cache=cache)
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def _verify(self) -> None:
"""Verify the integrity of the dataset.
@ -96,14 +99,13 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
RuntimeError: if dataset is missing
"""
# Check if the extracted files already exist
pathname = os.path.join(self.root, self.filename_glob)
if glob.glob(pathname):
if self.files:
return
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `root={self.root}` and `download=False`, "
f"Dataset not found in `root={self.paths}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
@ -113,15 +115,16 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
def _download(self) -> None:
"""Download the dataset."""
download_url(self.url, self.root, self.base_filename)
assert isinstance(self.paths, str)
download_url(self.url, self.paths, self.base_filename)
with open(os.path.join(self.root, self.base_filename)) as f:
with open(os.path.join(self.paths, self.base_filename)) as f:
content = json.load(f)
for item in content["features"]:
download_url(
item["properties"]["download"],
self.root,
self.paths,
item["properties"]["tile_id"] + ".tif",
)

Просмотреть файл

@ -3,9 +3,7 @@
"""Aster Global Digital Elevation Model dataset."""
import glob
import os
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -47,7 +45,7 @@ class AsterGDEM(RasterDataset):
def __init__(
self,
root: str = "data",
paths: Union[str, list[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
@ -56,8 +54,8 @@ class AsterGDEM(RasterDataset):
"""Initialize a new Dataset instance.
Args:
root: root directory where dataset can be found, here the collection of
individual zip files for each tile should be found
paths: one or more root directories to search or files to load, here
the collection of individual zip files for each tile should be found
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@ -67,14 +65,17 @@ class AsterGDEM(RasterDataset):
cache: if True, cache file handle to speed up repeated sampling
Raises:
FileNotFoundError: if no files are found in ``root``
FileNotFoundError: if no files are found in ``paths``
RuntimeError: if dataset is missing
.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
self.root = root
self.paths = paths
self._verify()
super().__init__(root, crs, res, transforms=transforms, cache=cache)
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def _verify(self) -> None:
"""Verify the integrity of the dataset.
@ -83,12 +84,11 @@ class AsterGDEM(RasterDataset):
RuntimeError: if dataset is missing
"""
# Check if the extracted files already exists
pathname = os.path.join(self.root, self.filename_glob)
if glob.glob(pathname):
if self.files:
return
raise RuntimeError(
f"Dataset not found in `root={self.root}` "
f"Dataset not found in `root={self.paths}` "
"either specify a different `root` directory or make sure you "
"have manually downloaded dataset tiles as suggested in the documentation."
)

Просмотреть файл

@ -3,9 +3,9 @@
"""CDL dataset."""
import glob
import os
from typing import Any, Callable, Optional
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
import torch
@ -205,7 +205,7 @@ class CDL(RasterDataset):
def __init__(
self,
root: str = "data",
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
years: list[int] = [2022],
@ -218,7 +218,7 @@ class CDL(RasterDataset):
"""Initialize a new Dataset instance.
Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@ -234,11 +234,14 @@ class CDL(RasterDataset):
Raises:
AssertionError: if ``years`` or ``classes`` are invalid
FileNotFoundError: if no files are found in ``root``
FileNotFoundError: if no files are found in ``paths``
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
.. versionadded:: 0.5
The *years* and *classes* parameters.
.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
assert set(years) <= self.md5s.keys(), (
"CDL data product only exists for the following years: "
@ -249,7 +252,7 @@ class CDL(RasterDataset):
), f"Only the following classes are valid: {list(self.cmap.keys())}."
assert 0 in classes, "Classes must include the background class: 0"
self.root = root
self.paths = paths
self.years = years
self.classes = classes
self.download = download
@ -259,7 +262,7 @@ class CDL(RasterDataset):
self._verify()
super().__init__(root, crs, res, transforms=transforms, cache=cache)
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
# Map chosen classes to ordinal numbers, all others mapped to background class
for v, k in enumerate(self.classes):
@ -289,22 +292,15 @@ class CDL(RasterDataset):
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
# Check if the extracted files already exist
exists = []
for year in self.years:
filename_year = self.filename_glob.replace("*", str(year))
pathname = os.path.join(self.root, "**", filename_year)
for fname in glob.iglob(pathname, recursive=True):
if not fname.endswith(".zip"):
exists.append(True)
if len(exists) == len(self.years):
if self.files:
return
# Check if the zip files have already been downloaded
exists = []
assert isinstance(self.paths, str)
for year in self.years:
pathname = os.path.join(
self.root, self.zipfile_glob.replace("*", str(year))
self.paths, self.zipfile_glob.replace("*", str(year))
)
if os.path.exists(pathname):
exists.append(True)
@ -318,7 +314,7 @@ class CDL(RasterDataset):
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `root={self.root}` and `download=False`, "
f"Dataset not found in `root={self.paths}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
@ -332,16 +328,17 @@ class CDL(RasterDataset):
for year in self.years:
download_url(
self.url.format(year),
self.root,
self.paths,
md5=self.md5s[year] if self.checksum else None,
)
def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str)
for year in self.years:
zipfile_name = self.zipfile_glob.replace("*", str(year))
pathname = os.path.join(self.root, zipfile_name)
extract_archive(pathname, self.root)
pathname = os.path.join(self.paths, zipfile_name)
extract_archive(pathname, self.paths)
def plot(
self,

Просмотреть файл

@ -6,8 +6,8 @@
import abc
import os
import sys
from collections.abc import Sequence
from typing import Any, Callable, Optional, cast
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, Union, cast
import fiona
import matplotlib.pyplot as plt
@ -89,7 +89,7 @@ class Chesapeake(RasterDataset, abc.ABC):
def __init__(
self,
root: str = "data",
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
@ -100,7 +100,7 @@ class Chesapeake(RasterDataset, abc.ABC):
"""Initialize a new Dataset instance.
Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@ -112,10 +112,13 @@ class Chesapeake(RasterDataset, abc.ABC):
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
FileNotFoundError: if no files are found in ``root``
FileNotFoundError: if no files are found in ``paths``
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
self.root = root
self.paths = paths
self.download = download
self.checksum = checksum
@ -132,7 +135,7 @@ class Chesapeake(RasterDataset, abc.ABC):
)
self._cmap = ListedColormap(colors)
super().__init__(root, crs, res, transforms=transforms, cache=cache)
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def _verify(self) -> None:
"""Verify the integrity of the dataset.
@ -141,18 +144,19 @@ class Chesapeake(RasterDataset, abc.ABC):
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
# Check if the extracted file already exists
if os.path.exists(os.path.join(self.root, self.filename)):
if self.files:
return
# Check if the zip file has already been downloaded
if os.path.exists(os.path.join(self.root, self.zipfile)):
assert isinstance(self.paths, str)
if os.path.exists(os.path.join(self.paths, self.zipfile)):
self._extract()
return
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `root={self.root}` and `download=False`, "
f"Dataset not found in `root={self.paths}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
@ -163,11 +167,12 @@ class Chesapeake(RasterDataset, abc.ABC):
def _download(self) -> None:
"""Download the dataset."""
download_url(self.url, self.root, filename=self.zipfile, md5=self.md5)
download_url(self.url, self.paths, filename=self.zipfile, md5=self.md5)
def _extract(self) -> None:
"""Extract the dataset."""
extract_archive(os.path.join(self.root, self.zipfile))
assert isinstance(self.paths, str)
extract_archive(os.path.join(self.paths, self.zipfile))
def plot(
self,

Просмотреть файл

@ -3,9 +3,8 @@
"""CMS Global Mangrove Canopy dataset."""
import glob
import os
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -168,7 +167,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
def __init__(
self,
root: str = "data",
paths: Union[str, list[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
measurement: str = "agb",
@ -180,7 +179,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
"""Initialize a new Dataset instance.
Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@ -193,11 +192,14 @@ class CMSGlobalMangroveCanopy(RasterDataset):
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
FileNotFoundError: if no files are found in ``root``
FileNotFoundError: if no files are found in ``paths``
RuntimeError: if dataset is missing or checksum fails
AssertionError: if country or measurement arg are not str or invalid
.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
self.root = root
self.paths = paths
self.checksum = checksum
assert isinstance(country, str), "Country argument must be a str."
@ -220,7 +222,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
self._verify()
super().__init__(root, crs, res, transforms=transforms, cache=cache)
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def _verify(self) -> None:
"""Verify the integrity of the dataset.
@ -229,12 +231,12 @@ class CMSGlobalMangroveCanopy(RasterDataset):
RuntimeError: if dataset is missing or checksum fails
"""
# Check if the extracted files already exist
pathname = os.path.join(self.root, "**", self.filename_glob)
if glob.glob(pathname):
if self.files:
return
# Check if the zip file has already been downloaded
pathname = os.path.join(self.root, self.zipfile)
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, self.zipfile)
if os.path.exists(pathname):
if self.checksum and not check_integrity(pathname, self.md5):
raise RuntimeError("Dataset found, but corrupted.")
@ -242,14 +244,15 @@ class CMSGlobalMangroveCanopy(RasterDataset):
return
raise RuntimeError(
f"Dataset not found in `root={self.root}` "
f"Dataset not found in `root={self.paths}` "
"either specify a different `root` directory or make sure you "
"have manually downloaded the dataset as instructed in the documentation."
)
def _extract(self) -> None:
"""Extract the dataset."""
pathname = os.path.join(self.root, self.zipfile)
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, self.zipfile)
extract_archive(pathname)
def plot(

Просмотреть файл

@ -5,7 +5,8 @@
import glob
import os
from typing import Any, Callable, Optional
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -67,7 +68,7 @@ class Esri2020(RasterDataset):
def __init__(
self,
root: str = "data",
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
@ -78,7 +79,7 @@ class Esri2020(RasterDataset):
"""Initialize a new Dataset instance.
Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@ -90,16 +91,19 @@ class Esri2020(RasterDataset):
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
FileNotFoundError: if no files are found in ``root``
FileNotFoundError: if no files are found in ``paths``
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
self.root = root
self.paths = paths
self.download = download
self.checksum = checksum
self._verify()
super().__init__(root, crs, res, transforms=transforms, cache=cache)
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def _verify(self) -> None:
"""Verify the integrity of the dataset.
@ -108,12 +112,12 @@ class Esri2020(RasterDataset):
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
# Check if the extracted file already exists
pathname = os.path.join(self.root, "**", self.filename_glob)
if glob.glob(pathname):
if self.files:
return
# Check if the zip files have already been downloaded
pathname = os.path.join(self.root, self.zipfile)
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, self.zipfile)
if glob.glob(pathname):
self._extract()
return
@ -121,7 +125,7 @@ class Esri2020(RasterDataset):
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `root={self.root}` and `download=False`, "
f"Dataset not found in `root={self.paths}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
@ -132,11 +136,12 @@ class Esri2020(RasterDataset):
def _download(self) -> None:
"""Download the dataset."""
download_url(self.url, self.root, filename=self.zipfile, md5=self.md5)
download_url(self.url, self.paths, filename=self.zipfile, md5=self.md5)
def _extract(self) -> None:
"""Extract the dataset."""
extract_archive(os.path.join(self.root, self.zipfile))
assert isinstance(self.paths, str)
extract_archive(os.path.join(self.paths, self.zipfile))
def plot(
self,

Просмотреть файл

@ -5,7 +5,8 @@
import glob
import os
from typing import Any, Callable, Optional
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -82,7 +83,7 @@ class EUDEM(RasterDataset):
def __init__(
self,
root: str = "data",
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
@ -92,8 +93,8 @@ class EUDEM(RasterDataset):
"""Initialize a new Dataset instance.
Args:
root: root directory where dataset can be found, here the collection of
individual zip files for each tile should be found
paths: one or more root directories to search or files to load, here
the collection of individual zip files for each tile should be found
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@ -104,14 +105,17 @@ class EUDEM(RasterDataset):
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
FileNotFoundError: if no files are found in ``root``
FileNotFoundError: if no files are found in ``paths``
.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
self.root = root
self.paths = paths
self.checksum = checksum
self._verify()
super().__init__(root, crs, res, transforms=transforms, cache=cache)
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def _verify(self) -> None:
"""Verify the integrity of the dataset.
@ -120,12 +124,12 @@ class EUDEM(RasterDataset):
RuntimeError: if dataset is missing or checksum fails
"""
# Check if the extracted file already exists
pathname = os.path.join(self.root, self.filename_glob)
if glob.glob(pathname):
if self.files:
return
# Check if the zip files have already been downloaded
pathname = os.path.join(self.root, self.zipfile_glob)
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, self.zipfile_glob)
if glob.glob(pathname):
for zipfile in glob.iglob(pathname):
filename = os.path.basename(zipfile)
@ -135,7 +139,7 @@ class EUDEM(RasterDataset):
return
raise RuntimeError(
f"Dataset not found in `root={self.root}` "
f"Dataset not found in `root={self.paths}` "
"either specify a different `root` directory or make sure you "
"have manually downloaded the dataset as suggested in the documentation."
)

Просмотреть файл

@ -9,8 +9,8 @@ import glob
import os
import re
import sys
from collections.abc import Sequence
from typing import Any, Callable, Optional, cast
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, Union, cast
import fiona
import fiona.transform
@ -329,7 +329,7 @@ class RasterDataset(GeoDataset):
def __init__(
self,
root: str = "data",
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
bands: Optional[Sequence[str]] = None,
@ -339,7 +339,7 @@ class RasterDataset(GeoDataset):
"""Initialize a new Dataset instance.
Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@ -350,19 +350,21 @@ class RasterDataset(GeoDataset):
cache: if True, cache file handle to speed up repeated sampling
Raises:
FileNotFoundError: if no files are found in ``root``
FileNotFoundError: if no files are found in ``paths``
.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
super().__init__(transforms)
self.root = root
self.paths = paths
self.bands = bands or self.all_bands
self.cache = cache
# Populate the dataset index
i = 0
pathname = os.path.join(root, "**", self.filename_glob)
filename_regex = re.compile(self.filename_regex, re.VERBOSE)
for filepath in glob.iglob(pathname, recursive=True):
for filepath in self.files:
match = re.match(filename_regex, os.path.basename(filepath))
if match is not None:
try:
@ -396,7 +398,10 @@ class RasterDataset(GeoDataset):
i += 1
if i == 0:
msg = f"No {self.__class__.__name__} data was found in `root='{self.root}'`"
msg = (
f"No {self.__class__.__name__} data was found "
f"in `paths={self.paths!r}`"
)
if self.bands:
msg += f" with `bands={self.bands}`"
raise FileNotFoundError(msg)
@ -418,6 +423,32 @@ class RasterDataset(GeoDataset):
self._crs = cast(CRS, crs)
self._res = cast(float, res)
@property
def files(self) -> set[str]:
    """A set of all files in the dataset.

    Returns:
        All files matching :attr:`filename_glob` under the root
        directories in :attr:`paths`, plus any explicitly listed files.

    .. versionadded:: 0.5
    """
    # Normalize: a single string is treated as one root/file so that
    # str and Iterable[str] arguments take the same code path below.
    if isinstance(self.paths, str):
        paths: Iterable[str] = [self.paths]
    else:
        paths = self.paths

    # Using set to remove any duplicates if directories are overlapping
    files: set[str] = set()
    for path in paths:
        if os.path.isdir(path):
            # Recursively search the directory tree for matching files
            # ('**' with recursive=True also matches zero directories,
            # so files directly inside *path* are found too).
            pathname = os.path.join(path, "**", self.filename_glob)
            files |= set(glob.iglob(pathname, recursive=True))
        else:
            # Not a directory: assume the entry is a file to load as-is
            # (no existence check here — missing files surface later).
            files.add(path)
    return files
def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
"""Retrieve image/mask and metadata indexed by query.

Просмотреть файл

@ -5,7 +5,8 @@
import glob
import os
from typing import Any, Callable, Optional, cast
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union, cast
import matplotlib.pyplot as plt
import torch
@ -118,7 +119,7 @@ class GlobBiomass(RasterDataset):
def __init__(
self,
root: str = "data",
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
measurement: str = "agb",
@ -129,7 +130,7 @@ class GlobBiomass(RasterDataset):
"""Initialize a new Dataset instance.
Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@ -141,11 +142,14 @@ class GlobBiomass(RasterDataset):
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
FileNotFoundError: if no files are found in ``root``
FileNotFoundError: if no files are found in ``paths``
RuntimeError: if dataset is missing or checksum fails
AssertionError: if measurement argument is invalid, or not a str
.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
self.root = root
self.paths = paths
self.checksum = checksum
assert isinstance(measurement, str), "Measurement argument must be a str."
@ -161,7 +165,7 @@ class GlobBiomass(RasterDataset):
self._verify()
super().__init__(root, crs, res, transforms=transforms, cache=cache)
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
"""Retrieve image/mask and metadata indexed by query.
@ -206,12 +210,12 @@ class GlobBiomass(RasterDataset):
RuntimeError: if dataset is missing or checksum fails
"""
# Check if the extracted file already exists
pathname = os.path.join(self.root, self.filename_glob)
if glob.glob(pathname):
if self.files:
return
# Check if the zip files have already been downloaded
pathname = os.path.join(self.root, self.zipfile_glob)
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, self.zipfile_glob)
if glob.glob(pathname):
for zipfile in glob.iglob(pathname):
filename = os.path.basename(zipfile)
@ -221,7 +225,7 @@ class GlobBiomass(RasterDataset):
return
raise RuntimeError(
f"Dataset not found in `root={self.root}` "
f"Dataset not found in `root={self.paths}` "
"either specify a different `root` directory or make sure you "
"have manually downloaded the dataset as suggested in the documentation."
)

Просмотреть файл

@ -5,8 +5,8 @@
import glob
import os
from collections.abc import Sequence
from typing import Any, Callable, Optional, cast
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, Union, cast
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -91,7 +91,7 @@ class L7Irish(RasterDataset):
def __init__(
self,
root: str = "data",
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = CRS.from_epsg(3857),
res: Optional[float] = None,
bands: Sequence[str] = all_bands,
@ -103,7 +103,7 @@ class L7Irish(RasterDataset):
"""Initialize a new L7Irish instance.
Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to EPSG:3857)
res: resolution of the dataset in units of CRS
@ -118,15 +118,18 @@ class L7Irish(RasterDataset):
Raises:
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
self.root = root
self.paths = paths
self.download = download
self.checksum = checksum
self._verify()
super().__init__(
root, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache
paths, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache
)
def _verify(self) -> None:
@ -136,12 +139,12 @@ class L7Irish(RasterDataset):
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
# Check if the extracted files already exist
pathname = os.path.join(self.root, "**", self.filename_glob)
for fname in glob.iglob(pathname, recursive=True):
if self.files:
return
# Check if the tar.gz files have already been downloaded
pathname = os.path.join(self.root, "*.tar.gz")
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, "*.tar.gz")
if glob.glob(pathname):
self._extract()
return
@ -149,7 +152,7 @@ class L7Irish(RasterDataset):
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `root={self.root}` and `download=False`, "
f"Dataset not found in `root={self.paths}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
@ -162,12 +165,13 @@ class L7Irish(RasterDataset):
"""Download the dataset."""
for biome, md5 in self.md5s.items():
download_url(
self.url.format(biome), self.root, md5=md5 if self.checksum else None
self.url.format(biome), self.paths, md5=md5 if self.checksum else None
)
def _extract(self) -> None:
"""Extract the dataset."""
pathname = os.path.join(self.root, "*.tar.gz")
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, "*.tar.gz")
for tarfile in glob.iglob(pathname):
extract_archive(tarfile)

Просмотреть файл

@ -5,8 +5,8 @@
import glob
import os
from collections.abc import Sequence
from typing import Any, Callable, Optional, cast
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, Union, cast
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -90,7 +90,7 @@ class L8Biome(RasterDataset):
def __init__(
self,
root: str = "data",
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = CRS.from_epsg(3857),
res: Optional[float] = None,
bands: Sequence[str] = all_bands,
@ -102,7 +102,7 @@ class L8Biome(RasterDataset):
"""Initialize a new L8Biome instance.
Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to EPSG:3857)
res: resolution of the dataset in units of CRS
@ -117,15 +117,18 @@ class L8Biome(RasterDataset):
Raises:
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
self.root = root
self.paths = paths
self.download = download
self.checksum = checksum
self._verify()
super().__init__(
root, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache
paths, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache
)
def _verify(self) -> None:
@ -135,12 +138,12 @@ class L8Biome(RasterDataset):
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
# Check if the extracted files already exist
pathname = os.path.join(self.root, "**", self.filename_glob)
for fname in glob.iglob(pathname, recursive=True):
if self.files:
return
# Check if the tar.gz files have already been downloaded
pathname = os.path.join(self.root, "*.tar.gz")
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, "*.tar.gz")
if glob.glob(pathname):
self._extract()
return
@ -148,7 +151,7 @@ class L8Biome(RasterDataset):
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `root={self.root}` and `download=False`, "
f"Dataset not found in `root={self.paths}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
@ -161,12 +164,13 @@ class L8Biome(RasterDataset):
"""Download the dataset."""
for biome, md5 in self.md5s.items():
download_url(
self.url.format(biome), self.root, md5=md5 if self.checksum else None
self.url.format(biome), self.paths, md5=md5 if self.checksum else None
)
def _extract(self) -> None:
"""Extract the dataset."""
pathname = os.path.join(self.root, "*.tar.gz")
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, "*.tar.gz")
for tarfile in glob.iglob(pathname):
extract_archive(tarfile)

Просмотреть файл

@ -222,20 +222,20 @@ class LandCoverAIGeo(LandCoverAIBase, RasterDataset):
"""Initialize a new LandCover.ai NonGeo dataset instance.
Args:
root: root directory where dataset can be found
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
(defaults to the resolution of the first file found)
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
cache: if True, cache file handle to speed up repeated sampling
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
root: root directory where dataset can be found
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
(defaults to the resolution of the first file found)
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
cache: if True, cache file handle to speed up repeated sampling
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
"""
LandCoverAIBase.__init__(self, root, download, checksum)
RasterDataset.__init__(self, root, crs, res, transforms=transforms, cache=cache)

Просмотреть файл

@ -4,8 +4,8 @@
"""Landsat datasets."""
import abc
from collections.abc import Sequence
from typing import Any, Callable, Optional
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
@ -58,7 +58,7 @@ class Landsat(RasterDataset, abc.ABC):
def __init__(
self,
root: str = "data",
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
bands: Optional[Sequence[str]] = None,
@ -68,7 +68,7 @@ class Landsat(RasterDataset, abc.ABC):
"""Initialize a new Dataset instance.
Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@ -79,12 +79,15 @@ class Landsat(RasterDataset, abc.ABC):
cache: if True, cache file handle to speed up repeated sampling
Raises:
FileNotFoundError: if no files are found in ``root``
FileNotFoundError: if no files are found in ``paths``
.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
bands = bands or self.default_bands
self.filename_glob = self.filename_glob.format(bands[0])
super().__init__(root, crs, res, bands, transforms, cache)
super().__init__(paths, crs, res, bands, transforms, cache)
def plot(
self,

Просмотреть файл

@ -5,7 +5,8 @@
import glob
import os
from typing import Any, Callable, Optional
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
import torch
@ -106,7 +107,7 @@ class NLCD(RasterDataset):
def __init__(
self,
root: str = "data",
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
years: list[int] = [2019],
@ -119,7 +120,7 @@ class NLCD(RasterDataset):
"""Initialize a new Dataset instance.
Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@ -135,8 +136,11 @@ class NLCD(RasterDataset):
Raises:
AssertionError: if ``years`` or ``classes`` are invalid
FileNotFoundError: if no files are found in ``root``
FileNotFoundError: if no files are found in ``paths``
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
assert set(years) <= self.md5s.keys(), (
"NLCD data product only exists for the following years: "
@ -147,7 +151,7 @@ class NLCD(RasterDataset):
), f"Only the following classes are valid: {list(self.cmap.keys())}."
assert 0 in classes, "Classes must include the background class: 0"
self.root = root
self.paths = paths
self.years = years
self.classes = classes
self.download = download
@ -157,7 +161,7 @@ class NLCD(RasterDataset):
self._verify()
super().__init__(root, crs, res, transforms=transforms, cache=cache)
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
# Map chosen classes to ordinal numbers, all others mapped to background class
for v, k in enumerate(self.classes):
@ -187,23 +191,15 @@ class NLCD(RasterDataset):
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
# Check if the extracted files already exist
exists = []
for year in self.years:
filename_year = self.filename_glob.replace("*", str(year), 1)
pathname = os.path.join(self.root, "**", filename_year)
if glob.glob(pathname, recursive=True):
exists.append(True)
else:
exists.append(False)
if all(exists):
if self.files:
return
# Check if the zip files have already been downloaded
exists = []
for year in self.years:
zipfile_year = self.zipfile_glob.replace("*", str(year), 1)
pathname = os.path.join(self.root, "**", zipfile_year)
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, "**", zipfile_year)
if glob.glob(pathname, recursive=True):
exists.append(True)
self._extract()
@ -216,7 +212,7 @@ class NLCD(RasterDataset):
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `root={self.root}` and `download=False`, "
f"Dataset not found in `root={self.paths}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
@ -230,7 +226,7 @@ class NLCD(RasterDataset):
for year in self.years:
download_url(
self.url.format(year),
self.root,
self.paths,
md5=self.md5s[year] if self.checksum else None,
)
@ -238,8 +234,9 @@ class NLCD(RasterDataset):
"""Extract the dataset."""
for year in self.years:
zipfile_name = self.zipfile_glob.replace("*", str(year), 1)
pathname = os.path.join(self.root, "**", zipfile_name)
extract_archive(glob.glob(pathname, recursive=True)[0], self.root)
assert isinstance(self.paths, str)
pathname = os.path.join(self.paths, "**", zipfile_name)
extract_archive(glob.glob(pathname, recursive=True)[0], self.paths)
def plot(
self,

Просмотреть файл

@ -3,8 +3,8 @@
"""Sentinel datasets."""
from collections.abc import Sequence
from typing import Any, Callable, Optional
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
import torch
@ -140,7 +140,7 @@ class Sentinel1(Sentinel):
def __init__(
self,
root: str = "data",
paths: Union[str, list[str]] = "data",
crs: Optional[CRS] = None,
res: float = 10,
bands: Sequence[str] = ["VV", "VH"],
@ -150,7 +150,7 @@ class Sentinel1(Sentinel):
"""Initialize a new Dataset instance.
Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@ -162,7 +162,10 @@ class Sentinel1(Sentinel):
Raises:
AssertionError: if ``bands`` is invalid
FileNotFoundError: if no files are found in ``root``
FileNotFoundError: if no files are found in ``paths``
.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
assert len(bands) > 0, "'bands' cannot be an empty list"
assert len(bands) == len(set(bands)), "'bands' contains duplicate bands"
@ -184,7 +187,7 @@ To create a dataset containing both, use:
self.filename_glob = self.filename_glob.format(bands[0])
super().__init__(root, crs, res, bands, transforms, cache)
super().__init__(paths, crs, res, bands, transforms, cache)
def plot(
self,
@ -293,7 +296,7 @@ class Sentinel2(Sentinel):
def __init__(
self,
root: str = "data",
paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: float = 10,
bands: Optional[Sequence[str]] = None,
@ -303,7 +306,7 @@ class Sentinel2(Sentinel):
"""Initialize a new Dataset instance.
Args:
root: root directory where dataset can be found
paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@ -314,13 +317,16 @@ class Sentinel2(Sentinel):
cache: if True, cache file handle to speed up repeated sampling
Raises:
FileNotFoundError: if no files are found in ``root``
FileNotFoundError: if no files are found in ``paths``
.. versionchanged:: 0.5
*root* was renamed to *paths*.
"""
bands = bands or self.all_bands
self.filename_glob = self.filename_glob.format(bands[0])
self.filename_regex = self.filename_regex.format(res)
super().__init__(root, crs, res, bands, transforms, cache)
super().__init__(paths, crs, res, bands, transforms, cache)
def plot(
self,