Mirror of https://github.com/microsoft/torchgeo.git
Add SpaceNet 7 (#241)
Parent
53f7c1d839
Commit
5a071cef9f
@@ -162,6 +162,7 @@ SpaceNet
 .. autoclass:: SpaceNet1
 .. autoclass:: SpaceNet2
 .. autoclass:: SpaceNet4
+.. autoclass:: SpaceNet7

 Tropical Cyclone Wind Estimation Competition
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -13,7 +13,7 @@ import torch.nn as nn
 from _pytest.fixtures import SubRequest
 from _pytest.monkeypatch import MonkeyPatch

-from torchgeo.datasets import SpaceNet1, SpaceNet2, SpaceNet4
+from torchgeo.datasets import SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet7

 TEST_DATA_DIR = "tests/data/spacenet"
@@ -171,7 +171,7 @@ class TestSpaceNet4:

     def test_getitem(self, dataset: SpaceNet4) -> None:
         # Get image-label pair with empty label to
-        # enusre coverage
+        # ensure coverage
         x = dataset[2]
         assert isinstance(x, dict)
         assert isinstance(x["image"], torch.Tensor)
@@ -199,3 +199,56 @@ class TestSpaceNet4:
             RuntimeError, match="Collection sn4_AOI_6_Atlanta corrupted"
         ):
             SpaceNet4(root=dataset.root, download=True, checksum=True)
+
+
+class TestSpaceNet7:
+    @pytest.fixture(params=["train", "test"])
+    def dataset(
+        self,
+        request: SubRequest,
+        monkeypatch: Generator[MonkeyPatch, None, None],
+        tmp_path: Path,
+    ) -> SpaceNet7:
+        radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
+        monkeypatch.setattr(  # type: ignore[attr-defined]
+            radiant_mlhub.Collection, "fetch", fetch_collection
+        )
+        test_md5 = {
+            "sn7_train_source": "254fd6b16e350b071137b2658332091f",
+            "sn7_train_labels": "05befe86b037a3af75c7143553033664",
+            "sn7_test_source": "37d98d44a9da39657ed4b7beee22a21e",
+        }
+
+        monkeypatch.setattr(  # type: ignore[attr-defined]
+            SpaceNet7, "collection_md5_dict", test_md5
+        )
+        root = str(tmp_path)
+        transforms = nn.Identity()  # type: ignore[attr-defined]
+        return SpaceNet7(
+            root, split=request.param, transforms=transforms, download=True, api_key=""
+        )
+
+    def test_getitem(self, dataset: SpaceNet7) -> None:
+        x = dataset[0]
+        assert isinstance(x, dict)
+        assert isinstance(x["image"], torch.Tensor)
+        if dataset.split == "train":
+            assert isinstance(x["mask"], torch.Tensor)
+
+    def test_len(self, dataset: SpaceNet7) -> None:
+        if dataset.split == "train":
+            assert len(dataset) == 2
+        else:
+            assert len(dataset) == 1
+
+    def test_already_downloaded(self, dataset: SpaceNet7) -> None:
+        SpaceNet7(root=dataset.root, download=True)
+
+    def test_not_downloaded(self, tmp_path: Path) -> None:
+        with pytest.raises(RuntimeError, match="Dataset not found"):
+            SpaceNet7(str(tmp_path))
+
+    def test_collection_checksum(self, dataset: SpaceNet7) -> None:
+        dataset.collection_md5_dict["sn7_train_source"] = "randommd5hash123"
+        with pytest.raises(
+            RuntimeError, match="Collection sn7_train_source corrupted"
+        ):
+            SpaceNet7(root=dataset.root, download=True, checksum=True)
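Note: `fetch_collection`, patched in via `monkeypatch` above, is a network stub defined earlier in this test module and not shown in this diff. It replaces `radiant_mlhub.Collection.fetch` so the tests never call the Radiant Earth MLHub API. A minimal sketch of such a stub, assuming the fixture data is stored as tarballs under TEST_DATA_DIR (the `Collection` class and tarball layout here are illustrative assumptions, not the exact helper from the test module):

import glob
import os
import shutil


class Collection:
    """Stands in for radiant_mlhub.Collection in tests."""

    def download(self, output_dir: str, **kwargs: str) -> None:
        # Copy pre-built test archives instead of downloading from MLHub
        for tarball in glob.iglob(os.path.join(TEST_DATA_DIR, "*.tar.gz")):
            shutil.copy(tarball, output_dir)


def fetch_collection(collection_id: str, **kwargs: str) -> Collection:
    # Same call shape as radiant_mlhub.Collection.fetch, minus the API call
    return Collection()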
@@ -59,7 +59,7 @@ from .seco import SeasonalContrastS2
 from .sen12ms import SEN12MS, SEN12MSDataModule
 from .sentinel import Sentinel, Sentinel2
 from .so2sat import So2Sat, So2SatDataModule
-from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4
+from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet7
 from .ucmerced import UCMerced, UCMercedDataModule
 from .utils import BoundingBox, collate_dict
 from .zuericrop import ZueriCrop
@@ -123,6 +123,7 @@ __all__ = (
     "SpaceNet1",
     "SpaceNet2",
     "SpaceNet4",
+    "SpaceNet7",
     "TropicalCycloneWindEstimation",
     "CycloneDataModule",
     "UCMerced",
@@ -668,3 +668,158 @@ class SpaceNet4(SpaceNet):
         for angle in self.angles:
             files.extend(angle_file_map[angle])
         return files
+
+
+class SpaceNet7(SpaceNet):
+    """SpaceNet 7: Multi-Temporal Urban Development Challenge.
+
+    `SpaceNet 7 <https://spacenet.ai/sn7-challenge/>`_ is a dataset which
+    consists of medium resolution (4.0m) satellite imagery mosaics acquired from
+    Planet Labs’ Dove constellation between 2017 and 2020. It includes ≈ 24
+    images (one per month) covering > 100 unique geographies, and comprises >
+    40,000 km2 of imagery and exhaustive polygon labels of building footprints
+    therein, totaling over 11M individual annotations.
+
+    Dataset features:
+
+    * No. of train samples: 1423
+    * No. of test samples: 466
+    * No. of building footprints: 11,080,000
+    * Area Coverage: 41,000 sq km
+    * Chip size: 1023 x 1023
+    * GSD: ~4m
+
+    Dataset format:
+
+    * Imagery - Planet Dove GeoTIFF
+
+        * mosaic.tif
+
+    * Labels - GeoJSON
+
+        * labels.geojson
+
+    If you use this dataset in your research, please cite the following paper:
+
+    * https://arxiv.org/abs/2102.04420
+
+    .. note::
+
+       This dataset requires the following additional library to be installed:
+
+       * `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
+         imagery and labels from the Radiant Earth MLHub
+    """
+
+    dataset_id = "spacenet7"
+    collection_md5_dict = {
+        "sn7_train_source": "9f8cc109d744537d087bd6ff33132340",
+        "sn7_train_labels": "16f873e3f0f914d95a916fb39b5111b5",
+        "sn7_test_source": "e97914f58e962bba3e898f08a14f83b2",
+    }
+
+    imagery = {"img": "mosaic.tif"}
+    chip_size = {"img": (1023, 1023)}
+
+    label_glob = "labels.geojson"
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+        download: bool = False,
+        api_key: Optional[str] = None,
+        checksum: bool = False,
+    ) -> None:
+        """Initialize a new SpaceNet 7 Dataset instance.
+
+        Args:
+            root: root directory where dataset can be found
+            split: split selection which must be in ["train", "test"]
+            transforms: a function/transform that takes input sample and its target as
+                entry and returns a transformed version
+            download: if True, download dataset and store it in the root directory.
+            api_key: a RadiantEarth MLHub API key to use for downloading the dataset
+            checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+        Raises:
+            RuntimeError: if ``download=False`` but dataset is missing
+        """
+        self.root = root
+        self.split = split
+        self.filename = self.imagery["img"]
+        self.transforms = transforms
+        self.checksum = checksum
+
+        assert split in {"train", "test"}, "Invalid split"
+
+        if split == "test":
+            self.collections = ["sn7_test_source"]
+        else:
+            self.collections = ["sn7_train_source", "sn7_train_labels"]
+
+        to_be_downloaded = self._check_integrity()
+
+        if to_be_downloaded:
+            if not download:
+                raise RuntimeError(
+                    "Dataset not found. You can use download=True to download it."
+                )
+            else:
+                self._download(to_be_downloaded, api_key)
+
+        self.files = self._load_files(root)
+
+    def _load_files(self, root: str) -> List[Dict[str, str]]:
+        """Return the paths of the files in the dataset.
+
+        Args:
+            root: root dir of dataset
+
+        Returns:
+            list of dicts containing paths for images and labels (if train split)
+        """
+        files = []
+        if self.split == "train":
+            imgs = sorted(
+                glob.glob(os.path.join(root, "sn7_train_source", "*", self.filename))
+            )
+            lbls = sorted(
+                glob.glob(os.path.join(root, "sn7_train_labels", "*", self.label_glob))
+            )
+            for img, lbl in zip(imgs, lbls):
+                files.append({"image_path": img, "label_path": lbl})
+        else:
+            imgs = sorted(
+                glob.glob(os.path.join(root, "sn7_test_source", "*", self.filename))
+            )
+            for img in imgs:
+                files.append({"image_path": img})
+        return files
+
+    def __getitem__(self, index: int) -> Dict[str, Tensor]:
+        """Return an index within the dataset.
+
+        Args:
+            index: index to return
+
+        Returns:
+            data at that index
+        """
+        sample = {}
+        files = self.files[index]
+        img, tfm = self._load_image(files["image_path"])
+        h, w = img.shape[1:]
+
+        ch, cw = self.chip_size["img"]
+        sample["image"] = img[:, :ch, :cw]
+        if self.split == "train":
+            mask = self._load_mask(files["label_path"], tfm, (h, w))
+            sample["mask"] = mask[:ch, :cw]
+
+        if self.transforms is not None:
+            sample = self.transforms(sample)
+
+        return sample
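For reference, a minimal usage sketch of the new class (assuming the SpaceNet 7 collections have already been downloaded under ./data; otherwise pass download=True, with a Radiant Earth MLHub API key configured):

from torchgeo.datasets import SpaceNet7

# Train split yields both the monthly Planet mosaic and the building mask
ds = SpaceNet7(root="data", split="train")
print(len(ds))                 # number of image (or image/label) chips
sample = ds[0]
print(sample["image"].shape)   # e.g. torch.Size([C, 1023, 1023]) after cropping
print(sample["mask"].shape)    # e.g. torch.Size([1023, 1023]) on the train split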