зеркало из https://github.com/microsoft/torchgeo.git
77 строки
2.5 KiB
Python
77 строки
2.5 KiB
Python
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Generator
|
|
|
|
import matplotlib.pyplot as plt
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
from _pytest.monkeypatch import MonkeyPatch
|
|
from rasterio.crs import CRS
|
|
|
|
from torchgeo.datasets import NAIP, BoundingBox, NAIPChesapeakeDataModule, ZipDataset
|
|
|
|
|
|
class TestNAIP:
|
|
@pytest.fixture
|
|
def dataset(self, monkeypatch: Generator[MonkeyPatch, None, None]) -> NAIP:
|
|
monkeypatch.setattr( # type: ignore[attr-defined]
|
|
plt, "show", lambda *args: None
|
|
)
|
|
root = os.path.join("tests", "data", "naip")
|
|
transforms = nn.Identity() # type: ignore[attr-defined]
|
|
return NAIP(root, transforms=transforms)
|
|
|
|
def test_getitem(self, dataset: NAIP) -> None:
|
|
x = dataset[dataset.bounds]
|
|
assert isinstance(x, dict)
|
|
assert isinstance(x["crs"], CRS)
|
|
assert isinstance(x["image"], torch.Tensor)
|
|
|
|
def test_add(self, dataset: NAIP) -> None:
|
|
ds = dataset + dataset
|
|
assert isinstance(ds, ZipDataset)
|
|
|
|
def test_plot(self, dataset: NAIP) -> None:
|
|
query = dataset.bounds
|
|
x = dataset[query]
|
|
dataset.plot(x["image"])
|
|
|
|
def test_no_data(self, tmp_path: Path) -> None:
|
|
with pytest.raises(FileNotFoundError, match="No NAIP data was found in "):
|
|
NAIP(str(tmp_path))
|
|
|
|
def test_invalid_query(self, dataset: NAIP) -> None:
|
|
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
|
with pytest.raises(
|
|
IndexError, match="query: .* not found in index with bounds:"
|
|
):
|
|
dataset[query]
|
|
|
|
|
|
class TestNAIPChesapeakeDataModule:
|
|
@pytest.fixture(scope="class")
|
|
def datamodule(self) -> NAIPChesapeakeDataModule:
|
|
dm = NAIPChesapeakeDataModule(
|
|
os.path.join("tests", "data", "naip"),
|
|
os.path.join("tests", "data", "chesapeake", "BAYWIDE"),
|
|
batch_size=2,
|
|
num_workers=0,
|
|
)
|
|
dm.patch_size = 32
|
|
dm.prepare_data()
|
|
dm.setup()
|
|
return dm
|
|
|
|
def test_train_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None:
|
|
next(iter(datamodule.train_dataloader()))
|
|
|
|
def test_val_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None:
|
|
next(iter(datamodule.val_dataloader()))
|
|
|
|
def test_test_dataloader(self, datamodule: NAIPChesapeakeDataModule) -> None:
|
|
next(iter(datamodule.test_dataloader()))
|