Mirror of https://github.com/microsoft/torchgeo.git
Add ETCI2021 Dataset (#119)
* add dataset to docs
* add sample test data
* add dataset unit tests
* add etci2021 dataset
* updated tests
* updated dataset to download only desired split file
* removed flood mask from file list for test set and other formatting
* Update torchgeo/datasets/etci2021.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* fixed doc formatting

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
Parent: ea7dc26393
Commit: 67f7d8a520
@@ -87,6 +87,11 @@ CV4A Kenya Crop Type Competition

.. autoclass:: CV4AKenyaCropType

ETCI2021 Flood Detection
^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: ETCI2021

GID-15 (Gaofen Image Dataset)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
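For orientation, a minimal usage sketch of the class registered above (not part of the commit; the local root path is an assumption, and the shapes follow the unit tests added below):

from torchgeo.datasets import ETCI2021

# download=True fetches and extracts the split archive listed in ETCI2021.metadata
ds = ETCI2021(root="data", split="train", download=True)
sample = ds[0]

# "image" stacks the VV and VH RGB tiles -> 6 channels;
# "mask" stacks the water body and flood masks for train/val -> 2 channels
print(sample["image"].shape, sample["mask"].shape)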
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,86 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path
from typing import Generator

import pytest
import torch
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datasets import ETCI2021
from torchgeo.transforms import Identity


def download_url(url: str, root: str, *args: str) -> None:
    shutil.copy(url, root)


class TestETCI2021:
    @pytest.fixture(params=["train", "val", "test"])
    def dataset(
        self,
        monkeypatch: Generator[MonkeyPatch, None, None],
        tmp_path: Path,
        request: SubRequest,
    ) -> ETCI2021:
        monkeypatch.setattr(  # type: ignore[attr-defined]
            torchgeo.datasets.utils, "download_url", download_url
        )
        data_dir = os.path.join("tests", "data", "etci2021")
        metadata = {
            "train": {
                "filename": "train.zip",
                "md5": "50c10eb07d6db9aee3ba36401e4a2c45",
                "directory": "train",
                "url": os.path.join(data_dir, "train.zip"),
            },
            "val": {
                "filename": "val_with_ref_labels.zip",
                "md5": "3e8b5a3cb95e6029e0e2c2d4b4ec6fba",
                "directory": "test",
                "url": os.path.join(data_dir, "val_with_ref_labels.zip"),
            },
            "test": {
                "filename": "test_without_ref_labels.zip",
                "md5": "c8ee1e5d3e478761cd00ebc6f28b0ae7",
                "directory": "test_internal",
                "url": os.path.join(data_dir, "test_without_ref_labels.zip"),
            },
        }
        monkeypatch.setattr(ETCI2021, "metadata", metadata)  # type: ignore[attr-defined]  # noqa: E501
        root = str(tmp_path)
        split = request.param
        transforms = Identity()
        return ETCI2021(root, split, transforms, download=True, checksum=True)

    def test_getitem(self, dataset: ETCI2021) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["mask"], torch.Tensor)
        assert x["image"].shape[0] == 6
        assert x["image"].shape[-2:] == x["mask"].shape[-2:]

        if dataset.split != "test":
            assert x["mask"].shape[0] == 2
        else:
            assert x["mask"].shape[0] == 1

    def test_len(self, dataset: ETCI2021) -> None:
        assert len(dataset) == 2

    def test_already_downloaded(self, dataset: ETCI2021) -> None:
        ETCI2021(root=dataset.root, download=True)

    def test_invalid_split(self) -> None:
        with pytest.raises(AssertionError):
            ETCI2021(split="foo")

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
            ETCI2021(str(tmp_path))
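The fixture above passes Identity() as the transform; any callable that maps a sample dict to a sample dict can be plugged in the same way. A small illustrative sketch (the normalization itself is an assumption, not part of the commit):

from typing import Dict

from torch import Tensor

from torchgeo.datasets import ETCI2021


def scale_image(sample: Dict[str, Tensor]) -> Dict[str, Tensor]:
    # Convert the uint8 image to floats in [0, 1]; leave the mask untouched.
    sample["image"] = sample["image"].float() / 255.0
    return sample


# ds = ETCI2021(root="data", split="train", transforms=scale_image)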
@@ -22,6 +22,7 @@ from .chesapeake import (
from .cowc import COWC, COWCCounting, COWCDetection
from .cv4a_kenya_crop_type import CV4AKenyaCropType
from .cyclone import TropicalCycloneWindEstimation
from .etci2021 import ETCI2021
from .geo import GeoDataset, RasterDataset, VectorDataset, VisionDataset, ZipDataset
from .gid15 import GID15
from .landcoverai import LandCoverAI

@@ -82,6 +83,7 @@ __all__ = (
    "COWCCounting",
    "COWCDetection",
    "CV4AKenyaCropType",
    "ETCI2021",
    "GID15",
    "LandCoverAI",
    "LEVIRCDPlus",
@@ -0,0 +1,256 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""ETCI 2021 dataset."""

import glob
import os
import shutil
from typing import Callable, Dict, List, Optional

import numpy as np
import torch
from PIL import Image
from torch import Tensor

from .geo import VisionDataset
from .utils import download_and_extract_archive

class ETCI2021(VisionDataset):
    """ETCI 2021 Flood Detection dataset.

    The `ETCI2021 <https://nasa-impact.github.io/etci2021/>`_
    dataset is a dataset for flood detection.

    Dataset features:

    * 33,405 VV & VH Sentinel-1 Synthetic Aperture Radar (SAR) images
    * 2 binary masks per image representing water body & flood, respectively
    * 2 polarization band images (VV, VH) of 3 RGB channels per band
    * 3 RGB channels per band generated by the Hybrid Pluggable Processing
      Pipeline (hyp3)
    * Images with 5x20m per pixel resolution (256x256 px) taken in
      Interferometric Wide Swath acquisition mode
    * Flood events from 5 different regions

    Dataset format:

    * VV band three-channel png
    * VH band three-channel png
    * water body mask single-channel png where no water body = 0, water body = 255
    * flood mask single-channel png where no flood = 0, flood = 255

    Dataset classes:

    1. no flood/water
    2. flood/water

    If you use this dataset in your research, please add the following to your
    acknowledgements section::

        The authors would like to thank the NASA Earth Science Data Systems Program,
        NASA Digital Transformation AI/ML thrust, and IEEE GRSS for organizing
        the ETCI competition.
    """

bands = ["VV", "VH"]
|
||||
masks = ["flood", "water_body"]
|
||||
metadata = {
|
||||
"train": {
|
||||
"filename": "train.zip",
|
||||
"md5": "1e95792fe0f6e3c9000abdeab2a8ab0f",
|
||||
"directory": "train",
|
||||
"url": "https://drive.google.com/file/d/14HqNW5uWLS92n7KrxKgDwUTsSEST6LCr",
|
||||
},
|
||||
"val": {
|
||||
"filename": "val_with_ref_labels.zip",
|
||||
"md5": "fd18cecb318efc69f8319f90c3771bdf",
|
||||
"directory": "test",
|
||||
"url": "https://drive.google.com/file/d/19sriKPHCZLfJn_Jmk3Z_0b3VaCBVRVyn",
|
||||
},
|
||||
"test": {
|
||||
"filename": "test_without_ref_labels.zip",
|
||||
"md5": "da9fa69e1498bd49d5c766338c6dac3d",
|
||||
"directory": "test_internal",
|
||||
"url": "https://drive.google.com/file/d/1rpMVluASnSHBfm2FhpPDio0GyCPOqg7E",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
split: str = "train",
|
||||
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new ETCI 2021 dataset instance.
|
||||
|
||||
Args:
|
||||
root: root directory where dataset can be found
|
||||
split: one of "train", "val", or "test"
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory
|
||||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
AssertionError: if ``split`` argument is invalid
|
||||
RuntimeError: if ``download=False`` and data is not found, or checksums
|
||||
don't match
|
||||
"""
|
||||
assert split in self.metadata.keys()
|
||||
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transforms = transforms
|
||||
self.checksum = checksum
|
||||
|
||||
if download:
|
||||
self._download()
|
||||
|
||||
if not self._check_integrity():
|
||||
raise RuntimeError(
|
||||
"Dataset not found or corrupted. "
|
||||
+ "You can use download=True to download it"
|
||||
)
|
||||
|
||||
self.files = self._load_files(self.root, self.split)
|
||||
|
||||
    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            data and label at that index
        """
        files = self.files[index]
        vv = self._load_image(files["vv"])
        vh = self._load_image(files["vh"])
        water_mask = self._load_target(files["water_mask"])

        if self.split != "test":
            flood_mask = self._load_target(files["flood_mask"])
            mask = torch.stack(tensors=[water_mask, flood_mask], dim=0)
        else:
            mask = water_mask.unsqueeze(0)

        image = torch.cat(tensors=[vv, vh], dim=0)  # type: ignore[attr-defined]
        sample = {"image": image, "mask": mask}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.files)

    def _load_files(self, root: str, split: str) -> List[Dict[str, str]]:
        """Return the paths of the files in the dataset.

        Args:
            root: root dir of dataset
            split: subset of dataset, one of [train, val, test]

        Returns:
            list of dicts containing paths for each pair of vv, vh,
            water body mask, flood mask (train/val only)
        """
        files = []
        directory = self.metadata[split]["directory"]
        folders = sorted(glob.glob(os.path.join(root, directory, "*")))
        folders = [os.path.join(folder, "tiles") for folder in folders]
        for folder in folders:
            vvs = glob.glob(os.path.join(folder, "vv", "*.png"))
            vhs = glob.glob(os.path.join(folder, "vh", "*.png"))
            water_masks = glob.glob(os.path.join(folder, "water_body_label", "*.png"))

            if split != "test":
                flood_masks = glob.glob(os.path.join(folder, "flood_label", "*.png"))

                for vv, vh, flood_mask, water_mask in zip(
                    vvs, vhs, flood_masks, water_masks
                ):
                    files.append(
                        dict(vv=vv, vh=vh, flood_mask=flood_mask, water_mask=water_mask)
                    )
            else:
                for vv, vh, water_mask in zip(vvs, vhs, water_masks):
                    files.append(dict(vv=vv, vh=vh, water_mask=water_mask))

        return files

    def _load_image(self, path: str) -> Tensor:
        """Load a single image.

        Args:
            path: path to the image

        Returns:
            the image
        """
        filename = os.path.join(path)
        with Image.open(filename) as img:
            array = np.array(img.convert("RGB"))
            tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
            # Convert from HxWxC to CxHxW
            tensor = tensor.permute((2, 0, 1))
            return tensor

    def _load_target(self, path: str) -> Tensor:
        """Load the target mask for a single image.

        Args:
            path: path to the image

        Returns:
            the target mask
        """
        filename = os.path.join(path)
        with Image.open(filename) as img:
            array = np.array(img.convert("L"))
            tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
            tensor = torch.clip(tensor, min=0, max=1)  # type: ignore[attr-defined]
            tensor = tensor.to(torch.long)  # type: ignore[attr-defined]
            return tensor

    def _check_integrity(self) -> bool:
        """Checks the integrity of the dataset structure.

        Returns:
            True if the dataset directories and split files are found, else False
        """
        directory = self.metadata[self.split]["directory"]
        dirpath = os.path.join(self.root, directory)
        if not os.path.exists(dirpath):
            return False
        return True

    def _download(self) -> None:
        """Download the dataset and extract it.

        Raises:
            AssertionError: if the checksum of the downloaded archive does not match
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_and_extract_archive(
            self.metadata[self.split]["url"],
            self.root,
            filename=self.metadata[self.split]["filename"],
            md5=self.metadata[self.split]["md5"] if self.checksum else None,
        )

        if os.path.exists(os.path.join(self.root, "__MACOSX")):
            shutil.rmtree(os.path.join(self.root, "__MACOSX"))
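As a downstream-usage complement (not part of the commit), a sketch of batching the dataset with a standard PyTorch DataLoader; the root path, batch size, and worker count are illustrative:

from torch.utils.data import DataLoader

from torchgeo.datasets import ETCI2021

# Assumes the extracted "train" directory already exists under root.
dataset = ETCI2021(root="data", split="train")
# Samples are dicts of tensors, so the default collate_fn stacks them per key.
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

for batch in loader:
    images, masks = batch["image"], batch["mask"]  # (B, 6, H, W), (B, 2, H, W)
    break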