# Mirror of https://github.com/microsoft/torchgeo.git
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path
from typing import Generator

import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datasets import BigEarthNet, BigEarthNetDataModule

def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    """Stand-in for torchgeo's ``download_url`` used in tests.

    Copies the local fixture file at ``url`` into ``root`` instead of
    downloading anything. Extra positional/keyword arguments are accepted
    and ignored so the signature stays compatible with the real downloader.
    """
    shutil.copy(src=url, dst=root)
|
|
|
|
class TestBigEarthNet:
    """Tests for the BigEarthNet dataset."""

    @pytest.fixture(
        params=zip(["all", "s1", "s2"], [43, 19, 19], ["train", "val", "test"])
    )
    def dataset(
        self,
        # pytest injects a plain MonkeyPatch instance here (we call .setattr
        # on it directly), not a Generator — the old annotation was wrong.
        monkeypatch: MonkeyPatch,
        tmp_path: Path,
        request: SubRequest,
    ) -> BigEarthNet:
        """Build a BigEarthNet dataset backed by small local fixture archives.

        ``download_url`` is monkeypatched to copy tarballs/CSVs from
        ``tests/data/bigearthnet`` instead of hitting the network, and the
        class-level metadata is pointed at those fixtures.
        """
        monkeypatch.setattr(
            torchgeo.datasets.bigearthnet, "download_url", download_url
        )
        data_dir = os.path.join("tests", "data", "bigearthnet")
        # Archive metadata rewritten to point at the local test fixtures.
        metadata = {
            "s1": {
                "url": os.path.join(data_dir, "BigEarthNet-S1-v1.0.tar.gz"),
                "md5": "5a64e9ce38deb036a435a7b59494924c",
                "filename": "BigEarthNet-S1-v1.0.tar.gz",
                "directory": "BigEarthNet-S1-v1.0",
            },
            "s2": {
                "url": os.path.join(data_dir, "BigEarthNet-S2-v1.0.tar.gz"),
                "md5": "ef5f41129b8308ca178b04d7538dbacf",
                "filename": "BigEarthNet-S2-v1.0.tar.gz",
                "directory": "BigEarthNet-v1.0",
            },
        }
        # Split CSVs likewise redirected to the fixture copies.
        splits_metadata = {
            "train": {
                "url": os.path.join(data_dir, "bigearthnet-train.csv"),
                "filename": "bigearthnet-train.csv",
                "md5": "167ac4d5de8dde7b5aeaa812f42031e7",
            },
            "val": {
                "url": os.path.join(data_dir, "bigearthnet-val.csv"),
                "filename": "bigearthnet-val.csv",
                "md5": "aff594ba256a52e839a3b5fefeb9ef42",
            },
            "test": {
                "url": os.path.join(data_dir, "bigearthnet-test.csv"),
                "filename": "bigearthnet-test.csv",
                "md5": "851a6bdda484d47f60e121352dcb1bf5",
            },
        }
        monkeypatch.setattr(BigEarthNet, "metadata", metadata)
        monkeypatch.setattr(BigEarthNet, "splits_metadata", splits_metadata)
        bands, num_classes, split = request.param
        root = str(tmp_path)
        transforms = nn.Identity()  # type: ignore[attr-defined]
        return BigEarthNet(
            root, split, bands, num_classes, transforms, download=True, checksum=True
        )

    def test_getitem(self, dataset: BigEarthNet) -> None:
        """A sample is a dict with correctly typed/shaped image and label."""
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["label"], torch.Tensor)
        assert x["label"].shape == (dataset.num_classes,)
        assert x["image"].dtype == torch.int32  # type: ignore[attr-defined]
        assert x["label"].dtype == torch.int64  # type: ignore[attr-defined]

        # Channel count depends on the selected bands: 2 (S1) + 12 (S2) = 14.
        if dataset.bands == "all":
            assert x["image"].shape == (14, 120, 120)
        elif dataset.bands == "s1":
            assert x["image"].shape == (2, 120, 120)
        else:
            assert x["image"].shape == (12, 120, 120)

    def test_len(self, dataset: BigEarthNet) -> None:
        """The fixture data has 2 train samples and 1 each for val/test."""
        if dataset.split == "train":
            assert len(dataset) == 2
        elif dataset.split == "val":
            assert len(dataset) == 1
        else:
            assert len(dataset) == 1

    def test_already_downloaded(self, dataset: BigEarthNet, tmp_path: Path) -> None:
        """Re-instantiating with download=True over existing data is a no-op."""
        BigEarthNet(
            root=str(tmp_path),
            bands=dataset.bands,
            split=dataset.split,
            num_classes=dataset.num_classes,
            download=True,
        )

    def test_already_downloaded_not_extracted(
        self, dataset: BigEarthNet, tmp_path: Path
    ) -> None:
        """Dataset extracts archives that are present but not yet unpacked."""
        # Remove the extracted directories, keep (re-copy) the archives only.
        if dataset.bands == "all":
            shutil.rmtree(
                os.path.join(dataset.root, dataset.metadata["s1"]["directory"])
            )
            shutil.rmtree(
                os.path.join(dataset.root, dataset.metadata["s2"]["directory"])
            )
            download_url(dataset.metadata["s1"]["url"], root=str(tmp_path))
            download_url(dataset.metadata["s2"]["url"], root=str(tmp_path))
        elif dataset.bands == "s1":
            shutil.rmtree(
                os.path.join(dataset.root, dataset.metadata["s1"]["directory"])
            )
            download_url(dataset.metadata["s1"]["url"], root=str(tmp_path))
        else:
            shutil.rmtree(
                os.path.join(dataset.root, dataset.metadata["s2"]["directory"])
            )
            download_url(dataset.metadata["s2"]["url"], root=str(tmp_path))

        BigEarthNet(
            root=str(tmp_path),
            bands=dataset.bands,
            split=dataset.split,
            num_classes=dataset.num_classes,
            download=False,
        )

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """A missing dataset with download=False raises a helpful error."""
        # Parenthesized so the three literals concatenate into one message;
        # previously the last two lines were dead no-op expression statements.
        err = (
            "Dataset not found in `root` directory and `download=False`, "
            "either specify a different `root` directory or use `download=True` "
            "to automaticaly download the dataset."
        )
        with pytest.raises(RuntimeError, match=err):
            BigEarthNet(str(tmp_path))
|
|
|
|
|
|
class TestBigEarthNetDataModule:
    """Tests for the BigEarthNet Lightning data module."""

    @pytest.fixture(scope="class", params=["s1", "s2", "all"])
    def datamodule(self, request: SubRequest) -> BigEarthNetDataModule:
        """Build a data module over the local fixture data and set it up.

        Uses 19 classes, batch size 1, and no worker processes.
        """
        fixture_root = os.path.join("tests", "data", "bigearthnet")
        module = BigEarthNetDataModule(fixture_root, request.param, 19, 1, 0)
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
        """A batch can be drawn from the training dataloader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
        """A batch can be drawn from the validation dataloader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
        """A batch can be drawn from the test dataloader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))