diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 2956a049b..504d836e1 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -127,6 +127,11 @@ So2Sat
.. autoclass:: So2Sat
+SpaceNet
+^^^^^^^^
+
+.. autoclass:: SpaceNet1
+
Tropical Cyclone Wind Estimation Competition
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/tests/data/spacenet/spacenet1/sn1_AOI_1_RIO.tar.gz b/tests/data/spacenet/spacenet1/sn1_AOI_1_RIO.tar.gz
new file mode 100644
index 000000000..231c81e1c
Binary files /dev/null and b/tests/data/spacenet/spacenet1/sn1_AOI_1_RIO.tar.gz differ
diff --git a/tests/datasets/test_spacenet.py b/tests/datasets/test_spacenet.py
new file mode 100644
index 000000000..39a7acf9b
--- /dev/null
+++ b/tests/datasets/test_spacenet.py
@@ -0,0 +1,77 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import glob
+import os
+import shutil
+from pathlib import Path
+from typing import Generator
+
+import pytest
+import torch
+from _pytest.fixtures import SubRequest
+from _pytest.monkeypatch import MonkeyPatch
+
+from torchgeo.datasets import SpaceNet1
+from torchgeo.transforms import Identity
+
+TEST_DATA_DIR = "tests/data/spacenet"
+
+
+class Dataset:
+ def __init__(self, collection_id: str) -> None:
+ self.collection_id = collection_id
+
+ def download(self, output_dir: str, **kwargs: str) -> None:
+ glob_path = os.path.join(TEST_DATA_DIR, self.collection_id, "*.tar.gz")
+ for tarball in glob.iglob(glob_path):
+ shutil.copy(tarball, output_dir)
+
+
+def fetch(collection_id: str, **kwargs: str) -> Dataset:
+ return Dataset(collection_id)
+
+
+class TestSpaceNet1:
+ @pytest.fixture(params=["rgb", "8band"])
+ def dataset(
+ self,
+ request: SubRequest,
+ monkeypatch: Generator[MonkeyPatch, None, None],
+ tmp_path: Path,
+ ) -> SpaceNet1:
+ radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
+ monkeypatch.setattr( # type: ignore[attr-defined]
+ radiant_mlhub.Dataset, "fetch", fetch
+ )
+ test_md5 = "829652022c2df4511ee4ae05bc290250"
+ monkeypatch.setattr(SpaceNet1, "md5", test_md5) # type: ignore[attr-defined]
+ root = str(tmp_path)
+ transforms = Identity()
+ return SpaceNet1(
+ root,
+ image=request.param,
+ transforms=transforms,
+ download=True,
+ api_key="",
+ )
+
+ def test_getitem(self, dataset: SpaceNet1) -> None:
+ x = dataset[0]
+ assert isinstance(x, dict)
+ assert isinstance(x["image"], torch.Tensor)
+ assert isinstance(x["mask"], torch.Tensor)
+ if dataset.image == "rgb":
+ assert x["image"].shape[0] == 3
+ else:
+ assert x["image"].shape[0] == 8
+
+ def test_len(self, dataset: SpaceNet1) -> None:
+ assert len(dataset) == 2
+
+ def test_already_downloaded(self, dataset: SpaceNet1) -> None:
+ SpaceNet1(root=dataset.root, download=True)
+
+ def test_not_downloaded(self, tmp_path: Path) -> None:
+ with pytest.raises(RuntimeError, match="Dataset not found"):
+ SpaceNet1(str(tmp_path))
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index cba8f93a0..7eee81cc9 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -47,6 +47,7 @@ from .resisc45 import RESISC45
from .sen12ms import SEN12MS
from .sentinel import Sentinel, Sentinel2
from .so2sat import So2Sat
+from .spacenet import SpaceNet1
from .utils import BoundingBox, collate_dict
__all__ = (
@@ -92,6 +93,7 @@ __all__ = (
"RESISC45",
"SEN12MS",
"So2Sat",
+ "SpaceNet1",
"TropicalCycloneWindEstimation",
"VHR10",
# Base classes
diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py
new file mode 100644
index 000000000..0bdfbab80
--- /dev/null
+++ b/torchgeo/datasets/spacenet.py
@@ -0,0 +1,239 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""SpaceNet datasets."""
+
+import glob
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import fiona
+import numpy as np
+import rasterio as rio
+import torch
+from affine import Affine
+from rasterio.features import rasterize
+from torch import Tensor
+
+from torchgeo.datasets.geo import VisionDataset
+from torchgeo.datasets.utils import (
+ check_integrity,
+ download_radiant_mlhub,
+ extract_archive,
+)
+
+
+class SpaceNet1(VisionDataset):
+ """SpaceNet 1: Building Detection v1 Dataset.
+
+ `SpaceNet 1 `_
+ is a dataset of building footprints over the city of Rio de Janeiro.
+
+ Dataset features:
+
+ * No. of images - 6940 (8 Band) + 6940 (RGB)
+ * No. of polygons - 382,534 building labels
+ * Area Coverage - 2544 sq km
+
+ Dataset format:
+
+ * Imagery - Raw 8 band Worldview-3 (GeoTIFF) & Pansharpened RGB image (GeoTIFF)
+ * Labels - GeoJSON
+
+ If you are using data from SpaceNet in a paper, please cite the following paper:
+
+ * https://arxiv.org/abs/1807.01232
+
+ .. note::
+
+ This dataset requires the following additional library to be installed:
+
+ * `radiant-mlhub `_ to download the
+ imagery and labels from the Radiant Earth MLHub
+
+ """
+
+ dataset_id = "spacenet1"
+ md5 = "e6ea35331636fa0c036c04b3d1cbf226"
+ imagery = {"rgb": "RGB.tif", "8band": "8Band.tif"}
+ label_glob = "labels.geojson"
+ foldername = "sn1_AOI_1_RIO"
+
+ def __init__(
+ self,
+ root: str,
+ image: str = "rgb",
+ transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+ download: bool = False,
+ api_key: Optional[str] = None,
+ checksum: bool = False,
+ ) -> None:
+ """Initialise a new SpaceNet 1 Dataset instance.
+
+ Args:
+ root: root directory where dataset can be found
+ image: image selection which must be "rgb" or "8band"
+ transforms: a function/transform that takes input sample and its target as
+ entry and returns a transformed version.
+ download: if True, download dataset and store it in the root directory.
+ api_key: a RadiantEarth MLHub API key to use for downloading the dataset
+
+ Raises:
+ RuntimeError: if ``download=False`` but dataset is missing
+ """
+ self.root = root
+ self.image = image # For testing
+ self.filename = self.imagery[image]
+ self.transforms = transforms
+ self.checksum = checksum
+
+ if not self._check_integrity():
+ if download:
+ self._download(api_key)
+ else:
+ raise RuntimeError(
+ "Dataset not found. You can use download=True to download it."
+ )
+
+ self.files = self._load_files(os.path.join(root, self.foldername))
+
+ def _load_files(self, root: str) -> List[Dict[str, str]]:
+ """Return the paths of the files in the dataset.
+
+ Args:
+ root: root dir of dataset
+
+ Returns:
+ list of dicts containing paths for each triple of rgb,
+ 8band and label
+ """
+ files = []
+ images = glob.glob(os.path.join(root, "*", self.filename))
+ images = sorted(images)
+ for imgpath in images:
+ lbl_path = os.path.join(
+ os.path.dirname(imgpath) + "-labels", "labels.geojson"
+ )
+ files.append({"image_path": imgpath, "label_path": lbl_path})
+ return files
+
+ def _load_image(self, path: str) -> Tuple[Tensor, Affine]:
+ """Load a single image.
+
+ Args:
+ path: path to the image
+
+ Returns:
+ the image
+ """
+ filename = os.path.join(path)
+ with rio.open(filename) as img:
+ array = img.read().astype(np.float32)
+ tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
+ return tensor, img.transform
+
+ def _load_mask(self, path: str, tfm: Affine, shape: Tuple[int, int]) -> Tensor:
+ """Rasterizes the dataset's labels (in geojson format).
+
+ Args:
+ path (str): path to the label
+ tfm (Affine): transform of corresponding image
+ shape (List[int, int]): shape of corresponding image
+
+ Returns:
+ Tensor: label tensor
+ """
+ with fiona.open(path) as src:
+ labels = [feature["geometry"] for feature in src]
+
+ if not labels:
+ mask_data = np.zeros(shape=shape)
+ else:
+ mask_data = rasterize(
+ labels,
+ out_shape=shape,
+ fill=0, # nodata value
+ transform=tfm,
+ all_touched=False,
+ dtype=np.uint8,
+ )
+
+ mask: Tensor = torch.from_numpy(mask_data).long() # type: ignore[attr-defined]
+
+ return mask
+
+ def __len__(self) -> int:
+ """Return the number of samples in the dataset.
+
+ Returns:
+ length of the dataset
+ """
+ return len(self.files)
+
+ def __getitem__(self, index: int) -> Dict[str, Tensor]:
+ """Return an index within the dataset.
+
+ Args:
+ index: index to return
+
+ Returns:
+ data and label at that index
+ """
+ files = self.files[index]
+ img, tfm = self._load_image(files["image_path"])
+ h, w = img.shape[1:]
+ mask = self._load_mask(files["label_path"], tfm, (h, w))
+
+ sample = {"image": img, "mask": mask}
+
+ if self.transforms is not None:
+ sample = self.transforms(sample)
+
+ return sample
+
+ def _check_integrity(self) -> bool:
+ """Checks the integrity of the dataset structure.
+
+ Returns:
+ True if the dataset directories are found, else False
+ """
+ stacpath = os.path.join(self.root, self.foldername, "collection.json")
+
+ if os.path.exists(stacpath):
+ return True
+
+ # If dataset folder does not exist, check for uncorrupted archive
+ archive_path = os.path.join(self.root, self.foldername + ".tar.gz")
+ if not os.path.exists(archive_path):
+ return False
+ print("Archive found")
+ if self.checksum and not check_integrity(archive_path, self.md5):
+ print("Dataset corrupted")
+ return False
+ print("Extracting...")
+ extract_archive(archive_path)
+ return True
+
+ def _download(self, api_key: Optional[str] = None) -> None:
+ """Download the dataset and extract it.
+
+ Args:
+ api_key: a RadiantEarth MLHub API key to use for downloading the dataset
+
+ Raises:
+ RuntimeError: if download doesn't work correctly or checksums don't match
+ """
+ if self._check_integrity():
+ print("Files already downloaded")
+ return
+
+ download_radiant_mlhub(self.dataset_id, self.root, api_key)
+ archive_path = os.path.join(self.root, self.foldername + ".tar.gz")
+ if (
+ self.checksum
+ and check_integrity(archive_path, self.md5)
+ or not self.checksum
+ ):
+ extract_archive(archive_path)
+ else:
+ raise RuntimeError("Dataset corrupted")