From f60cbee39b168618973dc410c5f7137fe7b55e9d Mon Sep 17 00:00:00 2001 From: isaac <22203655+isaaccorley@users.noreply.github.com> Date: Sun, 12 Sep 2021 10:50:15 -0500 Subject: [PATCH] Add RESISC45 Dataset (#126) * updated docs * added sample data for tests * added unit tests * added dataset * updated tests to not run on windows due to rar --- docs/api/datasets.rst | 5 + tests/data/resisc45/NWPU-RESISC45.rar | Bin 0 -> 1287 bytes tests/datasets/test_resisc45.py | 57 +++++++ torchgeo/datasets/__init__.py | 2 + torchgeo/datasets/resisc45.py | 208 ++++++++++++++++++++++++++ 5 files changed, 272 insertions(+) create mode 100644 tests/data/resisc45/NWPU-RESISC45.rar create mode 100644 tests/datasets/test_resisc45.py create mode 100644 torchgeo/datasets/resisc45.py diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index d6c07b0f5..2956a049b 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -112,6 +112,11 @@ PatternNet .. autoclass:: PatternNet +RESISC45 (Remote Sensing Image Scene Classification) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: RESISC45 + SEN12MS ^^^^^^^ diff --git a/tests/data/resisc45/NWPU-RESISC45.rar b/tests/data/resisc45/NWPU-RESISC45.rar new file mode 100644 index 0000000000000000000000000000000000000000..b9771fee9387798a90c4fe67322117a4194e1d96 GIT binary patch literal 1287 zcmWGaEK-zWXJq*Nu<127BP%-t8zW;wLjwbMLbbIk6EpW+W|r?POPZM)A2CjDU}jYJ z3l9j@4RQ_k40bj#)lbYUD#%I9ONBGy4GawRvI^3fn@VJfn6l<@T9bo=U+jYMbXwxPIdPnY%!I z|H%Cl_YWMAIDJM)@WP)~?eBXtW8@MR!aouCt zj>Hy)CBNJy0+S3Db!$xKX31P;E3-H8)YUnEwie&|m;75Vqp(|o?LzMZDaWncG0u*T zUle0)B+WZs9bR~)rRXKA-NarGUUh?rjZHi^Haluw+$i`=Q&J|aDWsx8q<6(kbrp|0 z8qxoZ7M@U1<1u|1l@q#;IpIgC3!kToz`XcHl}8+(tWD<@3VQL9>6XWO=huJd8mmrT`#^YT=>YUu(s~-zVlf|=G*VQ`@Z`hKY#bh zr=KhT%FPd7fBW6{KPvm|_2=Jy|KrcU{q-mPzyAx1ic3n%%B!ktYU}EoG;P+RWvkY0 z+O})op>vn6-MaVa*{gToe*I@==jIm{S60{7H#WDncXsy=4v&scPS4ISbaarjxa0Zw z{QUz`j4EQEonER+*vK$Uns(&4xT4a|WgojAH54jbTbKN0e))Qa%*|^$WKG#CtS9lF zTpH!wqb=gjH(|wxN#YUF)tma% None: + shutil.copy(url, root) + + +@pytest.mark.skipif(sys.platform == "win32", reason="requires unrar executable") +class TestRESISC45: + @pytest.fixture(params=["train", "test"]) + def dataset( + self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path + ) -> RESISC45: + monkeypatch.setattr( # type: ignore[attr-defined] + torchgeo.datasets.utils, "download_url", download_url + ) + md5 = "9c221122164d17b8118d2b6527ee5e9c" + monkeypatch.setattr(RESISC45, "md5", md5) # type: ignore[attr-defined] + url = os.path.join("tests", "data", "resisc45", "NWPU-RESISC45.rar") + monkeypatch.setattr(RESISC45, "url", url) # type: ignore[attr-defined] + root = str(tmp_path) + transforms = nn.Identity() # type: ignore[attr-defined] + return RESISC45(root, transforms, download=True, checksum=True) + + def test_getitem(self, dataset: RESISC45) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert isinstance(x["label"], torch.Tensor) + assert x["image"].shape[0] == 3 + + def test_len(self, dataset: RESISC45) -> None: + assert len(dataset) == 2 + + def test_already_downloaded(self, dataset: RESISC45) -> None: + RESISC45(root=dataset.root, download=True) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): + RESISC45(str(tmp_path)) diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index c057e0327..cba8f93a0 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -43,6 +43,7 @@ from .levircd import LEVIRCDPlus from .naip import NAIP from .nwpu import VHR10 from .patternnet import PatternNet +from .resisc45 import RESISC45 from .sen12ms import SEN12MS from .sentinel import Sentinel, Sentinel2 from .so2sat import So2Sat @@ -88,6 +89,7 @@ __all__ = ( "LandCoverAI", "LEVIRCDPlus", "PatternNet", + "RESISC45", "SEN12MS", "So2Sat", "TropicalCycloneWindEstimation", diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py new file mode 100644 index 000000000..4a10c21fc --- /dev/null +++ b/torchgeo/datasets/resisc45.py @@ -0,0 +1,208 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""RESISC45 dataset.""" + +import os +from typing import Callable, Dict, Optional, Tuple + +import numpy as np +import torch +from torch import Tensor +from torchvision.datasets import ImageFolder + +from .geo import VisionDataset +from .utils import download_and_extract_archive + + +class RESISC45(VisionDataset, ImageFolder): # type: ignore[misc] + """RESISC45 dataset. + + The `RESISC45 `_ + dataset is a dataset for remote sensing image scene classification. + + Dataset features: + + * 31,500 images with 0.2-30 m per pixel resolution (256x256 px) + * three spectral bands - RGB + * 45 scene classes, 700 images per class + * images extracted from Google Earth from over 100 countries + * images conditions with high variability (resolution, weather, illumination) + + Dataset format: + + * images are three-channel jpgs + + Dataset classes: + + 0. airplane + 1. airport + 2. baseball_diamond + 3. basketball_court + 4. beach + 5. bridge + 6. chaparral + 7. church + 8. circular_farmland + 9. cloud + 10. commercial_area + 11. dense_residential + 12. desert + 13. forest + 14. freeway + 15. golf_course + 16. ground_track_field + 17. harbor + 18. industrial_area + 19. intersection + 20. island + 21. lake + 22. meadow + 23. medium_residential + 24. mobile_home_park + 25. mountain + 26. overpass + 27. palace + 28. parking_lot + 29. railway + 30. railway_station + 31. rectangular_farmland + 32. river + 33. roundabout + 34. runway + 35. sea_ice + 36. ship + 37. snowberg + 38. sparse_residential + 39. stadium + 40. storage_tank + 41. tennis_court + 42. terrace + 43. thermal_power_station + 44. wetland + + If you use this dataset in your research, please cite the following paper: + + * https://doi.org/10.1109/jproc.2017.2675998 + + """ + + url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv" + md5 = "d824acb73957502b00efd559fc6cfbbb" + filename = "NWPU-RESISC45.rar" + directory = "NWPU-RESISC45" + + def __init__( + self, + root: str = "data", + transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new PatternNet dataset instance. + + Args: + root: root directory where dataset can be found + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + RuntimeError: if ``download=False`` and data is not found, or checksums + don't match + """ + self.root = root + self.checksum = checksum + + if download: + self._download() + + if not self._check_integrity(): + raise RuntimeError( + "Dataset not found or corrupted. " + + "You can use download=True to download it" + ) + + # When transform & target_transform are None, ImageFolder.__getitem__[index] + # returns a PIL.Image and int for image and label, respectively + super().__init__( + root=os.path.join(root, self.directory), + transform=None, + target_transform=None, + ) + + # Must be set after calling super().__init__() + self.transforms = transforms + + def __getitem__(self, index: int) -> Dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + data and label at that index + """ + image, label = self._load_image(index) + sample = {"image": image, "label": label} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ + return len(self.imgs) + + def _load_image(self, index: int) -> Tuple[Tensor, Tensor]: + """Load a single image and it's class label. + + Args: + index: index to return + + Returns: + the image + the image class label + """ + img, label = ImageFolder.__getitem__(self, index) + array = np.array(img) + tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined] + # Convert from HxWxC to CxHxW + tensor = tensor.permute((2, 0, 1)) + label = torch.tensor(label) # type: ignore[attr-defined] + return tensor, label + + def _check_integrity(self) -> bool: + """Checks the integrity of the dataset structure. + + Returns: + True if the dataset directories and split files are found, else False + """ + filepath = os.path.join(self.root, self.directory) + if not os.path.exists(filepath): + return False + + return True + + def _download(self) -> None: + """Download the dataset and extract it. + + Raises: + AssertionError: if the checksum of split.py does not match + """ + if self._check_integrity(): + print("Files already downloaded and verified") + return + + download_and_extract_archive( + self.url, + self.root, + filename=self.filename, + md5=self.md5 if self.checksum else None, + )