зеркало из https://github.com/microsoft/torchgeo.git
Add SpaceNet3 (#480)
* Add SpaceNet3 * Fixes * Replace itertools.product with zip * Update docstring * Remove unused options
This commit is contained in:
Родитель
b2e178f1fe
Коммит
2e5c2b274e
|
@ -233,6 +233,7 @@ SpaceNet
|
|||
.. autoclass:: SpaceNet
|
||||
.. autoclass:: SpaceNet1
|
||||
.. autoclass:: SpaceNet2
|
||||
.. autoclass:: SpaceNet3
|
||||
.. autoclass:: SpaceNet4
|
||||
.. autoclass:: SpaceNet5
|
||||
.. autoclass:: SpaceNet7
|
||||
|
|
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -2,7 +2,6 @@
|
|||
# Licensed under the MIT License.
|
||||
|
||||
import glob
|
||||
import itertools
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
@ -14,7 +13,14 @@ import torch.nn as nn
|
|||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7
|
||||
from torchgeo.datasets import (
|
||||
SpaceNet1,
|
||||
SpaceNet2,
|
||||
SpaceNet3,
|
||||
SpaceNet4,
|
||||
SpaceNet5,
|
||||
SpaceNet7,
|
||||
)
|
||||
|
||||
TEST_DATA_DIR = "tests/data/spacenet"
|
||||
|
||||
|
@ -142,6 +148,71 @@ class TestSpaceNet2:
|
|||
plt.close()
|
||||
|
||||
|
||||
class TestSpaceNet3:
|
||||
@pytest.fixture(params=zip(["PAN", "MS"], [False, True]))
|
||||
def dataset(
|
||||
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
|
||||
) -> SpaceNet3:
|
||||
radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
|
||||
monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_collection)
|
||||
test_md5 = {
|
||||
"sn3_AOI_3_Paris": "197440e0ade970169a801a173a492c27",
|
||||
"sn3_AOI_5_Khartoum": "b21ff7dd33a15ec32bd380c083263cdf",
|
||||
}
|
||||
|
||||
monkeypatch.setattr(SpaceNet3, "collection_md5_dict", test_md5)
|
||||
root = str(tmp_path)
|
||||
transforms = nn.Identity() # type: ignore[no-untyped-call]
|
||||
return SpaceNet3(
|
||||
root,
|
||||
image=request.param[0],
|
||||
speed_mask=request.param[1],
|
||||
collections=["sn3_AOI_3_Paris", "sn3_AOI_5_Khartoum"],
|
||||
transforms=transforms,
|
||||
download=True,
|
||||
api_key="",
|
||||
)
|
||||
|
||||
def test_getitem(self, dataset: SpaceNet3) -> None:
|
||||
# Iterate over all elements to maximize coverage
|
||||
samples = [dataset[i] for i in range(len(dataset))]
|
||||
x = samples[0]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x["image"], torch.Tensor)
|
||||
assert isinstance(x["mask"], torch.Tensor)
|
||||
if dataset.image == "MS":
|
||||
assert x["image"].shape[0] == 8
|
||||
else:
|
||||
assert x["image"].shape[0] == 1
|
||||
|
||||
def test_len(self, dataset: SpaceNet3) -> None:
|
||||
assert len(dataset) == 4
|
||||
|
||||
def test_already_downloaded(self, dataset: SpaceNet3) -> None:
|
||||
SpaceNet3(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
SpaceNet3(str(tmp_path))
|
||||
|
||||
def test_collection_checksum(self, dataset: SpaceNet3) -> None:
|
||||
dataset.collection_md5_dict["sn3_AOI_5_Khartoum"] = "randommd5hash123"
|
||||
with pytest.raises(
|
||||
RuntimeError, match="Collection sn3_AOI_5_Khartoum corrupted"
|
||||
):
|
||||
SpaceNet3(root=dataset.root, download=True, checksum=True)
|
||||
|
||||
def test_plot(self, dataset: SpaceNet3) -> None:
|
||||
x = dataset[0].copy()
|
||||
x["prediction"] = x["mask"]
|
||||
dataset.plot(x, suptitle="Test")
|
||||
plt.close()
|
||||
dataset.plot(x, show_titles=False)
|
||||
plt.close()
|
||||
dataset.plot({"image": x["image"]})
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestSpaceNet4:
|
||||
@pytest.fixture(params=["PAN", "MS", "PS-RGBNIR"])
|
||||
def dataset(
|
||||
|
@ -206,9 +277,7 @@ class TestSpaceNet4:
|
|||
|
||||
|
||||
class TestSpaceNet5:
|
||||
@pytest.fixture(
|
||||
params=itertools.product(["PAN", "MS", "PS-MS", "PS-RGB"], [False, True])
|
||||
)
|
||||
@pytest.fixture(params=zip(["PAN", "MS"], [False, True]))
|
||||
def dataset(
|
||||
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
|
||||
) -> SpaceNet5:
|
||||
|
@ -239,9 +308,7 @@ class TestSpaceNet5:
|
|||
assert isinstance(x, dict)
|
||||
assert isinstance(x["image"], torch.Tensor)
|
||||
assert isinstance(x["mask"], torch.Tensor)
|
||||
if dataset.image == "PS-RGB":
|
||||
assert x["image"].shape[0] == 3
|
||||
elif dataset.image in ["MS", "PS-MS"]:
|
||||
if dataset.image == "MS":
|
||||
assert x["image"].shape[0] == 8
|
||||
else:
|
||||
assert x["image"].shape[0] == 1
|
||||
|
|
|
@ -75,7 +75,15 @@ from .seco import SeasonalContrastS2
|
|||
from .sen12ms import SEN12MS
|
||||
from .sentinel import Sentinel, Sentinel2
|
||||
from .so2sat import So2Sat
|
||||
from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7
|
||||
from .spacenet import (
|
||||
SpaceNet,
|
||||
SpaceNet1,
|
||||
SpaceNet2,
|
||||
SpaceNet3,
|
||||
SpaceNet4,
|
||||
SpaceNet5,
|
||||
SpaceNet7,
|
||||
)
|
||||
from .ucmerced import UCMerced
|
||||
from .usavars import USAVars
|
||||
from .utils import (
|
||||
|
@ -155,6 +163,7 @@ __all__ = (
|
|||
"SpaceNet",
|
||||
"SpaceNet1",
|
||||
"SpaceNet2",
|
||||
"SpaceNet3",
|
||||
"SpaceNet4",
|
||||
"SpaceNet5",
|
||||
"SpaceNet7",
|
||||
|
|
|
@ -133,7 +133,7 @@ class SpaceNet(VisionDataset, abc.ABC):
|
|||
images = sorted(images)
|
||||
for imgpath in images:
|
||||
lbl_path = os.path.join(
|
||||
os.path.dirname(imgpath) + "-labels", self.label_glob
|
||||
f"{os.path.dirname(imgpath)}-labels", self.label_glob
|
||||
)
|
||||
files.append({"image_path": imgpath, "label_path": lbl_path})
|
||||
return files
|
||||
|
@ -248,7 +248,7 @@ class SpaceNet(VisionDataset, abc.ABC):
|
|||
|
||||
to_be_downloaded = []
|
||||
for collection in missing_collections:
|
||||
archive_path = os.path.join(self.root, collection + ".tar.gz")
|
||||
archive_path = os.path.join(self.root, f"{collection}.tar.gz")
|
||||
if os.path.exists(archive_path):
|
||||
print(f"Found {collection} archive")
|
||||
if (
|
||||
|
@ -281,7 +281,7 @@ class SpaceNet(VisionDataset, abc.ABC):
|
|||
"""
|
||||
for collection in collections:
|
||||
download_radiant_mlhub_collection(collection, self.root, api_key)
|
||||
archive_path = os.path.join(self.root, collection + ".tar.gz")
|
||||
archive_path = os.path.join(self.root, f"{collection}.tar.gz")
|
||||
if (
|
||||
not self.checksum
|
||||
or not check_integrity(
|
||||
|
@ -538,7 +538,8 @@ class SpaceNet2(SpaceNet):
|
|||
image: image selection which must be in ["MS", "PAN", "PS-MS", "PS-RGB"]
|
||||
collections: collection selection which must be a subset of:
|
||||
[sn2_AOI_2_Vegas, sn2_AOI_3_Paris, sn2_AOI_4_Shanghai,
|
||||
sn2_AOI_5_Khartoum]
|
||||
sn2_AOI_5_Khartoum]. If unspecified, all collections will be
|
||||
used.
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory.
|
||||
|
@ -554,6 +555,272 @@ class SpaceNet2(SpaceNet):
|
|||
)
|
||||
|
||||
|
||||
class SpaceNet3(SpaceNet):
|
||||
r"""SpaceNet 3: Road Network Detection.
|
||||
|
||||
`SpaceNet 3 <https://spacenet.ai/spacenet-roads-dataset/>`_
|
||||
is a dataset of road networks over the cities of Las Vegas, Paris, Shanghai,
|
||||
and Khartoum.
|
||||
|
||||
Collection features:
|
||||
|
||||
+------------+---------------------+------------+---------------------------+
|
||||
| AOI | Area (km\ :sup:`2`\)| # Images | # Road Network Labels (km)|
|
||||
+============+=====================+============+===========================+
|
||||
| Vegas | 216 | 854 | 3685 |
|
||||
+------------+---------------------+------------+---------------------------+
|
||||
| Paris | 1030 | 257 | 425 |
|
||||
+------------+---------------------+------------+---------------------------+
|
||||
| Shanghai | 1000 | 1028 | 3537 |
|
||||
+------------+---------------------+------------+---------------------------+
|
||||
| Khartoum | 765 | 283 | 1030 |
|
||||
+------------+---------------------+------------+---------------------------+
|
||||
|
||||
Imagery features:
|
||||
|
||||
.. list-table::
|
||||
:widths: 10 10 10 10 10
|
||||
:header-rows: 1
|
||||
:stub-columns: 1
|
||||
|
||||
* -
|
||||
- PAN
|
||||
- MS
|
||||
- PS-MS
|
||||
- PS-RGB
|
||||
* - GSD (m)
|
||||
- 0.31
|
||||
- 1.24
|
||||
- 0.30
|
||||
- 0.30
|
||||
* - Chip size (px)
|
||||
- 1300 x 1300
|
||||
- 325 x 325
|
||||
- 1300 x 1300
|
||||
- 1300 x 1300
|
||||
|
||||
Dataset format:
|
||||
|
||||
* Imagery - Worldview-3 GeoTIFFs
|
||||
|
||||
* PAN.tif (Panchromatic)
|
||||
* MS.tif (Multispectral)
|
||||
* PS-MS (Pansharpened Multispectral)
|
||||
* PS-RGB (Pansharpened RGB)
|
||||
|
||||
* Labels - GeoJSON
|
||||
|
||||
* labels.geojson
|
||||
|
||||
If you use this dataset in your research, please cite the following paper:
|
||||
|
||||
* https://arxiv.org/abs/1807.01232
|
||||
|
||||
.. note::
|
||||
|
||||
This dataset requires the following additional library to be installed:
|
||||
|
||||
* `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
|
||||
imagery and labels from the Radiant Earth MLHub
|
||||
|
||||
.. versionadded:: 0.3
|
||||
"""
|
||||
|
||||
dataset_id = "spacenet3"
|
||||
collection_md5_dict = {
|
||||
"sn3_AOI_2_Vegas": "8ce7e6abffb8849eb88885035f061ee8",
|
||||
"sn3_AOI_3_Paris": "90b9ebd64cd83dc8d3d4773f45050d8f",
|
||||
"sn3_AOI_4_Shanghai": "3ea291df34548962dfba8b5ed37d700c",
|
||||
"sn3_AOI_5_Khartoum": "b8d549ac9a6d7456c0f7a8e6de23d9f9",
|
||||
}
|
||||
|
||||
imagery = {
|
||||
"MS": "MS.tif",
|
||||
"PAN": "PAN.tif",
|
||||
"PS-MS": "PS-MS.tif",
|
||||
"PS-RGB": "PS-RGB.tif",
|
||||
}
|
||||
chip_size = {
|
||||
"MS": (325, 325),
|
||||
"PAN": (1300, 1300),
|
||||
"PS-MS": (1300, 1300),
|
||||
"PS-RGB": (1300, 1300),
|
||||
}
|
||||
label_glob = "labels.geojson"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str,
|
||||
image: str = "PS-RGB",
|
||||
speed_mask: Optional[bool] = False,
|
||||
collections: List[str] = [],
|
||||
transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
||||
download: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new SpaceNet 3 Dataset instance.
|
||||
|
||||
Args:
|
||||
root: root directory where dataset can be found
|
||||
image: image selection which must be in ["MS", "PAN", "PS-MS", "PS-RGB"]
|
||||
speed_mask: use multi-class speed mask (created by binning roads at
|
||||
10 mph increments) as label if true, else use binary mask
|
||||
collections: collection selection which must be a subset of:
|
||||
[sn3_AOI_2_Vegas, sn3_AOI_3_Paris, sn3_AOI_4_Shanghai,
|
||||
sn3_AOI_5_Khartoum]. If unspecified, all collections will be
|
||||
used.
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory.
|
||||
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
|
||||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing
|
||||
"""
|
||||
assert image in {"MS", "PAN", "PS-MS", "PS-RGB"}
|
||||
self.speed_mask = speed_mask
|
||||
super().__init__(
|
||||
root, image, collections, transforms, download, api_key, checksum
|
||||
)
|
||||
|
||||
def _load_mask(
|
||||
self, path: str, tfm: Affine, raster_crs: CRS, shape: Tuple[int, int]
|
||||
) -> Tensor:
|
||||
"""Rasterizes the dataset's labels (in geojson format).
|
||||
|
||||
Args:
|
||||
path: path to the label
|
||||
tfm: transform of corresponding image
|
||||
shape: shape of corresponding image
|
||||
|
||||
Returns:
|
||||
Tensor: label tensor
|
||||
"""
|
||||
min_speed_bin = 1
|
||||
max_speed_bin = 65
|
||||
speed_arr_bin = np.arange(min_speed_bin, max_speed_bin + 1)
|
||||
bin_size_mph = 10.0
|
||||
speed_cls_arr: "np.typing.NDArray[np.int_]" = np.array(
|
||||
[int(math.ceil(s / bin_size_mph)) for s in speed_arr_bin]
|
||||
)
|
||||
|
||||
try:
|
||||
with fiona.open(path) as src:
|
||||
vector_crs = CRS(src.crs)
|
||||
labels = []
|
||||
|
||||
for feature in src:
|
||||
if raster_crs != vector_crs:
|
||||
geom = transform_geom(
|
||||
vector_crs.to_string(),
|
||||
raster_crs.to_string(),
|
||||
feature["geometry"],
|
||||
)
|
||||
else:
|
||||
geom = feature["geometry"]
|
||||
|
||||
if self.speed_mask:
|
||||
val = speed_cls_arr[
|
||||
int(feature["properties"]["inferred_speed_mph"]) - 1
|
||||
]
|
||||
else:
|
||||
val = 1
|
||||
|
||||
labels.append((geom, val))
|
||||
|
||||
except FionaValueError:
|
||||
labels = []
|
||||
|
||||
if not labels:
|
||||
mask_data = np.zeros(shape=shape)
|
||||
else:
|
||||
mask_data = rasterize(
|
||||
labels,
|
||||
out_shape=shape,
|
||||
fill=0, # nodata value
|
||||
transform=tfm,
|
||||
all_touched=False,
|
||||
dtype=np.uint8,
|
||||
)
|
||||
|
||||
mask = torch.from_numpy(mask_data).long()
|
||||
return mask
|
||||
|
||||
def plot(
|
||||
self,
|
||||
sample: Dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
sample: a sample returned by :meth:`SpaceNet.__getitem__`
|
||||
show_titles: flag indicating whether to show titles above each panel
|
||||
suptitle: optional string to use as a suptitle
|
||||
|
||||
Returns:
|
||||
a matplotlib Figure with the rendered sample
|
||||
|
||||
"""
|
||||
# image can be 1 channel or >3 channels
|
||||
if sample["image"].shape[0] == 1:
|
||||
image = np.rollaxis(sample["image"].numpy(), 0, 3)
|
||||
else:
|
||||
image = np.rollaxis(sample["image"][:3].numpy(), 0, 3)
|
||||
image = percentile_normalization(image, axis=(0, 1))
|
||||
|
||||
ncols = 1
|
||||
show_mask = "mask" in sample
|
||||
show_predictions = "prediction" in sample
|
||||
|
||||
if show_mask:
|
||||
mask = sample["mask"].numpy()
|
||||
ncols += 1
|
||||
|
||||
if show_predictions:
|
||||
prediction = sample["prediction"].numpy()
|
||||
ncols += 1
|
||||
|
||||
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 8, 8))
|
||||
if not isinstance(axs, np.ndarray):
|
||||
axs = [axs]
|
||||
axs[0].imshow(image)
|
||||
axs[0].axis("off")
|
||||
if show_titles:
|
||||
axs[0].set_title("Image")
|
||||
|
||||
if show_mask:
|
||||
if self.speed_mask:
|
||||
cmap = copy.copy(plt.get_cmap("autumn_r"))
|
||||
cmap.set_under(color="black")
|
||||
axs[1].imshow(mask, vmin=0.1, vmax=7, cmap=cmap, interpolation="none")
|
||||
else:
|
||||
axs[1].imshow(mask, cmap="Greys_r", interpolation="none")
|
||||
axs[1].axis("off")
|
||||
if show_titles:
|
||||
axs[1].set_title("Label")
|
||||
|
||||
if show_predictions:
|
||||
if self.speed_mask:
|
||||
cmap = copy.copy(plt.get_cmap("autumn_r"))
|
||||
cmap.set_under(color="black")
|
||||
axs[2].imshow(
|
||||
prediction, vmin=0.1, vmax=7, cmap=cmap, interpolation="none"
|
||||
)
|
||||
else:
|
||||
axs[2].imshow(prediction, cmap="Greys_r", interpolation="none")
|
||||
axs[2].axis("off")
|
||||
if show_titles:
|
||||
axs[2].set_title("Prediction")
|
||||
|
||||
if suptitle is not None:
|
||||
plt.suptitle(suptitle)
|
||||
return fig
|
||||
|
||||
|
||||
class SpaceNet4(SpaceNet):
|
||||
"""SpaceNet 4: Off-Nadir Buildings Dataset.
|
||||
|
||||
|
@ -699,7 +966,7 @@ class SpaceNet4(SpaceNet):
|
|||
|
||||
lbl_dir = os.path.dirname(imgpath).split("-nadir")[0]
|
||||
|
||||
lbl_path = os.path.join(lbl_dir + "-labels", self.label_glob)
|
||||
lbl_path = os.path.join(f"{lbl_dir}-labels", self.label_glob)
|
||||
assert os.path.exists(lbl_path)
|
||||
|
||||
_file = {"image_path": imgpath, "label_path": lbl_path}
|
||||
|
@ -724,7 +991,7 @@ class SpaceNet4(SpaceNet):
|
|||
return files
|
||||
|
||||
|
||||
class SpaceNet5(SpaceNet):
|
||||
class SpaceNet5(SpaceNet3):
|
||||
r"""SpaceNet 5: Automated Road Network Extraction and Route Travel Time Estimation.
|
||||
|
||||
`SpaceNet 5 <https://spacenet.ai/sn5-challenge/>`_
|
||||
|
@ -832,7 +1099,8 @@ class SpaceNet5(SpaceNet):
|
|||
speed_mask: use multi-class speed mask (created by binning roads at
|
||||
10 mph increments) as label if true, else use binary mask
|
||||
collections: collection selection which must be a subset of:
|
||||
[sn5_AOI_7_Moscow, sn5_AOI_8_Mumbai]
|
||||
[sn5_AOI_7_Moscow, sn5_AOI_8_Mumbai]. If unspecified, all
|
||||
collections will be used.
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory.
|
||||
|
@ -842,148 +1110,17 @@ class SpaceNet5(SpaceNet):
|
|||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing
|
||||
"""
|
||||
assert image in {"MS", "PAN", "PS-MS", "PS-RGB"}
|
||||
self.speed_mask = speed_mask
|
||||
super().__init__(
|
||||
root, image, collections, transforms, download, api_key, checksum
|
||||
root,
|
||||
image,
|
||||
speed_mask,
|
||||
collections,
|
||||
transforms,
|
||||
download,
|
||||
api_key,
|
||||
checksum,
|
||||
)
|
||||
|
||||
def _load_mask(
|
||||
self, path: str, tfm: Affine, raster_crs: CRS, shape: Tuple[int, int]
|
||||
) -> Tensor:
|
||||
"""Rasterizes the dataset's labels (in geojson format).
|
||||
|
||||
Args:
|
||||
path: path to the label
|
||||
tfm: transform of corresponding image
|
||||
shape: shape of corresponding image
|
||||
|
||||
Returns:
|
||||
Tensor: label tensor
|
||||
"""
|
||||
min_speed_bin = 1
|
||||
max_speed_bin = 65
|
||||
speed_arr_bin = np.arange(min_speed_bin, max_speed_bin + 1)
|
||||
bin_size_mph = 10.0
|
||||
speed_cls_arr: "np.typing.NDArray[np.int_]" = np.array(
|
||||
[int(math.ceil(s / bin_size_mph)) for s in speed_arr_bin]
|
||||
)
|
||||
|
||||
try:
|
||||
with fiona.open(path) as src:
|
||||
vector_crs = CRS(src.crs)
|
||||
labels = []
|
||||
|
||||
for feature in src:
|
||||
if raster_crs != vector_crs:
|
||||
geom = transform_geom(
|
||||
vector_crs.to_string(),
|
||||
raster_crs.to_string(),
|
||||
feature["geometry"],
|
||||
)
|
||||
else:
|
||||
geom = feature["geometry"]
|
||||
|
||||
if self.speed_mask:
|
||||
val = speed_cls_arr[
|
||||
int(feature["properties"]["inferred_speed_mph"]) - 1
|
||||
]
|
||||
else:
|
||||
val = 1
|
||||
|
||||
labels.append((geom, val))
|
||||
|
||||
except FionaValueError:
|
||||
labels = []
|
||||
|
||||
if not labels:
|
||||
mask_data = np.zeros(shape=shape)
|
||||
else:
|
||||
mask_data = rasterize(
|
||||
labels,
|
||||
out_shape=shape,
|
||||
fill=0, # nodata value
|
||||
transform=tfm,
|
||||
all_touched=False,
|
||||
dtype=np.uint8,
|
||||
)
|
||||
|
||||
mask = torch.from_numpy(mask_data).long()
|
||||
return mask
|
||||
|
||||
def plot(
|
||||
self,
|
||||
sample: Dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
sample: a sample returned by :meth:`SpaceNet.__getitem__`
|
||||
show_titles: flag indicating whether to show titles above each panel
|
||||
suptitle: optional string to use as a suptitle
|
||||
|
||||
Returns:
|
||||
a matplotlib Figure with the rendered sample
|
||||
|
||||
.. versionadded:: 0.2
|
||||
"""
|
||||
# image can be 1 channel or >3 channels
|
||||
if sample["image"].shape[0] == 1:
|
||||
image = np.rollaxis(sample["image"].numpy(), 0, 3)
|
||||
else:
|
||||
image = np.rollaxis(sample["image"][:3].numpy(), 0, 3)
|
||||
image = percentile_normalization(image, axis=(0, 1))
|
||||
|
||||
ncols = 1
|
||||
show_mask = "mask" in sample
|
||||
show_predictions = "prediction" in sample
|
||||
|
||||
if show_mask:
|
||||
mask = sample["mask"].numpy()
|
||||
ncols += 1
|
||||
|
||||
if show_predictions:
|
||||
prediction = sample["prediction"].numpy()
|
||||
ncols += 1
|
||||
|
||||
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 8, 8))
|
||||
if not isinstance(axs, np.ndarray):
|
||||
axs = [axs]
|
||||
axs[0].imshow(image)
|
||||
axs[0].axis("off")
|
||||
if show_titles:
|
||||
axs[0].set_title("Image")
|
||||
|
||||
if show_mask:
|
||||
if self.speed_mask:
|
||||
cmap = copy.copy(plt.get_cmap("autumn_r"))
|
||||
cmap.set_under(color="black")
|
||||
axs[1].imshow(mask, vmin=0.1, vmax=7, cmap=cmap, interpolation="none")
|
||||
else:
|
||||
axs[1].imshow(mask, cmap="Greys_r", interpolation="none")
|
||||
axs[1].axis("off")
|
||||
if show_titles:
|
||||
axs[1].set_title("Label")
|
||||
|
||||
if show_predictions:
|
||||
if self.speed_mask:
|
||||
cmap = copy.copy(plt.get_cmap("autumn_r"))
|
||||
cmap.set_under(color="black")
|
||||
axs[2].imshow(
|
||||
prediction, vmin=0.1, vmax=7, cmap=cmap, interpolation="none"
|
||||
)
|
||||
else:
|
||||
axs[2].imshow(prediction, cmap="Greys_r", interpolation="none")
|
||||
axs[2].axis("off")
|
||||
if show_titles:
|
||||
axs[2].set_title("Prediction")
|
||||
|
||||
if suptitle is not None:
|
||||
plt.suptitle(suptitle)
|
||||
return fig
|
||||
|
||||
|
||||
class SpaceNet7(SpaceNet):
|
||||
"""SpaceNet 7: Multi-Temporal Urban Development Challenge.
|
||||
|
@ -1126,13 +1263,12 @@ class SpaceNet7(SpaceNet):
|
|||
Returns:
|
||||
data at that index
|
||||
"""
|
||||
sample = {}
|
||||
files = self.files[index]
|
||||
img, tfm, raster_crs = self._load_image(files["image_path"])
|
||||
h, w = img.shape[1:]
|
||||
|
||||
ch, cw = self.chip_size["img"]
|
||||
sample["image"] = img[:, :ch, :cw]
|
||||
sample = {"image": img[:, :ch, :cw]}
|
||||
if self.split == "train":
|
||||
mask = self._load_mask(files["label_path"], tfm, raster_crs, (h, w))
|
||||
sample["mask"] = mask[:ch, :cw]
|
||||
|
|
Загрузка…
Ссылка в новой задаче