Add plotting method and band selection to Sen12ms, replacing #320 (#338)

* add plot method to sen12

* tuple
This commit is contained in:
Nils Lehmann 2021-12-31 20:54:01 +01:00 коммит произвёл GitHub
Родитель 9e07927c63
Коммит 7d90045b9b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 157 добавлений и 11 удалений

Просмотреть файл

@ -5,6 +5,7 @@ import os
from pathlib import Path
from typing import Generator
import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
@ -82,3 +83,21 @@ class TestSEN12MS:
ds = SEN12MS(root, bands=bands, checksum=False)
x = ds[0]["image"]
assert x.shape[0] == len(bands)
def test_invalid_bands(self) -> None:
with pytest.raises(ValueError):
SEN12MS(bands=("OK", "BK"))
def test_plot(self, dataset: SEN12MS) -> None:
dataset.plot(dataset[0], suptitle="Test")
plt.close()
sample = dataset[0]
sample["prediction"] = sample["mask"].clone()
dataset.plot(sample, suptitle="prediction")
plt.close()
def test_plot_rgb(self, dataset: SEN12MS) -> None:
dataset = SEN12MS(root=dataset.root, bands=("B03",))
with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"):
dataset.plot(dataset[0], suptitle="Single Band")

Просмотреть файл

@ -4,15 +4,16 @@
"""SEN12MS dataset."""
import os
from typing import Callable, Dict, List, Optional
from typing import Callable, Dict, Optional, Sequence, Tuple
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from torch import Tensor
from .geo import VisionDataset
from .utils import check_integrity
from .utils import check_integrity, percentile_normalization
class SEN12MS(VisionDataset):
@ -62,13 +63,63 @@ class SEN12MS(VisionDataset):
This download will likely take several hours.
""" # noqa: E501
BAND_SETS: Dict[str, List[int]] = {
"all": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
"s1": [0, 1],
"s2-all": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
"s2-reduced": [3, 4, 5, 9, 12, 13],
BAND_SETS: Dict[str, Tuple[str, ...]] = {
"all": (
"VV",
"VH",
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B8A",
"B09",
"B10",
"B11",
"B12",
),
"s1": ("VV", "VH"),
"s2-all": (
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B8A",
"B09",
"B10",
"B11",
"B12",
),
"s2-reduced": ("B02", "B03", "B04", "B08", "B10", "B11"),
}
band_names = (
"VV",
"VH",
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B8A",
"B09",
"B10",
"B11",
"B12",
)
RGB_BANDS = ["B04", "B03", "B02"]
filenames = [
"ROIs1158_spring_lc.tar.gz",
"ROIs1158_spring_s1.tar.gz",
@ -114,7 +165,7 @@ class SEN12MS(VisionDataset):
self,
root: str = "data",
split: str = "train",
bands: List[int] = BAND_SETS["all"],
bands: Sequence[str] = BAND_SETS["all"],
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
checksum: bool = False,
) -> None:
@ -128,7 +179,7 @@ class SEN12MS(VisionDataset):
Args:
root: root directory where dataset can be found
split: one of "train" or "test"
bands: a list of band indices to use where the indices correspond to the
bands: a sequence of band indices to use where the indices correspond to the
array index of combined Sentinel 1 and Sentinel 2
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
@ -140,9 +191,14 @@ class SEN12MS(VisionDataset):
"""
assert split in ["train", "test"]
self._validate_bands(bands)
self.band_indices = torch.tensor( # type: ignore[attr-defined]
[self.band_names.index(b) for b in bands]
).long()
self.bands = bands
self.root = root
self.split = split
self.bands = torch.tensor(bands).long() # type: ignore[attr-defined]
self.transforms = transforms
self.checksum = checksum
@ -173,7 +229,7 @@ class SEN12MS(VisionDataset):
image = torch.cat(tensors=[s1, s2], dim=0) # type: ignore[attr-defined]
image = torch.index_select( # type: ignore[attr-defined]
image, dim=0, index=self.bands
image, dim=0, index=self.band_indices
)
sample: Dict[str, Tensor] = {"image": image, "mask": lc}
@ -216,6 +272,21 @@ class SEN12MS(VisionDataset):
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
return tensor
def _validate_bands(self, bands: Sequence[str]) -> None:
"""Validate list of bands.
Args:
bands: user-provided sequence of bands to load
Raises:
AssertionError: if ``bands`` is not a sequence
ValueError: if an invalid band name is provided
"""
assert isinstance(bands, tuple), "'bands' must be a sequence"
for band in bands:
if band not in self.band_names:
raise ValueError(f"'{band}' is an invalid band name.")
def _check_integrity_light(self) -> bool:
"""Checks the integrity of the dataset structure.
@ -239,3 +310,59 @@ class SEN12MS(VisionDataset):
if not check_integrity(filepath, md5 if self.checksum else None):
return False
return True
def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional suptitle to use for figure
Returns:
a matplotlib Figure with the rendered sample
.. versionadded:: 0.2
"""
rgb_indices = []
for band in self.RGB_BANDS:
if band in self.bands:
rgb_indices.append(self.bands.index(band))
else:
raise ValueError("Dataset doesn't contain some of the RGB bands")
image, mask = sample["image"][rgb_indices].numpy(), sample["mask"][0]
image = percentile_normalization(image)
ncols = 2
showing_predictions = "prediction" in sample
if showing_predictions:
prediction = sample["prediction"][0]
ncols += 1
fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5))
axs[0].imshow(np.transpose(image, (1, 2, 0)))
axs[0].axis("off")
axs[1].imshow(mask)
axs[1].axis("off")
if showing_predictions:
axs[2].imshow(prediction)
axs[2].axis("off")
if show_titles:
axs[0].set_title("Image")
axs[1].set_title("Mask")
if showing_predictions:
axs[2].set_title("Prediction")
if suptitle is not None:
plt.suptitle(suptitle)
return fig