This commit is contained in:
Ashwin Nair 2021-11-13 22:39:40 +04:00 коммит произвёл GitHub
Родитель 53f7c1d839
Коммит 5a071cef9f
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 213 добавлений и 3 удалений

Просмотреть файл

@ -162,6 +162,7 @@ SpaceNet
.. autoclass:: SpaceNet1
.. autoclass:: SpaceNet2
.. autoclass:: SpaceNet4
.. autoclass:: SpaceNet7
Tropical Cyclone Wind Estimation Competition
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Двоичные данные
tests/data/spacenet/sn7_test_source.tar.gz Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/spacenet/sn7_train_labels.tar.gz Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/spacenet/sn7_train_source.tar.gz Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -13,7 +13,7 @@ import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torchgeo.datasets import SpaceNet1, SpaceNet2, SpaceNet4
from torchgeo.datasets import SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet7
TEST_DATA_DIR = "tests/data/spacenet"
@ -171,7 +171,7 @@ class TestSpaceNet4:
def test_getitem(self, dataset: SpaceNet4) -> None:
# Get image-label pair with empty label to
# enusre coverage
# ensure coverage
x = dataset[2]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
@ -199,3 +199,56 @@ class TestSpaceNet4:
RuntimeError, match="Collection sn4_AOI_6_Atlanta corrupted"
):
SpaceNet4(root=dataset.root, download=True, checksum=True)
class TestSpaceNet7:
    """Tests for the SpaceNet7 dataset, run once for each of its two splits."""

    @pytest.fixture(params=["train", "test"])
    def dataset(
        self,
        request: SubRequest,
        monkeypatch: Generator[MonkeyPatch, None, None],
        tmp_path: Path,
    ) -> SpaceNet7:
        """Build a SpaceNet7 dataset in a temporary root for the requested split.

        Args:
            request: pytest request whose ``param`` is the split name
            monkeypatch: pytest monkeypatch fixture
            tmp_path: temporary directory used as the dataset root

        Returns:
            a SpaceNet7 instance created with ``download=True``
        """
        radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.2.1")
        # Replace the real collection fetch with ``fetch_collection`` (defined
        # elsewhere in this test module) so nothing is requested from MLHub.
        monkeypatch.setattr(  # type: ignore[attr-defined]
            radiant_mlhub.Collection, "fetch", fetch_collection
        )
        # Checksums for the test archives; presumably these match the small
        # sn7_*.tar.gz files under TEST_DATA_DIR -- verify if the archives change.
        test_md5 = {
            "sn7_train_source": "254fd6b16e350b071137b2658332091f",
            "sn7_train_labels": "05befe86b037a3af75c7143553033664",
            "sn7_test_source": "37d98d44a9da39657ed4b7beee22a21e",
        }
        monkeypatch.setattr(  # type: ignore[attr-defined]
            SpaceNet7, "collection_md5_dict", test_md5
        )
        root = str(tmp_path)
        transforms = nn.Identity()  # type: ignore[attr-defined]
        return SpaceNet7(
            root, split=request.param, transforms=transforms, download=True, api_key=""
        )

    def test_getitem(self, dataset: SpaceNet7) -> None:
        """A sample is a dict with an image tensor, plus a mask tensor for train."""
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        if dataset.split == "train":
            assert isinstance(x["mask"], torch.Tensor)

    def test_len(self, dataset: SpaceNet7) -> None:
        """The test data holds two train samples and one test sample."""
        if dataset.split == "train":
            assert len(dataset) == 2
        else:
            assert len(dataset) == 1

    def test_already_downloaded(self, dataset: SpaceNet7) -> None:
        """Constructing again on a populated root must not raise."""
        SpaceNet7(root=dataset.root, download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """An empty root without ``download=True`` raises a RuntimeError."""
        with pytest.raises(RuntimeError, match="Dataset not found"):
            SpaceNet7(str(tmp_path))

    def test_collection_checksum(self, dataset: SpaceNet7) -> None:
        """A wrong MD5 for a train collection is reported as corruption."""
        # The new instance uses the default split ("train"), so the corrupted
        # entry below is one of the collections it works with.
        dataset.collection_md5_dict["sn7_train_source"] = "randommd5hash123"
        with pytest.raises(RuntimeError, match="Collection sn7_train_source corrupted"):
            SpaceNet7(root=dataset.root, download=True, checksum=True)

Просмотреть файл

@ -59,7 +59,7 @@ from .seco import SeasonalContrastS2
from .sen12ms import SEN12MS, SEN12MSDataModule
from .sentinel import Sentinel, Sentinel2
from .so2sat import So2Sat, So2SatDataModule
from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4
from .spacenet import SpaceNet, SpaceNet1, SpaceNet2, SpaceNet4, SpaceNet7
from .ucmerced import UCMerced, UCMercedDataModule
from .utils import BoundingBox, collate_dict
from .zuericrop import ZueriCrop
@ -123,6 +123,7 @@ __all__ = (
"SpaceNet1",
"SpaceNet2",
"SpaceNet4",
"SpaceNet7",
"TropicalCycloneWindEstimation",
"CycloneDataModule",
"UCMerced",

Просмотреть файл

@ -668,3 +668,158 @@ class SpaceNet4(SpaceNet):
for angle in self.angles:
files.extend(angle_file_map[angle])
return files
class SpaceNet7(SpaceNet):
    """SpaceNet 7: Multi-Temporal Urban Development Challenge.

    `SpaceNet 7 <https://spacenet.ai/sn7-challenge/>`_ is a dataset which
    consists of medium resolution (4.0m) satellite imagery mosaics acquired
    from the Planet Labs Dove constellation between 2017 and 2020. It includes
    24 images (one per month) covering > 100 unique geographies, and comprises
    > 40,000 km2 of imagery and exhaustive polygon labels of building
    footprints therein, totaling over 11M individual annotations.

    Dataset features:

    * No. of train samples: 1423
    * No. of test samples: 466
    * No. of building footprints: 11,080,000
    * Area Coverage: 41,000 sq km
    * Chip size: 1023 x 1023
    * GSD: ~4m

    Dataset format:

    * Imagery - Planet Dove GeoTIFF

        * mosaic.tif

    * Labels - GeoJSON

        * labels.geojson

    If you use this dataset in your research, please cite the following paper:

    * https://arxiv.org/abs/2102.04420

    .. note::

       This dataset requires the following additional library to be installed:

       * `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
         imagery and labels from the Radiant Earth MLHub

    """

    dataset_id = "spacenet7"
    # Expected MD5 checksum for each collection
    collection_md5_dict = {
        "sn7_train_source": "9f8cc109d744537d087bd6ff33132340",
        "sn7_train_labels": "16f873e3f0f914d95a916fb39b5111b5",
        "sn7_test_source": "e97914f58e962bba3e898f08a14f83b2",
    }

    # Image file name, crop size (rows, cols) and label file name
    imagery = {"img": "mosaic.tif"}
    chip_size = {"img": (1023, 1023)}
    label_glob = "labels.geojson"

    def __init__(
        self,
        root: str,
        split: str = "train",
        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
        download: bool = False,
        api_key: Optional[str] = None,
        checksum: bool = False,
    ) -> None:
        """Initialize a new SpaceNet 7 Dataset instance.

        Args:
            root: root directory where dataset can be found
            split: split selection which must be in ["train", "test"]
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory.
            api_key: a RadiantEarth MLHub API key to use for downloading the dataset
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: if ``split`` is neither "train" nor "test"
            RuntimeError: if ``download=False`` but dataset is missing
        """
        self.root = root
        self.split = split
        self.filename = self.imagery["img"]
        self.transforms = transforms
        self.checksum = checksum

        assert split in {"train", "test"}, "Invalid split"

        # Only the train split has a labels collection.
        self.collections = (
            ["sn7_test_source"]
            if split == "test"
            else ["sn7_train_source", "sn7_train_labels"]
        )

        missing = self._check_integrity()
        if missing:
            if not download:
                raise RuntimeError(
                    "Dataset not found. You can use download=True to download it."
                )
            self._download(missing, api_key)

        self.files = self._load_files(root)

    def _load_files(self, root: str) -> List[Dict[str, str]]:
        """Return the paths of the files in the dataset.

        Args:
            root: root dir of dataset

        Returns:
            list of dicts containing paths for images and labels (if train split)
        """

        def _matches(collection: str, name: str) -> List[str]:
            # Sorted paths of ``name`` one directory level below the collection.
            return sorted(glob.glob(os.path.join(root, collection, "*", name)))

        if self.split == "train":
            images = _matches("sn7_train_source", self.filename)
            labels = _matches("sn7_train_labels", self.label_glob)
            # Images and labels are paired purely by their sorted position.
            return [
                {"image_path": image, "label_path": label}
                for image, label in zip(images, labels)
            ]

        return [
            {"image_path": image}
            for image in _matches("sn7_test_source", self.filename)
        ]

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            data at that index: an "image" entry, plus a "mask" entry for the
            train split, both cropped to the configured chip size
        """
        paths = self.files[index]
        image, transform = self._load_image(paths["image_path"])
        height, width = image.shape[1:]

        # Keep only the top-left chip so every sample has at most chip_size extent.
        chip_h, chip_w = self.chip_size["img"]
        sample = {"image": image[:, :chip_h, :chip_w]}

        if self.split == "train":
            # The mask is built at the full image size, then cropped like the image.
            mask = self._load_mask(paths["label_path"], transform, (height, width))
            sample["mask"] = mask[:chip_h, :chip_w]

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample