Mirror of https://github.com/microsoft/torchgeo.git
QuakeSet dataset (#1997)
* add quakeset dataset
* add datamodule and tests
* update plot
* add plot title spacing
* fix tests
* fix tests finally
* fix mypy
* fix url
* fix mypy
* pin hf url to commit
* fix docs
* update dataset docs
* add missing h5py test
* fixes per suggestions
* updates per suggestions x3
* add setup method to define validation split
* undo split renaming
* update docstring
Parent: b71a948736
Commit: 0f063916cc
docs/api/datamodules.rst
@@ -133,6 +133,11 @@ Potsdam

.. autoclass:: Potsdam2DDataModule

QuakeSet
^^^^^^^^

.. autoclass:: QuakeSetDataModule

RESISC45
^^^^^^^^
docs/api/datasets.rst
@@ -353,6 +353,11 @@ Potsdam

.. autoclass:: Potsdam2D

QuakeSet
^^^^^^^^

.. autoclass:: QuakeSet

ReforesTree
^^^^^^^^^^^
docs/api/datasets/non_geo_datasets.csv
@@ -29,6 +29,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`PASTIS`_,I,Sentinel-1/2,"CC-BY-4.0","2,433",19,128x128xT,10,MSI
`PatternNet`_,C,Google Earth,-,"30,400",38,256x256,0.06--5,RGB
`Potsdam`_,S,Aerial,-,38,6,"6,000x6,000",0.05,MSI
`QuakeSet`_,"C, R",Sentinel-1,"OpenRAIL","3,327",2,512x512,10,SAR
`ReforesTree`_,"OD, R",Aerial,"CC-BY-4.0",100,6,"4,000x4,000",0.02,RGB
`RESISC45`_,C,Google Earth,"CC-BY-NC-4.0","31,500",45,256x256,0.2--30,RGB
`Rwanda Field Boundary`_,S,Planetscope,"NICFI AND CC-BY-4.0",70,2,256x256,4.7,RGB + NIR
tests/conf/quakeset.yaml
@@ -0,0 +1,14 @@
model:
  class_path: ClassificationTask
  init_args:
    loss: "ce"
    model: "resnet18"
    in_channels: 4
    num_classes: 2
data:
  class_path: QuakeSetDataModule
  init_args:
    batch_size: 2
  dict_kwargs:
    root: "tests/data/quakeset"
    download: false
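For readers unfamiliar with torchgeo's Lightning configs, here is a minimal sketch of the equivalent wiring in plain Python. It assumes the test fixture exists at `tests/data/quakeset` and that `h5py` is installed; in the repository this YAML is consumed by the LightningCLI-based test harness rather than hand-written code like this.

```python
from lightning.pytorch import Trainer

from torchgeo.datamodules import QuakeSetDataModule
from torchgeo.trainers import ClassificationTask

# Mirror of the YAML above: a ResNet-18 binary classifier on 4-channel input.
task = ClassificationTask(model="resnet18", loss="ce", in_channels=4, num_classes=2)

# dict_kwargs (root, download) are forwarded to the underlying QuakeSet dataset.
datamodule = QuakeSetDataModule(
    batch_size=2, root="tests/data/quakeset", download=False
)

# fast_dev_run pushes a single batch through fit/validate, like the unit tests.
trainer = Trainer(fast_dev_run=True)
trainer.fit(model=task, datamodule=datamodule)
```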
tests/data/quakeset/data.py
@@ -0,0 +1,49 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os

import h5py
import numpy as np

NUM_CHANNELS = 2
SIZE = 32

np.random.seed(0)

filename = "earthquakes.h5"

splits = {
    "train": ["611645479", "611658170"],
    "validation": ["611684805", "611744956"],
    "test": ["611798698", "611818836"],
}

# Remove old data
if os.path.exists(filename):
    os.remove(filename)

# Create dataset file
data = np.random.randn(SIZE, SIZE, NUM_CHANNELS)
data = data.astype(np.float32)


with h5py.File(filename, "w") as f:
    for split, keys in splits.items():
        for key in keys:
            sample = f.create_group(key)
            sample.attrs.create(name="magnitude", data=np.float32(0.0))
            sample.attrs.create(name="split", data=split)
            for i in range(2):
                patch = sample.create_group(f"patch_{i}")
                patch.create_dataset("before", data=data)
                patch.create_dataset("pre", data=data)
                patch.create_dataset("post", data=data)

# Compute checksums
with open(filename, "rb") as f:
    md5 = hashlib.md5(f.read()).hexdigest()
    print(f"md5: {md5}")
tests/data/quakeset/earthquakes.h5: binary file not shown.
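Since the binary fixture cannot be rendered in the diff, a quick way to inspect it is to walk the HDF5 tree; a short sketch (group and attribute names follow the generation script above):

```python
import h5py

# Walk the synthetic file: each top-level group is one earthquake sample
# carrying split/magnitude attributes, with patch_* subgroups holding
# before/pre/post rasters.
with h5py.File("earthquakes.h5") as f:
    for key in sorted(f):
        print(key, f[key].attrs["split"], float(f[key].attrs["magnitude"]))
        for patch in sorted(f[key]):
            print("  ", patch, sorted(f[key][patch]))
```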
tests/datasets/test_quakeset.py
@@ -0,0 +1,88 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, QuakeSet


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    shutil.copy(url, root)


class TestQuakeSet:
    @pytest.fixture(params=["train", "val", "test"])
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> QuakeSet:
        monkeypatch.setattr(torchgeo.datasets.quakeset, "download_url", download_url)
        url = os.path.join("tests", "data", "quakeset", "earthquakes.h5")
        md5 = "127d0d6a1f82d517129535f50053a4c9"
        monkeypatch.setattr(QuakeSet, "md5", md5)
        monkeypatch.setattr(QuakeSet, "url", url)
        root = str(tmp_path)
        split = request.param
        transforms = nn.Identity()
        return QuakeSet(
            root, split, transforms=transforms, download=True, checksum=True
        )

    @pytest.fixture
    def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
        import_orig = builtins.__import__

        def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name == "h5py":
                raise ImportError()
            return import_orig(name, *args, **kwargs)

        monkeypatch.setattr(builtins, "__import__", mocked_import)

    def test_mock_missing_module(
        self, dataset: QuakeSet, tmp_path: Path, mock_missing_module: None
    ) -> None:
        with pytest.raises(
            ImportError,
            match="h5py is not installed and is required to use this dataset",
        ):
            QuakeSet(dataset.root, download=True, checksum=True)

    def test_getitem(self, dataset: QuakeSet) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["label"], torch.Tensor)
        assert x["image"].shape[0] == 4

    def test_len(self, dataset: QuakeSet) -> None:
        assert len(dataset) == 8

    def test_already_downloaded(self, dataset: QuakeSet, tmp_path: Path) -> None:
        QuakeSet(root=str(tmp_path), download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
            QuakeSet(str(tmp_path))

    def test_plot(self, dataset: QuakeSet) -> None:
        x = dataset[0].copy()
        dataset.plot(x, suptitle="Test")
        plt.close()
        dataset.plot(x, show_titles=False)
        plt.close()
        x["prediction"] = x["label"].clone()
        x["magnitude"] = torch.tensor(0.0)
        dataset.plot(x)
        plt.close()
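These tests run as part of the normal suite; to invoke just this module programmatically, a sketch equivalent to running `pytest tests/datasets/test_quakeset.py` from the repository root:

```python
import sys

import pytest

# Run only the QuakeSet dataset tests; exit code mirrors pytest's result.
sys.exit(pytest.main(["tests/datasets/test_quakeset.py", "-v"]))
```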
tests/trainers/test_classification.py
@@ -76,6 +76,7 @@ class TestClassificationTask:
            "eurosat",
            "eurosat100",
            "fire_risk",
            "quakeset",
            "resisc45",
            "so2sat_all",
            "so2sat_s1",
@@ -87,7 +88,7 @@ class TestClassificationTask:
    def test_trainer(
        self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
    ) -> None:
-        if name.startswith("so2sat"):
+        if name.startswith("so2sat") or name == "quakeset":
            pytest.importorskip("h5py", minversion="3")

        config = os.path.join("tests", "conf", name + ".yaml")
torchgeo/datamodules/__init__.py
@@ -27,6 +27,7 @@ from .naip import NAIPChesapeakeDataModule
from .nasa_marine_debris import NASAMarineDebrisDataModule
from .oscd import OSCDDataModule
from .potsdam import Potsdam2DDataModule
from .quakeset import QuakeSetDataModule
from .resisc45 import RESISC45DataModule
from .seco import SeasonalContrastS2DataModule
from .sen12ms import SEN12MSDataModule

@@ -78,6 +79,7 @@ __all__ = (
    "NASAMarineDebrisDataModule",
    "OSCDDataModule",
    "Potsdam2DDataModule",
    "QuakeSetDataModule",
    "RESISC45DataModule",
    "SeasonalContrastS2DataModule",
    "SEN12MSDataModule",
torchgeo/datamodules/quakeset.py
@@ -0,0 +1,42 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""QuakeSet datamodule."""

from typing import Any

import kornia.augmentation as K
import torch

from ..datasets import QuakeSet
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule


class QuakeSetDataModule(NonGeoDataModule):
    """LightningDataModule implementation for the QuakeSet dataset.

    .. versionadded:: 0.6
    """

    mean = torch.tensor(0.0)
    std = torch.tensor(1.0)

    def __init__(
        self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a new QuakeSetDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.QuakeSet`.
        """
        super().__init__(QuakeSet, batch_size, num_workers, **kwargs)
        self.train_aug = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            data_keys=["image"],
        )
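As a usage note, a minimal sketch of driving the datamodule by hand (assuming `earthquakes.h5` already sits under `root`; `setup()` and the dataloader accessors are inherited from `NonGeoDataModule`):

```python
from torchgeo.datamodules import QuakeSetDataModule

# Build the "fit" datasets and pull one raw training batch. Note that
# train_aug (normalize + flips) is applied later in on_after_batch_transfer,
# not inside the DataLoader itself.
dm = QuakeSetDataModule(batch_size=4, root="data", download=False)
dm.setup("fit")
batch = next(iter(dm.train_dataloader()))
print(batch["image"].shape)  # e.g. torch.Size([4, 4, 512, 512])
print(batch["label"].shape)  # torch.Size([4])
```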
torchgeo/datasets/__init__.py
@@ -91,6 +91,7 @@ from .pastis import PASTIS
from .patternnet import PatternNet
from .potsdam import Potsdam2D
from .prisma import PRISMA
from .quakeset import QuakeSet
from .reforestree import ReforesTree
from .resisc45 import RESISC45
from .rwanda_field_boundary import RwandaFieldBoundary

@@ -228,6 +229,7 @@ __all__ = (
    "PASTIS",
    "PatternNet",
    "Potsdam2D",
    "QuakeSet",
    "RESISC45",
    "ReforesTree",
    "RwandaFieldBoundary",
torchgeo/datasets/quakeset.py
@@ -0,0 +1,290 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""QuakeSet dataset."""

import os
from collections.abc import Callable
from typing import Any, cast

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from torch import Tensor

from .geo import NonGeoDataset
from .utils import DatasetNotFoundError, download_url, percentile_normalization


class QuakeSet(NonGeoDataset):
    """QuakeSet dataset.

    `QuakeSet <https://huggingface.co/datasets/DarthReca/quakeset>`__
    is a dataset for Earthquake Change Detection and Magnitude Estimation and is used
    for the Seismic Monitoring and Analysis (SMAC) ECML-PKDD 2024 Discovery Challenge.

    Dataset features:

    * Sentinel-1 SAR imagery
    * before/pre/post imagery of areas affected by earthquakes
    * 2 SAR bands (VV/VH)
    * 3,327 pairs of pre and post images with 5 m per pixel resolution (512x512 px)
    * 2 classification labels (unaffected / affected by earthquake)
    * pre/post image pairs represent earthquake affected areas
    * before/pre image pairs represent hard negative unaffected areas
    * earthquake magnitudes for each sample

    Dataset format:

    * single hdf5 file containing images, magnitudes, hypocenters, and splits

    Dataset classes:

    0. unaffected area
    1. earthquake affected area

    If you use this dataset in your research, please cite the following paper:

    * https://arxiv.org/abs/2403.18116

    .. note::

       This dataset requires the following additional library to be installed:

       * `h5py <https://pypi.org/project/h5py/>`_ to load the dataset

    .. versionadded:: 0.6
    """

    filename = "earthquakes.h5"
    url = "https://hf.co/datasets/DarthReca/quakeset/resolve/bead1d25fb9979dbf703f9ede3e8b349f73b29f7/earthquakes.h5"
    md5 = "76fc7c76b7ca56f4844d852e175e1560"
    splits = {"train": "train", "val": "validation", "test": "test"}
    classes = ["unaffected_area", "earthquake_affected_area"]

    def __init__(
        self,
        root: str = "data",
        split: str = "train",
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new QuakeSet dataset instance.

        Args:
            root: root directory where dataset can be found
            split: one of "train", "val", or "test"
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: If ``split`` argument is invalid.
            DatasetNotFoundError: If dataset is not found and *download* is False.
            ImportError: If h5py is not installed.
        """
        assert split in self.splits

        self.root = root
        self.split = split
        self.transforms = transforms
        self.download = download
        self.checksum = checksum
        self.filepath = os.path.join(root, self.filename)

        self._verify()

        try:
            import h5py  # noqa: F401
        except ImportError:
            raise ImportError(
                "h5py is not installed and is required to use this dataset"
            )

        self.data = self._load_data()

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            sample containing image, label, and magnitude
        """
        image = self._load_image(index)
        label = torch.tensor(self.data[index]["label"])
        magnitude = torch.tensor(self.data[index]["magnitude"])

        sample = {"image": image, "label": label, "magnitude": magnitude}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.data)

    def _load_data(self) -> list[dict[str, Any]]:
        """Return the metadata for a given split.

        Returns:
            the sample keys, patches, images, labels, and magnitudes
        """
        import h5py

        data = []
        with h5py.File(self.filepath) as f:
            for k in sorted(f.keys()):
                if f[k].attrs["split"] != self.splits[self.split]:
                    continue

                for patch in sorted(f[k].keys()):
                    if patch not in ["x", "y"]:
                        # positive sample
                        magnitude = float(f[k].attrs["magnitude"])
                        data.append(
                            dict(
                                key=k,
                                patch=patch,
                                images=("pre", "post"),
                                label=1,
                                magnitude=magnitude,
                            )
                        )

                        # hard negative sample
                        if "before" in f[k][patch].keys():
                            data.append(
                                dict(
                                    key=k,
                                    patch=patch,
                                    images=("before", "pre"),
                                    label=0,
                                    magnitude=0.0,
                                )
                            )
        return data

    def _load_image(self, index: int) -> Tensor:
        """Load a single image.

        Args:
            index: index to return

        Returns:
            the image
        """
        import h5py

        key = self.data[index]["key"]
        patch = self.data[index]["patch"]
        images = self.data[index]["images"]

        with h5py.File(self.filepath) as f:
            pre_array = f[key][patch][images[0]][:]
            pre_array = np.nan_to_num(pre_array, nan=0)
            post_array = f[key][patch][images[1]][:]
            post_array = np.nan_to_num(post_array, nan=0)
            array = np.concatenate([pre_array, post_array], axis=-1)
            array = array.astype(np.float32)

        tensor = torch.from_numpy(array)
        # Convert from HxWxC to CxHxW
        tensor = tensor.permute((2, 0, 1))
        return tensor

    def _verify(self) -> None:
        """Verify the integrity of the dataset."""
        # Check if the file already exists
        if os.path.exists(self.filepath):
            return

        # Check if the user requested to download the dataset
        if not self.download:
            raise DatasetNotFoundError(self)

        # Download the dataset
        self._download()

    def _download(self) -> None:
        """Download the dataset."""
        if not os.path.exists(self.filepath):
            download_url(
                self.url,
                self.root,
                filename=self.filename,
                md5=self.md5 if self.checksum else None,
            )

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: str | None = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional suptitle to use for figure

        Returns:
            a matplotlib Figure with the rendered sample
        """
        image = sample["image"].permute((1, 2, 0)).numpy()
        label = cast(int, sample["label"].item())
        label_class = self.classes[label]

        # Create false color image for image1
        vv = percentile_normalization(image[..., 0]) + 1e-16
        vh = percentile_normalization(image[..., 1]) + 1e-16
        fci1 = np.stack([vv, vh, vv / vh], axis=-1).clip(0, 1)

        # Create false color image for image2
        vv = percentile_normalization(image[..., 2]) + 1e-16
        vh = percentile_normalization(image[..., 3]) + 1e-16
        fci2 = np.stack([vv, vh, vv / vh], axis=-1).clip(0, 1)

        showing_predictions = "prediction" in sample
        if showing_predictions:
            prediction = cast(int, sample["prediction"].item())
            prediction_class = self.classes[prediction]

        ncols = 2
        fig, axs = plt.subplots(
            nrows=1, ncols=ncols, figsize=(ncols * 5, 10), sharex=True
        )

        axs[0].imshow(fci1)
        axs[0].axis("off")
        axs[0].set_title("Image Pre")
        axs[1].imshow(fci2)
        axs[1].axis("off")
        axs[1].set_title("Image Post")

        if show_titles:
            title = f"Label: {label_class}"
            if "magnitude" in sample:
                magnitude = cast(float, sample["magnitude"].item())
                title += f" | Magnitude: {magnitude:.2f}"
            if showing_predictions:
                title += f"\nPrediction: {prediction_class}"
            fig.supxlabel(title, y=0.22)

        if suptitle is not None:
            fig.suptitle(suptitle, y=0.8)

        fig.tight_layout()

        return fig
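Finally, a minimal usage sketch for the dataset itself (assuming the file was already downloaded into `root`; the sample keys match `__getitem__` above):

```python
import matplotlib.pyplot as plt

from torchgeo.datasets import QuakeSet

# Load one validation sample and render the pre/post false-color panels.
ds = QuakeSet(root="data", split="val", download=False)
sample = ds[0]
print(sample["image"].shape)       # torch.Size([4, 512, 512]) on the real data
print(int(sample["label"]))        # 0 = unaffected, 1 = affected
print(float(sample["magnitude"]))  # 0.0 for hard negative pairs
fig = ds.plot(sample, suptitle="QuakeSet sample")
plt.show()
```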