Add band selection to So2Sat Dataset and adopt plot method (#394)

* add band selection and adopt plot method

* np array typing and raise error doc
This commit is contained in:
Nils Lehmann 2022-02-20 20:47:52 +01:00 коммит произвёл GitHub
Родитель b36e053ecf
Коммит 89277dc325
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 99 добавлений и 10 удалений

Просмотреть файл

@ -33,7 +33,7 @@ class TestSo2Sat:
root = os.path.join("tests", "data", "so2sat") root = os.path.join("tests", "data", "so2sat")
split = request.param split = request.param
transforms = nn.Identity() # type: ignore[attr-defined] transforms = nn.Identity() # type: ignore[attr-defined]
return So2Sat(root, split, transforms, checksum=True) return So2Sat(root=root, split=split, transforms=transforms, checksum=True)
@pytest.fixture @pytest.fixture
def mock_missing_module( def mock_missing_module(
@ -69,6 +69,10 @@ class TestSo2Sat:
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
So2Sat(split="foo") So2Sat(split="foo")
def test_invalid_bands(self) -> None:
with pytest.raises(ValueError):
So2Sat(bands=("OK", "BK"))
def test_not_downloaded(self, tmp_path: Path) -> None: def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."): with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
So2Sat(str(tmp_path)) So2Sat(str(tmp_path))
@ -83,6 +87,11 @@ class TestSo2Sat:
dataset.plot(x) dataset.plot(x)
plt.close() plt.close()
def test_plot_rgb(self, dataset: So2Sat) -> None:
dataset = So2Sat(root=dataset.root, bands=("B03",))
with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"):
dataset.plot(dataset[0], suptitle="Single Band")
def test_mock_missing_module( def test_mock_missing_module(
self, dataset: So2Sat, mock_missing_module: None self, dataset: So2Sat, mock_missing_module: None
) -> None: ) -> None:

Просмотреть файл

@ -4,7 +4,7 @@
"""So2Sat dataset.""" """So2Sat dataset."""
import os import os
from typing import Callable, Dict, Optional, cast from typing import Callable, Dict, Optional, Sequence, cast
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -105,10 +105,34 @@ class So2Sat(VisionDataset):
"Water", "Water",
] ]
all_s1_band_names = ("S1B1", "S1B2", "S1B3", "S1B4", "S1B5", "S1B6", "S1B7", "S1B8")
all_s2_band_names = (
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B08A",
"B11 SWIR",
"B12 SWIR",
)
all_band_names = all_s1_band_names + all_s2_band_names
RGB_BANDS = ["B04", "B03", "B02"]
BAND_SETS = {
"all": all_band_names,
"s1": all_s1_band_names,
"s2": all_s2_band_names,
}
def __init__( def __init__(
self, self,
root: str = "data", root: str = "data",
split: str = "train", split: str = "train",
bands: Sequence[str] = BAND_SETS["all"],
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
checksum: bool = False, checksum: bool = False,
) -> None: ) -> None:
@ -117,6 +141,8 @@ class So2Sat(VisionDataset):
Args: Args:
root: root directory where dataset can be found root: root directory where dataset can be found
split: one of "train", "validation", or "test" split: one of "train", "validation", or "test"
bands: a sequence of band names to use where the indices correspond to the
array index of combined Sentinel 1 and Sentinel 2
transforms: a function/transform that takes input sample and its target as transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version entry and returns a transformed version
checksum: if True, check the MD5 of the downloaded files (may be slow) checksum: if True, check the MD5 of the downloaded files (may be slow)
@ -134,16 +160,39 @@ class So2Sat(VisionDataset):
assert split in ["train", "validation", "test"] assert split in ["train", "validation", "test"]
self._validate_bands(bands)
self.s1_band_indices: "np.typing.NDArray[np.int_]" = np.array(
[
self.all_s1_band_names.index(b)
for b in bands
if b in self.all_s1_band_names
]
).astype(int)
self.s1_band_names = [self.all_s1_band_names[i] for i in self.s1_band_indices]
self.s2_band_indices: "np.typing.NDArray[np.int_]" = np.array(
[
self.all_s2_band_names.index(b)
for b in bands
if b in self.all_s2_band_names
]
).astype(int)
self.s2_band_names = [self.all_s2_band_names[i] for i in self.s2_band_indices]
self.bands = bands
self.root = root self.root = root
self.split = split self.split = split
self.transforms = transforms self.transforms = transforms
self.checksum = checksum self.checksum = checksum
self.fn = os.path.join(self.root, self.filenames[split])
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted.") raise RuntimeError("Dataset not found or corrupted.")
self.fn = os.path.join(self.root, self.filenames[split])
with h5py.File(self.fn, "r") as f: with h5py.File(self.fn, "r") as f:
self.size = int(f["label"].shape[0]) self.size = int(f["label"].shape[0])
@ -160,7 +209,10 @@ class So2Sat(VisionDataset):
with h5py.File(self.fn, "r") as f: with h5py.File(self.fn, "r") as f:
s1 = f["sen1"][index].astype(np.float64) # convert from <f8 to float64 s1 = f["sen1"][index].astype(np.float64) # convert from <f8 to float64
s1 = np.take(s1, indices=self.s1_band_indices, axis=2)
s2 = f["sen2"][index].astype(np.float64) # convert from <f8 to float64 s2 = f["sen2"][index].astype(np.float64) # convert from <f8 to float64
s2 = np.take(s2, indices=self.s2_band_indices, axis=2)
# convert one-hot encoding to int64 then torch int # convert one-hot encoding to int64 then torch int
label = torch.tensor( # type: ignore[attr-defined] label = torch.tensor( # type: ignore[attr-defined]
f["label"][index].argmax() f["label"][index].argmax()
@ -196,13 +248,28 @@ class So2Sat(VisionDataset):
Returns: Returns:
True if dataset files are found and/or MD5s match, else False True if dataset files are found and/or MD5s match, else False
""" """
for split_name, filename in self.filenames.items(): md5 = self.md5s[self.split]
filepath = os.path.join(self.root, filename) if not check_integrity(self.fn, md5 if self.checksum else None):
md5 = self.md5s[split_name] return False
if not check_integrity(filepath, md5 if self.checksum else None):
return False
return True return True
def _validate_bands(self, bands: Sequence[str]) -> None:
"""Validate list of bands.
Args:
bands: user-provided sequence of bands to load
Raises:
AssertionError: if ``bands`` is not a sequence
ValueError: if an invalid band name is provided
.. versionadded:: 0.3
"""
assert isinstance(bands, Sequence), "'bands' must be a sequence"
for band in bands:
if band not in self.all_band_names:
raise ValueError(f"'{band}' is an invalid band name.")
def plot( def plot(
self, self,
sample: Dict[str, Tensor], sample: Dict[str, Tensor],
@ -219,10 +286,23 @@ class So2Sat(VisionDataset):
Returns: Returns:
a matplotlib Figure with the rendered sample a matplotlib Figure with the rendered sample
Raises:
ValueError: if RGB bands are not found in dataset
.. versionadded:: 0.2 .. versionadded:: 0.2
""" """
image = np.rollaxis(sample["image"][[10, 9, 8]].numpy(), 0, 3) rgb_indices = []
for band in self.RGB_BANDS:
if band in self.s2_band_names:
idx = self.s2_band_names.index(band) + len(self.s1_band_names)
rgb_indices.append(idx)
else:
raise ValueError("Dataset doesn't contain some of the RGB bands")
image = np.take(sample["image"].numpy(), indices=rgb_indices, axis=0)
image = np.rollaxis(image, 0, 3)
image = percentile_normalization(image, 0, 100) image = percentile_normalization(image, 0, 100)
label = cast(int, sample["label"].item()) label = cast(int, sample["label"].item())
label_class = self.classes[label] label_class = self.classes[label]