This commit is contained in:
Adam J. Stewart 2021-07-16 19:09:32 +00:00
Родитель 73e09ce833
Коммит fae4a0fcc1
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: C66C0675661156FC
5 изменённых файлов: 73 добавлений и 2 удалений

Двоичные данные
tests/data/cdl/2020_30m_cdls.zip Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/cdl/2021_30m_cdls.zip Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -0,0 +1,66 @@
import os
import shutil
from pathlib import Path
from typing import Generator
import pytest
import torch
import torchvision.datasets.utils
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from rasterio.crs import CRS
from torchgeo.datasets import CDL, BoundingBox, ZipDataset
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str) -> None:
shutil.copy(url, root)
class TestCDL:
@pytest.fixture
def dataset(
self,
monkeypatch: Generator[MonkeyPatch, None, None],
tmp_path: Path,
request: SubRequest,
) -> CDL:
monkeypatch.setattr( # type: ignore[attr-defined]
torchvision.datasets.utils, "download_url", download_url
)
md5s = [
(2021, "f86f6931c9146f140c306a3529260047"),
(2020, "8bc282d7a0a99e397b3d53097f820189"),
]
monkeypatch.setattr(CDL, "md5s", md5s) # type: ignore[attr-defined]
url = os.path.join("tests", "data", "cdl", "{}_30m_cdls.zip")
monkeypatch.setattr(CDL, "url", url) # type: ignore[attr-defined]
(tmp_path / "cdl").mkdir()
root = str(tmp_path)
transforms = Identity()
return CDL(root, transforms=transforms, download=True, checksum=True)
def test_getitem(self, dataset: CDL) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["masks"], torch.Tensor)
def test_add(self, dataset: CDL) -> None:
ds = dataset + dataset
assert isinstance(ds, ZipDataset)
def test_already_downloaded(self, dataset: CDL) -> None:
CDL(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
CDL(str(tmp_path))
def test_invalid_query(self, dataset: CDL) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
IndexError, match="query: .* is not within bounds of the index:"
):
dataset[query]

Просмотреть файл

@ -3,8 +3,8 @@ from typing import Tuple
import pytest
import torch
from rasterio.crs import CRS
from torchgeo.datasets import BoundingBox, collate_dict
from torchgeo.datasets.utils import working_dir

Просмотреть файл

@ -153,11 +153,16 @@ class CDL(GeoDataset):
window = Window(col_off, row_off, width, height)
masks = vrt.read(window=window)
masks = masks.astype(np.int32)
return {
sample = {
"masks": torch.tensor(masks), # type: ignore[attr-defined]
"crs": self.crs,
}
if self.transforms is not None:
sample = self.transforms(sample)
return sample
def _check_integrity(self) -> bool:
"""Check integrity of dataset.