зеркало из https://github.com/microsoft/torchgeo.git
Add unit tests
This commit is contained in:
Родитель
8810f48d44
Коммит
7a616d0a8b
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -0,0 +1,93 @@
|
|||
import glob
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from pytest import MonkeyPatch
|
||||
import torch
|
||||
|
||||
from torchgeo.datasets import BeninSmallHolderCashews, ZipDataset
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
class Dataset:
|
||||
def download(self, output_dir: str, **kwargs: str) -> None:
|
||||
glob_path = os.path.join("tests", "data", "ts_cashew_benin", "*.tar.gz")
|
||||
for tarball in glob.iglob(glob_path):
|
||||
shutil.copy(tarball, output_dir)
|
||||
|
||||
|
||||
def fetch(collection_id: str, **kwargs: str) -> Dataset:
|
||||
return Dataset()
|
||||
|
||||
|
||||
class TestBeninSmallHolderCashews:
|
||||
@pytest.fixture
|
||||
def dataset(
|
||||
self,
|
||||
monkeypatch: Generator[MonkeyPatch, None, None],
|
||||
tmp_path: Path,
|
||||
) -> BeninSmallHolderCashews:
|
||||
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
radiant_mlhub.Dataset, "fetch", fetch
|
||||
)
|
||||
source_md5 = "255efff0f03bc6322470949a09bc76db"
|
||||
labels_md5 = "ed2195d93ca6822d48eb02bc3e81c127"
|
||||
monkeypatch.setitem( # type: ignore[attr-defined]
|
||||
BeninSmallHolderCashews.image_meta, "md5", source_md5
|
||||
)
|
||||
monkeypatch.setitem( # type: ignore[attr-defined]
|
||||
BeninSmallHolderCashews.target_meta, "md5", labels_md5
|
||||
)
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
BeninSmallHolderCashews, "dates", ("2019_11_05",)
|
||||
)
|
||||
(tmp_path / "ts_cashew_benin").mkdir()
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
return BeninSmallHolderCashews(
|
||||
root,
|
||||
transforms=transforms,
|
||||
download=True,
|
||||
api_key="",
|
||||
checksum=True,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
def test_getitem(self, dataset: BeninSmallHolderCashews) -> None:
|
||||
x = dataset[0]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x["image"], torch.Tensor)
|
||||
assert isinstance(x["mask"], torch.Tensor)
|
||||
assert isinstance(x["x"], torch.Tensor)
|
||||
assert isinstance(x["y"], torch.Tensor)
|
||||
|
||||
def test_len(self, dataset: BeninSmallHolderCashews) -> None:
|
||||
assert len(dataset) == 72
|
||||
|
||||
def test_add(self, dataset: BeninSmallHolderCashews) -> None:
|
||||
ds = dataset + dataset
|
||||
assert isinstance(ds, ZipDataset)
|
||||
assert len(ds) == 72
|
||||
|
||||
def test_already_downloaded(self, dataset: BeninSmallHolderCashews) -> None:
|
||||
BeninSmallHolderCashews(root=dataset.root, download=True, api_key="")
|
||||
|
||||
def test_missing_api_key(self) -> None:
|
||||
match = "You must pass an MLHub API key if download=True."
|
||||
with pytest.raises(RuntimeError, match=match):
|
||||
BeninSmallHolderCashews(download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
BeninSmallHolderCashews(str(tmp_path))
|
||||
|
||||
def test_invalid_bands(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
BeninSmallHolderCashews(bands=["B01", "B02"]) # type: ignore[arg-type]
|
||||
|
||||
with pytest.raises(ValueError, match="is an invalid band name."):
|
||||
BeninSmallHolderCashews(bands=("foo", "bar"))
|
|
@ -432,9 +432,6 @@ class BeninSmallHolderCashews(GeoDataset):
|
|||
output_dir=os.path.join(self.root, self.dataset_id), api_key=api_key
|
||||
)
|
||||
|
||||
if not self._check_integrity():
|
||||
raise RuntimeError("Dataset files not found or corrupted.")
|
||||
|
||||
image_archive_path = os.path.join(
|
||||
self.root, self.dataset_id, self.image_meta["filename"]
|
||||
)
|
||||
|
|
Загрузка…
Ссылка в новой задаче