Adding plotting to ChesapeakeCVPR dataset (#820)

* Adding plotting to ChesapeakeCVPR dataset

* De for-looping
This commit is contained in:
Caleb Robinson 2022-10-06 11:49:56 -07:00 коммит произвёл GitHub
Родитель 240df5bd23
Коммит ba966d76cc
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 146 добавлений и 0 удалений

Просмотреть файл

@ -211,3 +211,13 @@ class TestChesapeakeCVPR:
IndexError, match="query: .* spans multiple tiles which is not valid" IndexError, match="query: .* spans multiple tiles which is not valid"
): ):
ds[dataset.bounds] ds[dataset.bounds]
def test_plot(self, dataset: ChesapeakeCVPR) -> None:
x = dataset[dataset.bounds].copy()
dataset.plot(x, suptitle="Test")
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x["prediction"] = x["mask"][:, :, 0].clone().unsqueeze(2)
dataset.plot(x)
plt.close()

Просмотреть файл

@ -19,6 +19,7 @@ import shapely.ops
import torch import torch
from matplotlib.colors import ListedColormap from matplotlib.colors import ListedColormap
from rasterio.crs import CRS from rasterio.crs import CRS
from torch import Tensor
from .geo import GeoDataset, RasterDataset from .geo import GeoDataset, RasterDataset
from .utils import BoundingBox, download_url, extract_archive from .utils import BoundingBox, download_url, extract_archive
@ -438,6 +439,46 @@ class ChesapeakeCVPR(GeoDataset):
crs = CRS.from_epsg(3857) crs = CRS.from_epsg(3857)
res = 1 res = 1
lc_cmap = {
0: (0, 0, 0, 0),
1: (0, 197, 255, 255),
2: (38, 115, 0, 255),
3: (163, 255, 115, 255),
4: (255, 170, 0, 255),
5: (156, 156, 156, 255),
6: (0, 0, 0, 255),
15: (0, 0, 0, 0),
}
nlcd_cmap = {
0: (0, 0, 0, 0),
11: (70, 107, 159, 255),
12: (209, 222, 248, 255),
21: (222, 197, 197, 255),
22: (217, 146, 130, 255),
23: (235, 0, 0, 255),
24: (171, 0, 0, 255),
31: (179, 172, 159, 255),
41: (104, 171, 95, 255),
42: (28, 95, 44, 255),
43: (181, 197, 143, 255),
52: (204, 184, 121, 255),
71: (223, 223, 194, 255),
81: (220, 217, 57, 255),
82: (171, 108, 40, 255),
90: (184, 217, 235, 255),
95: (108, 159, 184, 255),
}
prior_color_matrix = np.array(
[
[0.0, 0.77254902, 1.0, 1.0],
[0.14901961, 0.45098039, 0.0, 1.0],
[0.63921569, 1.0, 0.45098039, 1.0],
[0.61176471, 0.61176471, 0.61176471, 1.0],
]
)
valid_layers = [ valid_layers = [
"naip-new", "naip-new",
"naip-old", "naip-old",
@ -540,6 +581,16 @@ class ChesapeakeCVPR(GeoDataset):
super().__init__(transforms) super().__init__(transforms)
lc_colors = np.zeros((max(self.lc_cmap.keys()) + 1, 4))
lc_colors[list(self.lc_cmap.keys())] = list(self.lc_cmap.values())
lc_colors = lc_colors[:, :3] / 255
self._lc_cmap = ListedColormap(lc_colors)
nlcd_colors = np.zeros((max(self.nlcd_cmap.keys()) + 1, 4))
nlcd_colors[list(self.nlcd_cmap.keys())] = list(self.nlcd_cmap.values())
nlcd_colors = nlcd_colors[:, :3] / 255
self._nlcd_cmap = ListedColormap(nlcd_colors)
# Add all tiles into the index in epsg:3857 based on the included geojson # Add all tiles into the index in epsg:3857 based on the included geojson
mint: float = 0 mint: float = 0
maxt: float = sys.maxsize maxt: float = sys.maxsize
@ -694,3 +745,88 @@ class ChesapeakeCVPR(GeoDataset):
"""Extract the dataset.""" """Extract the dataset."""
for subdataset in self.subdatasets: for subdataset in self.subdatasets:
extract_archive(os.path.join(self.root, self.filenames[subdataset])) extract_archive(os.path.join(self.root, self.filenames[subdataset]))
def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle
Returns:
a matplotlib Figure with the rendered sample
.. versionadded:: 0.4
"""
image = np.rollaxis(sample["image"].numpy(), 0, 3)
mask = np.rollaxis(sample["mask"].numpy(), 0, 3)
num_panels = len(self.layers)
showing_predictions = "prediction" in sample
if showing_predictions:
predictions = sample["prediction"].numpy()
num_panels += 1
fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5))
i = 0
for layer in self.layers:
if layer == "naip-new" or layer == "naip-old":
img = image[:, :, :3] / 255
image = image[:, :, 4:]
axs[i].axis("off")
axs[i].imshow(img)
elif layer == "landsat-leaf-on" or layer == "landsat-leaf-off":
img = image[:, :, [3, 2, 1]] / 3000
image = image[:, :, 9:]
axs[i].axis("off")
axs[i].imshow(img)
elif layer == "nlcd":
img = mask[:, :, 0]
mask = mask[:, :, 1:]
axs[i].imshow(
img, vmin=0, vmax=95, cmap=self._nlcd_cmap, interpolation="none"
)
axs[i].axis("off")
elif layer == "lc":
img = mask[:, :, 0]
mask = mask[:, :, 1:]
axs[i].imshow(
img, vmin=0, vmax=15, cmap=self._lc_cmap, interpolation="none"
)
axs[i].axis("off")
elif layer == "buildings":
img = mask[:, :, 0]
mask = mask[:, :, 1:]
axs[i].imshow(img, vmin=0, vmax=1, cmap="gray", interpolation="none")
axs[i].axis("off")
elif layer == "prior_from_cooccurrences_101_31_no_osm_no_buildings":
img = (mask[:, :, :4] @ self.prior_color_matrix) / 255
mask = mask[:, :, 4:]
axs[i].imshow(img)
axs[i].axis("off")
if show_titles:
if layer == "prior_from_cooccurrences_101_31_no_osm_no_buildings":
axs[i].set_title("prior")
else:
axs[i].set_title(layer)
i += 1
if showing_predictions:
axs[i].imshow(
predictions, vmin=0, vmax=15, cmap=self._lc_cmap, interpolation="none"
)
axs[i].axis("off")
if show_titles:
axs[i].set_title("Predictions")
if suptitle is not None:
plt.suptitle(suptitle)
return fig