Add band selection to So2Sat Dataset and adopt plot method (#394)

* add band selection and adopt plot method

* np array typing and raise error doc
This commit is contained in:
Nils Lehmann 2022-02-20 20:47:52 +01:00 коммит произвёл GitHub
Родитель b36e053ecf
Коммит 89277dc325
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 99 добавлений и 10 удалений

Просмотреть файл

@ -33,7 +33,7 @@ class TestSo2Sat:
root = os.path.join("tests", "data", "so2sat")
split = request.param
transforms = nn.Identity() # type: ignore[attr-defined]
return So2Sat(root, split, transforms, checksum=True)
return So2Sat(root=root, split=split, transforms=transforms, checksum=True)
@pytest.fixture
def mock_missing_module(
@ -69,6 +69,10 @@ class TestSo2Sat:
with pytest.raises(AssertionError):
So2Sat(split="foo")
def test_invalid_bands(self) -> None:
with pytest.raises(ValueError):
So2Sat(bands=("OK", "BK"))
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
So2Sat(str(tmp_path))
@ -83,6 +87,11 @@ class TestSo2Sat:
dataset.plot(x)
plt.close()
def test_plot_rgb(self, dataset: So2Sat) -> None:
dataset = So2Sat(root=dataset.root, bands=("B03",))
with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"):
dataset.plot(dataset[0], suptitle="Single Band")
def test_mock_missing_module(
self, dataset: So2Sat, mock_missing_module: None
) -> None:

Просмотреть файл

@ -4,7 +4,7 @@
"""So2Sat dataset."""
import os
from typing import Callable, Dict, Optional, cast
from typing import Callable, Dict, Optional, Sequence, cast
import matplotlib.pyplot as plt
import numpy as np
@ -105,10 +105,34 @@ class So2Sat(VisionDataset):
"Water",
]
all_s1_band_names = ("S1B1", "S1B2", "S1B3", "S1B4", "S1B5", "S1B6", "S1B7", "S1B8")
all_s2_band_names = (
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B08A",
"B11 SWIR",
"B12 SWIR",
)
all_band_names = all_s1_band_names + all_s2_band_names
RGB_BANDS = ["B04", "B03", "B02"]
BAND_SETS = {
"all": all_band_names,
"s1": all_s1_band_names,
"s2": all_s2_band_names,
}
def __init__(
self,
root: str = "data",
split: str = "train",
bands: Sequence[str] = BAND_SETS["all"],
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
checksum: bool = False,
) -> None:
@ -117,6 +141,8 @@ class So2Sat(VisionDataset):
Args:
root: root directory where dataset can be found
split: one of "train", "validation", or "test"
bands: a sequence of band names to use where the indices correspond to the
array index of combined Sentinel 1 and Sentinel 2
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
checksum: if True, check the MD5 of the downloaded files (may be slow)
@ -134,16 +160,39 @@ class So2Sat(VisionDataset):
assert split in ["train", "validation", "test"]
self._validate_bands(bands)
self.s1_band_indices: "np.typing.NDArray[np.int_]" = np.array(
[
self.all_s1_band_names.index(b)
for b in bands
if b in self.all_s1_band_names
]
).astype(int)
self.s1_band_names = [self.all_s1_band_names[i] for i in self.s1_band_indices]
self.s2_band_indices: "np.typing.NDArray[np.int_]" = np.array(
[
self.all_s2_band_names.index(b)
for b in bands
if b in self.all_s2_band_names
]
).astype(int)
self.s2_band_names = [self.all_s2_band_names[i] for i in self.s2_band_indices]
self.bands = bands
self.root = root
self.split = split
self.transforms = transforms
self.checksum = checksum
self.fn = os.path.join(self.root, self.filenames[split])
if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted.")
self.fn = os.path.join(self.root, self.filenames[split])
with h5py.File(self.fn, "r") as f:
self.size = int(f["label"].shape[0])
@ -160,7 +209,10 @@ class So2Sat(VisionDataset):
with h5py.File(self.fn, "r") as f:
s1 = f["sen1"][index].astype(np.float64) # convert from <f8 to float64
s1 = np.take(s1, indices=self.s1_band_indices, axis=2)
s2 = f["sen2"][index].astype(np.float64) # convert from <f8 to float64
s2 = np.take(s2, indices=self.s2_band_indices, axis=2)
# convert one-hot encoding to int64 then torch int
label = torch.tensor( # type: ignore[attr-defined]
f["label"][index].argmax()
@ -196,13 +248,28 @@ class So2Sat(VisionDataset):
Returns:
True if dataset files are found and/or MD5s match, else False
"""
for split_name, filename in self.filenames.items():
filepath = os.path.join(self.root, filename)
md5 = self.md5s[split_name]
if not check_integrity(filepath, md5 if self.checksum else None):
md5 = self.md5s[self.split]
if not check_integrity(self.fn, md5 if self.checksum else None):
return False
return True
def _validate_bands(self, bands: Sequence[str]) -> None:
"""Validate list of bands.
Args:
bands: user-provided sequence of bands to load
Raises:
AssertionError: if ``bands`` is not a sequence
ValueError: if an invalid band name is provided
.. versionadded:: 0.3
"""
assert isinstance(bands, Sequence), "'bands' must be a sequence"
for band in bands:
if band not in self.all_band_names:
raise ValueError(f"'{band}' is an invalid band name.")
def plot(
self,
sample: Dict[str, Tensor],
@ -219,10 +286,23 @@ class So2Sat(VisionDataset):
Returns:
a matplotlib Figure with the rendered sample
Raises:
ValueError: if RGB bands are not found in dataset
.. versionadded:: 0.2
"""
image = np.rollaxis(sample["image"][[10, 9, 8]].numpy(), 0, 3)
rgb_indices = []
for band in self.RGB_BANDS:
if band in self.s2_band_names:
idx = self.s2_band_names.index(band) + len(self.s1_band_names)
rgb_indices.append(idx)
else:
raise ValueError("Dataset doesn't contain some of the RGB bands")
image = np.take(sample["image"].numpy(), indices=rgb_indices, axis=0)
image = np.rollaxis(image, 0, 3)
image = percentile_normalization(image, 0, 100)
label = cast(int, sample["label"].item())
label_class = self.classes[label]