diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py index c877d1b79..1c49b313d 100644 --- a/tests/datasets/test_chesapeake.py +++ b/tests/datasets/test_chesapeake.py @@ -211,3 +211,13 @@ class TestChesapeakeCVPR: IndexError, match="query: .* spans multiple tiles which is not valid" ): ds[dataset.bounds] + + def test_plot(self, dataset: ChesapeakeCVPR) -> None: + x = dataset[dataset.bounds].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["mask"][:, :, 0].clone().unsqueeze(2) + dataset.plot(x) + plt.close() diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 8ba1c0c8f..db5296468 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -19,6 +19,7 @@ import shapely.ops import torch from matplotlib.colors import ListedColormap from rasterio.crs import CRS +from torch import Tensor from .geo import GeoDataset, RasterDataset from .utils import BoundingBox, download_url, extract_archive @@ -438,6 +439,46 @@ class ChesapeakeCVPR(GeoDataset): crs = CRS.from_epsg(3857) res = 1 + lc_cmap = { + 0: (0, 0, 0, 0), + 1: (0, 197, 255, 255), + 2: (38, 115, 0, 255), + 3: (163, 255, 115, 255), + 4: (255, 170, 0, 255), + 5: (156, 156, 156, 255), + 6: (0, 0, 0, 255), + 15: (0, 0, 0, 0), + } + + nlcd_cmap = { + 0: (0, 0, 0, 0), + 11: (70, 107, 159, 255), + 12: (209, 222, 248, 255), + 21: (222, 197, 197, 255), + 22: (217, 146, 130, 255), + 23: (235, 0, 0, 255), + 24: (171, 0, 0, 255), + 31: (179, 172, 159, 255), + 41: (104, 171, 95, 255), + 42: (28, 95, 44, 255), + 43: (181, 197, 143, 255), + 52: (204, 184, 121, 255), + 71: (223, 223, 194, 255), + 81: (220, 217, 57, 255), + 82: (171, 108, 40, 255), + 90: (184, 217, 235, 255), + 95: (108, 159, 184, 255), + } + + prior_color_matrix = np.array( + [ + [0.0, 0.77254902, 1.0, 1.0], + [0.14901961, 0.45098039, 0.0, 1.0], + [0.63921569, 1.0, 0.45098039, 1.0], + [0.61176471, 0.61176471, 0.61176471, 1.0], + ] + ) + valid_layers = [ "naip-new", "naip-old", @@ -540,6 +581,16 @@ class ChesapeakeCVPR(GeoDataset): super().__init__(transforms) + lc_colors = np.zeros((max(self.lc_cmap.keys()) + 1, 4)) + lc_colors[list(self.lc_cmap.keys())] = list(self.lc_cmap.values()) + lc_colors = lc_colors[:, :3] / 255 + self._lc_cmap = ListedColormap(lc_colors) + + nlcd_colors = np.zeros((max(self.nlcd_cmap.keys()) + 1, 4)) + nlcd_colors[list(self.nlcd_cmap.keys())] = list(self.nlcd_cmap.values()) + nlcd_colors = nlcd_colors[:, :3] / 255 + self._nlcd_cmap = ListedColormap(nlcd_colors) + # Add all tiles into the index in epsg:3857 based on the included geojson mint: float = 0 maxt: float = sys.maxsize @@ -694,3 +745,88 @@ class ChesapeakeCVPR(GeoDataset): """Extract the dataset.""" for subdataset in self.subdatasets: extract_archive(os.path.join(self.root, self.filenames[subdataset])) + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + .. versionadded:: 0.4 + """ + image = np.rollaxis(sample["image"].numpy(), 0, 3) + mask = np.rollaxis(sample["mask"].numpy(), 0, 3) + + num_panels = len(self.layers) + showing_predictions = "prediction" in sample + if showing_predictions: + predictions = sample["prediction"].numpy() + num_panels += 1 + + fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5)) + + i = 0 + for layer in self.layers: + if layer == "naip-new" or layer == "naip-old": + img = image[:, :, :3] / 255 + image = image[:, :, 4:] + axs[i].axis("off") + axs[i].imshow(img) + elif layer == "landsat-leaf-on" or layer == "landsat-leaf-off": + img = image[:, :, [3, 2, 1]] / 3000 + image = image[:, :, 9:] + axs[i].axis("off") + axs[i].imshow(img) + elif layer == "nlcd": + img = mask[:, :, 0] + mask = mask[:, :, 1:] + axs[i].imshow( + img, vmin=0, vmax=95, cmap=self._nlcd_cmap, interpolation="none" + ) + axs[i].axis("off") + elif layer == "lc": + img = mask[:, :, 0] + mask = mask[:, :, 1:] + axs[i].imshow( + img, vmin=0, vmax=15, cmap=self._lc_cmap, interpolation="none" + ) + axs[i].axis("off") + elif layer == "buildings": + img = mask[:, :, 0] + mask = mask[:, :, 1:] + axs[i].imshow(img, vmin=0, vmax=1, cmap="gray", interpolation="none") + axs[i].axis("off") + elif layer == "prior_from_cooccurrences_101_31_no_osm_no_buildings": + img = (mask[:, :, :4] @ self.prior_color_matrix) / 255 + mask = mask[:, :, 4:] + axs[i].imshow(img) + axs[i].axis("off") + + if show_titles: + if layer == "prior_from_cooccurrences_101_31_no_osm_no_buildings": + axs[i].set_title("prior") + else: + axs[i].set_title(layer) + i += 1 + + if showing_predictions: + axs[i].imshow( + predictions, vmin=0, vmax=15, cmap=self._lc_cmap, interpolation="none" + ) + axs[i].axis("off") + if show_titles: + axs[i].set_title("Predictions") + + if suptitle is not None: + plt.suptitle(suptitle) + return fig