diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index e18d18786..f4e90812e 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -233,6 +233,7 @@ SpaceNet
.. autoclass:: SpaceNet
.. autoclass:: SpaceNet1
.. autoclass:: SpaceNet2
+.. autoclass:: SpaceNet3
.. autoclass:: SpaceNet4
.. autoclass:: SpaceNet5
.. autoclass:: SpaceNet7
diff --git a/tests/data/spacenet/sn3_AOI_3_Paris.tar.gz b/tests/data/spacenet/sn3_AOI_3_Paris.tar.gz
new file mode 100644
index 000000000..d8cb305dc
Binary files /dev/null and b/tests/data/spacenet/sn3_AOI_3_Paris.tar.gz differ
diff --git a/tests/data/spacenet/sn3_AOI_5_Khartoum.tar.gz b/tests/data/spacenet/sn3_AOI_5_Khartoum.tar.gz
new file mode 100644
index 000000000..0daea2f53
Binary files /dev/null and b/tests/data/spacenet/sn3_AOI_5_Khartoum.tar.gz differ
diff --git a/tests/datasets/test_spacenet.py b/tests/datasets/test_spacenet.py
index 35ce3cecf..7e4cf90af 100644
--- a/tests/datasets/test_spacenet.py
+++ b/tests/datasets/test_spacenet.py
@@ -2,7 +2,6 @@
# Licensed under the MIT License.
import glob
-import itertools
import os
import shutil
from pathlib import Path
@@ -14,7 +13,14 @@ import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
-from torchgeo.datasets import SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7
+from torchgeo.datasets import (
+ SpaceNet1,
+ SpaceNet2,
+ SpaceNet3,
+ SpaceNet4,
+ SpaceNet5,
+ SpaceNet7,
+)
TEST_DATA_DIR = "tests/data/spacenet"
@@ -142,6 +148,71 @@ class TestSpaceNet2:
plt.close()
+class TestSpaceNet3:
+ @pytest.fixture(params=zip(["PAN", "MS"], [False, True]))
+ def dataset(
+ self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
+ ) -> SpaceNet3:
+ radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
+ monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_collection)
+ test_md5 = {
+ "sn3_AOI_3_Paris": "197440e0ade970169a801a173a492c27",
+ "sn3_AOI_5_Khartoum": "b21ff7dd33a15ec32bd380c083263cdf",
+ }
+
+ monkeypatch.setattr(SpaceNet3, "collection_md5_dict", test_md5)
+ root = str(tmp_path)
+ transforms = nn.Identity() # type: ignore[no-untyped-call]
+ return SpaceNet3(
+ root,
+ image=request.param[0],
+ speed_mask=request.param[1],
+ collections=["sn3_AOI_3_Paris", "sn3_AOI_5_Khartoum"],
+ transforms=transforms,
+ download=True,
+ api_key="",
+ )
+
+ def test_getitem(self, dataset: SpaceNet3) -> None:
+ # Iterate over all elements to maximize coverage
+ samples = [dataset[i] for i in range(len(dataset))]
+ x = samples[0]
+ assert isinstance(x, dict)
+ assert isinstance(x["image"], torch.Tensor)
+ assert isinstance(x["mask"], torch.Tensor)
+ if dataset.image == "MS":
+ assert x["image"].shape[0] == 8
+ else:
+ assert x["image"].shape[0] == 1
+
+ def test_len(self, dataset: SpaceNet3) -> None:
+ assert len(dataset) == 4
+
+ def test_already_downloaded(self, dataset: SpaceNet3) -> None:
+ SpaceNet3(root=dataset.root, download=True)
+
+ def test_not_downloaded(self, tmp_path: Path) -> None:
+ with pytest.raises(RuntimeError, match="Dataset not found"):
+ SpaceNet3(str(tmp_path))
+
+ def test_collection_checksum(self, dataset: SpaceNet3) -> None:
+ dataset.collection_md5_dict["sn3_AOI_5_Khartoum"] = "randommd5hash123"
+ with pytest.raises(
+ RuntimeError, match="Collection sn3_AOI_5_Khartoum corrupted"
+ ):
+ SpaceNet3(root=dataset.root, download=True, checksum=True)
+
+ def test_plot(self, dataset: SpaceNet3) -> None:
+ x = dataset[0].copy()
+ x["prediction"] = x["mask"]
+ dataset.plot(x, suptitle="Test")
+ plt.close()
+ dataset.plot(x, show_titles=False)
+ plt.close()
+ dataset.plot({"image": x["image"]})
+ plt.close()
+
+
class TestSpaceNet4:
@pytest.fixture(params=["PAN", "MS", "PS-RGBNIR"])
def dataset(
@@ -206,9 +277,7 @@ class TestSpaceNet4:
class TestSpaceNet5:
- @pytest.fixture(
- params=itertools.product(["PAN", "MS", "PS-MS", "PS-RGB"], [False, True])
- )
+ @pytest.fixture(params=zip(["PAN", "MS"], [False, True]))
def dataset(
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
) -> SpaceNet5:
@@ -239,9 +308,7 @@ class TestSpaceNet5:
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)
- if dataset.image == "PS-RGB":
- assert x["image"].shape[0] == 3
- elif dataset.image in ["MS", "PS-MS"]:
+ if dataset.image == "MS":
assert x["image"].shape[0] == 8
else:
assert x["image"].shape[0] == 1
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 8d28060b4..0bbdb94b0 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -75,7 +75,15 @@ from .seco import SeasonalContrastS2
from .sen12ms import SEN12MS
from .sentinel import Sentinel, Sentinel2
from .so2sat import So2Sat
-from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7
+from .spacenet import (
+ SpaceNet,
+ SpaceNet1,
+ SpaceNet2,
+ SpaceNet3,
+ SpaceNet4,
+ SpaceNet5,
+ SpaceNet7,
+)
from .ucmerced import UCMerced
from .usavars import USAVars
from .utils import (
@@ -155,6 +163,7 @@ __all__ = (
"SpaceNet",
"SpaceNet1",
"SpaceNet2",
+ "SpaceNet3",
"SpaceNet4",
"SpaceNet5",
"SpaceNet7",
diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py
index 5cf7d9fe5..7746e3e09 100644
--- a/torchgeo/datasets/spacenet.py
+++ b/torchgeo/datasets/spacenet.py
@@ -133,7 +133,7 @@ class SpaceNet(VisionDataset, abc.ABC):
images = sorted(images)
for imgpath in images:
lbl_path = os.path.join(
- os.path.dirname(imgpath) + "-labels", self.label_glob
+ f"{os.path.dirname(imgpath)}-labels", self.label_glob
)
files.append({"image_path": imgpath, "label_path": lbl_path})
return files
@@ -248,7 +248,7 @@ class SpaceNet(VisionDataset, abc.ABC):
to_be_downloaded = []
for collection in missing_collections:
- archive_path = os.path.join(self.root, collection + ".tar.gz")
+ archive_path = os.path.join(self.root, f"{collection}.tar.gz")
if os.path.exists(archive_path):
print(f"Found {collection} archive")
if (
@@ -281,7 +281,7 @@ class SpaceNet(VisionDataset, abc.ABC):
"""
for collection in collections:
download_radiant_mlhub_collection(collection, self.root, api_key)
- archive_path = os.path.join(self.root, collection + ".tar.gz")
+ archive_path = os.path.join(self.root, f"{collection}.tar.gz")
if (
not self.checksum
or not check_integrity(
@@ -538,7 +538,8 @@ class SpaceNet2(SpaceNet):
image: image selection which must be in ["MS", "PAN", "PS-MS", "PS-RGB"]
collections: collection selection which must be a subset of:
[sn2_AOI_2_Vegas, sn2_AOI_3_Paris, sn2_AOI_4_Shanghai,
- sn2_AOI_5_Khartoum]
+ sn2_AOI_5_Khartoum]. If unspecified, all collections will be
+ used.
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory.
@@ -554,6 +555,272 @@ class SpaceNet2(SpaceNet):
)
+class SpaceNet3(SpaceNet):
+ r"""SpaceNet 3: Road Network Detection.
+
+ `SpaceNet 3 `_
+ is a dataset of road networks over the cities of Las Vegas, Paris, Shanghai,
+ and Khartoum.
+
+ Collection features:
+
+ +------------+---------------------+------------+---------------------------+
+ | AOI | Area (km\ :sup:`2`\)| # Images | # Road Network Labels (km)|
+ +============+=====================+============+===========================+
+ | Vegas | 216 | 854 | 3685 |
+ +------------+---------------------+------------+---------------------------+
+ | Paris | 1030 | 257 | 425 |
+ +------------+---------------------+------------+---------------------------+
+ | Shanghai | 1000 | 1028 | 3537 |
+ +------------+---------------------+------------+---------------------------+
+ | Khartoum | 765 | 283 | 1030 |
+ +------------+---------------------+------------+---------------------------+
+
+ Imagery features:
+
+ .. list-table::
+ :widths: 10 10 10 10 10
+ :header-rows: 1
+ :stub-columns: 1
+
+ * -
+ - PAN
+ - MS
+ - PS-MS
+ - PS-RGB
+ * - GSD (m)
+ - 0.31
+ - 1.24
+ - 0.30
+ - 0.30
+ * - Chip size (px)
+ - 1300 x 1300
+ - 325 x 325
+ - 1300 x 1300
+ - 1300 x 1300
+
+ Dataset format:
+
+ * Imagery - Worldview-3 GeoTIFFs
+
+ * PAN.tif (Panchromatic)
+ * MS.tif (Multispectral)
+ * PS-MS (Pansharpened Multispectral)
+ * PS-RGB (Pansharpened RGB)
+
+ * Labels - GeoJSON
+
+ * labels.geojson
+
+ If you use this dataset in your research, please cite the following paper:
+
+ * https://arxiv.org/abs/1807.01232
+
+ .. note::
+
+ This dataset requires the following additional library to be installed:
+
+ * `radiant-mlhub `_ to download the
+ imagery and labels from the Radiant Earth MLHub
+
+ .. versionadded:: 0.3
+ """
+
+ dataset_id = "spacenet3"
+ collection_md5_dict = {
+ "sn3_AOI_2_Vegas": "8ce7e6abffb8849eb88885035f061ee8",
+ "sn3_AOI_3_Paris": "90b9ebd64cd83dc8d3d4773f45050d8f",
+ "sn3_AOI_4_Shanghai": "3ea291df34548962dfba8b5ed37d700c",
+ "sn3_AOI_5_Khartoum": "b8d549ac9a6d7456c0f7a8e6de23d9f9",
+ }
+
+ imagery = {
+ "MS": "MS.tif",
+ "PAN": "PAN.tif",
+ "PS-MS": "PS-MS.tif",
+ "PS-RGB": "PS-RGB.tif",
+ }
+ chip_size = {
+ "MS": (325, 325),
+ "PAN": (1300, 1300),
+ "PS-MS": (1300, 1300),
+ "PS-RGB": (1300, 1300),
+ }
+ label_glob = "labels.geojson"
+
+ def __init__(
+ self,
+ root: str,
+ image: str = "PS-RGB",
+ speed_mask: Optional[bool] = False,
+ collections: List[str] = [],
+ transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+ download: bool = False,
+ api_key: Optional[str] = None,
+ checksum: bool = False,
+ ) -> None:
+ """Initialize a new SpaceNet 3 Dataset instance.
+
+ Args:
+ root: root directory where dataset can be found
+ image: image selection which must be in ["MS", "PAN", "PS-MS", "PS-RGB"]
+ speed_mask: use multi-class speed mask (created by binning roads at
+ 10 mph increments) as label if true, else use binary mask
+ collections: collection selection which must be a subset of:
+ [sn3_AOI_2_Vegas, sn3_AOI_3_Paris, sn3_AOI_4_Shanghai,
+ sn3_AOI_5_Khartoum]. If unspecified, all collections will be
+ used.
+ transforms: a function/transform that takes input sample and its target as
+ entry and returns a transformed version
+ download: if True, download dataset and store it in the root directory.
+ api_key: a RadiantEarth MLHub API key to use for downloading the dataset
+ checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+ Raises:
+ RuntimeError: if ``download=False`` but dataset is missing
+ """
+ assert image in {"MS", "PAN", "PS-MS", "PS-RGB"}
+ self.speed_mask = speed_mask
+ super().__init__(
+ root, image, collections, transforms, download, api_key, checksum
+ )
+
+ def _load_mask(
+ self, path: str, tfm: Affine, raster_crs: CRS, shape: Tuple[int, int]
+ ) -> Tensor:
+ """Rasterizes the dataset's labels (in geojson format).
+
+ Args:
+ path: path to the label
+ tfm: transform of corresponding image
+ shape: shape of corresponding image
+
+ Returns:
+ Tensor: label tensor
+ """
+ min_speed_bin = 1
+ max_speed_bin = 65
+ speed_arr_bin = np.arange(min_speed_bin, max_speed_bin + 1)
+ bin_size_mph = 10.0
+ speed_cls_arr: "np.typing.NDArray[np.int_]" = np.array(
+ [int(math.ceil(s / bin_size_mph)) for s in speed_arr_bin]
+ )
+
+ try:
+ with fiona.open(path) as src:
+ vector_crs = CRS(src.crs)
+ labels = []
+
+ for feature in src:
+ if raster_crs != vector_crs:
+ geom = transform_geom(
+ vector_crs.to_string(),
+ raster_crs.to_string(),
+ feature["geometry"],
+ )
+ else:
+ geom = feature["geometry"]
+
+ if self.speed_mask:
+ val = speed_cls_arr[
+ int(feature["properties"]["inferred_speed_mph"]) - 1
+ ]
+ else:
+ val = 1
+
+ labels.append((geom, val))
+
+ except FionaValueError:
+ labels = []
+
+ if not labels:
+ mask_data = np.zeros(shape=shape)
+ else:
+ mask_data = rasterize(
+ labels,
+ out_shape=shape,
+ fill=0, # nodata value
+ transform=tfm,
+ all_touched=False,
+ dtype=np.uint8,
+ )
+
+ mask = torch.from_numpy(mask_data).long()
+ return mask
+
+ def plot(
+ self,
+ sample: Dict[str, Tensor],
+ show_titles: bool = True,
+ suptitle: Optional[str] = None,
+ ) -> Figure:
+ """Plot a sample from the dataset.
+
+ Args:
+ sample: a sample returned by :meth:`SpaceNet.__getitem__`
+ show_titles: flag indicating whether to show titles above each panel
+ suptitle: optional string to use as a suptitle
+
+ Returns:
+ a matplotlib Figure with the rendered sample
+
+ """
+ # image can be 1 channel or >3 channels
+ if sample["image"].shape[0] == 1:
+ image = np.rollaxis(sample["image"].numpy(), 0, 3)
+ else:
+ image = np.rollaxis(sample["image"][:3].numpy(), 0, 3)
+ image = percentile_normalization(image, axis=(0, 1))
+
+ ncols = 1
+ show_mask = "mask" in sample
+ show_predictions = "prediction" in sample
+
+ if show_mask:
+ mask = sample["mask"].numpy()
+ ncols += 1
+
+ if show_predictions:
+ prediction = sample["prediction"].numpy()
+ ncols += 1
+
+ fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 8, 8))
+ if not isinstance(axs, np.ndarray):
+ axs = [axs]
+ axs[0].imshow(image)
+ axs[0].axis("off")
+ if show_titles:
+ axs[0].set_title("Image")
+
+ if show_mask:
+ if self.speed_mask:
+ cmap = copy.copy(plt.get_cmap("autumn_r"))
+ cmap.set_under(color="black")
+ axs[1].imshow(mask, vmin=0.1, vmax=7, cmap=cmap, interpolation="none")
+ else:
+ axs[1].imshow(mask, cmap="Greys_r", interpolation="none")
+ axs[1].axis("off")
+ if show_titles:
+ axs[1].set_title("Label")
+
+ if show_predictions:
+ if self.speed_mask:
+ cmap = copy.copy(plt.get_cmap("autumn_r"))
+ cmap.set_under(color="black")
+ axs[2].imshow(
+ prediction, vmin=0.1, vmax=7, cmap=cmap, interpolation="none"
+ )
+ else:
+ axs[2].imshow(prediction, cmap="Greys_r", interpolation="none")
+ axs[2].axis("off")
+ if show_titles:
+ axs[2].set_title("Prediction")
+
+ if suptitle is not None:
+ plt.suptitle(suptitle)
+ return fig
+
+
class SpaceNet4(SpaceNet):
"""SpaceNet 4: Off-Nadir Buildings Dataset.
@@ -699,7 +966,7 @@ class SpaceNet4(SpaceNet):
lbl_dir = os.path.dirname(imgpath).split("-nadir")[0]
- lbl_path = os.path.join(lbl_dir + "-labels", self.label_glob)
+ lbl_path = os.path.join(f"{lbl_dir}-labels", self.label_glob)
assert os.path.exists(lbl_path)
_file = {"image_path": imgpath, "label_path": lbl_path}
@@ -724,7 +991,7 @@ class SpaceNet4(SpaceNet):
return files
-class SpaceNet5(SpaceNet):
+class SpaceNet5(SpaceNet3):
r"""SpaceNet 5: Automated Road Network Extraction and Route Travel Time Estimation.
`SpaceNet 5 `_
@@ -832,7 +1099,8 @@ class SpaceNet5(SpaceNet):
speed_mask: use multi-class speed mask (created by binning roads at
10 mph increments) as label if true, else use binary mask
collections: collection selection which must be a subset of:
- [sn5_AOI_7_Moscow, sn5_AOI_8_Mumbai]
+ [sn5_AOI_7_Moscow, sn5_AOI_8_Mumbai]. If unspecified, all
+ collections will be used.
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory.
@@ -842,148 +1110,17 @@ class SpaceNet5(SpaceNet):
Raises:
RuntimeError: if ``download=False`` but dataset is missing
"""
- assert image in {"MS", "PAN", "PS-MS", "PS-RGB"}
- self.speed_mask = speed_mask
super().__init__(
- root, image, collections, transforms, download, api_key, checksum
+ root,
+ image,
+ speed_mask,
+ collections,
+ transforms,
+ download,
+ api_key,
+ checksum,
)
- def _load_mask(
- self, path: str, tfm: Affine, raster_crs: CRS, shape: Tuple[int, int]
- ) -> Tensor:
- """Rasterizes the dataset's labels (in geojson format).
-
- Args:
- path: path to the label
- tfm: transform of corresponding image
- shape: shape of corresponding image
-
- Returns:
- Tensor: label tensor
- """
- min_speed_bin = 1
- max_speed_bin = 65
- speed_arr_bin = np.arange(min_speed_bin, max_speed_bin + 1)
- bin_size_mph = 10.0
- speed_cls_arr: "np.typing.NDArray[np.int_]" = np.array(
- [int(math.ceil(s / bin_size_mph)) for s in speed_arr_bin]
- )
-
- try:
- with fiona.open(path) as src:
- vector_crs = CRS(src.crs)
- labels = []
-
- for feature in src:
- if raster_crs != vector_crs:
- geom = transform_geom(
- vector_crs.to_string(),
- raster_crs.to_string(),
- feature["geometry"],
- )
- else:
- geom = feature["geometry"]
-
- if self.speed_mask:
- val = speed_cls_arr[
- int(feature["properties"]["inferred_speed_mph"]) - 1
- ]
- else:
- val = 1
-
- labels.append((geom, val))
-
- except FionaValueError:
- labels = []
-
- if not labels:
- mask_data = np.zeros(shape=shape)
- else:
- mask_data = rasterize(
- labels,
- out_shape=shape,
- fill=0, # nodata value
- transform=tfm,
- all_touched=False,
- dtype=np.uint8,
- )
-
- mask = torch.from_numpy(mask_data).long()
- return mask
-
- def plot(
- self,
- sample: Dict[str, Tensor],
- show_titles: bool = True,
- suptitle: Optional[str] = None,
- ) -> Figure:
- """Plot a sample from the dataset.
-
- Args:
- sample: a sample returned by :meth:`SpaceNet.__getitem__`
- show_titles: flag indicating whether to show titles above each panel
- suptitle: optional string to use as a suptitle
-
- Returns:
- a matplotlib Figure with the rendered sample
-
- .. versionadded:: 0.2
- """
- # image can be 1 channel or >3 channels
- if sample["image"].shape[0] == 1:
- image = np.rollaxis(sample["image"].numpy(), 0, 3)
- else:
- image = np.rollaxis(sample["image"][:3].numpy(), 0, 3)
- image = percentile_normalization(image, axis=(0, 1))
-
- ncols = 1
- show_mask = "mask" in sample
- show_predictions = "prediction" in sample
-
- if show_mask:
- mask = sample["mask"].numpy()
- ncols += 1
-
- if show_predictions:
- prediction = sample["prediction"].numpy()
- ncols += 1
-
- fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 8, 8))
- if not isinstance(axs, np.ndarray):
- axs = [axs]
- axs[0].imshow(image)
- axs[0].axis("off")
- if show_titles:
- axs[0].set_title("Image")
-
- if show_mask:
- if self.speed_mask:
- cmap = copy.copy(plt.get_cmap("autumn_r"))
- cmap.set_under(color="black")
- axs[1].imshow(mask, vmin=0.1, vmax=7, cmap=cmap, interpolation="none")
- else:
- axs[1].imshow(mask, cmap="Greys_r", interpolation="none")
- axs[1].axis("off")
- if show_titles:
- axs[1].set_title("Label")
-
- if show_predictions:
- if self.speed_mask:
- cmap = copy.copy(plt.get_cmap("autumn_r"))
- cmap.set_under(color="black")
- axs[2].imshow(
- prediction, vmin=0.1, vmax=7, cmap=cmap, interpolation="none"
- )
- else:
- axs[2].imshow(prediction, cmap="Greys_r", interpolation="none")
- axs[2].axis("off")
- if show_titles:
- axs[2].set_title("Prediction")
-
- if suptitle is not None:
- plt.suptitle(suptitle)
- return fig
-
class SpaceNet7(SpaceNet):
"""SpaceNet 7: Multi-Temporal Urban Development Challenge.
@@ -1126,13 +1263,12 @@ class SpaceNet7(SpaceNet):
Returns:
data at that index
"""
- sample = {}
files = self.files[index]
img, tfm, raster_crs = self._load_image(files["image_path"])
h, w = img.shape[1:]
ch, cw = self.chip_size["img"]
- sample["image"] = img[:, :ch, :cw]
+ sample = {"image": img[:, :ch, :cw]}
if self.split == "train":
mask = self._load_mask(files["label_path"], tfm, raster_crs, (h, w))
sample["mask"] = mask[:ch, :cw]