зеркало из https://github.com/microsoft/torchgeo.git
Add tests for CDL
This commit is contained in:
Родитель
73e09ce833
Коммит
fae4a0fcc1
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -0,0 +1,66 @@
|
|||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torchvision.datasets.utils
|
||||
from _pytest.fixtures import SubRequest
|
||||
from pytest import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import CDL, BoundingBox, ZipDataset
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
class TestCDL:
|
||||
@pytest.fixture
|
||||
def dataset(
|
||||
self,
|
||||
monkeypatch: Generator[MonkeyPatch, None, None],
|
||||
tmp_path: Path,
|
||||
request: SubRequest,
|
||||
) -> CDL:
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
torchvision.datasets.utils, "download_url", download_url
|
||||
)
|
||||
md5s = [
|
||||
(2021, "f86f6931c9146f140c306a3529260047"),
|
||||
(2020, "8bc282d7a0a99e397b3d53097f820189"),
|
||||
]
|
||||
monkeypatch.setattr(CDL, "md5s", md5s) # type: ignore[attr-defined]
|
||||
url = os.path.join("tests", "data", "cdl", "{}_30m_cdls.zip")
|
||||
monkeypatch.setattr(CDL, "url", url) # type: ignore[attr-defined]
|
||||
(tmp_path / "cdl").mkdir()
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
return CDL(root, transforms=transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: CDL) -> None:
|
||||
x = dataset[dataset.bounds]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x["crs"], CRS)
|
||||
assert isinstance(x["masks"], torch.Tensor)
|
||||
|
||||
def test_add(self, dataset: CDL) -> None:
|
||||
ds = dataset + dataset
|
||||
assert isinstance(ds, ZipDataset)
|
||||
|
||||
def test_already_downloaded(self, dataset: CDL) -> None:
|
||||
CDL(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
CDL(str(tmp_path))
|
||||
|
||||
def test_invalid_query(self, dataset: CDL) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
with pytest.raises(
|
||||
IndexError, match="query: .* is not within bounds of the index:"
|
||||
):
|
||||
dataset[query]
|
|
@ -3,8 +3,8 @@ from typing import Tuple
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import BoundingBox, collate_dict
|
||||
from torchgeo.datasets.utils import working_dir
|
||||
|
||||
|
|
|
@ -153,11 +153,16 @@ class CDL(GeoDataset):
|
|||
window = Window(col_off, row_off, width, height)
|
||||
masks = vrt.read(window=window)
|
||||
masks = masks.astype(np.int32)
|
||||
return {
|
||||
sample = {
|
||||
"masks": torch.tensor(masks), # type: ignore[attr-defined]
|
||||
"crs": self.crs,
|
||||
}
|
||||
|
||||
if self.transforms is not None:
|
||||
sample = self.transforms(sample)
|
||||
|
||||
return sample
|
||||
|
||||
def _check_integrity(self) -> bool:
|
||||
"""Check integrity of dataset.
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче