зеркало из https://github.com/microsoft/torchgeo.git
Add year filter functionality to CDL (#1337)
* add year functionality CDL
* requested changes and test coverage
* removesuffix clearer than strip
* recursive file verify
* Fix versionadded syntax

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
ded347e17b
Коммит
2d576d915f
|
@ -27,17 +27,23 @@ class TestCDL:
|
|||
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CDL:
|
||||
monkeypatch.setattr(torchgeo.datasets.cdl, "download_url", download_url)
|
||||
|
||||
md5s = [
|
||||
(2021, "e929beb9c8e59fa1d7b7f82e64edaae1"),
|
||||
(2020, "e95c2d40ce0c261ed6ee0bd00b49e4b6"),
|
||||
]
|
||||
md5s = {
|
||||
2021: "e929beb9c8e59fa1d7b7f82e64edaae1",
|
||||
2020: "e95c2d40ce0c261ed6ee0bd00b49e4b6",
|
||||
}
|
||||
monkeypatch.setattr(CDL, "md5s", md5s)
|
||||
url = os.path.join("tests", "data", "cdl", "{}_30m_cdls.zip")
|
||||
monkeypatch.setattr(CDL, "url", url)
|
||||
monkeypatch.setattr(plt, "show", lambda *args: None)
|
||||
root = str(tmp_path)
|
||||
transforms = nn.Identity()
|
||||
return CDL(root, transforms=transforms, download=True, checksum=True)
|
||||
return CDL(
|
||||
root,
|
||||
transforms=transforms,
|
||||
download=True,
|
||||
checksum=True,
|
||||
years=[2020, 2021],
|
||||
)
|
||||
|
||||
def test_getitem(self, dataset: CDL) -> None:
|
||||
x = dataset[dataset.bounds]
|
||||
|
@ -60,14 +66,21 @@ class TestCDL:
|
|||
next(dataset.index.intersection(tuple(query)))
|
||||
|
||||
def test_already_extracted(self, dataset: CDL) -> None:
|
||||
CDL(root=dataset.root, download=True)
|
||||
CDL(root=dataset.root, years=[2020, 2021])
|
||||
|
||||
def test_already_downloaded(self, tmp_path: Path) -> None:
|
||||
pathname = os.path.join("tests", "data", "cdl", "*_30m_cdls.zip")
|
||||
root = str(tmp_path)
|
||||
for zipfile in glob.iglob(pathname):
|
||||
shutil.copy(zipfile, root)
|
||||
CDL(root)
|
||||
CDL(root, years=[2020, 2021])
|
||||
|
||||
def test_invalid_year(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match="CDL data product only exists for the following years:",
|
||||
):
|
||||
CDL(str(tmp_path), years=[1996])
|
||||
|
||||
def test_plot(self, dataset: CDL) -> None:
|
||||
query = dataset.bounds
|
||||
|
|
|
@ -47,23 +47,23 @@ class CDL(RasterDataset):
|
|||
is_image = False
|
||||
|
||||
url = "https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/{}_30m_cdls.zip" # noqa: E501
|
||||
md5s = [
|
||||
(2022, "754cf50670cdfee511937554785de3e6"),
|
||||
(2021, "27606eab08fe975aa138baad3e5dfcd8"),
|
||||
(2020, "483ee48c503aa81b684225179b402d42"),
|
||||
(2019, "a5168a2fc93acbeaa93e24eee3d8c696"),
|
||||
(2018, "4ad0d7802a9bb751685eb239b0fa8609"),
|
||||
(2017, "d173f942a70f94622f9b8290e7548684"),
|
||||
(2016, "fddc5dff0bccc617d70a12864c993e51"),
|
||||
(2015, "2e92038ab62ba75e1687f60eecbdd055"),
|
||||
(2014, "50bdf9da84ebd0457ddd9e0bf9bbcc1f"),
|
||||
(2013, "7be66c650416dc7c4a945dd7fd93c5b7"),
|
||||
(2012, "286504ff0512e9fe1a1975c635a1bec2"),
|
||||
(2011, "517bad1a99beec45d90abb651fb1f0e3"),
|
||||
(2010, "98d354c5a62c9e3e40ccadce265c721c"),
|
||||
(2009, "663c8a5fdd92ebfc0d6bee008586d19a"),
|
||||
(2008, "0610f2f17ab60a9fbb3baeb7543993a4"),
|
||||
]
|
||||
md5s = {
|
||||
2022: "754cf50670cdfee511937554785de3e6",
|
||||
2021: "27606eab08fe975aa138baad3e5dfcd8",
|
||||
2020: "483ee48c503aa81b684225179b402d42",
|
||||
2019: "a5168a2fc93acbeaa93e24eee3d8c696",
|
||||
2018: "4ad0d7802a9bb751685eb239b0fa8609",
|
||||
2017: "d173f942a70f94622f9b8290e7548684",
|
||||
2016: "fddc5dff0bccc617d70a12864c993e51",
|
||||
2015: "2e92038ab62ba75e1687f60eecbdd055",
|
||||
2014: "50bdf9da84ebd0457ddd9e0bf9bbcc1f",
|
||||
2013: "7be66c650416dc7c4a945dd7fd93c5b7",
|
||||
2012: "286504ff0512e9fe1a1975c635a1bec2",
|
||||
2011: "517bad1a99beec45d90abb651fb1f0e3",
|
||||
2010: "98d354c5a62c9e3e40ccadce265c721c",
|
||||
2009: "663c8a5fdd92ebfc0d6bee008586d19a",
|
||||
2008: "0610f2f17ab60a9fbb3baeb7543993a4",
|
||||
}
|
||||
|
||||
cmap = {
|
||||
0: (0, 0, 0, 0),
|
||||
|
@ -466,6 +466,7 @@ class CDL(RasterDataset):
|
|||
root: str = "data",
|
||||
crs: Optional[CRS] = None,
|
||||
res: Optional[float] = None,
|
||||
years: list[int] = [2022],
|
||||
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
|
||||
cache: bool = True,
|
||||
download: bool = False,
|
||||
|
@ -479,6 +480,7 @@ class CDL(RasterDataset):
|
|||
(defaults to the CRS of the first file found)
|
||||
res: resolution of the dataset in units of CRS
|
||||
(defaults to the resolution of the first file found)
|
||||
years: list of years for which to use cdl layer
|
||||
transforms: a function/transform that takes an input sample
|
||||
and returns a transformed version
|
||||
cache: if True, cache file handle to speed up repeated sampling
|
||||
|
@ -488,7 +490,15 @@ class CDL(RasterDataset):
|
|||
Raises:
|
||||
FileNotFoundError: if no files are found in ``root``
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
|
||||
.. versionadded:: 0.5
|
||||
The *years* parameter.
|
||||
"""
|
||||
assert set(years).issubset(self.md5s.keys()), (
|
||||
"CDL data product only exists for the following years: "
|
||||
f"{list(self.md5s.keys())}."
|
||||
)
|
||||
self.years = years
|
||||
self.root = root
|
||||
self.download = download
|
||||
self.checksum = checksum
|
||||
|
@ -526,15 +536,30 @@ class CDL(RasterDataset):
|
|||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
"""
|
||||
# Check if the extracted files already exist
|
||||
pathname = os.path.join(self.root, "**", self.filename_glob)
|
||||
for fname in glob.iglob(pathname, recursive=True):
|
||||
if not fname.endswith(".zip"):
|
||||
return
|
||||
exists = []
|
||||
for year in self.years:
|
||||
filename_year = self.filename_glob.replace("*", str(year))
|
||||
pathname = os.path.join(self.root, "**", filename_year)
|
||||
for fname in glob.iglob(pathname, recursive=True):
|
||||
if not fname.endswith(".zip"):
|
||||
exists.append(True)
|
||||
|
||||
if len(exists) == len(self.years):
|
||||
return
|
||||
|
||||
# Check if the zip files have already been downloaded
|
||||
pathname = os.path.join(self.root, self.zipfile_glob)
|
||||
if glob.glob(pathname):
|
||||
self._extract()
|
||||
exists = []
|
||||
for year in self.years:
|
||||
pathname = os.path.join(
|
||||
self.root, self.zipfile_glob.replace("*", str(year))
|
||||
)
|
||||
if os.path.exists(pathname):
|
||||
exists.append(True)
|
||||
self._extract()
|
||||
else:
|
||||
exists.append(False)
|
||||
|
||||
if all(exists):
|
||||
return
|
||||
|
||||
# Check if the user requested to download the dataset
|
||||
|
@ -551,16 +576,19 @@ class CDL(RasterDataset):
|
|||
|
||||
def _download(self) -> None:
|
||||
"""Download the dataset."""
|
||||
for year, md5 in self.md5s:
|
||||
for year in self.years:
|
||||
download_url(
|
||||
self.url.format(year), self.root, md5=md5 if self.checksum else None
|
||||
self.url.format(year),
|
||||
self.root,
|
||||
md5=self.md5s[year] if self.checksum else None,
|
||||
)
|
||||
|
||||
def _extract(self) -> None:
|
||||
"""Extract the dataset."""
|
||||
pathname = os.path.join(self.root, self.zipfile_glob)
|
||||
for zipfile in glob.iglob(pathname):
|
||||
extract_archive(zipfile)
|
||||
for year in self.years:
|
||||
zipfile_name = self.zipfile_glob.replace("*", str(year))
|
||||
pathname = os.path.join(self.root, zipfile_name)
|
||||
extract_archive(pathname, self.root)
|
||||
|
||||
def plot(
|
||||
self,
|
||||
|
|
Загрузка…
Ссылка в новой задаче