Add year filter functionality CDL (#1337)

* add year functionality CDL

* requested changes and test coverage

* removesuffix clearer than strip

* recursive file verify

* Fix versionadded syntax

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Nils Lehmann 2023-05-25 08:25:46 +02:00 коммит произвёл GitHub
Родитель ded347e17b
Коммит 2d576d915f
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 77 добавлений и 36 удалений

Просмотреть файл

@ -27,17 +27,23 @@ class TestCDL:
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CDL:
monkeypatch.setattr(torchgeo.datasets.cdl, "download_url", download_url)
md5s = [
(2021, "e929beb9c8e59fa1d7b7f82e64edaae1"),
(2020, "e95c2d40ce0c261ed6ee0bd00b49e4b6"),
]
md5s = {
2021: "e929beb9c8e59fa1d7b7f82e64edaae1",
2020: "e95c2d40ce0c261ed6ee0bd00b49e4b6",
}
monkeypatch.setattr(CDL, "md5s", md5s)
url = os.path.join("tests", "data", "cdl", "{}_30m_cdls.zip")
monkeypatch.setattr(CDL, "url", url)
monkeypatch.setattr(plt, "show", lambda *args: None)
root = str(tmp_path)
transforms = nn.Identity()
return CDL(root, transforms=transforms, download=True, checksum=True)
return CDL(
root,
transforms=transforms,
download=True,
checksum=True,
years=[2020, 2021],
)
def test_getitem(self, dataset: CDL) -> None:
x = dataset[dataset.bounds]
@ -60,14 +66,21 @@ class TestCDL:
next(dataset.index.intersection(tuple(query)))
def test_already_extracted(self, dataset: CDL) -> None:
CDL(root=dataset.root, download=True)
CDL(root=dataset.root, years=[2020, 2021])
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "cdl", "*_30m_cdls.zip")
root = str(tmp_path)
for zipfile in glob.iglob(pathname):
shutil.copy(zipfile, root)
CDL(root)
CDL(root, years=[2020, 2021])
def test_invalid_year(self, tmp_path: Path) -> None:
with pytest.raises(
AssertionError,
match="CDL data product only exists for the following years:",
):
CDL(str(tmp_path), years=[1996])
def test_plot(self, dataset: CDL) -> None:
query = dataset.bounds

Просмотреть файл

@ -47,23 +47,23 @@ class CDL(RasterDataset):
is_image = False
url = "https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/{}_30m_cdls.zip" # noqa: E501
md5s = [
(2022, "754cf50670cdfee511937554785de3e6"),
(2021, "27606eab08fe975aa138baad3e5dfcd8"),
(2020, "483ee48c503aa81b684225179b402d42"),
(2019, "a5168a2fc93acbeaa93e24eee3d8c696"),
(2018, "4ad0d7802a9bb751685eb239b0fa8609"),
(2017, "d173f942a70f94622f9b8290e7548684"),
(2016, "fddc5dff0bccc617d70a12864c993e51"),
(2015, "2e92038ab62ba75e1687f60eecbdd055"),
(2014, "50bdf9da84ebd0457ddd9e0bf9bbcc1f"),
(2013, "7be66c650416dc7c4a945dd7fd93c5b7"),
(2012, "286504ff0512e9fe1a1975c635a1bec2"),
(2011, "517bad1a99beec45d90abb651fb1f0e3"),
(2010, "98d354c5a62c9e3e40ccadce265c721c"),
(2009, "663c8a5fdd92ebfc0d6bee008586d19a"),
(2008, "0610f2f17ab60a9fbb3baeb7543993a4"),
]
md5s = {
2022: "754cf50670cdfee511937554785de3e6",
2021: "27606eab08fe975aa138baad3e5dfcd8",
2020: "483ee48c503aa81b684225179b402d42",
2019: "a5168a2fc93acbeaa93e24eee3d8c696",
2018: "4ad0d7802a9bb751685eb239b0fa8609",
2017: "d173f942a70f94622f9b8290e7548684",
2016: "fddc5dff0bccc617d70a12864c993e51",
2015: "2e92038ab62ba75e1687f60eecbdd055",
2014: "50bdf9da84ebd0457ddd9e0bf9bbcc1f",
2013: "7be66c650416dc7c4a945dd7fd93c5b7",
2012: "286504ff0512e9fe1a1975c635a1bec2",
2011: "517bad1a99beec45d90abb651fb1f0e3",
2010: "98d354c5a62c9e3e40ccadce265c721c",
2009: "663c8a5fdd92ebfc0d6bee008586d19a",
2008: "0610f2f17ab60a9fbb3baeb7543993a4",
}
cmap = {
0: (0, 0, 0, 0),
@ -466,6 +466,7 @@ class CDL(RasterDataset):
root: str = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
years: list[int] = [2022],
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
cache: bool = True,
download: bool = False,
@ -479,6 +480,7 @@ class CDL(RasterDataset):
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
(defaults to the resolution of the first file found)
years: list of years for which to use cdl layer
transforms: a function/transform that takes an input sample
and returns a transformed version
cache: if True, cache file handle to speed up repeated sampling
@ -488,7 +490,15 @@ class CDL(RasterDataset):
Raises:
FileNotFoundError: if no files are found in ``root``
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
.. versionadded:: 0.5
The *years* parameter.
"""
assert set(years).issubset(self.md5s.keys()), (
"CDL data product only exists for the following years: "
f"{list(self.md5s.keys())}."
)
self.years = years
self.root = root
self.download = download
self.checksum = checksum
@ -526,15 +536,30 @@ class CDL(RasterDataset):
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
# Check if the extracted files already exist
pathname = os.path.join(self.root, "**", self.filename_glob)
for fname in glob.iglob(pathname, recursive=True):
if not fname.endswith(".zip"):
return
exists = []
for year in self.years:
filename_year = self.filename_glob.replace("*", str(year))
pathname = os.path.join(self.root, "**", filename_year)
for fname in glob.iglob(pathname, recursive=True):
if not fname.endswith(".zip"):
exists.append(True)
if len(exists) == len(self.years):
return
# Check if the zip files have already been downloaded
pathname = os.path.join(self.root, self.zipfile_glob)
if glob.glob(pathname):
self._extract()
exists = []
for year in self.years:
pathname = os.path.join(
self.root, self.zipfile_glob.replace("*", str(year))
)
if os.path.exists(pathname):
exists.append(True)
self._extract()
else:
exists.append(False)
if all(exists):
return
# Check if the user requested to download the dataset
@ -551,16 +576,19 @@ class CDL(RasterDataset):
def _download(self) -> None:
"""Download the dataset."""
for year, md5 in self.md5s:
for year in self.years:
download_url(
self.url.format(year), self.root, md5=md5 if self.checksum else None
self.url.format(year),
self.root,
md5=self.md5s[year] if self.checksum else None,
)
def _extract(self) -> None:
"""Extract the dataset."""
pathname = os.path.join(self.root, self.zipfile_glob)
for zipfile in glob.iglob(pathname):
extract_archive(zipfile)
for year in self.years:
zipfile_name = self.zipfile_glob.replace("*", str(year))
pathname = os.path.join(self.root, zipfile_name)
extract_archive(pathname, self.root)
def plot(
self,