зеркало из https://github.com/microsoft/torchgeo.git
Adding plotting to ChesapeakeCVPR dataset (#820)
* Adding plotting to ChesapeakeCVPR dataset * De for-looping
This commit is contained in:
Родитель
240df5bd23
Коммит
ba966d76cc
|
@ -211,3 +211,13 @@ class TestChesapeakeCVPR:
|
||||||
IndexError, match="query: .* spans multiple tiles which is not valid"
|
IndexError, match="query: .* spans multiple tiles which is not valid"
|
||||||
):
|
):
|
||||||
ds[dataset.bounds]
|
ds[dataset.bounds]
|
||||||
|
|
||||||
|
def test_plot(self, dataset: ChesapeakeCVPR) -> None:
|
||||||
|
x = dataset[dataset.bounds].copy()
|
||||||
|
dataset.plot(x, suptitle="Test")
|
||||||
|
plt.close()
|
||||||
|
dataset.plot(x, show_titles=False)
|
||||||
|
plt.close()
|
||||||
|
x["prediction"] = x["mask"][:, :, 0].clone().unsqueeze(2)
|
||||||
|
dataset.plot(x)
|
||||||
|
plt.close()
|
||||||
|
|
|
@ -19,6 +19,7 @@ import shapely.ops
|
||||||
import torch
|
import torch
|
||||||
from matplotlib.colors import ListedColormap
|
from matplotlib.colors import ListedColormap
|
||||||
from rasterio.crs import CRS
|
from rasterio.crs import CRS
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
from .geo import GeoDataset, RasterDataset
|
from .geo import GeoDataset, RasterDataset
|
||||||
from .utils import BoundingBox, download_url, extract_archive
|
from .utils import BoundingBox, download_url, extract_archive
|
||||||
|
@ -438,6 +439,46 @@ class ChesapeakeCVPR(GeoDataset):
|
||||||
crs = CRS.from_epsg(3857)
|
crs = CRS.from_epsg(3857)
|
||||||
res = 1
|
res = 1
|
||||||
|
|
||||||
|
lc_cmap = {
|
||||||
|
0: (0, 0, 0, 0),
|
||||||
|
1: (0, 197, 255, 255),
|
||||||
|
2: (38, 115, 0, 255),
|
||||||
|
3: (163, 255, 115, 255),
|
||||||
|
4: (255, 170, 0, 255),
|
||||||
|
5: (156, 156, 156, 255),
|
||||||
|
6: (0, 0, 0, 255),
|
||||||
|
15: (0, 0, 0, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
nlcd_cmap = {
|
||||||
|
0: (0, 0, 0, 0),
|
||||||
|
11: (70, 107, 159, 255),
|
||||||
|
12: (209, 222, 248, 255),
|
||||||
|
21: (222, 197, 197, 255),
|
||||||
|
22: (217, 146, 130, 255),
|
||||||
|
23: (235, 0, 0, 255),
|
||||||
|
24: (171, 0, 0, 255),
|
||||||
|
31: (179, 172, 159, 255),
|
||||||
|
41: (104, 171, 95, 255),
|
||||||
|
42: (28, 95, 44, 255),
|
||||||
|
43: (181, 197, 143, 255),
|
||||||
|
52: (204, 184, 121, 255),
|
||||||
|
71: (223, 223, 194, 255),
|
||||||
|
81: (220, 217, 57, 255),
|
||||||
|
82: (171, 108, 40, 255),
|
||||||
|
90: (184, 217, 235, 255),
|
||||||
|
95: (108, 159, 184, 255),
|
||||||
|
}
|
||||||
|
|
||||||
|
prior_color_matrix = np.array(
|
||||||
|
[
|
||||||
|
[0.0, 0.77254902, 1.0, 1.0],
|
||||||
|
[0.14901961, 0.45098039, 0.0, 1.0],
|
||||||
|
[0.63921569, 1.0, 0.45098039, 1.0],
|
||||||
|
[0.61176471, 0.61176471, 0.61176471, 1.0],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
valid_layers = [
|
valid_layers = [
|
||||||
"naip-new",
|
"naip-new",
|
||||||
"naip-old",
|
"naip-old",
|
||||||
|
@ -540,6 +581,16 @@ class ChesapeakeCVPR(GeoDataset):
|
||||||
|
|
||||||
super().__init__(transforms)
|
super().__init__(transforms)
|
||||||
|
|
||||||
|
lc_colors = np.zeros((max(self.lc_cmap.keys()) + 1, 4))
|
||||||
|
lc_colors[list(self.lc_cmap.keys())] = list(self.lc_cmap.values())
|
||||||
|
lc_colors = lc_colors[:, :3] / 255
|
||||||
|
self._lc_cmap = ListedColormap(lc_colors)
|
||||||
|
|
||||||
|
nlcd_colors = np.zeros((max(self.nlcd_cmap.keys()) + 1, 4))
|
||||||
|
nlcd_colors[list(self.nlcd_cmap.keys())] = list(self.nlcd_cmap.values())
|
||||||
|
nlcd_colors = nlcd_colors[:, :3] / 255
|
||||||
|
self._nlcd_cmap = ListedColormap(nlcd_colors)
|
||||||
|
|
||||||
# Add all tiles into the index in epsg:3857 based on the included geojson
|
# Add all tiles into the index in epsg:3857 based on the included geojson
|
||||||
mint: float = 0
|
mint: float = 0
|
||||||
maxt: float = sys.maxsize
|
maxt: float = sys.maxsize
|
||||||
|
@ -694,3 +745,88 @@ class ChesapeakeCVPR(GeoDataset):
|
||||||
"""Extract the dataset."""
|
"""Extract the dataset."""
|
||||||
for subdataset in self.subdatasets:
|
for subdataset in self.subdatasets:
|
||||||
extract_archive(os.path.join(self.root, self.filenames[subdataset]))
|
extract_archive(os.path.join(self.root, self.filenames[subdataset]))
|
||||||
|
|
||||||
|
def plot(
|
||||||
|
self,
|
||||||
|
sample: Dict[str, Tensor],
|
||||||
|
show_titles: bool = True,
|
||||||
|
suptitle: Optional[str] = None,
|
||||||
|
) -> plt.Figure:
|
||||||
|
"""Plot a sample from the dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample: a sample returned by :meth:`__getitem__`
|
||||||
|
show_titles: flag indicating whether to show titles above each panel
|
||||||
|
suptitle: optional string to use as a suptitle
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a matplotlib Figure with the rendered sample
|
||||||
|
|
||||||
|
.. versionadded:: 0.4
|
||||||
|
"""
|
||||||
|
image = np.rollaxis(sample["image"].numpy(), 0, 3)
|
||||||
|
mask = np.rollaxis(sample["mask"].numpy(), 0, 3)
|
||||||
|
|
||||||
|
num_panels = len(self.layers)
|
||||||
|
showing_predictions = "prediction" in sample
|
||||||
|
if showing_predictions:
|
||||||
|
predictions = sample["prediction"].numpy()
|
||||||
|
num_panels += 1
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5))
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
for layer in self.layers:
|
||||||
|
if layer == "naip-new" or layer == "naip-old":
|
||||||
|
img = image[:, :, :3] / 255
|
||||||
|
image = image[:, :, 4:]
|
||||||
|
axs[i].axis("off")
|
||||||
|
axs[i].imshow(img)
|
||||||
|
elif layer == "landsat-leaf-on" or layer == "landsat-leaf-off":
|
||||||
|
img = image[:, :, [3, 2, 1]] / 3000
|
||||||
|
image = image[:, :, 9:]
|
||||||
|
axs[i].axis("off")
|
||||||
|
axs[i].imshow(img)
|
||||||
|
elif layer == "nlcd":
|
||||||
|
img = mask[:, :, 0]
|
||||||
|
mask = mask[:, :, 1:]
|
||||||
|
axs[i].imshow(
|
||||||
|
img, vmin=0, vmax=95, cmap=self._nlcd_cmap, interpolation="none"
|
||||||
|
)
|
||||||
|
axs[i].axis("off")
|
||||||
|
elif layer == "lc":
|
||||||
|
img = mask[:, :, 0]
|
||||||
|
mask = mask[:, :, 1:]
|
||||||
|
axs[i].imshow(
|
||||||
|
img, vmin=0, vmax=15, cmap=self._lc_cmap, interpolation="none"
|
||||||
|
)
|
||||||
|
axs[i].axis("off")
|
||||||
|
elif layer == "buildings":
|
||||||
|
img = mask[:, :, 0]
|
||||||
|
mask = mask[:, :, 1:]
|
||||||
|
axs[i].imshow(img, vmin=0, vmax=1, cmap="gray", interpolation="none")
|
||||||
|
axs[i].axis("off")
|
||||||
|
elif layer == "prior_from_cooccurrences_101_31_no_osm_no_buildings":
|
||||||
|
img = (mask[:, :, :4] @ self.prior_color_matrix) / 255
|
||||||
|
mask = mask[:, :, 4:]
|
||||||
|
axs[i].imshow(img)
|
||||||
|
axs[i].axis("off")
|
||||||
|
|
||||||
|
if show_titles:
|
||||||
|
if layer == "prior_from_cooccurrences_101_31_no_osm_no_buildings":
|
||||||
|
axs[i].set_title("prior")
|
||||||
|
else:
|
||||||
|
axs[i].set_title(layer)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
if showing_predictions:
|
||||||
|
axs[i].imshow(
|
||||||
|
predictions, vmin=0, vmax=15, cmap=self._lc_cmap, interpolation="none"
|
||||||
|
)
|
||||||
|
axs[i].axis("off")
|
||||||
|
if show_titles:
|
||||||
|
axs[i].set_title("Predictions")
|
||||||
|
|
||||||
|
if suptitle is not None:
|
||||||
|
plt.suptitle(suptitle)
|
||||||
|
return fig
|
||||||
|
|
Загрузка…
Ссылка в новой задаче