зеркало из https://github.com/microsoft/torchgeo.git
89 строки
2.9 KiB
Python
89 строки
2.9 KiB
Python
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
|
|
import glob
|
|
import os
|
|
import shutil
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Generator
|
|
|
|
import matplotlib.pyplot as plt
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
from _pytest.monkeypatch import MonkeyPatch
|
|
from rasterio.crs import CRS
|
|
|
|
import torchgeo.datasets.utils
|
|
from torchgeo.datasets import CDL, BoundingBox, ZipDataset
|
|
|
|
|
|
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
|
shutil.copy(url, root)
|
|
|
|
|
|
class TestCDL:
|
|
@pytest.fixture
|
|
def dataset(
|
|
self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path
|
|
) -> CDL:
|
|
monkeypatch.setattr( # type: ignore[attr-defined]
|
|
torchgeo.datasets.cdl, "download_url", download_url
|
|
)
|
|
md5s = [
|
|
(2021, "0693f0bb10deb79c69bcafe4aa1635b7"),
|
|
(2020, "7695292902a8672d16ac034d4d560d84"),
|
|
]
|
|
monkeypatch.setattr(CDL, "md5s", md5s) # type: ignore[attr-defined]
|
|
url = os.path.join("tests", "data", "cdl", "{}_30m_cdls.zip")
|
|
monkeypatch.setattr(CDL, "url", url) # type: ignore[attr-defined]
|
|
monkeypatch.setattr( # type: ignore[attr-defined]
|
|
plt, "show", lambda *args: None
|
|
)
|
|
root = str(tmp_path)
|
|
transforms = nn.Identity() # type: ignore[attr-defined]
|
|
return CDL(root, transforms=transforms, download=True, checksum=True)
|
|
|
|
def test_getitem(self, dataset: CDL) -> None:
|
|
x = dataset[dataset.bounds]
|
|
assert isinstance(x, dict)
|
|
assert isinstance(x["crs"], CRS)
|
|
assert isinstance(x["mask"], torch.Tensor)
|
|
|
|
def test_add(self, dataset: CDL) -> None:
|
|
ds = dataset + dataset
|
|
assert isinstance(ds, ZipDataset)
|
|
|
|
def test_full_year(self, dataset: CDL) -> None:
|
|
bbox = dataset.bounds
|
|
time = datetime(2021, 6, 1).timestamp()
|
|
query = BoundingBox(bbox.minx, bbox.maxx, bbox.miny, bbox.maxy, time, time)
|
|
next(dataset.index.intersection(query))
|
|
|
|
def test_already_extracted(self, dataset: CDL) -> None:
|
|
CDL(root=dataset.root, download=True)
|
|
|
|
def test_already_downloaded(self, tmp_path: Path) -> None:
|
|
pathname = os.path.join("tests", "data", "cdl", "*_30m_cdls.zip")
|
|
root = str(tmp_path)
|
|
for zipfile in glob.iglob(pathname):
|
|
shutil.copy(zipfile, root)
|
|
CDL(root)
|
|
|
|
def test_plot(self, dataset: CDL) -> None:
|
|
query = dataset.bounds
|
|
x = dataset[query]
|
|
dataset.plot(x["mask"])
|
|
|
|
def test_not_downloaded(self, tmp_path: Path) -> None:
|
|
with pytest.raises(RuntimeError, match="Dataset not found"):
|
|
CDL(str(tmp_path))
|
|
|
|
def test_invalid_query(self, dataset: CDL) -> None:
|
|
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
|
with pytest.raises(
|
|
IndexError, match="query: .* not found in index with bounds:"
|
|
):
|
|
dataset[query]
|