# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Tests for the SpaceNet datasets."""

import glob
import os
import shutil
from pathlib import Path

import pytest
import torch
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch

from torchgeo.datasets import SpaceNet1
from torchgeo.transforms import Identity

TEST_DATA_DIR = "tests/data/spacenet"


class Dataset:
    """Offline stand-in for the object returned by ``radiant_mlhub.Dataset.fetch``.

    Instead of hitting the network, ``download`` copies the archives stored
    under ``TEST_DATA_DIR/<collection_id>`` into the requested directory.
    """

    def __init__(self, collection_id: str) -> None:
        """Remember which collection's test archives should be served.

        Args:
            collection_id: name of the sub-directory of ``TEST_DATA_DIR`` to use
        """
        self.collection_id = collection_id

    def download(self, output_dir: str, **kwargs: str) -> None:
        """Copy every ``*.tar.gz`` test archive of the collection to *output_dir*.

        Args:
            output_dir: directory that receives the archives
            **kwargs: accepted and ignored
        """
        pattern = os.path.join(TEST_DATA_DIR, self.collection_id, "*.tar.gz")
        for tarball in glob.iglob(pattern):
            shutil.copy(tarball, output_dir)


def fetch(collection_id: str, **kwargs: str) -> Dataset:
    """Replacement for ``radiant_mlhub.Dataset.fetch`` used via monkeypatching.

    Args:
        collection_id: collection to serve from the local test data
        **kwargs: accepted and ignored

    Returns:
        a fake dataset whose ``download`` copies local archives
    """
    return Dataset(collection_id)


class TestSpaceNet1:
    """Tests for :class:`torchgeo.datasets.SpaceNet1`."""

    @pytest.fixture(params=["rgb", "8band"])
    def dataset(
        self,
        request: SubRequest,
        monkeypatch: MonkeyPatch,
        tmp_path: Path,
    ) -> SpaceNet1:
        """Build a SpaceNet1 dataset from local test archives, once per image type."""
        radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
        # Route the download through the local fake instead of the MLHub API.
        monkeypatch.setattr(radiant_mlhub.Dataset, "fetch", fetch)
        # The test archive has a different checksum than the real dataset.
        test_md5 = "829652022c2df4511ee4ae05bc290250"
        monkeypatch.setattr(SpaceNet1, "md5", test_md5)
        root = str(tmp_path)
        transforms = Identity()
        return SpaceNet1(
            root,
            image=request.param,
            transforms=transforms,
            download=True,
            api_key="",
        )

    def test_getitem(self, dataset: SpaceNet1) -> None:
        """A sample is a dict of tensors with the band count of the image type."""
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["mask"], torch.Tensor)
        if dataset.image == "rgb":
            assert x["image"].shape[0] == 3
        else:
            assert x["image"].shape[0] == 8

    def test_len(self, dataset: SpaceNet1) -> None:
        """The test archive yields two samples."""
        assert len(dataset) == 2

    def test_already_downloaded(self, dataset: SpaceNet1) -> None:
        """Re-instantiating on an existing root with ``download=True`` succeeds."""
        SpaceNet1(root=dataset.root, download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """A missing dataset without ``download=True`` raises ``RuntimeError``."""
        with pytest.raises(RuntimeError, match="Dataset not found"):
            SpaceNet1(str(tmp_path))
"""SpaceNet datasets."""

import glob
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import fiona
import numpy as np
import rasterio as rio
import torch
from affine import Affine
from rasterio.features import rasterize
from torch import Tensor

from torchgeo.datasets.geo import VisionDataset
from torchgeo.datasets.utils import (
    check_integrity,
    download_radiant_mlhub,
    extract_archive,
)


class SpaceNet1(VisionDataset):
    """SpaceNet 1: Building Detection v1 Dataset.

    SpaceNet 1 is a dataset of building footprints over the city of
    Rio de Janeiro.

    Dataset features:

    * No. of images - 6940 (8 Band) + 6940 (RGB)
    * No. of polygons - 382,534 building labels
    * Area Coverage - 2544 sq km

    Dataset format:

    * Imagery - Raw 8 band Worldview-3 (GeoTIFF) & Pansharpened RGB image (GeoTIFF)
    * Labels - GeoJSON

    If you are using data from SpaceNet in a paper, please cite the following paper:

    * https://arxiv.org/abs/1807.01232

    .. note::

       This dataset requires the following additional library to be installed:

       * radiant-mlhub to download the imagery and labels from the
         Radiant Earth MLHub

    """

    dataset_id = "spacenet1"
    md5 = "e6ea35331636fa0c036c04b3d1cbf226"
    # Maps the ``image`` constructor argument to the GeoTIFF file name that is
    # looked up inside every sample directory.
    imagery = {"rgb": "RGB.tif", "8band": "8Band.tif"}
    # File name of the label file inside every ``<sample>-labels`` directory.
    label_glob = "labels.geojson"
    # Name of the extracted dataset folder and stem of the ``.tar.gz`` archive.
    foldername = "sn1_AOI_1_RIO"

    def __init__(
        self,
        root: str,
        image: str = "rgb",
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        download: bool = False,
        api_key: Optional[str] = None,
        checksum: bool = False,
    ) -> None:
        """Initialize a new SpaceNet 1 Dataset instance.

        Args:
            root: root directory where dataset can be found
            image: image selection which must be "rgb" or "8band"
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version.
            download: if True, download dataset and store it in the root directory.
            api_key: a RadiantEarth MLHub API key to use for downloading the dataset
            checksum: if True, check the MD5 of the archive before extracting it

        Raises:
            KeyError: if ``image`` is not a key of :attr:`imagery`
            RuntimeError: if ``download=False`` but dataset is missing
        """
        self.root = root
        self.image = image  # For testing
        self.filename = self.imagery[image]
        self.transforms = transforms
        self.checksum = checksum

        if not self._check_integrity():
            if download:
                self._download(api_key)
            else:
                raise RuntimeError(
                    "Dataset not found. You can use download=True to download it."
                )

        self.files = self._load_files(os.path.join(root, self.foldername))

    def _load_files(self, root: str) -> List[Dict[str, str]]:
        """Return the paths of the files in the dataset.

        Images are discovered as ``<root>/<sample>/<filename>``; the label of a
        sample is expected at ``<root>/<sample>-labels/<label_glob>``.

        Args:
            root: root dir of dataset

        Returns:
            list of dicts with the keys ``"image_path"`` and ``"label_path"``,
            sorted by image path
        """
        image_paths = sorted(glob.glob(os.path.join(root, "*", self.filename)))
        return [
            {
                "image_path": image_path,
                "label_path": os.path.join(
                    os.path.dirname(image_path) + "-labels", self.label_glob
                ),
            }
            for image_path in image_paths
        ]

    def _load_image(self, path: str) -> Tuple[Tensor, Affine]:
        """Load a single image.

        Args:
            path: path to the image

        Returns:
            the image as a float32 tensor and the affine transform of the raster
        """
        with rio.open(path) as img:
            array = img.read().astype(np.float32)
            # Read the transform while the dataset handle is still open.
            transform: Affine = img.transform
        tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
        return tensor, transform

    def _load_mask(self, path: str, tfm: Affine, shape: Tuple[int, int]) -> Tensor:
        """Rasterize the dataset's labels (in geojson format).

        Args:
            path: path to the label
            tfm: transform of corresponding image
            shape: (height, width) of corresponding image

        Returns:
            label tensor of dtype long; pixels not covered by any geometry are 0
        """
        with fiona.open(path) as src:
            labels = [feature["geometry"] for feature in src]

        if labels:
            mask_data = rasterize(
                labels,
                out_shape=shape,
                fill=0,  # nodata value
                transform=tfm,
                all_touched=False,
                dtype=np.uint8,
            )
        else:
            # No geometries to burn in: build an all-background mask directly
            # (rasterize is deliberately not called with an empty list).
            mask_data = np.zeros(shape=shape, dtype=np.uint8)

        mask: Tensor = torch.from_numpy(mask_data).long()  # type: ignore[attr-defined]

        return mask

    def __len__(self) -> int:
        """Return the number of samples in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.files)

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return the sample at the given index.

        Args:
            index: index to return

        Returns:
            dict with the keys ``"image"`` and ``"mask"``, passed through
            ``transforms`` if one was given
        """
        files = self.files[index]
        img, tfm = self._load_image(files["image_path"])
        # The image is read band-first, so the last two dims are the raster size.
        h, w = img.shape[1:]
        mask = self._load_mask(files["label_path"], tfm, (h, w))

        sample = {"image": img, "mask": mask}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def _check_integrity(self) -> bool:
        """Check that the dataset is available, extracting the archive if needed.

        The dataset counts as present when ``collection.json`` exists in the
        dataset folder. Otherwise an existing archive in ``root`` is extracted,
        after an MD5 check when ``checksum`` is enabled.

        Returns:
            True if the dataset folder is found or was extracted from the
            archive, else False
        """
        stacpath = os.path.join(self.root, self.foldername, "collection.json")

        if os.path.exists(stacpath):
            return True

        # If dataset folder does not exist, check for uncorrupted archive
        archive_path = os.path.join(self.root, self.foldername + ".tar.gz")
        if not os.path.exists(archive_path):
            return False
        print("Archive found")
        if self.checksum and not check_integrity(archive_path, self.md5):
            print("Dataset corrupted")
            return False
        print("Extracting...")
        extract_archive(archive_path)
        return True

    def _download(self, api_key: Optional[str] = None) -> None:
        """Download the dataset and extract it.

        Args:
            api_key: a RadiantEarth MLHub API key to use for downloading the dataset

        Raises:
            RuntimeError: if ``checksum`` is enabled and the MD5 of the downloaded
                archive does not match
        """
        if self._check_integrity():
            print("Files already downloaded")
            return

        download_radiant_mlhub(self.dataset_id, self.root, api_key)
        archive_path = os.path.join(self.root, self.foldername + ".tar.gz")
        # The archive is only verified when the caller asked for it.
        if self.checksum and not check_integrity(archive_path, self.md5):
            raise RuntimeError("Dataset corrupted")
        extract_archive(archive_path)