зеркало из https://github.com/microsoft/torchgeo.git
Add RESISC45 Dataset (#126)
* updated docs * added sample data for tests * added unit tests * added dataset * updated tests to not run on windows due to rar
This commit is contained in:
Родитель
455ea7e24b
Коммит
f60cbee39b
|
@ -112,6 +112,11 @@ PatternNet
|
||||||
|
|
||||||
.. autoclass:: PatternNet
|
.. autoclass:: PatternNet
|
||||||
|
|
||||||
|
RESISC45 (Remote Sensing Image Scene Classification)
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
.. autoclass:: RESISC45
|
||||||
|
|
||||||
SEN12MS
|
SEN12MS
|
||||||
^^^^^^^
|
^^^^^^^
|
||||||
|
|
||||||
|
|
Двоичный файл не отображается.
|
@ -0,0 +1,57 @@
|
||||||
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
|
|
||||||
|
import torchgeo.datasets.utils
|
||||||
|
from torchgeo.datasets import RESISC45
|
||||||
|
|
||||||
|
# Skip this whole test module when the optional ``rarfile`` package is not installed;
# the test archive used below is a ``.rar`` file.
pytest.importorskip("rarfile")
|
||||||
|
|
||||||
|
|
||||||
|
def download_url(url: str, root: str, *args: str) -> None:
    """Stand-in downloader that copies a local file instead of using the network.

    Args:
        url: path of the local file to copy
        root: destination passed straight to :func:`shutil.copy`
        *args: extra positional arguments, accepted and ignored
    """
    shutil.copy(url, root)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.platform == "win32", reason="requires unrar executable")
class TestRESISC45:
    """Tests for the RESISC45 dataset, run against a small local ``.rar`` archive."""

    # The dataset constructor takes no split argument, so the fixture is not
    # parametrized: identical parameter values would only repeat every test.
    @pytest.fixture
    def dataset(
        self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path
    ) -> RESISC45:
        """Build a RESISC45 instance from the sample archive under ``tests/data``.

        The downloader, expected MD5 and URL are all patched so that "downloading"
        copies the local sample archive into ``tmp_path``.
        """
        monkeypatch.setattr(  # type: ignore[attr-defined]
            torchgeo.datasets.utils, "download_url", download_url
        )
        md5 = "9c221122164d17b8118d2b6527ee5e9c"
        monkeypatch.setattr(RESISC45, "md5", md5)  # type: ignore[attr-defined]
        url = os.path.join("tests", "data", "resisc45", "NWPU-RESISC45.rar")
        monkeypatch.setattr(RESISC45, "url", url)  # type: ignore[attr-defined]
        root = str(tmp_path)
        transforms = nn.Identity()  # type: ignore[attr-defined]
        return RESISC45(root, transforms, download=True, checksum=True)

    def test_getitem(self, dataset: RESISC45) -> None:
        """A sample is a dict holding an image tensor (3 channels) and a label."""
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["label"], torch.Tensor)
        assert x["image"].shape[0] == 3

    def test_len(self, dataset: RESISC45) -> None:
        """The sample archive contains exactly two images."""
        assert len(dataset) == 2

    def test_already_downloaded(self, dataset: RESISC45, tmp_path: Path) -> None:
        """Requesting a download into an already populated root must succeed."""
        # Use the directory the fixture downloaded into rather than
        # ``dataset.root``: the dataset hands ``root/<directory>`` to its parent
        # initializer, so that attribute may no longer name the download root.
        RESISC45(root=str(tmp_path), download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """Without ``download=True`` an empty root raises ``RuntimeError``."""
        with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
            RESISC45(str(tmp_path))
|
|
@ -43,6 +43,7 @@ from .levircd import LEVIRCDPlus
|
||||||
from .naip import NAIP
|
from .naip import NAIP
|
||||||
from .nwpu import VHR10
|
from .nwpu import VHR10
|
||||||
from .patternnet import PatternNet
|
from .patternnet import PatternNet
|
||||||
|
from .resisc45 import RESISC45
|
||||||
from .sen12ms import SEN12MS
|
from .sen12ms import SEN12MS
|
||||||
from .sentinel import Sentinel, Sentinel2
|
from .sentinel import Sentinel, Sentinel2
|
||||||
from .so2sat import So2Sat
|
from .so2sat import So2Sat
|
||||||
|
@ -88,6 +89,7 @@ __all__ = (
|
||||||
"LandCoverAI",
|
"LandCoverAI",
|
||||||
"LEVIRCDPlus",
|
"LEVIRCDPlus",
|
||||||
"PatternNet",
|
"PatternNet",
|
||||||
|
"RESISC45",
|
||||||
"SEN12MS",
|
"SEN12MS",
|
||||||
"So2Sat",
|
"So2Sat",
|
||||||
"TropicalCycloneWindEstimation",
|
"TropicalCycloneWindEstimation",
|
||||||
|
|
|
@ -0,0 +1,208 @@
|
||||||
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
"""RESISC45 dataset."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torchvision.datasets import ImageFolder
|
||||||
|
|
||||||
|
from .geo import VisionDataset
|
||||||
|
from .utils import download_and_extract_archive
|
||||||
|
|
||||||
|
|
||||||
|
class RESISC45(VisionDataset, ImageFolder):  # type: ignore[misc]
    """RESISC45 dataset.

    The `RESISC45 <http://www.escience.cn/people/JunweiHan/NWPU-RESISC45.html>`_
    dataset is a dataset for remote sensing image scene classification.

    Dataset features:

    * 31,500 images with 0.2-30 m per pixel resolution (256x256 px)
    * three spectral bands - RGB
    * 45 scene classes, 700 images per class
    * images extracted from Google Earth from over 100 countries
    * image conditions with high variability (resolution, weather, illumination)

    Dataset format:

    * images are three-channel jpgs

    Dataset classes:

    0. airplane
    1. airport
    2. baseball_diamond
    3. basketball_court
    4. beach
    5. bridge
    6. chaparral
    7. church
    8. circular_farmland
    9. cloud
    10. commercial_area
    11. dense_residential
    12. desert
    13. forest
    14. freeway
    15. golf_course
    16. ground_track_field
    17. harbor
    18. industrial_area
    19. intersection
    20. island
    21. lake
    22. meadow
    23. medium_residential
    24. mobile_home_park
    25. mountain
    26. overpass
    27. palace
    28. parking_lot
    29. railway
    30. railway_station
    31. rectangular_farmland
    32. river
    33. roundabout
    34. runway
    35. sea_ice
    36. ship
    37. snowberg
    38. sparse_residential
    39. stadium
    40. storage_tank
    41. tennis_court
    42. terrace
    43. thermal_power_station
    44. wetland

    If you use this dataset in your research, please cite the following paper:

    * https://doi.org/10.1109/jproc.2017.2675998
    """

    url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv"
    md5 = "d824acb73957502b00efd559fc6cfbbb"
    filename = "NWPU-RESISC45.rar"
    directory = "NWPU-RESISC45"

    def __init__(
        self,
        root: str = "data",
        transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new RESISC45 dataset instance.

        Args:
            root: root directory where dataset can be found
            transforms: a function/transform that takes an input sample dict
                (with "image" and "label" entries) and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            RuntimeError: if the dataset directory is not found under ``root``
                after the optional download step
        """
        self.root = root
        self.checksum = checksum

        if download:
            self._download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. "
                "You can use download=True to download it"
            )

        # When transform & target_transform are None, ImageFolder.__getitem__[index]
        # returns a PIL.Image and int for image and label, respectively
        # NOTE(review): the parent initializer receives ``root/<directory>``; if it
        # stores its ``root`` argument on the instance, ``self.root`` no longer names
        # the download root after this call -- confirm before relying on it.
        super().__init__(
            root=os.path.join(root, self.directory),
            transform=None,
            target_transform=None,
        )

        # Must be set after calling super().__init__()
        self.transforms = transforms

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return the sample at the given index.

        Args:
            index: index to return

        Returns:
            dict with the "image" and "label" tensors at that index, passed through
            ``transforms`` when one was provided
        """
        image, label = self._load_image(index)
        sample = {"image": image, "label": label}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.imgs)

    def _load_image(self, index: int) -> Tuple[Tensor, Tensor]:
        """Load a single image and its class label.

        Args:
            index: index to return

        Returns:
            the image
            the image class label
        """
        img, label = ImageFolder.__getitem__(self, index)
        array = np.array(img)
        tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
        # Convert from HxWxC to CxHxW
        tensor = tensor.permute((2, 0, 1))
        # Keep the integer label and its tensor form under separate names
        label_tensor: Tensor = torch.tensor(label)  # type: ignore[attr-defined]
        return tensor, label_tensor

    def _check_integrity(self) -> bool:
        """Check that the extracted dataset directory is present.

        Only the existence of ``root/<directory>`` is tested; file contents are
        not verified here.

        Returns:
            True if the dataset directory exists, else False
        """
        filepath = os.path.join(self.root, self.directory)
        return os.path.exists(filepath)

    def _download(self) -> None:
        """Download the dataset archive and extract it into the root directory.

        Does nothing beyond printing a message when the dataset directory already
        exists. The expected MD5 is passed to the download helper only when
        ``checksum`` was enabled.
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            self.url,
            self.root,
            filename=self.filename,
            md5=self.md5 if self.checksum else None,
        )
|
Загрузка…
Ссылка в новой задаче