diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index e18d18786..f4e90812e 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -233,6 +233,7 @@ SpaceNet .. autoclass:: SpaceNet .. autoclass:: SpaceNet1 .. autoclass:: SpaceNet2 +.. autoclass:: SpaceNet3 .. autoclass:: SpaceNet4 .. autoclass:: SpaceNet5 .. autoclass:: SpaceNet7 diff --git a/tests/data/spacenet/sn3_AOI_3_Paris.tar.gz b/tests/data/spacenet/sn3_AOI_3_Paris.tar.gz new file mode 100644 index 000000000..d8cb305dc Binary files /dev/null and b/tests/data/spacenet/sn3_AOI_3_Paris.tar.gz differ diff --git a/tests/data/spacenet/sn3_AOI_5_Khartoum.tar.gz b/tests/data/spacenet/sn3_AOI_5_Khartoum.tar.gz new file mode 100644 index 000000000..0daea2f53 Binary files /dev/null and b/tests/data/spacenet/sn3_AOI_5_Khartoum.tar.gz differ diff --git a/tests/datasets/test_spacenet.py b/tests/datasets/test_spacenet.py index 35ce3cecf..7e4cf90af 100644 --- a/tests/datasets/test_spacenet.py +++ b/tests/datasets/test_spacenet.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import glob -import itertools import os import shutil from pathlib import Path @@ -14,7 +13,14 @@ import torch.nn as nn from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch -from torchgeo.datasets import SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7 +from torchgeo.datasets import ( + SpaceNet1, + SpaceNet2, + SpaceNet3, + SpaceNet4, + SpaceNet5, + SpaceNet7, +) TEST_DATA_DIR = "tests/data/spacenet" @@ -142,6 +148,71 @@ class TestSpaceNet2: plt.close() +class TestSpaceNet3: + @pytest.fixture(params=zip(["PAN", "MS"], [False, True])) + def dataset( + self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path + ) -> SpaceNet3: + radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1") + monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_collection) + test_md5 = { + "sn3_AOI_3_Paris": "197440e0ade970169a801a173a492c27", + "sn3_AOI_5_Khartoum": "b21ff7dd33a15ec32bd380c083263cdf", + } + + monkeypatch.setattr(SpaceNet3, "collection_md5_dict", test_md5) + root = str(tmp_path) + transforms = nn.Identity() # type: ignore[no-untyped-call] + return SpaceNet3( + root, + image=request.param[0], + speed_mask=request.param[1], + collections=["sn3_AOI_3_Paris", "sn3_AOI_5_Khartoum"], + transforms=transforms, + download=True, + api_key="", + ) + + def test_getitem(self, dataset: SpaceNet3) -> None: + # Iterate over all elements to maximize coverage + samples = [dataset[i] for i in range(len(dataset))] + x = samples[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert isinstance(x["mask"], torch.Tensor) + if dataset.image == "MS": + assert x["image"].shape[0] == 8 + else: + assert x["image"].shape[0] == 1 + + def test_len(self, dataset: SpaceNet3) -> None: + assert len(dataset) == 4 + + def test_already_downloaded(self, dataset: SpaceNet3) -> None: + SpaceNet3(root=dataset.root, download=True) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found"): + SpaceNet3(str(tmp_path)) + + def test_collection_checksum(self, dataset: SpaceNet3) -> None: + dataset.collection_md5_dict["sn3_AOI_5_Khartoum"] = "randommd5hash123" + with pytest.raises( + RuntimeError, match="Collection sn3_AOI_5_Khartoum corrupted" + ): + SpaceNet3(root=dataset.root, download=True, checksum=True) + + def test_plot(self, dataset: SpaceNet3) -> None: + x = dataset[0].copy() + x["prediction"] = x["mask"] + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + dataset.plot({"image": x["image"]}) + plt.close() + + class TestSpaceNet4: @pytest.fixture(params=["PAN", "MS", "PS-RGBNIR"]) def dataset( @@ -206,9 +277,7 @@ class TestSpaceNet4: class TestSpaceNet5: - @pytest.fixture( - params=itertools.product(["PAN", "MS", "PS-MS", "PS-RGB"], [False, True]) - ) + @pytest.fixture(params=zip(["PAN", "MS"], [False, True])) def dataset( self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path ) -> SpaceNet5: @@ -239,9 +308,7 @@ class TestSpaceNet5: assert isinstance(x, dict) assert isinstance(x["image"], torch.Tensor) assert isinstance(x["mask"], torch.Tensor) - if dataset.image == "PS-RGB": - assert x["image"].shape[0] == 3 - elif dataset.image in ["MS", "PS-MS"]: + if dataset.image == "MS": assert x["image"].shape[0] == 8 else: assert x["image"].shape[0] == 1 diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 8d28060b4..0bbdb94b0 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -75,7 +75,15 @@ from .seco import SeasonalContrastS2 from .sen12ms import SEN12MS from .sentinel import Sentinel, Sentinel2 from .so2sat import So2Sat -from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7 +from .spacenet import ( + SpaceNet, + SpaceNet1, + SpaceNet2, + SpaceNet3, + SpaceNet4, + SpaceNet5, + SpaceNet7, +) from .ucmerced import UCMerced from .usavars import USAVars from .utils import ( @@ -155,6 +163,7 @@ __all__ = ( "SpaceNet", "SpaceNet1", "SpaceNet2", + "SpaceNet3", "SpaceNet4", "SpaceNet5", "SpaceNet7", diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py index 5cf7d9fe5..7746e3e09 100644 --- a/torchgeo/datasets/spacenet.py +++ b/torchgeo/datasets/spacenet.py @@ -133,7 +133,7 @@ class SpaceNet(VisionDataset, abc.ABC): images = sorted(images) for imgpath in images: lbl_path = os.path.join( - os.path.dirname(imgpath) + "-labels", self.label_glob + f"{os.path.dirname(imgpath)}-labels", self.label_glob ) files.append({"image_path": imgpath, "label_path": lbl_path}) return files @@ -248,7 +248,7 @@ class SpaceNet(VisionDataset, abc.ABC): to_be_downloaded = [] for collection in missing_collections: - archive_path = os.path.join(self.root, collection + ".tar.gz") + archive_path = os.path.join(self.root, f"{collection}.tar.gz") if os.path.exists(archive_path): print(f"Found {collection} archive") if ( @@ -281,7 +281,7 @@ class SpaceNet(VisionDataset, abc.ABC): """ for collection in collections: download_radiant_mlhub_collection(collection, self.root, api_key) - archive_path = os.path.join(self.root, collection + ".tar.gz") + archive_path = os.path.join(self.root, f"{collection}.tar.gz") if ( not self.checksum or not check_integrity( @@ -538,7 +538,8 @@ class SpaceNet2(SpaceNet): image: image selection which must be in ["MS", "PAN", "PS-MS", "PS-RGB"] collections: collection selection which must be a subset of: [sn2_AOI_2_Vegas, sn2_AOI_3_Paris, sn2_AOI_4_Shanghai, - sn2_AOI_5_Khartoum] + sn2_AOI_5_Khartoum]. If unspecified, all collections will be + used. transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory. @@ -554,6 +555,272 @@ class SpaceNet2(SpaceNet): ) +class SpaceNet3(SpaceNet): + r"""SpaceNet 3: Road Network Detection. + + `SpaceNet 3 `_ + is a dataset of road networks over the cities of Las Vegas, Paris, Shanghai, + and Khartoum. + + Collection features: + + +------------+---------------------+------------+---------------------------+ + | AOI | Area (km\ :sup:`2`\)| # Images | # Road Network Labels (km)| + +============+=====================+============+===========================+ + | Vegas | 216 | 854 | 3685 | + +------------+---------------------+------------+---------------------------+ + | Paris | 1030 | 257 | 425 | + +------------+---------------------+------------+---------------------------+ + | Shanghai | 1000 | 1028 | 3537 | + +------------+---------------------+------------+---------------------------+ + | Khartoum | 765 | 283 | 1030 | + +------------+---------------------+------------+---------------------------+ + + Imagery features: + + .. list-table:: + :widths: 10 10 10 10 10 + :header-rows: 1 + :stub-columns: 1 + + * - + - PAN + - MS + - PS-MS + - PS-RGB + * - GSD (m) + - 0.31 + - 1.24 + - 0.30 + - 0.30 + * - Chip size (px) + - 1300 x 1300 + - 325 x 325 + - 1300 x 1300 + - 1300 x 1300 + + Dataset format: + + * Imagery - Worldview-3 GeoTIFFs + + * PAN.tif (Panchromatic) + * MS.tif (Multispectral) + * PS-MS (Pansharpened Multispectral) + * PS-RGB (Pansharpened RGB) + + * Labels - GeoJSON + + * labels.geojson + + If you use this dataset in your research, please cite the following paper: + + * https://arxiv.org/abs/1807.01232 + + .. note:: + + This dataset requires the following additional library to be installed: + + * `radiant-mlhub `_ to download the + imagery and labels from the Radiant Earth MLHub + + .. versionadded:: 0.3 + """ + + dataset_id = "spacenet3" + collection_md5_dict = { + "sn3_AOI_2_Vegas": "8ce7e6abffb8849eb88885035f061ee8", + "sn3_AOI_3_Paris": "90b9ebd64cd83dc8d3d4773f45050d8f", + "sn3_AOI_4_Shanghai": "3ea291df34548962dfba8b5ed37d700c", + "sn3_AOI_5_Khartoum": "b8d549ac9a6d7456c0f7a8e6de23d9f9", + } + + imagery = { + "MS": "MS.tif", + "PAN": "PAN.tif", + "PS-MS": "PS-MS.tif", + "PS-RGB": "PS-RGB.tif", + } + chip_size = { + "MS": (325, 325), + "PAN": (1300, 1300), + "PS-MS": (1300, 1300), + "PS-RGB": (1300, 1300), + } + label_glob = "labels.geojson" + + def __init__( + self, + root: str, + image: str = "PS-RGB", + speed_mask: Optional[bool] = False, + collections: List[str] = [], + transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + download: bool = False, + api_key: Optional[str] = None, + checksum: bool = False, + ) -> None: + """Initialize a new SpaceNet 3 Dataset instance. + + Args: + root: root directory where dataset can be found + image: image selection which must be in ["MS", "PAN", "PS-MS", "PS-RGB"] + speed_mask: use multi-class speed mask (created by binning roads at + 10 mph increments) as label if true, else use binary mask + collections: collection selection which must be a subset of: + [sn3_AOI_2_Vegas, sn3_AOI_3_Paris, sn3_AOI_4_Shanghai, + sn3_AOI_5_Khartoum]. If unspecified, all collections will be + used. + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory. + api_key: a RadiantEarth MLHub API key to use for downloading the dataset + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + RuntimeError: if ``download=False`` but dataset is missing + """ + assert image in {"MS", "PAN", "PS-MS", "PS-RGB"} + self.speed_mask = speed_mask + super().__init__( + root, image, collections, transforms, download, api_key, checksum + ) + + def _load_mask( + self, path: str, tfm: Affine, raster_crs: CRS, shape: Tuple[int, int] + ) -> Tensor: + """Rasterizes the dataset's labels (in geojson format). + + Args: + path: path to the label + tfm: transform of corresponding image + shape: shape of corresponding image + + Returns: + Tensor: label tensor + """ + min_speed_bin = 1 + max_speed_bin = 65 + speed_arr_bin = np.arange(min_speed_bin, max_speed_bin + 1) + bin_size_mph = 10.0 + speed_cls_arr: "np.typing.NDArray[np.int_]" = np.array( + [int(math.ceil(s / bin_size_mph)) for s in speed_arr_bin] + ) + + try: + with fiona.open(path) as src: + vector_crs = CRS(src.crs) + labels = [] + + for feature in src: + if raster_crs != vector_crs: + geom = transform_geom( + vector_crs.to_string(), + raster_crs.to_string(), + feature["geometry"], + ) + else: + geom = feature["geometry"] + + if self.speed_mask: + val = speed_cls_arr[ + int(feature["properties"]["inferred_speed_mph"]) - 1 + ] + else: + val = 1 + + labels.append((geom, val)) + + except FionaValueError: + labels = [] + + if not labels: + mask_data = np.zeros(shape=shape) + else: + mask_data = rasterize( + labels, + out_shape=shape, + fill=0, # nodata value + transform=tfm, + all_touched=False, + dtype=np.uint8, + ) + + mask = torch.from_numpy(mask_data).long() + return mask + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`SpaceNet.__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + """ + # image can be 1 channel or >3 channels + if sample["image"].shape[0] == 1: + image = np.rollaxis(sample["image"].numpy(), 0, 3) + else: + image = np.rollaxis(sample["image"][:3].numpy(), 0, 3) + image = percentile_normalization(image, axis=(0, 1)) + + ncols = 1 + show_mask = "mask" in sample + show_predictions = "prediction" in sample + + if show_mask: + mask = sample["mask"].numpy() + ncols += 1 + + if show_predictions: + prediction = sample["prediction"].numpy() + ncols += 1 + + fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 8, 8)) + if not isinstance(axs, np.ndarray): + axs = [axs] + axs[0].imshow(image) + axs[0].axis("off") + if show_titles: + axs[0].set_title("Image") + + if show_mask: + if self.speed_mask: + cmap = copy.copy(plt.get_cmap("autumn_r")) + cmap.set_under(color="black") + axs[1].imshow(mask, vmin=0.1, vmax=7, cmap=cmap, interpolation="none") + else: + axs[1].imshow(mask, cmap="Greys_r", interpolation="none") + axs[1].axis("off") + if show_titles: + axs[1].set_title("Label") + + if show_predictions: + if self.speed_mask: + cmap = copy.copy(plt.get_cmap("autumn_r")) + cmap.set_under(color="black") + axs[2].imshow( + prediction, vmin=0.1, vmax=7, cmap=cmap, interpolation="none" + ) + else: + axs[2].imshow(prediction, cmap="Greys_r", interpolation="none") + axs[2].axis("off") + if show_titles: + axs[2].set_title("Prediction") + + if suptitle is not None: + plt.suptitle(suptitle) + return fig + + class SpaceNet4(SpaceNet): """SpaceNet 4: Off-Nadir Buildings Dataset. @@ -699,7 +966,7 @@ class SpaceNet4(SpaceNet): lbl_dir = os.path.dirname(imgpath).split("-nadir")[0] - lbl_path = os.path.join(lbl_dir + "-labels", self.label_glob) + lbl_path = os.path.join(f"{lbl_dir}-labels", self.label_glob) assert os.path.exists(lbl_path) _file = {"image_path": imgpath, "label_path": lbl_path} @@ -724,7 +991,7 @@ class SpaceNet4(SpaceNet): return files -class SpaceNet5(SpaceNet): +class SpaceNet5(SpaceNet3): r"""SpaceNet 5: Automated Road Network Extraction and Route Travel Time Estimation. `SpaceNet 5 `_ @@ -832,7 +1099,8 @@ class SpaceNet5(SpaceNet): speed_mask: use multi-class speed mask (created by binning roads at 10 mph increments) as label if true, else use binary mask collections: collection selection which must be a subset of: - [sn5_AOI_7_Moscow, sn5_AOI_8_Mumbai] + [sn5_AOI_7_Moscow, sn5_AOI_8_Mumbai]. If unspecified, all + collections will be used. transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory. @@ -842,148 +1110,17 @@ class SpaceNet5(SpaceNet): Raises: RuntimeError: if ``download=False`` but dataset is missing """ - assert image in {"MS", "PAN", "PS-MS", "PS-RGB"} - self.speed_mask = speed_mask super().__init__( - root, image, collections, transforms, download, api_key, checksum + root, + image, + speed_mask, + collections, + transforms, + download, + api_key, + checksum, ) - def _load_mask( - self, path: str, tfm: Affine, raster_crs: CRS, shape: Tuple[int, int] - ) -> Tensor: - """Rasterizes the dataset's labels (in geojson format). - - Args: - path: path to the label - tfm: transform of corresponding image - shape: shape of corresponding image - - Returns: - Tensor: label tensor - """ - min_speed_bin = 1 - max_speed_bin = 65 - speed_arr_bin = np.arange(min_speed_bin, max_speed_bin + 1) - bin_size_mph = 10.0 - speed_cls_arr: "np.typing.NDArray[np.int_]" = np.array( - [int(math.ceil(s / bin_size_mph)) for s in speed_arr_bin] - ) - - try: - with fiona.open(path) as src: - vector_crs = CRS(src.crs) - labels = [] - - for feature in src: - if raster_crs != vector_crs: - geom = transform_geom( - vector_crs.to_string(), - raster_crs.to_string(), - feature["geometry"], - ) - else: - geom = feature["geometry"] - - if self.speed_mask: - val = speed_cls_arr[ - int(feature["properties"]["inferred_speed_mph"]) - 1 - ] - else: - val = 1 - - labels.append((geom, val)) - - except FionaValueError: - labels = [] - - if not labels: - mask_data = np.zeros(shape=shape) - else: - mask_data = rasterize( - labels, - out_shape=shape, - fill=0, # nodata value - transform=tfm, - all_touched=False, - dtype=np.uint8, - ) - - mask = torch.from_numpy(mask_data).long() - return mask - - def plot( - self, - sample: Dict[str, Tensor], - show_titles: bool = True, - suptitle: Optional[str] = None, - ) -> Figure: - """Plot a sample from the dataset. - - Args: - sample: a sample returned by :meth:`SpaceNet.__getitem__` - show_titles: flag indicating whether to show titles above each panel - suptitle: optional string to use as a suptitle - - Returns: - a matplotlib Figure with the rendered sample - - .. versionadded:: 0.2 - """ - # image can be 1 channel or >3 channels - if sample["image"].shape[0] == 1: - image = np.rollaxis(sample["image"].numpy(), 0, 3) - else: - image = np.rollaxis(sample["image"][:3].numpy(), 0, 3) - image = percentile_normalization(image, axis=(0, 1)) - - ncols = 1 - show_mask = "mask" in sample - show_predictions = "prediction" in sample - - if show_mask: - mask = sample["mask"].numpy() - ncols += 1 - - if show_predictions: - prediction = sample["prediction"].numpy() - ncols += 1 - - fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 8, 8)) - if not isinstance(axs, np.ndarray): - axs = [axs] - axs[0].imshow(image) - axs[0].axis("off") - if show_titles: - axs[0].set_title("Image") - - if show_mask: - if self.speed_mask: - cmap = copy.copy(plt.get_cmap("autumn_r")) - cmap.set_under(color="black") - axs[1].imshow(mask, vmin=0.1, vmax=7, cmap=cmap, interpolation="none") - else: - axs[1].imshow(mask, cmap="Greys_r", interpolation="none") - axs[1].axis("off") - if show_titles: - axs[1].set_title("Label") - - if show_predictions: - if self.speed_mask: - cmap = copy.copy(plt.get_cmap("autumn_r")) - cmap.set_under(color="black") - axs[2].imshow( - prediction, vmin=0.1, vmax=7, cmap=cmap, interpolation="none" - ) - else: - axs[2].imshow(prediction, cmap="Greys_r", interpolation="none") - axs[2].axis("off") - if show_titles: - axs[2].set_title("Prediction") - - if suptitle is not None: - plt.suptitle(suptitle) - return fig - class SpaceNet7(SpaceNet): """SpaceNet 7: Multi-Temporal Urban Development Challenge. @@ -1126,13 +1263,12 @@ class SpaceNet7(SpaceNet): Returns: data at that index """ - sample = {} files = self.files[index] img, tfm, raster_crs = self._load_image(files["image_path"]) h, w = img.shape[1:] ch, cw = self.chip_size["img"] - sample["image"] = img[:, :ch, :cw] + sample = {"image": img[:, :ch, :cw]} if self.split == "train": mask = self._load_mask(files["label_path"], tfm, raster_crs, (h, w)) sample["mask"] = mask[:ch, :cw]