This commit is contained in:
Isaac Corley 2021-09-27 21:57:35 -05:00 коммит произвёл Caleb Robinson
Родитель 2ff7d776e7
Коммит 1a49cfffbe
2 изменённых файлов: 34 добавлений и 51 удалений

Просмотреть файл

@ -8,15 +8,15 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import EuroSAT
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str) -> None:
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -28,32 +28,14 @@ class TestEuroSAT:
tmp_path: Path,
) -> EuroSAT:
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.utils, "download_url", download_url
torchgeo.datasets.eurosat, "download_url", download_url
)
md5 = "aa051207b0547daba0ac6af57808d68e"
monkeypatch.setattr(EuroSAT, "md5", md5) # type: ignore[attr-defined]
url = os.path.join("tests", "data", "eurosat", "EuroSATallBands.zip")
monkeypatch.setattr(EuroSAT, "url", url) # type: ignore[attr-defined]
monkeypatch.setattr( # type: ignore[attr-defined]
EuroSAT,
"class_counts",
{
"AnnualCrop": 1,
"Forest": 1,
"HerbaceousVegetation": 0,
"Highway": 0,
"Industrial": 0,
"Pasture": 0,
"PermanentCrop": 0,
"Residential": 0,
"River": 0,
"SeaLake": 0,
},
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return EuroSAT(root, transforms, download=True, checksum=True)
def test_getitem(self, dataset: EuroSAT) -> None:
@ -70,9 +52,19 @@ class TestEuroSAT:
assert isinstance(ds, ConcatDataset)
assert len(ds) == 4
def test_already_downloaded(self, dataset: EuroSAT) -> None:
EuroSAT(root=dataset.root, download=True)
def test_already_downloaded(self, dataset: EuroSAT, tmp_path: Path) -> None:
EuroSAT(root=str(tmp_path), download=True)
def test_already_downloaded_not_extracted(
self, dataset: EuroSAT, tmp_path: Path
) -> None:
shutil.rmtree(dataset.root)
download_url(dataset.url, root=str(tmp_path))
EuroSAT(root=str(tmp_path), download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
err = "Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automaticaly download the dataset."
with pytest.raises(RuntimeError, match=err):
EuroSAT(str(tmp_path))

Просмотреть файл

@ -8,15 +8,15 @@ from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import UCMerced
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str) -> None:
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -28,33 +28,14 @@ class TestUCMerced:
tmp_path: Path,
) -> UCMerced:
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.utils, "download_url", download_url
torchgeo.datasets.ucmerced, "download_url", download_url
)
md5 = "95e710774f3ef6d9ecb0cd42e4d0fc23"
monkeypatch.setattr(UCMerced, "md5", md5) # type: ignore[attr-defined]
url = os.path.join("tests", "data", "ucmerced", "UCMerced_LandUse.zip")
monkeypatch.setattr(UCMerced, "url", url) # type: ignore[attr-defined]
monkeypatch.setattr( # type: ignore[attr-defined]
UCMerced,
"classes",
[
"agricultural",
"airplane",
],
)
monkeypatch.setattr( # type: ignore[attr-defined]
UCMerced,
"class_counts",
{
"agricultural": 1,
"airplane": 1,
},
)
root = str(tmp_path)
transforms = Identity()
transforms = nn.Identity() # type: ignore[attr-defined]
return UCMerced(root, transforms, download=True, checksum=True)
def test_getitem(self, dataset: UCMerced) -> None:
@ -71,9 +52,19 @@ class TestUCMerced:
assert isinstance(ds, ConcatDataset)
assert len(ds) == 4
def test_already_downloaded(self, dataset: UCMerced) -> None:
UCMerced(root=dataset.root, download=True)
def test_already_downloaded(self, dataset: UCMerced, tmp_path: Path) -> None:
UCMerced(root=str(tmp_path), download=True)
def test_already_downloaded_not_extracted(
self, dataset: UCMerced, tmp_path: Path
) -> None:
shutil.rmtree(dataset.root)
download_url(dataset.url, root=str(tmp_path))
UCMerced(root=str(tmp_path), download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
err = "Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automaticaly download the dataset."
with pytest.raises(RuntimeError, match=err):
UCMerced(str(tmp_path))