зеркало из https://github.com/microsoft/torchgeo.git
Add band selection to So2Sat Dataset and adopt plot method (#394)
* add band selection and adopt plot method * np array typing and raise error doc
This commit is contained in:
Родитель
b36e053ecf
Коммит
89277dc325
|
@ -33,7 +33,7 @@ class TestSo2Sat:
|
|||
root = os.path.join("tests", "data", "so2sat")
|
||||
split = request.param
|
||||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return So2Sat(root, split, transforms, checksum=True)
|
||||
return So2Sat(root=root, split=split, transforms=transforms, checksum=True)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_missing_module(
|
||||
|
@ -69,6 +69,10 @@ class TestSo2Sat:
|
|||
with pytest.raises(AssertionError):
|
||||
So2Sat(split="foo")
|
||||
|
||||
def test_invalid_bands(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
So2Sat(bands=("OK", "BK"))
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
So2Sat(str(tmp_path))
|
||||
|
@ -83,6 +87,11 @@ class TestSo2Sat:
|
|||
dataset.plot(x)
|
||||
plt.close()
|
||||
|
||||
def test_plot_rgb(self, dataset: So2Sat) -> None:
|
||||
dataset = So2Sat(root=dataset.root, bands=("B03",))
|
||||
with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"):
|
||||
dataset.plot(dataset[0], suptitle="Single Band")
|
||||
|
||||
def test_mock_missing_module(
|
||||
self, dataset: So2Sat, mock_missing_module: None
|
||||
) -> None:
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"""So2Sat dataset."""
|
||||
|
||||
import os
|
||||
from typing import Callable, Dict, Optional, cast
|
||||
from typing import Callable, Dict, Optional, Sequence, cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@ -105,10 +105,34 @@ class So2Sat(VisionDataset):
|
|||
"Water",
|
||||
]
|
||||
|
||||
all_s1_band_names = ("S1B1", "S1B2", "S1B3", "S1B4", "S1B5", "S1B6", "S1B7", "S1B8")
|
||||
all_s2_band_names = (
|
||||
"B02",
|
||||
"B03",
|
||||
"B04",
|
||||
"B05",
|
||||
"B06",
|
||||
"B07",
|
||||
"B08",
|
||||
"B08A",
|
||||
"B11 SWIR",
|
||||
"B12 SWIR",
|
||||
)
|
||||
all_band_names = all_s1_band_names + all_s2_band_names
|
||||
|
||||
RGB_BANDS = ["B04", "B03", "B02"]
|
||||
|
||||
BAND_SETS = {
|
||||
"all": all_band_names,
|
||||
"s1": all_s1_band_names,
|
||||
"s2": all_s2_band_names,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
bands: Sequence[str] = BAND_SETS["all"],
|
||||
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
|
@ -117,6 +141,8 @@ class So2Sat(VisionDataset):
|
|||
Args:
|
||||
root: root directory where dataset can be found
|
||||
split: one of "train", "validation", or "test"
|
||||
bands: a sequence of band names to use where the indices correspond to the
|
||||
array index of combined Sentinel 1 and Sentinel 2
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
@ -134,16 +160,39 @@ class So2Sat(VisionDataset):
|
|||
|
||||
assert split in ["train", "validation", "test"]
|
||||
|
||||
self._validate_bands(bands)
|
||||
self.s1_band_indices: "np.typing.NDArray[np.int_]" = np.array(
|
||||
[
|
||||
self.all_s1_band_names.index(b)
|
||||
for b in bands
|
||||
if b in self.all_s1_band_names
|
||||
]
|
||||
).astype(int)
|
||||
|
||||
self.s1_band_names = [self.all_s1_band_names[i] for i in self.s1_band_indices]
|
||||
|
||||
self.s2_band_indices: "np.typing.NDArray[np.int_]" = np.array(
|
||||
[
|
||||
self.all_s2_band_names.index(b)
|
||||
for b in bands
|
||||
if b in self.all_s2_band_names
|
||||
]
|
||||
).astype(int)
|
||||
|
||||
self.s2_band_names = [self.all_s2_band_names[i] for i in self.s2_band_indices]
|
||||
|
||||
self.bands = bands
|
||||
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transforms = transforms
|
||||
self.checksum = checksum
|
||||
|
||||
self.fn = os.path.join(self.root, self.filenames[split])
|
||||
|
||||
if not self._check_integrity():
|
||||
raise RuntimeError("Dataset not found or corrupted.")
|
||||
|
||||
self.fn = os.path.join(self.root, self.filenames[split])
|
||||
|
||||
with h5py.File(self.fn, "r") as f:
|
||||
self.size = int(f["label"].shape[0])
|
||||
|
||||
|
@ -160,7 +209,10 @@ class So2Sat(VisionDataset):
|
|||
|
||||
with h5py.File(self.fn, "r") as f:
|
||||
s1 = f["sen1"][index].astype(np.float64) # convert from <f8 to float64
|
||||
s1 = np.take(s1, indices=self.s1_band_indices, axis=2)
|
||||
s2 = f["sen2"][index].astype(np.float64) # convert from <f8 to float64
|
||||
s2 = np.take(s2, indices=self.s2_band_indices, axis=2)
|
||||
|
||||
# convert one-hot encoding to int64 then torch int
|
||||
label = torch.tensor( # type: ignore[attr-defined]
|
||||
f["label"][index].argmax()
|
||||
|
@ -196,13 +248,28 @@ class So2Sat(VisionDataset):
|
|||
Returns:
|
||||
True if dataset files are found and/or MD5s match, else False
|
||||
"""
|
||||
for split_name, filename in self.filenames.items():
|
||||
filepath = os.path.join(self.root, filename)
|
||||
md5 = self.md5s[split_name]
|
||||
if not check_integrity(filepath, md5 if self.checksum else None):
|
||||
md5 = self.md5s[self.split]
|
||||
if not check_integrity(self.fn, md5 if self.checksum else None):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _validate_bands(self, bands: Sequence[str]) -> None:
|
||||
"""Validate list of bands.
|
||||
|
||||
Args:
|
||||
bands: user-provided sequence of bands to load
|
||||
|
||||
Raises:
|
||||
AssertionError: if ``bands`` is not a sequence
|
||||
ValueError: if an invalid band name is provided
|
||||
|
||||
.. versionadded:: 0.3
|
||||
"""
|
||||
assert isinstance(bands, Sequence), "'bands' must be a sequence"
|
||||
for band in bands:
|
||||
if band not in self.all_band_names:
|
||||
raise ValueError(f"'{band}' is an invalid band name.")
|
||||
|
||||
def plot(
|
||||
self,
|
||||
sample: Dict[str, Tensor],
|
||||
|
@ -219,10 +286,23 @@ class So2Sat(VisionDataset):
|
|||
Returns:
|
||||
a matplotlib Figure with the rendered sample
|
||||
|
||||
Raises:
|
||||
ValueError: if RGB bands are not found in dataset
|
||||
|
||||
.. versionadded:: 0.2
|
||||
"""
|
||||
image = np.rollaxis(sample["image"][[10, 9, 8]].numpy(), 0, 3)
|
||||
rgb_indices = []
|
||||
for band in self.RGB_BANDS:
|
||||
if band in self.s2_band_names:
|
||||
idx = self.s2_band_names.index(band) + len(self.s1_band_names)
|
||||
rgb_indices.append(idx)
|
||||
else:
|
||||
raise ValueError("Dataset doesn't contain some of the RGB bands")
|
||||
|
||||
image = np.take(sample["image"].numpy(), indices=rgb_indices, axis=0)
|
||||
image = np.rollaxis(image, 0, 3)
|
||||
image = percentile_normalization(image, 0, 100)
|
||||
|
||||
label = cast(int, sample["label"].item())
|
||||
label_class = self.classes[label]
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче