зеркало из https://github.com/microsoft/torchgeo.git
Adding plotting to ChesapeakeCVPR dataset (#820)
* Adding plotting to ChesapeakeCVPR dataset * De for-looping
This commit is contained in:
Родитель
240df5bd23
Коммит
ba966d76cc
|
@ -211,3 +211,13 @@ class TestChesapeakeCVPR:
|
|||
IndexError, match="query: .* spans multiple tiles which is not valid"
|
||||
):
|
||||
ds[dataset.bounds]
|
||||
|
||||
def test_plot(self, dataset: ChesapeakeCVPR) -> None:
|
||||
x = dataset[dataset.bounds].copy()
|
||||
dataset.plot(x, suptitle="Test")
|
||||
plt.close()
|
||||
dataset.plot(x, show_titles=False)
|
||||
plt.close()
|
||||
x["prediction"] = x["mask"][:, :, 0].clone().unsqueeze(2)
|
||||
dataset.plot(x)
|
||||
plt.close()
|
||||
|
|
|
@ -19,6 +19,7 @@ import shapely.ops
|
|||
import torch
|
||||
from matplotlib.colors import ListedColormap
|
||||
from rasterio.crs import CRS
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import GeoDataset, RasterDataset
|
||||
from .utils import BoundingBox, download_url, extract_archive
|
||||
|
@ -438,6 +439,46 @@ class ChesapeakeCVPR(GeoDataset):
|
|||
crs = CRS.from_epsg(3857)
|
||||
res = 1
|
||||
|
||||
lc_cmap = {
|
||||
0: (0, 0, 0, 0),
|
||||
1: (0, 197, 255, 255),
|
||||
2: (38, 115, 0, 255),
|
||||
3: (163, 255, 115, 255),
|
||||
4: (255, 170, 0, 255),
|
||||
5: (156, 156, 156, 255),
|
||||
6: (0, 0, 0, 255),
|
||||
15: (0, 0, 0, 0),
|
||||
}
|
||||
|
||||
nlcd_cmap = {
|
||||
0: (0, 0, 0, 0),
|
||||
11: (70, 107, 159, 255),
|
||||
12: (209, 222, 248, 255),
|
||||
21: (222, 197, 197, 255),
|
||||
22: (217, 146, 130, 255),
|
||||
23: (235, 0, 0, 255),
|
||||
24: (171, 0, 0, 255),
|
||||
31: (179, 172, 159, 255),
|
||||
41: (104, 171, 95, 255),
|
||||
42: (28, 95, 44, 255),
|
||||
43: (181, 197, 143, 255),
|
||||
52: (204, 184, 121, 255),
|
||||
71: (223, 223, 194, 255),
|
||||
81: (220, 217, 57, 255),
|
||||
82: (171, 108, 40, 255),
|
||||
90: (184, 217, 235, 255),
|
||||
95: (108, 159, 184, 255),
|
||||
}
|
||||
|
||||
prior_color_matrix = np.array(
|
||||
[
|
||||
[0.0, 0.77254902, 1.0, 1.0],
|
||||
[0.14901961, 0.45098039, 0.0, 1.0],
|
||||
[0.63921569, 1.0, 0.45098039, 1.0],
|
||||
[0.61176471, 0.61176471, 0.61176471, 1.0],
|
||||
]
|
||||
)
|
||||
|
||||
valid_layers = [
|
||||
"naip-new",
|
||||
"naip-old",
|
||||
|
@ -540,6 +581,16 @@ class ChesapeakeCVPR(GeoDataset):
|
|||
|
||||
super().__init__(transforms)
|
||||
|
||||
lc_colors = np.zeros((max(self.lc_cmap.keys()) + 1, 4))
|
||||
lc_colors[list(self.lc_cmap.keys())] = list(self.lc_cmap.values())
|
||||
lc_colors = lc_colors[:, :3] / 255
|
||||
self._lc_cmap = ListedColormap(lc_colors)
|
||||
|
||||
nlcd_colors = np.zeros((max(self.nlcd_cmap.keys()) + 1, 4))
|
||||
nlcd_colors[list(self.nlcd_cmap.keys())] = list(self.nlcd_cmap.values())
|
||||
nlcd_colors = nlcd_colors[:, :3] / 255
|
||||
self._nlcd_cmap = ListedColormap(nlcd_colors)
|
||||
|
||||
# Add all tiles into the index in epsg:3857 based on the included geojson
|
||||
mint: float = 0
|
||||
maxt: float = sys.maxsize
|
||||
|
@ -694,3 +745,88 @@ class ChesapeakeCVPR(GeoDataset):
|
|||
"""Extract the dataset."""
|
||||
for subdataset in self.subdatasets:
|
||||
extract_archive(os.path.join(self.root, self.filenames[subdataset]))
|
||||
|
||||
def plot(
|
||||
self,
|
||||
sample: Dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
sample: a sample returned by :meth:`__getitem__`
|
||||
show_titles: flag indicating whether to show titles above each panel
|
||||
suptitle: optional string to use as a suptitle
|
||||
|
||||
Returns:
|
||||
a matplotlib Figure with the rendered sample
|
||||
|
||||
.. versionadded:: 0.4
|
||||
"""
|
||||
image = np.rollaxis(sample["image"].numpy(), 0, 3)
|
||||
mask = np.rollaxis(sample["mask"].numpy(), 0, 3)
|
||||
|
||||
num_panels = len(self.layers)
|
||||
showing_predictions = "prediction" in sample
|
||||
if showing_predictions:
|
||||
predictions = sample["prediction"].numpy()
|
||||
num_panels += 1
|
||||
|
||||
fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5))
|
||||
|
||||
i = 0
|
||||
for layer in self.layers:
|
||||
if layer == "naip-new" or layer == "naip-old":
|
||||
img = image[:, :, :3] / 255
|
||||
image = image[:, :, 4:]
|
||||
axs[i].axis("off")
|
||||
axs[i].imshow(img)
|
||||
elif layer == "landsat-leaf-on" or layer == "landsat-leaf-off":
|
||||
img = image[:, :, [3, 2, 1]] / 3000
|
||||
image = image[:, :, 9:]
|
||||
axs[i].axis("off")
|
||||
axs[i].imshow(img)
|
||||
elif layer == "nlcd":
|
||||
img = mask[:, :, 0]
|
||||
mask = mask[:, :, 1:]
|
||||
axs[i].imshow(
|
||||
img, vmin=0, vmax=95, cmap=self._nlcd_cmap, interpolation="none"
|
||||
)
|
||||
axs[i].axis("off")
|
||||
elif layer == "lc":
|
||||
img = mask[:, :, 0]
|
||||
mask = mask[:, :, 1:]
|
||||
axs[i].imshow(
|
||||
img, vmin=0, vmax=15, cmap=self._lc_cmap, interpolation="none"
|
||||
)
|
||||
axs[i].axis("off")
|
||||
elif layer == "buildings":
|
||||
img = mask[:, :, 0]
|
||||
mask = mask[:, :, 1:]
|
||||
axs[i].imshow(img, vmin=0, vmax=1, cmap="gray", interpolation="none")
|
||||
axs[i].axis("off")
|
||||
elif layer == "prior_from_cooccurrences_101_31_no_osm_no_buildings":
|
||||
img = (mask[:, :, :4] @ self.prior_color_matrix) / 255
|
||||
mask = mask[:, :, 4:]
|
||||
axs[i].imshow(img)
|
||||
axs[i].axis("off")
|
||||
|
||||
if show_titles:
|
||||
if layer == "prior_from_cooccurrences_101_31_no_osm_no_buildings":
|
||||
axs[i].set_title("prior")
|
||||
else:
|
||||
axs[i].set_title(layer)
|
||||
i += 1
|
||||
|
||||
if showing_predictions:
|
||||
axs[i].imshow(
|
||||
predictions, vmin=0, vmax=15, cmap=self._lc_cmap, interpolation="none"
|
||||
)
|
||||
axs[i].axis("off")
|
||||
if show_titles:
|
||||
axs[i].set_title("Predictions")
|
||||
|
||||
if suptitle is not None:
|
||||
plt.suptitle(suptitle)
|
||||
return fig
|
||||
|
|
Загрузка…
Ссылка в новой задаче