Adding plotting to ChesapeakeCVPR dataset (#820)

* Adding plotting to ChesapeakeCVPR dataset

* De for-looping
This commit is contained in:
Caleb Robinson 2022-10-06 11:49:56 -07:00 коммит произвёл GitHub
Родитель 240df5bd23
Коммит ba966d76cc
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 146 добавлений и 0 удалений

Просмотреть файл

@ -211,3 +211,13 @@ class TestChesapeakeCVPR:
IndexError, match="query: .* spans multiple tiles which is not valid"
):
ds[dataset.bounds]
def test_plot(self, dataset: ChesapeakeCVPR) -> None:
x = dataset[dataset.bounds].copy()
dataset.plot(x, suptitle="Test")
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x["prediction"] = x["mask"][:, :, 0].clone().unsqueeze(2)
dataset.plot(x)
plt.close()

Просмотреть файл

@ -19,6 +19,7 @@ import shapely.ops
import torch
from matplotlib.colors import ListedColormap
from rasterio.crs import CRS
from torch import Tensor
from .geo import GeoDataset, RasterDataset
from .utils import BoundingBox, download_url, extract_archive
@ -438,6 +439,46 @@ class ChesapeakeCVPR(GeoDataset):
crs = CRS.from_epsg(3857)
res = 1
lc_cmap = {
0: (0, 0, 0, 0),
1: (0, 197, 255, 255),
2: (38, 115, 0, 255),
3: (163, 255, 115, 255),
4: (255, 170, 0, 255),
5: (156, 156, 156, 255),
6: (0, 0, 0, 255),
15: (0, 0, 0, 0),
}
nlcd_cmap = {
0: (0, 0, 0, 0),
11: (70, 107, 159, 255),
12: (209, 222, 248, 255),
21: (222, 197, 197, 255),
22: (217, 146, 130, 255),
23: (235, 0, 0, 255),
24: (171, 0, 0, 255),
31: (179, 172, 159, 255),
41: (104, 171, 95, 255),
42: (28, 95, 44, 255),
43: (181, 197, 143, 255),
52: (204, 184, 121, 255),
71: (223, 223, 194, 255),
81: (220, 217, 57, 255),
82: (171, 108, 40, 255),
90: (184, 217, 235, 255),
95: (108, 159, 184, 255),
}
prior_color_matrix = np.array(
[
[0.0, 0.77254902, 1.0, 1.0],
[0.14901961, 0.45098039, 0.0, 1.0],
[0.63921569, 1.0, 0.45098039, 1.0],
[0.61176471, 0.61176471, 0.61176471, 1.0],
]
)
valid_layers = [
"naip-new",
"naip-old",
@ -540,6 +581,16 @@ class ChesapeakeCVPR(GeoDataset):
super().__init__(transforms)
lc_colors = np.zeros((max(self.lc_cmap.keys()) + 1, 4))
lc_colors[list(self.lc_cmap.keys())] = list(self.lc_cmap.values())
lc_colors = lc_colors[:, :3] / 255
self._lc_cmap = ListedColormap(lc_colors)
nlcd_colors = np.zeros((max(self.nlcd_cmap.keys()) + 1, 4))
nlcd_colors[list(self.nlcd_cmap.keys())] = list(self.nlcd_cmap.values())
nlcd_colors = nlcd_colors[:, :3] / 255
self._nlcd_cmap = ListedColormap(nlcd_colors)
# Add all tiles into the index in epsg:3857 based on the included geojson
mint: float = 0
maxt: float = sys.maxsize
@ -694,3 +745,88 @@ class ChesapeakeCVPR(GeoDataset):
"""Extract the dataset."""
for subdataset in self.subdatasets:
extract_archive(os.path.join(self.root, self.filenames[subdataset]))
def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle
Returns:
a matplotlib Figure with the rendered sample
.. versionadded:: 0.4
"""
image = np.rollaxis(sample["image"].numpy(), 0, 3)
mask = np.rollaxis(sample["mask"].numpy(), 0, 3)
num_panels = len(self.layers)
showing_predictions = "prediction" in sample
if showing_predictions:
predictions = sample["prediction"].numpy()
num_panels += 1
fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5))
i = 0
for layer in self.layers:
if layer == "naip-new" or layer == "naip-old":
img = image[:, :, :3] / 255
image = image[:, :, 4:]
axs[i].axis("off")
axs[i].imshow(img)
elif layer == "landsat-leaf-on" or layer == "landsat-leaf-off":
img = image[:, :, [3, 2, 1]] / 3000
image = image[:, :, 9:]
axs[i].axis("off")
axs[i].imshow(img)
elif layer == "nlcd":
img = mask[:, :, 0]
mask = mask[:, :, 1:]
axs[i].imshow(
img, vmin=0, vmax=95, cmap=self._nlcd_cmap, interpolation="none"
)
axs[i].axis("off")
elif layer == "lc":
img = mask[:, :, 0]
mask = mask[:, :, 1:]
axs[i].imshow(
img, vmin=0, vmax=15, cmap=self._lc_cmap, interpolation="none"
)
axs[i].axis("off")
elif layer == "buildings":
img = mask[:, :, 0]
mask = mask[:, :, 1:]
axs[i].imshow(img, vmin=0, vmax=1, cmap="gray", interpolation="none")
axs[i].axis("off")
elif layer == "prior_from_cooccurrences_101_31_no_osm_no_buildings":
img = (mask[:, :, :4] @ self.prior_color_matrix) / 255
mask = mask[:, :, 4:]
axs[i].imshow(img)
axs[i].axis("off")
if show_titles:
if layer == "prior_from_cooccurrences_101_31_no_osm_no_buildings":
axs[i].set_title("prior")
else:
axs[i].set_title(layer)
i += 1
if showing_predictions:
axs[i].imshow(
predictions, vmin=0, vmax=15, cmap=self._lc_cmap, interpolation="none"
)
axs[i].axis("off")
if show_titles:
axs[i].set_title("Predictions")
if suptitle is not None:
plt.suptitle(suptitle)
return fig