зеркало из https://github.com/microsoft/torchgeo.git
update tests
This commit is contained in:
Родитель
2ff7d776e7
Коммит
1a49cfffbe
|
@ -8,15 +8,15 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import EuroSAT
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -28,32 +28,14 @@ class TestEuroSAT:
|
|||
tmp_path: Path,
|
||||
) -> EuroSAT:
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
torchgeo.datasets.utils, "download_url", download_url
|
||||
torchgeo.datasets.eurosat, "download_url", download_url
|
||||
)
|
||||
md5 = "aa051207b0547daba0ac6af57808d68e"
|
||||
monkeypatch.setattr(EuroSAT, "md5", md5) # type: ignore[attr-defined]
|
||||
url = os.path.join("tests", "data", "eurosat", "EuroSATallBands.zip")
|
||||
monkeypatch.setattr(EuroSAT, "url", url) # type: ignore[attr-defined]
|
||||
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
EuroSAT,
|
||||
"class_counts",
|
||||
{
|
||||
"AnnualCrop": 1,
|
||||
"Forest": 1,
|
||||
"HerbaceousVegetation": 0,
|
||||
"Highway": 0,
|
||||
"Industrial": 0,
|
||||
"Pasture": 0,
|
||||
"PermanentCrop": 0,
|
||||
"Residential": 0,
|
||||
"River": 0,
|
||||
"SeaLake": 0,
|
||||
},
|
||||
)
|
||||
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return EuroSAT(root, transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: EuroSAT) -> None:
|
||||
|
@ -70,9 +52,19 @@ class TestEuroSAT:
|
|||
assert isinstance(ds, ConcatDataset)
|
||||
assert len(ds) == 4
|
||||
|
||||
def test_already_downloaded(self, dataset: EuroSAT) -> None:
|
||||
EuroSAT(root=dataset.root, download=True)
|
||||
def test_already_downloaded(self, dataset: EuroSAT, tmp_path: Path) -> None:
|
||||
EuroSAT(root=str(tmp_path), download=True)
|
||||
|
||||
def test_already_downloaded_not_extracted(
|
||||
self, dataset: EuroSAT, tmp_path: Path
|
||||
) -> None:
|
||||
shutil.rmtree(dataset.root)
|
||||
download_url(dataset.url, root=str(tmp_path))
|
||||
EuroSAT(root=str(tmp_path), download=False)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
err = "Dataset not found in `root` directory and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automaticaly download the dataset."
|
||||
with pytest.raises(RuntimeError, match=err):
|
||||
EuroSAT(str(tmp_path))
|
||||
|
|
|
@ -8,15 +8,15 @@ from typing import Generator
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import UCMerced
|
||||
from torchgeo.transforms import Identity
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -28,33 +28,14 @@ class TestUCMerced:
|
|||
tmp_path: Path,
|
||||
) -> UCMerced:
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
torchgeo.datasets.utils, "download_url", download_url
|
||||
torchgeo.datasets.ucmerced, "download_url", download_url
|
||||
)
|
||||
md5 = "95e710774f3ef6d9ecb0cd42e4d0fc23"
|
||||
monkeypatch.setattr(UCMerced, "md5", md5) # type: ignore[attr-defined]
|
||||
url = os.path.join("tests", "data", "ucmerced", "UCMerced_LandUse.zip")
|
||||
monkeypatch.setattr(UCMerced, "url", url) # type: ignore[attr-defined]
|
||||
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
UCMerced,
|
||||
"classes",
|
||||
[
|
||||
"agricultural",
|
||||
"airplane",
|
||||
],
|
||||
)
|
||||
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
UCMerced,
|
||||
"class_counts",
|
||||
{
|
||||
"agricultural": 1,
|
||||
"airplane": 1,
|
||||
},
|
||||
)
|
||||
|
||||
root = str(tmp_path)
|
||||
transforms = Identity()
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return UCMerced(root, transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: UCMerced) -> None:
|
||||
|
@ -71,9 +52,19 @@ class TestUCMerced:
|
|||
assert isinstance(ds, ConcatDataset)
|
||||
assert len(ds) == 4
|
||||
|
||||
def test_already_downloaded(self, dataset: UCMerced) -> None:
|
||||
UCMerced(root=dataset.root, download=True)
|
||||
def test_already_downloaded(self, dataset: UCMerced, tmp_path: Path) -> None:
|
||||
UCMerced(root=str(tmp_path), download=True)
|
||||
|
||||
def test_already_downloaded_not_extracted(
|
||||
self, dataset: UCMerced, tmp_path: Path
|
||||
) -> None:
|
||||
shutil.rmtree(dataset.root)
|
||||
download_url(dataset.url, root=str(tmp_path))
|
||||
UCMerced(root=str(tmp_path), download=False)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
err = "Dataset not found in `root` directory and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automaticaly download the dataset."
|
||||
with pytest.raises(RuntimeError, match=err):
|
||||
UCMerced(str(tmp_path))
|
||||
|
|
Загрузка…
Ссылка в новой задаче