Add Spacenet 1: Building Detection v1 (#129)

* Add Spacenet 1

* Add test data

* Style fixes

* Convert Spacenet1 to VisionDataset

* Add option for selecting imagery (see the usage sketch after this list)

* Consolidate spacenet

* Create single spacenet.py for all spacenet datasets
* Create single spacenet directory for all spacenet test data
* Create single test_spacenet.py for testing all spacenet datasets

* Add copyright

* Reorder Spacenet in docs

* Test both rgb & 8band

* Rename Spacenet -> SpaceNet
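
The new dataset follows torchgeo's VisionDataset API. A minimal usage sketch, assuming the sn1_AOI_1_RIO archive is already extracted under data/ (otherwise pass download=True with a Radiant MLHub API key):

    from torchgeo.datasets import SpaceNet1

    # image= selects the imagery product: pansharpened RGB or raw 8-band
    ds = SpaceNet1(root="data", image="8band")

    sample = ds[0]
    print(sample["image"].shape)  # (8, H, W); (3, H, W) with image="rgb"
    print(sample["mask"].shape)   # (H, W) rasterized building footprints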
This commit is contained in:
Ashwin Nair 2021-09-15 20:35:15 +04:00 committed by GitHub
Parent 3313a36014
Commit 0ce0a591b6
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
5 changed files: 323 additions and 0 deletions

View File

@@ -127,6 +127,11 @@ So2Sat
.. autoclass:: So2Sat

SpaceNet
^^^^^^^^

.. autoclass:: SpaceNet1

Tropical Cyclone Wind Estimation Competition
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Binary data
tests/data/spacenet/spacenet1/sn1_AOI_1_RIO.tar.gz Normal file

Binary file not shown.

View File

@@ -0,0 +1,77 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import glob
import os
import shutil
from pathlib import Path
from typing import Generator

import pytest
import torch
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch

from torchgeo.datasets import SpaceNet1
from torchgeo.transforms import Identity

TEST_DATA_DIR = "tests/data/spacenet"


class Dataset:
    def __init__(self, collection_id: str) -> None:
        self.collection_id = collection_id

    def download(self, output_dir: str, **kwargs: str) -> None:
        glob_path = os.path.join(TEST_DATA_DIR, self.collection_id, "*.tar.gz")
        for tarball in glob.iglob(glob_path):
            shutil.copy(tarball, output_dir)


def fetch(collection_id: str, **kwargs: str) -> Dataset:
    return Dataset(collection_id)


class TestSpaceNet1:
    @pytest.fixture(params=["rgb", "8band"])
    def dataset(
        self,
        request: SubRequest,
        monkeypatch: Generator[MonkeyPatch, None, None],
        tmp_path: Path,
    ) -> SpaceNet1:
        radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
        monkeypatch.setattr(  # type: ignore[attr-defined]
            radiant_mlhub.Dataset, "fetch", fetch
        )
        test_md5 = "829652022c2df4511ee4ae05bc290250"
        monkeypatch.setattr(SpaceNet1, "md5", test_md5)  # type: ignore[attr-defined]
        root = str(tmp_path)
        transforms = Identity()
        return SpaceNet1(
            root,
            image=request.param,
            transforms=transforms,
            download=True,
            api_key="",
        )

    def test_getitem(self, dataset: SpaceNet1) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["mask"], torch.Tensor)
        if dataset.image == "rgb":
            assert x["image"].shape[0] == 3
        else:
            assert x["image"].shape[0] == 8

    def test_len(self, dataset: SpaceNet1) -> None:
        assert len(dataset) == 2

    def test_already_downloaded(self, dataset: SpaceNet1) -> None:
        SpaceNet1(root=dataset.root, download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found"):
            SpaceNet1(str(tmp_path))
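
For context, the fixture above monkeypatches radiant_mlhub.Dataset.fetch so the tests never touch the network: the stub's download() copies the local test tarball instead of downloading anything. A simplified sketch of the helper this interception flows through (the real download_radiant_mlhub lives in torchgeo.datasets.utils and may differ in detail):

    import radiant_mlhub

    def download_radiant_mlhub(dataset_id, download_root, api_key=None):
        # fetch() is the call replaced in the fixture above; with the stub
        # in place, download() copies tests/data/spacenet/<id>/*.tar.gz
        dataset = radiant_mlhub.Dataset.fetch(dataset_id, api_key=api_key)
        dataset.download(output_dir=download_root, api_key=api_key)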

View File

@@ -47,6 +47,7 @@ from .resisc45 import RESISC45
from .sen12ms import SEN12MS
from .sentinel import Sentinel, Sentinel2
from .so2sat import So2Sat
from .spacenet import SpaceNet1
from .utils import BoundingBox, collate_dict
__all__ = (
@@ -92,6 +93,7 @@ __all__ = (
"RESISC45",
"SEN12MS",
"So2Sat",
"SpaceNet1",
"TropicalCycloneWindEstimation",
"VHR10",
# Base classes

View File

@@ -0,0 +1,239 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""SpaceNet datasets."""

import glob
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import fiona
import numpy as np
import rasterio as rio
import torch
from affine import Affine
from rasterio.features import rasterize
from torch import Tensor

from torchgeo.datasets.geo import VisionDataset
from torchgeo.datasets.utils import (
    check_integrity,
    download_radiant_mlhub,
    extract_archive,
)

class SpaceNet1(VisionDataset):
    """SpaceNet 1: Building Detection v1 Dataset.

    `SpaceNet 1 <https://spacenet.ai/spacenet-buildings-dataset-v1/>`_
    is a dataset of building footprints over the city of Rio de Janeiro.

    Dataset features:

    * No. of images - 6940 (8 band) + 6940 (RGB)
    * No. of polygons - 382,534 building labels
    * Area coverage - 2544 sq km

    Dataset format:

    * Imagery - raw 8-band WorldView-2 (GeoTIFF) & pansharpened RGB image (GeoTIFF)
    * Labels - GeoJSON

    If you are using data from SpaceNet in a paper, please cite the following paper:

    * https://arxiv.org/abs/1807.01232

    .. note::

       This dataset requires the following additional library to be installed:

       * `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
         imagery and labels from the Radiant Earth MLHub
    """

    dataset_id = "spacenet1"
    md5 = "e6ea35331636fa0c036c04b3d1cbf226"
    imagery = {"rgb": "RGB.tif", "8band": "8Band.tif"}
    label_glob = "labels.geojson"
    foldername = "sn1_AOI_1_RIO"
    def __init__(
        self,
        root: str,
        image: str = "rgb",
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        download: bool = False,
        api_key: Optional[str] = None,
        checksum: bool = False,
    ) -> None:
        """Initialise a new SpaceNet 1 Dataset instance.

        Args:
            root: root directory where dataset can be found
            image: image selection which must be "rgb" or "8band"
            transforms: a function/transform that takes an input sample and returns
                a transformed version
            download: if True, download dataset and store it in the root directory
            api_key: a RadiantEarth MLHub API key to use for downloading the dataset
            checksum: if True, check the MD5 of the downloaded archive

        Raises:
            RuntimeError: if ``download=False`` but dataset is missing
        """
        self.root = root
        self.image = image  # For testing
        self.filename = self.imagery[image]
        self.transforms = transforms
        self.checksum = checksum

        if not self._check_integrity():
            if download:
                self._download(api_key)
            else:
                raise RuntimeError(
                    "Dataset not found. You can use download=True to download it."
                )

        self.files = self._load_files(os.path.join(root, self.foldername))
    def _load_files(self, root: str) -> List[Dict[str, str]]:
        """Return the paths of the files in the dataset.

        Args:
            root: root dir of dataset

        Returns:
            list of dicts containing paths for each pair of image and label
        """
        files = []
        images = sorted(glob.glob(os.path.join(root, "*", self.filename)))
        for imgpath in images:
            lbl_path = os.path.join(
                os.path.dirname(imgpath) + "-labels", "labels.geojson"
            )
            files.append({"image_path": imgpath, "label_path": lbl_path})
        return files
    def _load_image(self, path: str) -> Tuple[Tensor, Affine]:
        """Load a single image.

        Args:
            path: path to the image

        Returns:
            the image and its affine transform
        """
        with rio.open(path) as img:
            array = img.read().astype(np.float32)
            tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
            return tensor, img.transform
    def _load_mask(self, path: str, tfm: Affine, shape: Tuple[int, int]) -> Tensor:
        """Rasterize the dataset's labels (in GeoJSON format).

        Args:
            path: path to the label
            tfm: transform of corresponding image
            shape: shape of corresponding image

        Returns:
            label tensor
        """
        with fiona.open(path) as src:
            labels = [feature["geometry"] for feature in src]

        if not labels:
            mask_data = np.zeros(shape=shape)
        else:
            mask_data = rasterize(
                labels,
                out_shape=shape,
                fill=0,  # nodata value
                transform=tfm,
                all_touched=False,
                dtype=np.uint8,
            )

        mask: Tensor = torch.from_numpy(mask_data).long()  # type: ignore[attr-defined]
        return mask
    def __len__(self) -> int:
        """Return the number of samples in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.files)

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            data and label at that index
        """
        files = self.files[index]
        img, tfm = self._load_image(files["image_path"])
        h, w = img.shape[1:]
        mask = self._load_mask(files["label_path"], tfm, (h, w))

        sample = {"image": img, "mask": mask}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample
    def _check_integrity(self) -> bool:
        """Check the integrity of the dataset structure.

        Returns:
            True if the dataset directories are found, else False
        """
        stacpath = os.path.join(self.root, self.foldername, "collection.json")

        if os.path.exists(stacpath):
            return True

        # If dataset folder does not exist, check for uncorrupted archive
        archive_path = os.path.join(self.root, self.foldername + ".tar.gz")
        if not os.path.exists(archive_path):
            return False
        print("Archive found")
        if self.checksum and not check_integrity(archive_path, self.md5):
            print("Dataset corrupted")
            return False
        print("Extracting...")
        extract_archive(archive_path)
        return True
    def _download(self, api_key: Optional[str] = None) -> None:
        """Download the dataset and extract it.

        Args:
            api_key: a RadiantEarth MLHub API key to use for downloading the dataset

        Raises:
            RuntimeError: if download doesn't work correctly or checksums don't match
        """
        if self._check_integrity():
            print("Files already downloaded")
            return

        download_radiant_mlhub(self.dataset_id, self.root, api_key)
        archive_path = os.path.join(self.root, self.foldername + ".tar.gz")
        if not self.checksum or check_integrity(archive_path, self.md5):
            extract_archive(archive_path)
        else:
            raise RuntimeError("Dataset corrupted")
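
A hedged end-to-end sketch pairing the new class with a PyTorch DataLoader. collate_dict is the collate function exported alongside the dataset (see the __init__.py hunk above); batch_size=1 is used because SpaceNet tiles are not guaranteed to share a common size:

    from torch.utils.data import DataLoader
    from torchgeo.datasets import SpaceNet1, collate_dict

    ds = SpaceNet1(root="data", image="rgb", checksum=True)
    loader = DataLoader(ds, batch_size=1, collate_fn=collate_dict)

    for batch in loader:
        # batch["image"]: float32 (1, C, H, W); batch["mask"]: int64 (1, H, W)
        print(batch["image"].shape, batch["mask"].shape)
        break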