This commit is contained in:
Adam J. Stewart 2021-06-24 15:36:16 +00:00
Родитель 8810f48d44
Коммит 7a616d0a8b
4 изменённых файлов: 93 добавлений и 3 удалений

Двоичные данные
tests/data/ts_cashew_benin/ts_cashew_benin_labels.tar.gz Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/ts_cashew_benin/ts_cashew_benin_source.tar.gz Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -0,0 +1,93 @@
import glob
import os
from pathlib import Path
import shutil
from typing import Generator
import pytest
from pytest import MonkeyPatch
import torch
from torchgeo.datasets import BeninSmallHolderCashews, ZipDataset
from torchgeo.transforms import Identity
class Dataset:
def download(self, output_dir: str, **kwargs: str) -> None:
glob_path = os.path.join("tests", "data", "ts_cashew_benin", "*.tar.gz")
for tarball in glob.iglob(glob_path):
shutil.copy(tarball, output_dir)
def fetch(collection_id: str, **kwargs: str) -> Dataset:
return Dataset()
class TestBeninSmallHolderCashews:
@pytest.fixture
def dataset(
self,
monkeypatch: Generator[MonkeyPatch, None, None],
tmp_path: Path,
) -> BeninSmallHolderCashews:
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
monkeypatch.setattr( # type: ignore[attr-defined]
radiant_mlhub.Dataset, "fetch", fetch
)
source_md5 = "255efff0f03bc6322470949a09bc76db"
labels_md5 = "ed2195d93ca6822d48eb02bc3e81c127"
monkeypatch.setitem( # type: ignore[attr-defined]
BeninSmallHolderCashews.image_meta, "md5", source_md5
)
monkeypatch.setitem( # type: ignore[attr-defined]
BeninSmallHolderCashews.target_meta, "md5", labels_md5
)
monkeypatch.setattr( # type: ignore[attr-defined]
BeninSmallHolderCashews, "dates", ("2019_11_05",)
)
(tmp_path / "ts_cashew_benin").mkdir()
root = str(tmp_path)
transforms = Identity()
return BeninSmallHolderCashews(
root,
transforms=transforms,
download=True,
api_key="",
checksum=True,
verbose=True,
)
def test_getitem(self, dataset: BeninSmallHolderCashews) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)
assert isinstance(x["x"], torch.Tensor)
assert isinstance(x["y"], torch.Tensor)
def test_len(self, dataset: BeninSmallHolderCashews) -> None:
assert len(dataset) == 72
def test_add(self, dataset: BeninSmallHolderCashews) -> None:
ds = dataset + dataset
assert isinstance(ds, ZipDataset)
assert len(ds) == 72
def test_already_downloaded(self, dataset: BeninSmallHolderCashews) -> None:
BeninSmallHolderCashews(root=dataset.root, download=True, api_key="")
def test_missing_api_key(self) -> None:
match = "You must pass an MLHub API key if download=True."
with pytest.raises(RuntimeError, match=match):
BeninSmallHolderCashews(download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
BeninSmallHolderCashews(str(tmp_path))
def test_invalid_bands(self) -> None:
with pytest.raises(AssertionError):
BeninSmallHolderCashews(bands=["B01", "B02"]) # type: ignore[arg-type]
with pytest.raises(ValueError, match="is an invalid band name."):
BeninSmallHolderCashews(bands=("foo", "bar"))

Просмотреть файл

@ -432,9 +432,6 @@ class BeninSmallHolderCashews(GeoDataset):
output_dir=os.path.join(self.root, self.dataset_id), api_key=api_key
)
if not self._check_integrity():
raise RuntimeError("Dataset files not found or corrupted.")
image_archive_path = os.path.join(
self.root, self.dataset_id, self.image_meta["filename"]
)