ChesapeakeCVPR: fix non-existing dir support, add unit tests (#195)

This commit is contained in:
Adam J. Stewart 2021-10-13 00:20:45 -05:00 committed by GitHub
Parent 09d4df9efa
Commit 24c3f70f5f
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
3 changed files: 126 additions and 39 deletions

Binary data
tests/data/chesapeake/cvpr/cvpr_chesapeake_landcover.zip Normal file

Binary file not shown.

View file

@ -9,15 +9,16 @@ from typing import Generator
import matplotlib.pyplot as plt
import pytest
import torch
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from rasterio.crs import CRS
import torchgeo.datasets.utils
from torchgeo.datasets import BoundingBox, Chesapeake13, ZipDataset
from torchgeo.datasets import BoundingBox, Chesapeake13, ChesapeakeCVPR, ZipDataset
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    """Mock of ``torchgeo.datasets.utils.download_url`` used by the tests.

    Instead of fetching ``url`` over the network, copy the local fixture file
    into ``root``. Extra positional and keyword arguments (e.g. ``filename``,
    ``md5``) are accepted and ignored so the signature is call-compatible with
    the real downloader.

    Args:
        url: path to the local fixture file to "download"
        root: destination directory
    """
    # The diff artifact stacked the old signature on top of the new one; only
    # the ``**kwargs``-accepting version is kept so keyword args don't crash.
    shutil.copy(url, root)
@ -76,3 +77,84 @@ class TestChesapeake13:
IndexError, match="query: .* not found in index with bounds:"
):
dataset[query]
class TestChesapeakeCVPR:
    """Unit tests for the ChesapeakeCVPR dataset, run against a small local
    fixture archive instead of the real (multi-GB) download."""

    @pytest.fixture(
        params=[
            ("naip-new", "naip-old", "nlcd"),
            ("landsat-leaf-on", "landsat-leaf-off", "lc"),
            ("naip-new", "landsat-leaf-on", "lc", "nlcd", "buildings"),
        ]
    )
    def dataset(
        self,
        request: SubRequest,
        monkeypatch: Generator[MonkeyPatch, None, None],
        tmp_path: Path,
    ) -> ChesapeakeCVPR:
        """Build a ChesapeakeCVPR dataset from the test fixture zip.

        The module-level ``download_url`` helper is monkeypatched so the
        fixture archive is copied rather than downloaded, and the class-level
        ``url``/``md5``/``files`` attributes are redirected at the fixture.
        Parametrized over several layer combinations (see ``params`` above).
        """
        monkeypatch.setattr(  # type: ignore[attr-defined]
            torchgeo.datasets.chesapeake, "download_url", download_url
        )
        md5 = "564b8d944a941b0b65db9f56c92b93a2"
        monkeypatch.setattr(ChesapeakeCVPR, "md5", md5)  # type: ignore[attr-defined]
        url = os.path.join(
            "tests", "data", "chesapeake", "cvpr", "cvpr_chesapeake_landcover.zip"
        )
        monkeypatch.setattr(ChesapeakeCVPR, "url", url)  # type: ignore[attr-defined]
        monkeypatch.setattr(  # type: ignore[attr-defined]
            ChesapeakeCVPR,
            "files",
            ["de_1m_2013_extended-debuffered-test_tiles", "spatial_index.geojson"],
        )
        root = str(tmp_path)
        transforms = Identity()
        return ChesapeakeCVPR(
            root,
            splits=["de-test"],
            layers=request.param,
            transforms=transforms,
            download=True,
            checksum=True,
        )

    def test_getitem(self, dataset: ChesapeakeCVPR) -> None:
        """A bounds query returns a dict carrying a CRS and a mask tensor."""
        x = dataset[dataset.bounds]
        assert isinstance(x, dict)
        assert isinstance(x["crs"], CRS)
        assert isinstance(x["mask"], torch.Tensor)

    def test_add(self, dataset: ChesapeakeCVPR) -> None:
        """Adding two datasets produces a ZipDataset."""
        ds = dataset + dataset
        assert isinstance(ds, ZipDataset)

    def test_already_extracted(self, dataset: ChesapeakeCVPR) -> None:
        """Re-instantiating over an extracted root must not re-download."""
        ChesapeakeCVPR(root=dataset.root, download=True)

    def test_already_downloaded(self, tmp_path: Path) -> None:
        """A root containing only the zip is extracted, not re-downloaded."""
        url = os.path.join(
            "tests", "data", "chesapeake", "cvpr", "cvpr_chesapeake_landcover.zip"
        )
        root = str(tmp_path)
        shutil.copy(url, root)
        ChesapeakeCVPR(root)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """Missing data with download=False raises a helpful RuntimeError."""
        with pytest.raises(RuntimeError, match="Dataset not found"):
            ChesapeakeCVPR(str(tmp_path), checksum=True)

    def test_out_of_bounds_query(self, dataset: ChesapeakeCVPR) -> None:
        """A degenerate query outside the index raises IndexError."""
        query = BoundingBox(0, 0, 0, 0, 0, 0)
        with pytest.raises(
            IndexError, match="query: .* not found in index with bounds:"
        ):
            dataset[query]

    def test_multiple_hits_query(self, dataset: ChesapeakeCVPR) -> None:
        """A query spanning more than one tile is rejected."""
        ds = ChesapeakeCVPR(
            root=dataset.root, splits=["de-train", "de-test"], layers=dataset.layers
        )
        with pytest.raises(
            IndexError, match="query: .* spans multiple tiles which is not valid"
        ):
            ds[dataset.bounds]

View file

@ -19,7 +19,13 @@ import torch
from rasterio.crs import CRS
from .geo import GeoDataset, RasterDataset
from .utils import BoundingBox, check_integrity, download_and_extract_archive
from .utils import (
BoundingBox,
check_integrity,
download_and_extract_archive,
download_url,
extract_archive,
)
class Chesapeake(RasterDataset, abc.ABC):
@ -376,21 +382,15 @@ class ChesapeakeCVPR(GeoDataset):
for split in splits:
assert split in self.splits
assert all([layer in self.valid_layers for layer in layers])
super().__init__(transforms) # creates self.index and self.transform
self.root = root
self.layers = layers
self.cache = cache
self.download = download
self.checksum = checksum
if download and not self._check_structure():
self._download()
self._verify()
if checksum:
if not self._check_integrity():
raise RuntimeError(
"Dataset not found or corrupted. "
+ "You can use download=True to download it"
)
super().__init__(transforms)
# Add all tiles into the index in epsg:3857 based on the included geojson
mint: float = 0
@ -496,40 +496,45 @@ class ChesapeakeCVPR(GeoDataset):
return sample
def _check_integrity(self) -> bool:
"""Check integrity of the dataset archive.
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Returns:
True if dataset archive is found and/or MD5s match, else False
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
integrity: bool = check_integrity(
os.path.join(self.root, self.filename),
self.md5 if self.checksum else None,
)
# Check if the extracted files already exist
def exists(filename: str) -> bool:
return os.path.exists(os.path.join(self.root, filename))
return integrity
def _check_structure(self) -> bool:
"""Checks to see if the dataset files exist in the root directory.
Returns:
True if the dataset files are found, else False
"""
dataset_files = os.listdir(self.root)
for file in self.files:
if file not in dataset_files:
return False
return True
def _download(self) -> None:
"""Download the dataset and extract it."""
if self._check_integrity():
print("Files already downloaded and verified")
if all(map(exists, self.files)):
return
download_and_extract_archive(
# Check if the zip files have already been downloaded
if os.path.exists(os.path.join(self.root, self.filename)):
self._extract()
return
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `root={self.root}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automaticaly download the dataset."
)
# Download the dataset
self._download()
self._extract()
def _download(self) -> None:
    """Fetch the dataset archive into ``self.root``.

    Delegates to :func:`download_url`, which also verifies the file
    against ``self.md5`` when a checksum is configured.
    """
    download_url(self.url, self.root, filename=self.filename, md5=self.md5)
def _extract(self) -> None:
    """Unpack the downloaded archive in place under ``self.root``."""
    archive_path = os.path.join(self.root, self.filename)
    extract_archive(archive_path)