* updated docs

* added sample data for tests

* added unit tests

* added dataset

* updated tests to not run on Windows due to rar
isaac committed 2021-09-12 10:50:15 -05:00 (committed via GitHub)
Parent: 455ea7e24b
Commit: f60cbee39b
5 changed files with 272 additions and 0 deletions

View file

@@ -112,6 +112,11 @@ PatternNet
.. autoclass:: PatternNet

RESISC45 (Remote Sensing Image Scene Classification)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: RESISC45

SEN12MS
^^^^^^^

Binary data
tests/data/resisc45/NWPU-RESISC45.rar Normal file

Binary file not shown.
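The test archive itself is checked in as binary data, so the diff cannot show its contents. Below is a minimal sketch of how a fixture like this could be regenerated, assuming Pillow and an external `rar` executable are available; the class names and image count are illustrative, not the actual fixture contents.

import os
import subprocess

from PIL import Image

# Illustrative ImageFolder layout: a couple of classes under NWPU-RESISC45/,
# each with one dummy 256x256 RGB image
root = "NWPU-RESISC45"
for cls in ["airplane", "airport"]:
    os.makedirs(os.path.join(root, cls), exist_ok=True)
    Image.new("RGB", (256, 256)).save(os.path.join(root, cls, f"{cls}_001.jpg"))

# rarfile can only read archives, so packing relies on the external `rar` CLI
subprocess.run(["rar", "a", "NWPU-RESISC45.rar", root], check=True)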

View file

@@ -0,0 +1,57 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
import sys
from pathlib import Path
from typing import Generator

import pytest
import torch
import torch.nn as nn
from _pytest.monkeypatch import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datasets import RESISC45

pytest.importorskip("rarfile")


def download_url(url: str, root: str, *args: str) -> None:
    shutil.copy(url, root)


@pytest.mark.skipif(sys.platform == "win32", reason="requires unrar executable")
class TestRESISC45:
    @pytest.fixture(params=["train", "test"])
    def dataset(
        self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path
    ) -> RESISC45:
        monkeypatch.setattr(  # type: ignore[attr-defined]
            torchgeo.datasets.utils, "download_url", download_url
        )
        md5 = "9c221122164d17b8118d2b6527ee5e9c"
        monkeypatch.setattr(RESISC45, "md5", md5)  # type: ignore[attr-defined]
        url = os.path.join("tests", "data", "resisc45", "NWPU-RESISC45.rar")
        monkeypatch.setattr(RESISC45, "url", url)  # type: ignore[attr-defined]
        root = str(tmp_path)
        transforms = nn.Identity()  # type: ignore[attr-defined]
        return RESISC45(root, transforms, download=True, checksum=True)

    def test_getitem(self, dataset: RESISC45) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["label"], torch.Tensor)
        assert x["image"].shape[0] == 3

    def test_len(self, dataset: RESISC45) -> None:
        assert len(dataset) == 2

    def test_already_downloaded(self, dataset: RESISC45) -> None:
        RESISC45(root=dataset.root, download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
            RESISC45(str(tmp_path))
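The fixture hardcodes the MD5 of the checked-in test archive. If the fixture were ever regenerated, the digest could be recomputed with a small standalone helper like the following sketch (not part of the test suite):

import hashlib


def md5sum(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the MD5 hex digest of a file, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


print(md5sum("tests/data/resisc45/NWPU-RESISC45.rar"))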

View file

@@ -43,6 +43,7 @@ from .levircd import LEVIRCDPlus
from .naip import NAIP
from .nwpu import VHR10
from .patternnet import PatternNet
from .resisc45 import RESISC45
from .sen12ms import SEN12MS
from .sentinel import Sentinel, Sentinel2
from .so2sat import So2Sat
@@ -88,6 +89,7 @@ __all__ = (
"LandCoverAI",
"LEVIRCDPlus",
"PatternNet",
"RESISC45",
"SEN12MS",
"So2Sat",
"TropicalCycloneWindEstimation",

View file

@@ -0,0 +1,208 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""RESISC45 dataset."""

import os
from typing import Callable, Dict, Optional, Tuple

import numpy as np
import torch
from torch import Tensor
from torchvision.datasets import ImageFolder

from .geo import VisionDataset
from .utils import download_and_extract_archive


class RESISC45(VisionDataset, ImageFolder):  # type: ignore[misc]
"""RESISC45 dataset.
The `RESISC45 <http://www.escience.cn/people/JunweiHan/NWPU-RESISC45.html>`_
dataset is a dataset for remote sensing image scene classification.
Dataset features:
* 31,500 images with 0.2-30 m per pixel resolution (256x256 px)
* three spectral bands - RGB
* 45 scene classes, 700 images per class
* images extracted from Google Earth from over 100 countries
* images conditions with high variability (resolution, weather, illumination)
Dataset format:
* images are three-channel jpgs
Dataset classes:
0. airplane
1. airport
2. baseball_diamond
3. basketball_court
4. beach
5. bridge
6. chaparral
7. church
8. circular_farmland
9. cloud
10. commercial_area
11. dense_residential
12. desert
13. forest
14. freeway
15. golf_course
16. ground_track_field
17. harbor
18. industrial_area
19. intersection
20. island
21. lake
22. meadow
23. medium_residential
24. mobile_home_park
25. mountain
26. overpass
27. palace
28. parking_lot
29. railway
30. railway_station
31. rectangular_farmland
32. river
33. roundabout
34. runway
35. sea_ice
36. ship
37. snowberg
38. sparse_residential
39. stadium
40. storage_tank
41. tennis_court
42. terrace
43. thermal_power_station
44. wetland
If you use this dataset in your research, please cite the following paper:
* https://doi.org/10.1109/jproc.2017.2675998
"""
url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv"
md5 = "d824acb73957502b00efd559fc6cfbbb"
filename = "NWPU-RESISC45.rar"
directory = "NWPU-RESISC45"
    def __init__(
        self,
        root: str = "data",
        transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new RESISC45 dataset instance.

        Args:
            root: root directory where dataset can be found
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            RuntimeError: if ``download=False`` and data is not found, or checksums
                don't match
        """
        self.root = root
        self.checksum = checksum

        if download:
            self._download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. "
                + "You can use download=True to download it"
            )

        # When transform & target_transform are None, ImageFolder.__getitem__[index]
        # returns a PIL.Image and int for image and label, respectively
        super().__init__(
            root=os.path.join(root, self.directory),
            transform=None,
            target_transform=None,
        )

        # Must be set after calling super().__init__()
        self.transforms = transforms
    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            data and label at that index
        """
        image, label = self._load_image(index)
        sample = {"image": image, "label": label}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.imgs)
    def _load_image(self, index: int) -> Tuple[Tensor, Tensor]:
        """Load a single image and its class label.

        Args:
            index: index to return

        Returns:
            the image
            the image class label
        """
        img, label = ImageFolder.__getitem__(self, index)
        array = np.array(img)
        tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
        # Convert from HxWxC to CxHxW
        tensor = tensor.permute((2, 0, 1))
        label = torch.tensor(label)  # type: ignore[attr-defined]
        return tensor, label
    def _check_integrity(self) -> bool:
        """Check the integrity of the dataset structure.

        Returns:
            True if the dataset directory is found, else False
        """
        filepath = os.path.join(self.root, self.directory)
        if not os.path.exists(filepath):
            return False

        return True
    def _download(self) -> None:
        """Download the dataset and extract it.

        Raises:
            RuntimeError: if the checksum of the downloaded archive does not match
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            self.url,
            self.root,
            filename=self.filename,
            md5=self.md5 if self.checksum else None,
        )
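For context, here is a minimal usage sketch of the new class (the root path and DataLoader settings are illustrative): each sample is a dict holding a uint8 3x256x256 image tensor and an integer class label, so torch's default collate function can batch it directly.

from torch.utils.data import DataLoader

from torchgeo.datasets import RESISC45

# "data" is an illustrative root; download=True fetches and unpacks NWPU-RESISC45.rar
dataset = RESISC45(root="data", download=True, checksum=True)
sample = dataset[0]
print(sample["image"].shape, sample["label"])  # e.g. torch.Size([3, 256, 256]) tensor(0)

# dict-of-tensor samples batch cleanly with the default collate_fn
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
for batch in loader:
    images, labels = batch["image"], batch["label"]
    break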