зеркало из https://github.com/microsoft/torchgeo.git
* add plot method to sen12 * tuple
This commit is contained in:
Родитель
9e07927c63
Коммит
7d90045b9b
|
@ -5,6 +5,7 @@ import os
|
|||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -82,3 +83,21 @@ class TestSEN12MS:
|
|||
ds = SEN12MS(root, bands=bands, checksum=False)
|
||||
x = ds[0]["image"]
|
||||
assert x.shape[0] == len(bands)
|
||||
|
||||
def test_invalid_bands(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
SEN12MS(bands=("OK", "BK"))
|
||||
|
||||
def test_plot(self, dataset: SEN12MS) -> None:
|
||||
dataset.plot(dataset[0], suptitle="Test")
|
||||
plt.close()
|
||||
|
||||
sample = dataset[0]
|
||||
sample["prediction"] = sample["mask"].clone()
|
||||
dataset.plot(sample, suptitle="prediction")
|
||||
plt.close()
|
||||
|
||||
def test_plot_rgb(self, dataset: SEN12MS) -> None:
|
||||
dataset = SEN12MS(root=dataset.root, bands=("B03",))
|
||||
with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"):
|
||||
dataset.plot(dataset[0], suptitle="Single Band")
|
||||
|
|
|
@ -4,15 +4,16 @@
|
|||
"""SEN12MS dataset."""
|
||||
|
||||
import os
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from typing import Callable, Dict, Optional, Sequence, Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import rasterio
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import VisionDataset
|
||||
from .utils import check_integrity
|
||||
from .utils import check_integrity, percentile_normalization
|
||||
|
||||
|
||||
class SEN12MS(VisionDataset):
|
||||
|
@ -62,13 +63,63 @@ class SEN12MS(VisionDataset):
|
|||
This download will likely take several hours.
|
||||
""" # noqa: E501
|
||||
|
||||
BAND_SETS: Dict[str, List[int]] = {
|
||||
"all": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
|
||||
"s1": [0, 1],
|
||||
"s2-all": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
|
||||
"s2-reduced": [3, 4, 5, 9, 12, 13],
|
||||
BAND_SETS: Dict[str, Tuple[str, ...]] = {
|
||||
"all": (
|
||||
"VV",
|
||||
"VH",
|
||||
"B01",
|
||||
"B02",
|
||||
"B03",
|
||||
"B04",
|
||||
"B05",
|
||||
"B06",
|
||||
"B07",
|
||||
"B08",
|
||||
"B8A",
|
||||
"B09",
|
||||
"B10",
|
||||
"B11",
|
||||
"B12",
|
||||
),
|
||||
"s1": ("VV", "VH"),
|
||||
"s2-all": (
|
||||
"B01",
|
||||
"B02",
|
||||
"B03",
|
||||
"B04",
|
||||
"B05",
|
||||
"B06",
|
||||
"B07",
|
||||
"B08",
|
||||
"B8A",
|
||||
"B09",
|
||||
"B10",
|
||||
"B11",
|
||||
"B12",
|
||||
),
|
||||
"s2-reduced": ("B02", "B03", "B04", "B08", "B10", "B11"),
|
||||
}
|
||||
|
||||
band_names = (
|
||||
"VV",
|
||||
"VH",
|
||||
"B01",
|
||||
"B02",
|
||||
"B03",
|
||||
"B04",
|
||||
"B05",
|
||||
"B06",
|
||||
"B07",
|
||||
"B08",
|
||||
"B8A",
|
||||
"B09",
|
||||
"B10",
|
||||
"B11",
|
||||
"B12",
|
||||
)
|
||||
|
||||
RGB_BANDS = ["B04", "B03", "B02"]
|
||||
|
||||
filenames = [
|
||||
"ROIs1158_spring_lc.tar.gz",
|
||||
"ROIs1158_spring_s1.tar.gz",
|
||||
|
@ -114,7 +165,7 @@ class SEN12MS(VisionDataset):
|
|||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
bands: List[int] = BAND_SETS["all"],
|
||||
bands: Sequence[str] = BAND_SETS["all"],
|
||||
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -128,7 +179,7 @@ class SEN12MS(VisionDataset):
|
|||
Args:
|
||||
root: root directory where dataset can be found
|
||||
split: one of "train" or "test"
|
||||
bands: a list of band indices to use where the indices correspond to the
|
||||
bands: a sequence of band indices to use where the indices correspond to the
|
||||
array index of combined Sentinel 1 and Sentinel 2
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
|
@ -140,9 +191,14 @@ class SEN12MS(VisionDataset):
|
|||
"""
|
||||
assert split in ["train", "test"]
|
||||
|
||||
self._validate_bands(bands)
|
||||
self.band_indices = torch.tensor( # type: ignore[attr-defined]
|
||||
[self.band_names.index(b) for b in bands]
|
||||
).long()
|
||||
self.bands = bands
|
||||
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.bands = torch.tensor(bands).long() # type: ignore[attr-defined]
|
||||
self.transforms = transforms
|
||||
self.checksum = checksum
|
||||
|
||||
|
@ -173,7 +229,7 @@ class SEN12MS(VisionDataset):
|
|||
|
||||
image = torch.cat(tensors=[s1, s2], dim=0) # type: ignore[attr-defined]
|
||||
image = torch.index_select( # type: ignore[attr-defined]
|
||||
image, dim=0, index=self.bands
|
||||
image, dim=0, index=self.band_indices
|
||||
)
|
||||
|
||||
sample: Dict[str, Tensor] = {"image": image, "mask": lc}
|
||||
|
@ -216,6 +272,21 @@ class SEN12MS(VisionDataset):
|
|||
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
|
||||
return tensor
|
||||
|
||||
def _validate_bands(self, bands: Sequence[str]) -> None:
|
||||
"""Validate list of bands.
|
||||
|
||||
Args:
|
||||
bands: user-provided sequence of bands to load
|
||||
|
||||
Raises:
|
||||
AssertionError: if ``bands`` is not a sequence
|
||||
ValueError: if an invalid band name is provided
|
||||
"""
|
||||
assert isinstance(bands, tuple), "'bands' must be a sequence"
|
||||
for band in bands:
|
||||
if band not in self.band_names:
|
||||
raise ValueError(f"'{band}' is an invalid band name.")
|
||||
|
||||
def _check_integrity_light(self) -> bool:
|
||||
"""Checks the integrity of the dataset structure.
|
||||
|
||||
|
@ -239,3 +310,59 @@ class SEN12MS(VisionDataset):
|
|||
if not check_integrity(filepath, md5 if self.checksum else None):
|
||||
return False
|
||||
return True
|
||||
|
||||
def plot(
|
||||
self,
|
||||
sample: Dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
sample: a sample returned by :meth:`__getitem__`
|
||||
show_titles: flag indicating whether to show titles above each panel
|
||||
suptitle: optional suptitle to use for figure
|
||||
|
||||
Returns:
|
||||
a matplotlib Figure with the rendered sample
|
||||
|
||||
.. versionadded:: 0.2
|
||||
"""
|
||||
rgb_indices = []
|
||||
for band in self.RGB_BANDS:
|
||||
if band in self.bands:
|
||||
rgb_indices.append(self.bands.index(band))
|
||||
else:
|
||||
raise ValueError("Dataset doesn't contain some of the RGB bands")
|
||||
|
||||
image, mask = sample["image"][rgb_indices].numpy(), sample["mask"][0]
|
||||
image = percentile_normalization(image)
|
||||
ncols = 2
|
||||
|
||||
showing_predictions = "prediction" in sample
|
||||
if showing_predictions:
|
||||
prediction = sample["prediction"][0]
|
||||
ncols += 1
|
||||
|
||||
fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5))
|
||||
|
||||
axs[0].imshow(np.transpose(image, (1, 2, 0)))
|
||||
axs[0].axis("off")
|
||||
axs[1].imshow(mask)
|
||||
axs[1].axis("off")
|
||||
|
||||
if showing_predictions:
|
||||
axs[2].imshow(prediction)
|
||||
axs[2].axis("off")
|
||||
|
||||
if show_titles:
|
||||
axs[0].set_title("Image")
|
||||
axs[1].set_title("Mask")
|
||||
if showing_predictions:
|
||||
axs[2].set_title("Prediction")
|
||||
|
||||
if suptitle is not None:
|
||||
plt.suptitle(suptitle)
|
||||
|
||||
return fig
|
||||
|
|
Загрузка…
Ссылка в новой задаче