Mirror of https://github.com/microsoft/torchgeo.git
Add Spacenet 1: Building Detection v1 (#129)
* Add Spacenet 1
* Add test data
* Style fixes
* Convert Spacenet1 to VisionDataset
* Add option for selecting imagery
* Consolidate spacenet
* Create single spacenet.py for all spacenet datasets
* Create single spacenet directory for all spacenet test data
* Create single test_spacenet.py for testing all spacenet datasets
* Add copyright
* Reorder Spacenet in docs
* Test both rgb & 8band
* Rename Spacenet -> SpaceNet
Parent: 3313a36014
Commit: 0ce0a591b6
@@ -127,6 +127,11 @@ So2Sat
 .. autoclass:: So2Sat
 
+SpaceNet
+^^^^^^^^
+
+.. autoclass:: SpaceNet1
+
 Tropical Cyclone Wind Estimation Competition
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Binary file not shown.
@@ -0,0 +1,77 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import glob
import os
import shutil
from pathlib import Path
from typing import Generator

import pytest
import torch
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch

from torchgeo.datasets import SpaceNet1
from torchgeo.transforms import Identity

TEST_DATA_DIR = "tests/data/spacenet"


class Dataset:
    def __init__(self, collection_id: str) -> None:
        self.collection_id = collection_id

    def download(self, output_dir: str, **kwargs: str) -> None:
        glob_path = os.path.join(TEST_DATA_DIR, self.collection_id, "*.tar.gz")
        for tarball in glob.iglob(glob_path):
            shutil.copy(tarball, output_dir)


def fetch(collection_id: str, **kwargs: str) -> Dataset:
    return Dataset(collection_id)


class TestSpaceNet1:
    @pytest.fixture(params=["rgb", "8band"])
    def dataset(
        self,
        request: SubRequest,
        monkeypatch: Generator[MonkeyPatch, None, None],
        tmp_path: Path,
    ) -> SpaceNet1:
        radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
        monkeypatch.setattr(  # type: ignore[attr-defined]
            radiant_mlhub.Dataset, "fetch", fetch
        )
        test_md5 = "829652022c2df4511ee4ae05bc290250"
        monkeypatch.setattr(SpaceNet1, "md5", test_md5)  # type: ignore[attr-defined]
        root = str(tmp_path)
        transforms = Identity()
        return SpaceNet1(
            root,
            image=request.param,
            transforms=transforms,
            download=True,
            api_key="",
        )

    def test_getitem(self, dataset: SpaceNet1) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["mask"], torch.Tensor)
        if dataset.image == "rgb":
            assert x["image"].shape[0] == 3
        else:
            assert x["image"].shape[0] == 8

    def test_len(self, dataset: SpaceNet1) -> None:
        assert len(dataset) == 2

    def test_already_downloaded(self, dataset: SpaceNet1) -> None:
        SpaceNet1(root=dataset.root, download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found"):
            SpaceNet1(str(tmp_path))
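These tests never touch the network: pytest.importorskip skips the module when radiant-mlhub is missing, and monkeypatch.setattr swaps radiant_mlhub.Dataset.fetch for a stub that copies local tarballs from tests/data/spacenet instead of downloading. To exercise just this module (assuming the file lives at tests/datasets/test_spacenet.py, per the repository's usual test layout):

python -m pytest tests/datasets/test_spacenet.py -v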
@@ -47,6 +47,7 @@ from .resisc45 import RESISC45
 from .sen12ms import SEN12MS
 from .sentinel import Sentinel, Sentinel2
 from .so2sat import So2Sat
+from .spacenet import SpaceNet1
 from .utils import BoundingBox, collate_dict
 
 __all__ = (
@@ -92,6 +93,7 @@ __all__ = (
     "RESISC45",
     "SEN12MS",
     "So2Sat",
+    "SpaceNet1",
     "TropicalCycloneWindEstimation",
     "VHR10",
     # Base classes
@@ -0,0 +1,239 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""SpaceNet datasets."""

import glob
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import fiona
import numpy as np
import rasterio as rio
import torch
from affine import Affine
from rasterio.features import rasterize
from torch import Tensor

from torchgeo.datasets.geo import VisionDataset
from torchgeo.datasets.utils import (
    check_integrity,
    download_radiant_mlhub,
    extract_archive,
)


class SpaceNet1(VisionDataset):
    """SpaceNet 1: Building Detection v1 Dataset.

    `SpaceNet 1 <https://spacenet.ai/spacenet-buildings-dataset-v1/>`_
    is a dataset of building footprints over the city of Rio de Janeiro.

    Dataset features:

    * No. of images - 6940 (8 band) + 6940 (RGB)
    * No. of polygons - 382,534 building labels
    * Area coverage - 2544 sq km

    Dataset format:

    * Imagery - raw 8-band WorldView-2 (GeoTIFF) & pansharpened RGB image (GeoTIFF)
    * Labels - GeoJSON

    If you are using data from SpaceNet in a paper, please cite the following paper:

    * https://arxiv.org/abs/1807.01232

    .. note::

       This dataset requires the following additional library to be installed:

       * `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
         imagery and labels from the Radiant Earth MLHub

    """

    dataset_id = "spacenet1"
    md5 = "e6ea35331636fa0c036c04b3d1cbf226"
    imagery = {"rgb": "RGB.tif", "8band": "8Band.tif"}
    label_glob = "labels.geojson"
    foldername = "sn1_AOI_1_RIO"

    def __init__(
        self,
        root: str,
        image: str = "rgb",
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        download: bool = False,
        api_key: Optional[str] = None,
        checksum: bool = False,
    ) -> None:
        """Initialize a new SpaceNet 1 Dataset instance.

        Args:
            root: root directory where dataset can be found
            image: image selection, must be "rgb" or "8band"
            transforms: a function/transform that takes an input sample and
                returns a transformed version
            download: if True, download dataset and store it in the root directory
            api_key: a RadiantEarth MLHub API key to use for downloading the dataset
            checksum: if True, check the MD5 of the downloaded archive

        Raises:
            RuntimeError: if ``download=False`` but dataset is missing
        """
        self.root = root
        self.image = image  # For testing
        self.filename = self.imagery[image]
        self.transforms = transforms
        self.checksum = checksum

        if not self._check_integrity():
            if download:
                self._download(api_key)
            else:
                raise RuntimeError(
                    "Dataset not found. You can use download=True to download it."
                )

        self.files = self._load_files(os.path.join(root, self.foldername))

    def _load_files(self, root: str) -> List[Dict[str, str]]:
        """Return the paths of the files in the dataset.

        Args:
            root: root dir of dataset

        Returns:
            list of dicts containing image and label paths for each sample
        """
        files = []
        images = glob.glob(os.path.join(root, "*", self.filename))
        images = sorted(images)
        for imgpath in images:
            lbl_path = os.path.join(
                os.path.dirname(imgpath) + "-labels", self.label_glob
            )
            files.append({"image_path": imgpath, "label_path": lbl_path})
        return files

    def _load_image(self, path: str) -> Tuple[Tensor, Affine]:
        """Load a single image.

        Args:
            path: path to the image

        Returns:
            the image and its affine transform
        """
        with rio.open(path) as img:
            array = img.read().astype(np.float32)
            tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
            return tensor, img.transform

    def _load_mask(self, path: str, tfm: Affine, shape: Tuple[int, int]) -> Tensor:
        """Rasterize the dataset's labels (in GeoJSON format).

        Args:
            path: path to the label
            tfm: affine transform of the corresponding image
            shape: shape of the corresponding image

        Returns:
            label tensor
        """
        with fiona.open(path) as src:
            labels = [feature["geometry"] for feature in src]

        if not labels:
            mask_data = np.zeros(shape=shape)
        else:
            mask_data = rasterize(
                labels,
                out_shape=shape,
                fill=0,  # nodata value
                transform=tfm,
                all_touched=False,
                dtype=np.uint8,
            )

        mask: Tensor = torch.from_numpy(mask_data).long()  # type: ignore[attr-defined]

        return mask

    def __len__(self) -> int:
        """Return the number of samples in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.files)

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            data and label at that index
        """
        files = self.files[index]
        img, tfm = self._load_image(files["image_path"])
        h, w = img.shape[1:]
        mask = self._load_mask(files["label_path"], tfm, (h, w))

        sample = {"image": img, "mask": mask}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def _check_integrity(self) -> bool:
        """Check the integrity of the dataset structure.

        Returns:
            True if the dataset directories are found, else False
        """
        stacpath = os.path.join(self.root, self.foldername, "collection.json")

        if os.path.exists(stacpath):
            return True

        # If the dataset folder does not exist, check for an uncorrupted archive
        archive_path = os.path.join(self.root, self.foldername + ".tar.gz")
        if not os.path.exists(archive_path):
            return False
        print("Archive found")
        if self.checksum and not check_integrity(archive_path, self.md5):
            print("Dataset corrupted")
            return False
        print("Extracting...")
        extract_archive(archive_path)
        return True

    def _download(self, api_key: Optional[str] = None) -> None:
        """Download the dataset and extract it.

        Args:
            api_key: a RadiantEarth MLHub API key to use for downloading the dataset

        Raises:
            RuntimeError: if the download doesn't work correctly or checksums don't match
        """
        if self._check_integrity():
            print("Files already downloaded")
            return

        download_radiant_mlhub(self.dataset_id, self.root, api_key)
        archive_path = os.path.join(self.root, self.foldername + ".tar.gz")
        if not self.checksum or check_integrity(archive_path, self.md5):
            extract_archive(archive_path)
        else:
            raise RuntimeError("Dataset corrupted")
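For context, a minimal usage sketch of the new class (not part of the commit). The root path and API key below are placeholders, and download=True requires a Radiant Earth MLHub account.

# Hypothetical usage sketch: "data" and "YOUR_API_KEY" are placeholders.
from torchgeo.datasets import SpaceNet1

ds = SpaceNet1(root="data", image="rgb", download=True, api_key="YOUR_API_KEY", checksum=True)
print(len(ds))                # number of image/label pairs found by _load_files
sample = ds[0]                # dict with "image" and "mask" tensors
print(sample["image"].shape)  # (3, H, W) for "rgb", (8, H, W) for "8band"
print(sample["mask"].shape)   # (H, W); rasterize burns building polygons with value 1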