Mirror of https://github.com/microsoft/torchgeo.git

ChesapeakeCVPR: fix non-existing dir support, add unit tests (#195)

Parent: 09d4df9efa
Commit: 24c3f70f5f

Binary file not shown.
tests/datasets/test_chesapeake.py

@@ -9,15 +9,16 @@ from typing import Generator
 import matplotlib.pyplot as plt
 import pytest
 import torch
 from _pytest.fixtures import SubRequest
 from _pytest.monkeypatch import MonkeyPatch
 from rasterio.crs import CRS
 
 import torchgeo.datasets.utils
-from torchgeo.datasets import BoundingBox, Chesapeake13, ZipDataset
+from torchgeo.datasets import BoundingBox, Chesapeake13, ChesapeakeCVPR, ZipDataset
 from torchgeo.transforms import Identity
 
 
-def download_url(url: str, root: str, *args: str) -> None:
+def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
     shutil.copy(url, root)
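The widened stub signature matters because the reworked download path now calls download_url with keyword arguments (filename=, md5=), which the old *args-only test double would reject with a TypeError. A minimal, self-contained sketch of the idea, using hypothetical paths and names:

# Hedged sketch (hypothetical paths): the test double "downloads" by copying
# a local fixture, and must swallow the keyword arguments the dataset passes.
import shutil
import tempfile
from pathlib import Path


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    # Copy the local "remote" file into the destination directory.
    shutil.copy(url, root)


with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "fixture.zip"
    src.write_bytes(b"fake archive")
    dest = Path(tmp) / "root"
    dest.mkdir()
    # filename= and md5= are accepted and simply ignored by the stub.
    download_url(str(src), str(dest), filename="fixture.zip", md5="0" * 32)
    assert (dest / "fixture.zip").exists()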
@@ -76,3 +77,84 @@ class TestChesapeake13:
             IndexError, match="query: .* not found in index with bounds:"
         ):
             dataset[query]
+
+
+class TestChesapeakeCVPR:
+    @pytest.fixture(
+        params=[
+            ("naip-new", "naip-old", "nlcd"),
+            ("landsat-leaf-on", "landsat-leaf-off", "lc"),
+            ("naip-new", "landsat-leaf-on", "lc", "nlcd", "buildings"),
+        ]
+    )
+    def dataset(
+        self,
+        request: SubRequest,
+        monkeypatch: Generator[MonkeyPatch, None, None],
+        tmp_path: Path,
+    ) -> ChesapeakeCVPR:
+        monkeypatch.setattr(  # type: ignore[attr-defined]
+            torchgeo.datasets.chesapeake, "download_url", download_url
+        )
+        md5 = "564b8d944a941b0b65db9f56c92b93a2"
+        monkeypatch.setattr(ChesapeakeCVPR, "md5", md5)  # type: ignore[attr-defined]
+        url = os.path.join(
+            "tests", "data", "chesapeake", "cvpr", "cvpr_chesapeake_landcover.zip"
+        )
+        monkeypatch.setattr(ChesapeakeCVPR, "url", url)  # type: ignore[attr-defined]
+        monkeypatch.setattr(  # type: ignore[attr-defined]
+            ChesapeakeCVPR,
+            "files",
+            ["de_1m_2013_extended-debuffered-test_tiles", "spatial_index.geojson"],
+        )
+        root = str(tmp_path)
+        transforms = Identity()
+        return ChesapeakeCVPR(
+            root,
+            splits=["de-test"],
+            layers=request.param,
+            transforms=transforms,
+            download=True,
+            checksum=True,
+        )
+
+    def test_getitem(self, dataset: ChesapeakeCVPR) -> None:
+        x = dataset[dataset.bounds]
+        assert isinstance(x, dict)
+        assert isinstance(x["crs"], CRS)
+        assert isinstance(x["mask"], torch.Tensor)
+
+    def test_add(self, dataset: ChesapeakeCVPR) -> None:
+        ds = dataset + dataset
+        assert isinstance(ds, ZipDataset)
+
+    def test_already_extracted(self, dataset: ChesapeakeCVPR) -> None:
+        ChesapeakeCVPR(root=dataset.root, download=True)
+
+    def test_already_downloaded(self, tmp_path: Path) -> None:
+        url = os.path.join(
+            "tests", "data", "chesapeake", "cvpr", "cvpr_chesapeake_landcover.zip"
+        )
+        root = str(tmp_path)
+        shutil.copy(url, root)
+        ChesapeakeCVPR(root)
+
+    def test_not_downloaded(self, tmp_path: Path) -> None:
+        with pytest.raises(RuntimeError, match="Dataset not found"):
+            ChesapeakeCVPR(str(tmp_path), checksum=True)
+
+    def test_out_of_bounds_query(self, dataset: ChesapeakeCVPR) -> None:
+        query = BoundingBox(0, 0, 0, 0, 0, 0)
+        with pytest.raises(
+            IndexError, match="query: .* not found in index with bounds:"
+        ):
+            dataset[query]
+
+    def test_multiple_hits_query(self, dataset: ChesapeakeCVPR) -> None:
+        ds = ChesapeakeCVPR(
+            root=dataset.root, splits=["de-train", "de-test"], layers=dataset.layers
+        )
+        with pytest.raises(
+            IndexError, match="query: .* spans multiple tiles which is not valid"
+        ):
+            ds[dataset.bounds]
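To exercise only the new test class locally, a standard pytest invocation works; the programmatic equivalent (not part of the diff) is:

# Run only the new ChesapeakeCVPR tests; equivalent to
# `pytest tests/datasets/test_chesapeake.py -k TestChesapeakeCVPR` on the CLI.
import pytest

pytest.main(["tests/datasets/test_chesapeake.py", "-k", "TestChesapeakeCVPR"])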
torchgeo/datasets/chesapeake.py

@@ -19,7 +19,13 @@ import torch
 from rasterio.crs import CRS
 
 from .geo import GeoDataset, RasterDataset
-from .utils import BoundingBox, check_integrity, download_and_extract_archive
+from .utils import (
+    BoundingBox,
+    check_integrity,
+    download_and_extract_archive,
+    download_url,
+    extract_archive,
+)
 
 
 class Chesapeake(RasterDataset, abc.ABC):
@@ -376,21 +382,15 @@ class ChesapeakeCVPR(GeoDataset):
         for split in splits:
             assert split in self.splits
         assert all([layer in self.valid_layers for layer in layers])
-        super().__init__(transforms)  # creates self.index and self.transform
         self.root = root
         self.layers = layers
         self.cache = cache
         self.download = download
         self.checksum = checksum
 
-        if download and not self._check_structure():
-            self._download()
-        if checksum:
-            if not self._check_integrity():
-                raise RuntimeError(
-                    "Dataset not found or corrupted. "
-                    + "You can use download=True to download it"
-                )
+        self._verify()
+
+        super().__init__(transforms)
 
         # Add all tiles into the index in epsg:3857 based on the included geojson
         mint: float = 0
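The reordering moves dataset verification ahead of super().__init__(transforms), which builds the spatial index, so a bad root directory now fails fast instead of erroring midway through indexing. A condensed, hypothetical sketch of the pattern (not the real class):

# Hedged sketch: verify on-disk files *before* any index construction.
# SketchDataset and its file list are illustrative, not from the commit.
import os


class SketchDataset:
    files = ["spatial_index.geojson"]  # illustrative file list

    def __init__(self, root: str, download: bool = False) -> None:
        self.root = root
        self.download = download
        self._verify()  # raises if the dataset is absent and download=False
        # ...the real class calls super().__init__(transforms) here, which
        # creates the index populated from spatial_index.geojson

    def _verify(self) -> None:
        if all(os.path.exists(os.path.join(self.root, f)) for f in self.files):
            return
        if not self.download:
            raise RuntimeError(f"Dataset not found in `root={self.root}`")
        # download/extract would happen here in the real implementation


# Fails fast on a non-existing directory:
try:
    SketchDataset("/nonexistent")
except RuntimeError as exc:
    print(exc)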
@@ -496,40 +496,45 @@ class ChesapeakeCVPR(GeoDataset):
 
         return sample
 
-    def _check_integrity(self) -> bool:
-        """Check integrity of the dataset archive.
-
-        Returns:
-            True if dataset archive is found and/or MD5s match, else False
-        """
-        integrity: bool = check_integrity(
-            os.path.join(self.root, self.filename),
-            self.md5 if self.checksum else None,
-        )
-
-        return integrity
-
-    def _check_structure(self) -> bool:
-        """Checks to see if the dataset files exist in the root directory.
-
-        Returns:
-            True if the dataset files are found, else False
-        """
-        dataset_files = os.listdir(self.root)
-        for file in self.files:
-            if file not in dataset_files:
-                return False
-        return True
-
-    def _download(self) -> None:
-        """Download the dataset and extract it."""
-        if self._check_integrity():
-            print("Files already downloaded and verified")
-            return
-
-        download_and_extract_archive(
-            self.url,
-            self.root,
-            filename=self.filename,
-            md5=self.md5 if self.checksum else None,
-        )
+    def _verify(self) -> None:
+        """Verify the integrity of the dataset.
+
+        Raises:
+            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+        """
+        # Check if the extracted files already exist
+        def exists(filename: str) -> bool:
+            return os.path.exists(os.path.join(self.root, filename))
+
+        if all(map(exists, self.files)):
+            return
+
+        # Check if the zip files have already been downloaded
+        if os.path.exists(os.path.join(self.root, self.filename)):
+            self._extract()
+            return
+
+        # Check if the user requested to download the dataset
+        if not self.download:
+            raise RuntimeError(
+                f"Dataset not found in `root={self.root}` and `download=False`, "
+                "either specify a different `root` directory or use `download=True` "
+                "to automatically download the dataset."
+            )
+
+        # Download the dataset
+        self._download()
+        self._extract()
+
+    def _download(self) -> None:
+        """Download the dataset."""
+        download_url(
+            self.url,
+            self.root,
+            filename=self.filename,
+            md5=self.md5,
+        )
+
+    def _extract(self) -> None:
+        """Extract the dataset."""
+        extract_archive(os.path.join(self.root, self.filename))
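For reference, a hedged usage sketch of the new download flow; the constructor arguments mirror those exercised by the tests, and the root path is hypothetical:

# Usage sketch (hypothetical `root`): with the new _verify flow, a missing
# dataset is either fetched (download=True) or raises RuntimeError up front.
from torchgeo.datasets import ChesapeakeCVPR

ds = ChesapeakeCVPR(
    root="data/cvpr_chesapeake",  # hypothetical local directory
    splits=["de-test"],
    layers=["naip-new", "lc"],
    download=True,   # triggers _download() then _extract() if files are absent
    checksum=True,   # md5 is passed to download_url for verification
)
sample = ds[ds.bounds]  # returns a dict with "crs" and "mask" keys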