# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
|
import builtins
|
|
|
|
import glob
|
2021-09-16 19:07:09 +03:00
|
|
|
import math
|
2021-07-25 22:40:39 +03:00
|
|
|
import os
|
2021-08-18 23:15:52 +03:00
|
|
|
import pickle
|
2021-07-25 22:40:39 +03:00
|
|
|
import shutil
|
|
|
|
import sys
|
2021-09-16 19:07:09 +03:00
|
|
|
from datetime import datetime
|
2021-06-19 01:10:38 +03:00
|
|
|
from pathlib import Path
|
2021-07-25 22:40:39 +03:00
|
|
|
from typing import Any, Generator, Tuple
|
2021-06-17 23:50:17 +03:00
|
|
|
|
2021-07-14 02:23:26 +03:00
|
|
|
import pytest
|
2021-07-13 23:45:21 +03:00
|
|
|
import torch
|
2021-08-11 19:07:27 +03:00
|
|
|
from _pytest.monkeypatch import MonkeyPatch
|
2021-07-16 20:38:53 +03:00
|
|
|
from rasterio.crs import CRS
|
2021-10-12 01:35:38 +03:00
|
|
|
from torch.utils.data import TensorDataset
|
2021-07-16 22:09:32 +03:00
|
|
|
|
2021-07-25 22:40:39 +03:00
|
|
|
import torchgeo.datasets.utils
|
|
|
|
from torchgeo.datasets.utils import (
|
|
|
|
BoundingBox,
|
|
|
|
collate_dict,
|
2021-10-12 01:35:38 +03:00
|
|
|
dataset_split,
|
2021-09-16 19:07:09 +03:00
|
|
|
disambiguate_timestamp,
|
2021-07-25 22:40:39 +03:00
|
|
|
download_and_extract_archive,
|
2021-09-20 10:23:42 +03:00
|
|
|
download_radiant_mlhub_collection,
|
|
|
|
download_radiant_mlhub_dataset,
|
2021-07-25 22:40:39 +03:00
|
|
|
extract_archive,
|
|
|
|
working_dir,
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_missing_module(monkeypatch: MonkeyPatch) -> None:
    """Patch ``builtins.__import__`` so optional deps appear uninstalled.

    Any import of ``rarfile`` or ``radiant_mlhub`` raises :exc:`ImportError`;
    all other imports are delegated to the real ``__import__``.

    Note: pytest's ``monkeypatch`` fixture passes a :class:`MonkeyPatch`
    instance to the consuming function, so it is annotated as such here —
    the previous ``Generator[MonkeyPatch, None, None]`` annotation described
    the fixture *generator*, not the injected value.
    """
    import_orig = builtins.__import__

    def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
        # Simulate the optional dependency being absent.
        if name in ["rarfile", "radiant_mlhub"]:
            raise ImportError()
        return import_orig(name, *args, **kwargs)

    monkeypatch.setattr(builtins, "__import__", mocked_import)
|
|
|
|
|
|
|
|
class Dataset:
    """Stand-in for ``radiant_mlhub.Dataset`` used by the download tests."""

    def download(self, output_dir: str, **kwargs: str) -> None:
        """Copy the local test tarballs into *output_dir* instead of downloading."""
        pattern = os.path.join(
            "tests", "data", "ref_african_crops_kenya_02", "*.tar.gz"
        )
        for archive in glob.iglob(pattern):
            shutil.copy(archive, output_dir)
|
2021-09-20 10:23:42 +03:00
|
|
|
class Collection:
    """Stand-in for ``radiant_mlhub.Collection`` used by the download tests."""

    def download(self, output_dir: str, **kwargs: str) -> None:
        """Copy the local test tarballs into *output_dir* instead of downloading."""
        pattern = os.path.join(
            "tests", "data", "ref_african_crops_kenya_02", "*.tar.gz"
        )
        for archive in glob.iglob(pattern):
            shutil.copy(archive, output_dir)
|
|
|
def fetch_dataset(dataset_id: str, **kwargs: str) -> Dataset:
    """Mocked replacement for ``radiant_mlhub.Dataset.fetch``."""
    mocked = Dataset()
    return mocked
|
2021-09-20 10:23:42 +03:00
|
|
|
def fetch_collection(collection_id: str, **kwargs: str) -> Collection:
    """Mocked replacement for ``radiant_mlhub.Collection.fetch``."""
    mocked = Collection()
    return mocked
2021-07-25 22:40:39 +03:00
|
|
|
def download_url(url: str, root: str, *args: str) -> None:
    """Mocked download: treat *url* as a local path and copy it into *root*."""
    shutil.copy(src=url, dst=root)
|
|
|
def test_mock_missing_module(mock_missing_module: None) -> None:
|
|
|
|
import sys # noqa: F401
|
|
|
|
|
|
|
|
|
|
|
|
# TODO: figure out how to install unrar on Windows in GitHub Actions
@pytest.mark.skipif(sys.platform == "win32", reason="requires unrar executable")
@pytest.mark.parametrize(
    "src",
    [
        os.path.join("cowc_detection", "COWC_Detection_Columbus_CSUAV_AFRL.tbz"),
        os.path.join("cowc_detection", "COWC_test_list_detection.txt.bz2"),
        os.path.join("vhr10", "NWPU VHR-10 dataset.rar"),
        os.path.join("landcoverai", "landcover.ai.v1.zip"),
        os.path.join("sen12ms", "ROIs1158_spring_lc.tar.gz"),
    ],
)
def test_extract_archive(src: str, tmp_path: Path) -> None:
    """Archives of every supported compression scheme extract without error."""
    pytest.importorskip("rarfile", minversion="3")
    archive = os.path.join("tests", "data", src)
    extract_archive(archive, str(tmp_path))
|
|
|
|
def test_missing_rarfile(mock_missing_module: None) -> None:
    """A helpful ImportError is raised when rarfile is unavailable."""
    message = "rarfile is not installed and is required to extract this dataset"
    archive = os.path.join("tests", "data", "vhr10", "NWPU VHR-10 dataset.rar")
    with pytest.raises(ImportError, match=message):
        extract_archive(archive)
|
|
|
def test_unsupported_scheme() -> None:
    """An unrecognized file extension is rejected with a RuntimeError."""
    message = "src file has unknown archival/compression scheme"
    with pytest.raises(RuntimeError, match=message):
        extract_archive("foo.bar")
|
|
|
def test_download_and_extract_archive(
    tmp_path: Path, monkeypatch: MonkeyPatch
) -> None:
    """Download (mocked as a local copy) and extract a zip archive.

    pytest injects a :class:`MonkeyPatch` instance, so the parameter is
    annotated as such — the previous ``Generator[MonkeyPatch, None, None]``
    annotation was incorrect, and the matching ``type: ignore`` comment on
    ``setattr`` is no longer needed.
    """
    monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url)
    download_and_extract_archive(
        os.path.join("tests", "data", "landcoverai", "landcover.ai.v1.zip"),
        str(tmp_path),
    )
|
2021-09-20 10:23:42 +03:00
|
|
|
def test_download_radiant_mlhub_dataset(
    tmp_path: Path, monkeypatch: MonkeyPatch
) -> None:
    """Dataset downloads are routed through the mocked ``fetch_dataset``.

    ``monkeypatch`` is a :class:`MonkeyPatch` instance (previous
    ``Generator[...]`` annotation was incorrect).
    """
    radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
    monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch_dataset)
    download_radiant_mlhub_dataset("", str(tmp_path))
|
|
def test_download_radiant_mlhub_collection(
    tmp_path: Path, monkeypatch: MonkeyPatch
) -> None:
    """Collection downloads are routed through the mocked ``fetch_collection``.

    ``monkeypatch`` is a :class:`MonkeyPatch` instance (previous
    ``Generator[...]`` annotation was incorrect).
    """
    radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
    monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_collection)
    download_radiant_mlhub_collection("", str(tmp_path))
|
|
|
def test_missing_radiant_mlhub(mock_missing_module: None) -> None:
    """Both download helpers raise a clear ImportError without radiant_mlhub."""
    prefix = "radiant_mlhub is not installed and is required to download this"

    with pytest.raises(ImportError, match=prefix + " dataset"):
        download_radiant_mlhub_dataset("", "")

    with pytest.raises(ImportError, match=prefix + " collection"):
        download_radiant_mlhub_collection("", "")
|
|
2021-07-14 02:23:26 +03:00
|
|
|
class TestBoundingBox:
    """Unit tests for :class:`BoundingBox`."""

    def test_new_init(self) -> None:
        bbox = BoundingBox(0, 1, 2, 3, 4, 5)

        # Named attribute access
        assert bbox.minx == 0
        assert bbox.maxx == 1
        assert bbox.miny == 2
        assert bbox.maxy == 3
        assert bbox.mint == 4
        assert bbox.maxt == 5

        # Tuple-style indexing and slicing
        assert bbox[0] == 0
        assert bbox[-1] == 5
        assert bbox[1:3] == (1, 2)

    def test_repr_str(self) -> None:
        bbox = BoundingBox(0, 1, 2.0, 3.0, -5, -4)
        expected = "BoundingBox(minx=0, maxx=1, miny=2.0, maxy=3.0, mint=-5, maxt=-4)"
        assert repr(bbox) == expected
        assert str(bbox) == expected

    @pytest.mark.parametrize(
        "test_input,expected",
        [
            # Same box
            ((0, 1, 0, 1, 0, 1), True),
            ((0.0, 1.0, 0.0, 1.0, 0.0, 1.0), True),
            # bbox1 strictly within bbox2
            ((-1, 2, -1, 2, -1, 2), True),
            # bbox2 strictly within bbox1
            ((0.25, 0.75, 0.25, 0.75, 0.25, 0.75), True),
            # One corner of bbox1 within bbox2
            ((0.5, 1.5, 0.5, 1.5, 0.5, 1.5), True),
            ((0.5, 1.5, -0.5, 0.5, 0.5, 1.5), True),
            ((0.5, 1.5, 0.5, 1.5, -0.5, 0.5), True),
            ((0.5, 1.5, -0.5, 0.5, -0.5, 0.5), True),
            ((-0.5, 0.5, 0.5, 1.5, 0.5, 1.5), True),
            ((-0.5, 0.5, -0.5, 0.5, 0.5, 1.5), True),
            ((-0.5, 0.5, 0.5, 1.5, -0.5, 0.5), True),
            ((-0.5, 0.5, -0.5, 0.5, -0.5, 0.5), True),
            # No overlap
            ((0.5, 1.5, 0.5, 1.5, 2, 3), False),
            ((0.5, 1.5, 2, 3, 0.5, 1.5), False),
            ((2, 3, 0.5, 1.5, 0.5, 1.5), False),
            ((2, 3, 2, 3, 2, 3), False),
        ],
    )
    def test_intersects(
        self,
        test_input: Tuple[float, float, float, float, float, float],
        expected: bool,
    ) -> None:
        bbox1 = BoundingBox(0, 1, 0, 1, 0, 1)
        bbox2 = BoundingBox(*test_input)
        # Intersection must be symmetric
        assert bbox1.intersects(bbox2) == bbox2.intersects(bbox1) == expected

    def test_picklable(self) -> None:
        bbox = BoundingBox(0, 1, 2, 3, 4, 5)
        roundtripped = pickle.loads(pickle.dumps(bbox))
        assert bbox == roundtripped

    def test_invalid_x(self) -> None:
        message = "Bounding box is invalid: 'minx=1' > 'maxx=0'"
        with pytest.raises(ValueError, match=message):
            BoundingBox(1, 0, 2, 3, 4, 5)

    def test_invalid_y(self) -> None:
        message = "Bounding box is invalid: 'miny=3' > 'maxy=2'"
        with pytest.raises(ValueError, match=message):
            BoundingBox(0, 1, 3, 2, 4, 5)

    def test_invalid_t(self) -> None:
        message = "Bounding box is invalid: 'mint=5' > 'maxt=4'"
        with pytest.raises(ValueError, match=message):
            BoundingBox(0, 1, 2, 3, 5, 4)
|
2021-09-16 19:07:09 +03:00
|
|
|
@pytest.mark.parametrize(
    "date_string,format,min_datetime,max_datetime",
    [
        # An empty format spans all representable time
        ("", "", 0, sys.maxsize),
        (
            "2021",
            "%Y",
            datetime(2021, 1, 1).timestamp(),
            datetime(2021, 12, 31, 23, 59, 59, 999999).timestamp(),
        ),
        (
            "2021-09",
            "%Y-%m",
            datetime(2021, 9, 1).timestamp(),
            datetime(2021, 9, 30, 23, 59, 59, 999999).timestamp(),
        ),
        (
            "Dec 21",
            "%b %y",
            datetime(2021, 12, 1).timestamp(),
            datetime(2021, 12, 31, 23, 59, 59, 999999).timestamp(),
        ),
        (
            "2021-09-13",
            "%Y-%m-%d",
            datetime(2021, 9, 13).timestamp(),
            datetime(2021, 9, 13, 23, 59, 59, 999999).timestamp(),
        ),
        (
            "2021-09-13 17",
            "%Y-%m-%d %H",
            datetime(2021, 9, 13, 17).timestamp(),
            datetime(2021, 9, 13, 17, 59, 59, 999999).timestamp(),
        ),
        (
            "2021-09-13 17:21",
            "%Y-%m-%d %H:%M",
            datetime(2021, 9, 13, 17, 21).timestamp(),
            datetime(2021, 9, 13, 17, 21, 59, 999999).timestamp(),
        ),
        (
            "2021-09-13 17:21:53",
            "%Y-%m-%d %H:%M:%S",
            datetime(2021, 9, 13, 17, 21, 53).timestamp(),
            datetime(2021, 9, 13, 17, 21, 53, 999999).timestamp(),
        ),
        (
            "2021-09-13 17:21:53:000123",
            "%Y-%m-%d %H:%M:%S:%f",
            datetime(2021, 9, 13, 17, 21, 53, 123).timestamp(),
            datetime(2021, 9, 13, 17, 21, 53, 123).timestamp(),
        ),
    ],
)
def test_disambiguate_timestamp(
    date_string: str, format: str, min_datetime: float, max_datetime: float
) -> None:
    """A partial date string maps to the widest matching (mint, maxt) span."""
    mint, maxt = disambiguate_timestamp(date_string, format)
    assert math.isclose(mint, min_datetime)
    assert math.isclose(maxt, max_datetime)
|
|
2021-07-13 23:45:21 +03:00
|
|
|
def test_collate_dict() -> None:
    """Tensor-valued keys of the samples are stacked into batched tensors."""
    samples = [
        {
            "foo": torch.tensor(1),  # type: ignore[attr-defined]
            "bar": torch.tensor(2),  # type: ignore[attr-defined]
            "crs": CRS.from_epsg(3005),
        },
        {
            "foo": torch.tensor(3),  # type: ignore[attr-defined]
            "bar": torch.tensor(4),  # type: ignore[attr-defined]
            "crs": CRS.from_epsg(3005),
        },
    ]
    batch = collate_dict(samples)
    expected = {
        "foo": torch.tensor([1, 3]),  # type: ignore[attr-defined]
        "bar": torch.tensor([2, 4]),  # type: ignore[attr-defined]
    }
    for key, value in expected.items():
        assert torch.allclose(batch[key], value)  # type: ignore[attr-defined]
|
|
2021-06-19 01:10:38 +03:00
|
|
|
def test_existing_directory(tmp_path: Path) -> None:
    """``working_dir`` changes into a directory that already exists."""
    target = tmp_path / "foo" / "bar"
    target.mkdir(parents=True)
    assert target.exists()

    with working_dir(str(target)):
        assert target.cwd() == target
|
2021-06-19 01:10:38 +03:00
|
|
|
def test_nonexisting_directory(tmp_path: Path) -> None:
    """With ``create=True``, ``working_dir`` creates the missing directory."""
    target = tmp_path / "foo" / "bar"
    assert not target.exists()

    with working_dir(str(target), create=True):
        assert target.cwd() == target
|
|
|
|
def test_dataset_split() -> None:
    """``dataset_split`` produces evenly sized train/val(/test) subsets."""
    num_samples = 24
    x = torch.ones(num_samples, 5)  # type: ignore[attr-defined]
    y = torch.randint(low=0, high=2, size=(num_samples,))  # type: ignore[attr-defined]
    ds = TensorDataset(x, y)

    # Two-way split: train/val
    train_ds, val_ds = dataset_split(ds, val_pct=1 / 2)
    assert len(train_ds) == num_samples // 2
    assert len(val_ds) == num_samples // 2

    # Three-way split: train/val/test
    train_ds, val_ds, test_ds = dataset_split(ds, val_pct=1 / 3, test_pct=1 / 3)
    for subset in (train_ds, val_ds, test_ds):
        assert len(subset) == num_samples // 3