* Add SpaceNet3

* Fixes

* Replace itertools.product with zip

* Update docstring

* Remove unused options
This commit is contained in:
Ashwin Nair 2022-03-29 15:47:14 +01:00 коммит произвёл GitHub
Родитель b2e178f1fe
Коммит 2e5c2b274e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 370 добавлений и 157 удалений

Просмотреть файл

@ -233,6 +233,7 @@ SpaceNet
.. autoclass:: SpaceNet
.. autoclass:: SpaceNet1
.. autoclass:: SpaceNet2
.. autoclass:: SpaceNet3
.. autoclass:: SpaceNet4
.. autoclass:: SpaceNet5
.. autoclass:: SpaceNet7

Двоичные данные
tests/data/spacenet/sn3_AOI_3_Paris.tar.gz Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/spacenet/sn3_AOI_5_Khartoum.tar.gz Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -2,7 +2,6 @@
# Licensed under the MIT License.
import glob
import itertools
import os
import shutil
from pathlib import Path
@ -14,7 +13,14 @@ import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torchgeo.datasets import SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7
from torchgeo.datasets import (
SpaceNet1,
SpaceNet2,
SpaceNet3,
SpaceNet4,
SpaceNet5,
SpaceNet7,
)
TEST_DATA_DIR = "tests/data/spacenet"
@ -142,6 +148,71 @@ class TestSpaceNet2:
plt.close()
class TestSpaceNet3:
    """Tests for :class:`SpaceNet3` using the small archives under TEST_DATA_DIR."""

    # Each parameter pairs an image product with a speed_mask setting:
    # PAN with the binary mask, MS with the multi-class speed mask.
    @pytest.fixture(params=[("PAN", False), ("MS", True)])
    def dataset(
        self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
    ) -> SpaceNet3:
        """Build a SpaceNet3 dataset rooted in a temporary directory.

        ``radiant_mlhub.Collection.fetch`` is replaced by ``fetch_collection``
        and the class-level MD5 table is swapped for the checksums of the
        test archives before the dataset is constructed with ``download=True``.
        """
        radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
        monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch_collection)

        fixture_md5s = {
            "sn3_AOI_3_Paris": "197440e0ade970169a801a173a492c27",
            "sn3_AOI_5_Khartoum": "b21ff7dd33a15ec32bd380c083263cdf",
        }
        monkeypatch.setattr(SpaceNet3, "collection_md5_dict", fixture_md5s)

        image_product, use_speed_mask = request.param
        return SpaceNet3(
            str(tmp_path),
            image=image_product,
            speed_mask=use_speed_mask,
            collections=["sn3_AOI_3_Paris", "sn3_AOI_5_Khartoum"],
            transforms=nn.Identity(),  # type: ignore[no-untyped-call]
            download=True,
            api_key="",
        )

    def test_getitem(self, dataset: SpaceNet3) -> None:
        """Every sample loads; the first one has tensors with the right bands."""
        # Index every element (not just the first) to maximize coverage.
        loaded = [dataset[idx] for idx in range(len(dataset))]
        first = loaded[0]

        assert isinstance(first, dict)
        assert isinstance(first["image"], torch.Tensor)
        assert isinstance(first["mask"], torch.Tensor)

        # MS imagery has 8 bands; the only other product in the fixture is PAN.
        expected_bands = 8 if dataset.image == "MS" else 1
        assert first["image"].shape[0] == expected_bands

    def test_len(self, dataset: SpaceNet3) -> None:
        """The two test collections together contain four samples."""
        assert len(dataset) == 4

    def test_already_downloaded(self, dataset: SpaceNet3) -> None:
        """Constructing again over a populated root must not fail."""
        SpaceNet3(root=dataset.root, download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """An empty root without ``download=True`` raises a RuntimeError."""
        empty_root = str(tmp_path)
        with pytest.raises(RuntimeError, match="Dataset not found"):
            SpaceNet3(empty_root)

    def test_collection_checksum(self, dataset: SpaceNet3) -> None:
        """A wrong MD5 for a collection is reported as corruption."""
        dataset.collection_md5_dict["sn3_AOI_5_Khartoum"] = "randommd5hash123"
        expected_error = "Collection sn3_AOI_5_Khartoum corrupted"
        with pytest.raises(RuntimeError, match=expected_error):
            SpaceNet3(root=dataset.root, download=True, checksum=True)

    def test_plot(self, dataset: SpaceNet3) -> None:
        """Plotting works with/without titles and for an image-only sample."""
        sample = dataset[0].copy()
        sample["prediction"] = sample["mask"]

        dataset.plot(sample, suptitle="Test")
        plt.close()

        dataset.plot(sample, show_titles=False)
        plt.close()

        image_only = {"image": sample["image"]}
        dataset.plot(image_only)
        plt.close()
class TestSpaceNet4:
@pytest.fixture(params=["PAN", "MS", "PS-RGBNIR"])
def dataset(
@ -206,9 +277,7 @@ class TestSpaceNet4:
class TestSpaceNet5:
@pytest.fixture(
params=itertools.product(["PAN", "MS", "PS-MS", "PS-RGB"], [False, True])
)
@pytest.fixture(params=zip(["PAN", "MS"], [False, True]))
def dataset(
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
) -> SpaceNet5:
@ -239,9 +308,7 @@ class TestSpaceNet5:
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)
if dataset.image == "PS-RGB":
assert x["image"].shape[0] == 3
elif dataset.image in ["MS", "PS-MS"]:
if dataset.image == "MS":
assert x["image"].shape[0] == 8
else:
assert x["image"].shape[0] == 1

Просмотреть файл

@ -75,7 +75,15 @@ from .seco import SeasonalContrastS2
from .sen12ms import SEN12MS
from .sentinel import Sentinel, Sentinel2
from .so2sat import So2Sat
from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet5, SpaceNet7
from .spacenet import (
SpaceNet,
SpaceNet1,
SpaceNet2,
SpaceNet3,
SpaceNet4,
SpaceNet5,
SpaceNet7,
)
from .ucmerced import UCMerced
from .usavars import USAVars
from .utils import (
@ -155,6 +163,7 @@ __all__ = (
"SpaceNet",
"SpaceNet1",
"SpaceNet2",
"SpaceNet3",
"SpaceNet4",
"SpaceNet5",
"SpaceNet7",

Просмотреть файл

@ -133,7 +133,7 @@ class SpaceNet(VisionDataset, abc.ABC):
images = sorted(images)
for imgpath in images:
lbl_path = os.path.join(
os.path.dirname(imgpath) + "-labels", self.label_glob
f"{os.path.dirname(imgpath)}-labels", self.label_glob
)
files.append({"image_path": imgpath, "label_path": lbl_path})
return files
@ -248,7 +248,7 @@ class SpaceNet(VisionDataset, abc.ABC):
to_be_downloaded = []
for collection in missing_collections:
archive_path = os.path.join(self.root, collection + ".tar.gz")
archive_path = os.path.join(self.root, f"{collection}.tar.gz")
if os.path.exists(archive_path):
print(f"Found {collection} archive")
if (
@ -281,7 +281,7 @@ class SpaceNet(VisionDataset, abc.ABC):
"""
for collection in collections:
download_radiant_mlhub_collection(collection, self.root, api_key)
archive_path = os.path.join(self.root, collection + ".tar.gz")
archive_path = os.path.join(self.root, f"{collection}.tar.gz")
if (
not self.checksum
or not check_integrity(
@ -538,7 +538,8 @@ class SpaceNet2(SpaceNet):
image: image selection which must be in ["MS", "PAN", "PS-MS", "PS-RGB"]
collections: collection selection which must be a subset of:
[sn2_AOI_2_Vegas, sn2_AOI_3_Paris, sn2_AOI_4_Shanghai,
sn2_AOI_5_Khartoum]
sn2_AOI_5_Khartoum]. If unspecified, all collections will be
used.
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory.
@ -554,6 +555,272 @@ class SpaceNet2(SpaceNet):
)
class SpaceNet3(SpaceNet):
    r"""SpaceNet 3: Road Network Detection.

    `SpaceNet 3 <https://spacenet.ai/spacenet-roads-dataset/>`_
    is a dataset of road networks over the cities of Las Vegas, Paris, Shanghai,
    and Khartoum.

    Collection features:

    +------------+---------------------+------------+---------------------------+
    |    AOI     | Area (km\ :sup:`2`\)| # Images   | # Road Network Labels (km)|
    +============+=====================+============+===========================+
    | Vegas      |    216              |   854      |         3685              |
    +------------+---------------------+------------+---------------------------+
    | Paris      |    1030             |   257      |         425               |
    +------------+---------------------+------------+---------------------------+
    | Shanghai   |    1000             |   1028     |         3537              |
    +------------+---------------------+------------+---------------------------+
    | Khartoum   |    765              |   283      |         1030              |
    +------------+---------------------+------------+---------------------------+

    Imagery features:

    .. list-table::
        :widths: 10 10 10 10 10
        :header-rows: 1
        :stub-columns: 1

        *   -
            - PAN
            - MS
            - PS-MS
            - PS-RGB
        *   - GSD (m)
            - 0.31
            - 1.24
            - 0.30
            - 0.30
        *   - Chip size (px)
            - 1300 x 1300
            - 325 x 325
            - 1300 x 1300
            - 1300 x 1300

    Dataset format:

    * Imagery - Worldview-3 GeoTIFFs

        * PAN.tif (Panchromatic)
        * MS.tif (Multispectral)
        * PS-MS (Pansharpened Multispectral)
        * PS-RGB (Pansharpened RGB)

    * Labels - GeoJSON

        * labels.geojson

    If you use this dataset in your research, please cite the following paper:

    * https://arxiv.org/abs/1807.01232

    .. note::

       This dataset requires the following additional library to be installed:

       * `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
         imagery and labels from the Radiant Earth MLHub

    .. versionadded:: 0.3
    """

    dataset_id = "spacenet3"
    # MD5 checksum of each collection archive, keyed by collection name.
    collection_md5_dict = {
        "sn3_AOI_2_Vegas": "8ce7e6abffb8849eb88885035f061ee8",
        "sn3_AOI_3_Paris": "90b9ebd64cd83dc8d3d4773f45050d8f",
        "sn3_AOI_4_Shanghai": "3ea291df34548962dfba8b5ed37d700c",
        "sn3_AOI_5_Khartoum": "b8d549ac9a6d7456c0f7a8e6de23d9f9",
    }

    # File name of each image product.
    imagery = {
        "MS": "MS.tif",
        "PAN": "PAN.tif",
        "PS-MS": "PS-MS.tif",
        "PS-RGB": "PS-RGB.tif",
    }
    # Chip size in pixels for each image product.
    chip_size = {
        "MS": (325, 325),
        "PAN": (1300, 1300),
        "PS-MS": (1300, 1300),
        "PS-RGB": (1300, 1300),
    }
    label_glob = "labels.geojson"

    def __init__(
        self,
        root: str,
        image: str = "PS-RGB",
        speed_mask: Optional[bool] = False,
        collections: Optional[List[str]] = None,
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        download: bool = False,
        api_key: Optional[str] = None,
        checksum: bool = False,
    ) -> None:
        """Initialize a new SpaceNet 3 Dataset instance.

        Args:
            root: root directory where dataset can be found
            image: image selection which must be in ["MS", "PAN", "PS-MS", "PS-RGB"]
            speed_mask: use multi-class speed mask (created by binning roads at
                10 mph increments) as label if true, else use binary mask
            collections: collection selection which must be a subset of:
                [sn3_AOI_2_Vegas, sn3_AOI_3_Paris, sn3_AOI_4_Shanghai,
                sn3_AOI_5_Khartoum]. If unspecified, all collections will be
                used.
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory.
            api_key: a RadiantEarth MLHub API key to use for downloading the dataset
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: if ``image`` is not one of the supported products
            RuntimeError: if ``download=False`` but dataset is missing
        """
        assert image in {"MS", "PAN", "PS-MS", "PS-RGB"}
        # A fresh list per instance avoids sharing one mutable default object
        # between every dataset constructed without explicit collections.
        if collections is None:
            collections = []
        # Must be set before the parent constructor runs.
        # NOTE(review): presumably the parent may load masks during init - confirm.
        self.speed_mask = speed_mask
        super().__init__(
            root, image, collections, transforms, download, api_key, checksum
        )

    def _load_mask(
        self, path: str, tfm: Affine, raster_crs: CRS, shape: Tuple[int, int]
    ) -> Tensor:
        """Rasterizes the dataset's labels (in geojson format).

        Args:
            path: path to the label
            tfm: transform of corresponding image
            raster_crs: CRS of the corresponding image; label geometries in a
                different CRS are reprojected to it before rasterizing
            shape: shape of corresponding image

        Returns:
            Tensor: label tensor
        """
        # Lookup table mapping an integer speed in [1, 65] mph to its class,
        # ceil(speed / 10): entry ``i`` holds the class of speed ``i + 1``.
        min_speed_bin = 1
        max_speed_bin = 65
        speed_arr_bin = np.arange(min_speed_bin, max_speed_bin + 1)
        bin_size_mph = 10.0
        speed_cls_arr: "np.typing.NDArray[np.int_]" = np.array(
            [int(math.ceil(s / bin_size_mph)) for s in speed_arr_bin]
        )

        try:
            with fiona.open(path) as src:
                vector_crs = CRS(src.crs)
                labels = []

                for feature in src:
                    if raster_crs != vector_crs:
                        geom = transform_geom(
                            vector_crs.to_string(),
                            raster_crs.to_string(),
                            feature["geometry"],
                        )
                    else:
                        geom = feature["geometry"]

                    if self.speed_mask:
                        # NOTE(review): a speed of 0 indexes -1 (the last
                        # bin) and a speed above 65 raises IndexError -
                        # confirm neither occurs in the label files.
                        val = speed_cls_arr[
                            int(feature["properties"]["inferred_speed_mph"]) - 1
                        ]
                    else:
                        val = 1

                    labels.append((geom, val))

        except FionaValueError:
            # A label file fiona rejects is treated as having no roads.
            labels = []

        if not labels:
            mask_data = np.zeros(shape=shape)
        else:
            mask_data = rasterize(
                labels,
                out_shape=shape,
                fill=0,  # nodata value
                transform=tfm,
                all_touched=False,
                dtype=np.uint8,
            )

        mask = torch.from_numpy(mask_data).long()
        return mask

    def plot(
        self,
        sample: Dict[str, Tensor],
        show_titles: bool = True,
        suptitle: Optional[str] = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        The image is always drawn in the first panel. A ``"mask"`` entry and a
        ``"prediction"`` entry, when present, each get one further panel, in
        that order.

        Args:
            sample: a sample returned by :meth:`SpaceNet.__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        # image can be 1 channel or >3 channels
        if sample["image"].shape[0] == 1:
            image = np.rollaxis(sample["image"].numpy(), 0, 3)
        else:
            image = np.rollaxis(sample["image"][:3].numpy(), 0, 3)
        image = percentile_normalization(image, axis=(0, 1))

        # Optional panels in display order. Building the list first lets each
        # one take the next free axis, so a sample with a prediction but no
        # mask no longer indexes past the end of ``axs``.
        panels = [
            (title, sample[key].numpy())
            for key, title in (("mask", "Label"), ("prediction", "Prediction"))
            if key in sample
        ]
        ncols = 1 + len(panels)

        fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 8, 8))
        # plt.subplots returns a bare Axes (not an array) for a single panel.
        if not isinstance(axs, np.ndarray):
            axs = [axs]

        axs[0].imshow(image)
        axs[0].axis("off")
        if show_titles:
            axs[0].set_title("Image")

        imshow_kwargs: Dict[str, Any]
        if self.speed_mask:
            # Speed classes run from 1 to 7; with vmin=0.1 the background
            # value 0 falls "under" the range and is drawn in black.
            cmap = copy.copy(plt.get_cmap("autumn_r"))
            cmap.set_under(color="black")
            imshow_kwargs = {
                "vmin": 0.1,
                "vmax": 7,
                "cmap": cmap,
                "interpolation": "none",
            }
        else:
            imshow_kwargs = {"cmap": "Greys_r", "interpolation": "none"}

        for ax, (title, data) in zip(axs[1:], panels):
            ax.imshow(data, **imshow_kwargs)
            ax.axis("off")
            if show_titles:
                ax.set_title(title)

        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
class SpaceNet4(SpaceNet):
"""SpaceNet 4: Off-Nadir Buildings Dataset.
@ -699,7 +966,7 @@ class SpaceNet4(SpaceNet):
lbl_dir = os.path.dirname(imgpath).split("-nadir")[0]
lbl_path = os.path.join(lbl_dir + "-labels", self.label_glob)
lbl_path = os.path.join(f"{lbl_dir}-labels", self.label_glob)
assert os.path.exists(lbl_path)
_file = {"image_path": imgpath, "label_path": lbl_path}
@ -724,7 +991,7 @@ class SpaceNet4(SpaceNet):
return files
class SpaceNet5(SpaceNet):
class SpaceNet5(SpaceNet3):
r"""SpaceNet 5: Automated Road Network Extraction and Route Travel Time Estimation.
`SpaceNet 5 <https://spacenet.ai/sn5-challenge/>`_
@ -832,7 +1099,8 @@ class SpaceNet5(SpaceNet):
speed_mask: use multi-class speed mask (created by binning roads at
10 mph increments) as label if true, else use binary mask
collections: collection selection which must be a subset of:
[sn5_AOI_7_Moscow, sn5_AOI_8_Mumbai]
[sn5_AOI_7_Moscow, sn5_AOI_8_Mumbai]. If unspecified, all
collections will be used.
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory.
@ -842,148 +1110,17 @@ class SpaceNet5(SpaceNet):
Raises:
RuntimeError: if ``download=False`` but dataset is missing
"""
assert image in {"MS", "PAN", "PS-MS", "PS-RGB"}
self.speed_mask = speed_mask
super().__init__(
root, image, collections, transforms, download, api_key, checksum
root,
image,
speed_mask,
collections,
transforms,
download,
api_key,
checksum,
)
def _load_mask(
self, path: str, tfm: Affine, raster_crs: CRS, shape: Tuple[int, int]
) -> Tensor:
"""Rasterizes the dataset's labels (in geojson format).
Args:
path: path to the label
tfm: transform of corresponding image
shape: shape of corresponding image
Returns:
Tensor: label tensor
"""
min_speed_bin = 1
max_speed_bin = 65
speed_arr_bin = np.arange(min_speed_bin, max_speed_bin + 1)
bin_size_mph = 10.0
speed_cls_arr: "np.typing.NDArray[np.int_]" = np.array(
[int(math.ceil(s / bin_size_mph)) for s in speed_arr_bin]
)
try:
with fiona.open(path) as src:
vector_crs = CRS(src.crs)
labels = []
for feature in src:
if raster_crs != vector_crs:
geom = transform_geom(
vector_crs.to_string(),
raster_crs.to_string(),
feature["geometry"],
)
else:
geom = feature["geometry"]
if self.speed_mask:
val = speed_cls_arr[
int(feature["properties"]["inferred_speed_mph"]) - 1
]
else:
val = 1
labels.append((geom, val))
except FionaValueError:
labels = []
if not labels:
mask_data = np.zeros(shape=shape)
else:
mask_data = rasterize(
labels,
out_shape=shape,
fill=0, # nodata value
transform=tfm,
all_touched=False,
dtype=np.uint8,
)
mask = torch.from_numpy(mask_data).long()
return mask
def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`SpaceNet.__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle
Returns:
a matplotlib Figure with the rendered sample
.. versionadded:: 0.2
"""
# image can be 1 channel or >3 channels
if sample["image"].shape[0] == 1:
image = np.rollaxis(sample["image"].numpy(), 0, 3)
else:
image = np.rollaxis(sample["image"][:3].numpy(), 0, 3)
image = percentile_normalization(image, axis=(0, 1))
ncols = 1
show_mask = "mask" in sample
show_predictions = "prediction" in sample
if show_mask:
mask = sample["mask"].numpy()
ncols += 1
if show_predictions:
prediction = sample["prediction"].numpy()
ncols += 1
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 8, 8))
if not isinstance(axs, np.ndarray):
axs = [axs]
axs[0].imshow(image)
axs[0].axis("off")
if show_titles:
axs[0].set_title("Image")
if show_mask:
if self.speed_mask:
cmap = copy.copy(plt.get_cmap("autumn_r"))
cmap.set_under(color="black")
axs[1].imshow(mask, vmin=0.1, vmax=7, cmap=cmap, interpolation="none")
else:
axs[1].imshow(mask, cmap="Greys_r", interpolation="none")
axs[1].axis("off")
if show_titles:
axs[1].set_title("Label")
if show_predictions:
if self.speed_mask:
cmap = copy.copy(plt.get_cmap("autumn_r"))
cmap.set_under(color="black")
axs[2].imshow(
prediction, vmin=0.1, vmax=7, cmap=cmap, interpolation="none"
)
else:
axs[2].imshow(prediction, cmap="Greys_r", interpolation="none")
axs[2].axis("off")
if show_titles:
axs[2].set_title("Prediction")
if suptitle is not None:
plt.suptitle(suptitle)
return fig
class SpaceNet7(SpaceNet):
"""SpaceNet 7: Multi-Temporal Urban Development Challenge.
@ -1126,13 +1263,12 @@ class SpaceNet7(SpaceNet):
Returns:
data at that index
"""
sample = {}
files = self.files[index]
img, tfm, raster_crs = self._load_image(files["image_path"])
h, w = img.shape[1:]
ch, cw = self.chip_size["img"]
sample["image"] = img[:, :ch, :cw]
sample = {"image": img[:, :ch, :cw]}
if self.split == "train":
mask = self._load_mask(files["label_path"], tfm, raster_crs, (h, w))
sample["mask"] = mask[:ch, :cw]