2021-08-31 18:23:10 +03:00
|
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
# Licensed under the MIT License.
|
|
|
|
|
2021-12-18 03:28:57 +03:00
|
|
|
import builtins
|
2021-07-07 21:53:12 +03:00
|
|
|
import os
|
|
|
|
from pathlib import Path
|
2021-12-18 03:28:57 +03:00
|
|
|
from typing import Any, Generator
|
2021-07-07 21:53:12 +03:00
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
2021-10-16 07:59:17 +03:00
|
|
|
import torch.nn as nn
|
2021-07-07 21:53:12 +03:00
|
|
|
from _pytest.fixtures import SubRequest
|
2021-08-11 19:07:27 +03:00
|
|
|
from _pytest.monkeypatch import MonkeyPatch
|
2021-07-07 21:53:12 +03:00
|
|
|
|
2021-11-05 01:15:29 +03:00
|
|
|
from torchgeo.datasets import So2Sat, So2SatDataModule
|
2021-07-07 21:53:12 +03:00
|
|
|
|
2021-12-18 03:28:57 +03:00
|
|
|
pytest.importorskip("h5py")
|
|
|
|
|
2021-07-07 21:53:12 +03:00
|
|
|
|
|
|
|
class TestSo2Sat:
    """Tests for the So2Sat dataset backed by the files under tests/data/so2sat."""

    @pytest.fixture(params=["train", "validation", "test"])
    def dataset(
        self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest
    ) -> So2Sat:
        """Build a So2Sat instance for each split with ``checksum=True``.

        ``So2Sat.md5s`` is replaced for the duration of the test so that the
        checksum step validates against these hashes (presumably those of the
        small files in tests/data/so2sat — confirm when regenerating them).
        """
        # NOTE(review): pytest injects a MonkeyPatch instance here, not a
        # Generator. The Generator annotation is what forces the
        # ``type: ignore[attr-defined]`` on ``monkeypatch.setattr`` below.
        md5s = {
            "train": "086c5fa964a401d4194d09ab161c39f1",
            "validation": "dd864f1af0cd495af99d7de80103f49e",
            "test": "320102c5c15f3cee7691f203824028ce",
        }

        monkeypatch.setattr(So2Sat, "md5s", md5s)  # type: ignore[attr-defined]
        root = os.path.join("tests", "data", "so2sat")
        split = request.param
        transforms = nn.Identity()  # type: ignore[attr-defined]
        return So2Sat(root, split, transforms, checksum=True)

    @pytest.fixture
    def mock_missing_module(
        self, monkeypatch: Generator[MonkeyPatch, None, None]
    ) -> None:
        """Make every ``import h5py`` raise ImportError while the test runs.

        All other imports are forwarded to the original ``builtins.__import__``.
        """
        import_orig = builtins.__import__

        def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name == "h5py":
                raise ImportError()
            return import_orig(name, *args, **kwargs)

        monkeypatch.setattr(  # type: ignore[attr-defined]
            builtins, "__import__", mocked_import
        )

    def test_getitem(self, dataset: So2Sat) -> None:
        """A sample is a dict holding a tensor image and an int label."""
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["label"], int)

    def test_len(self, dataset: So2Sat) -> None:
        """Every split of the test fixture contains 10 samples."""
        assert len(dataset) == 10

    def test_out_of_bounds(self, dataset: So2Sat) -> None:
        """Indexing one past the end raises instead of returning data."""
        # h5py at version 2.10.0 raises a ValueError instead of an IndexError so we
        # check for both here
        with pytest.raises((IndexError, ValueError)):
            # Derive the index from the dataset rather than hard-coding it, so
            # the check stays "one past the end" if the fixture size changes.
            dataset[len(dataset)]

    def test_invalid_split(self) -> None:
        """An unknown split name is rejected."""
        with pytest.raises(AssertionError):
            So2Sat(split="foo")

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """An empty root directory is reported as missing/corrupted data."""
        with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
            So2Sat(str(tmp_path))

    def test_mock_missing_module(self, mock_missing_module: None) -> None:
        """A missing h5py surfaces as a descriptive ImportError.

        The root is built directly rather than taken from the parametrized
        ``dataset`` fixture. That fixture would run this test once per split
        and build (and md5-check) a full dataset each time only to read its
        root path.
        """
        root = os.path.join("tests", "data", "so2sat")
        with pytest.raises(
            ImportError,
            match="h5py is not installed and is required to use this dataset",
        ):
            So2Sat(root)
|
|
|
|
|
2021-11-05 01:15:29 +03:00
|
|
|
|
|
|
|
class TestSo2SatDataModule:
    """Smoke tests that each So2SatDataModule dataloader yields a batch."""

    # Each param is an (unsupervised_mode, bands) pair; only these two
    # combinations are exercised, not the full cross product.
    @pytest.fixture(scope="class", params=[(True, "rgb"), (False, "s2")])
    def datamodule(self, request: SubRequest) -> So2SatDataModule:
        """Create, prepare and set up a datamodule over tests/data/so2sat.

        The instance is shared by every test in the class (``scope="class"``)
        for a given parameter pair.
        """
        unsupervised_mode, bands = request.param
        module = So2SatDataModule(
            os.path.join("tests", "data", "so2sat"),
            2,  # batch_size
            0,  # num_workers
            bands,
            unsupervised_mode,
        )
        module.prepare_data()
        module.setup()
        return module

    def test_train_dataloader(self, datamodule: So2SatDataModule) -> None:
        """Pull a single batch from the training loader."""
        loader = datamodule.train_dataloader()
        next(iter(loader))

    def test_val_dataloader(self, datamodule: So2SatDataModule) -> None:
        """Pull a single batch from the validation loader."""
        loader = datamodule.val_dataloader()
        next(iter(loader))

    def test_test_dataloader(self, datamodule: So2SatDataModule) -> None:
        """Pull a single batch from the test loader."""
        loader = datamodule.test_dataloader()
        next(iter(loader))
|