* add quakeset dataset

* add datamodule and tests

* update plot

* add plot title spacing

* fix tests

* fix tests finally

* fix mypy

* fix url

* fix mypy

* pin hf url to commit

* fix docs

* update dataset docs

* add missing h5py test

* fixes per suggestions

* updates per suggestions x3

* add setup method to define validation split

* undo split renaming

* update docstring
Isaac Corley authored 2024-04-19 15:27:43 -05:00, committed by GitHub
Parent: b71a948736
Commit: 0f063916cc
No known key found for this signature
GPG key ID: B5690EEEBB952194
12 changed files with 500 additions and 1 deletion


@@ -133,6 +133,11 @@ Potsdam
.. autoclass:: Potsdam2DDataModule

QuakeSet
^^^^^^^^

.. autoclass:: QuakeSetDataModule

RESISC45
^^^^^^^^


@@ -353,6 +353,11 @@ Potsdam
.. autoclass:: Potsdam2D

QuakeSet
^^^^^^^^

.. autoclass:: QuakeSet

ReforesTree
^^^^^^^^^^^


@@ -29,6 +29,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`PASTIS`_,I,Sentinel-1/2,"CC-BY-4.0","2,433",19,128x128xT,10,MSI
`PatternNet`_,C,Google Earth,-,"30,400",38,256x256,0.06--5,RGB
`Potsdam`_,S,Aerial,-,38,6,"6,000x6,000",0.05,MSI
`QuakeSet`_,"C, R",Sentinel-1,"OpenRAIL","3,327",2,512x512,10,SAR
`ReforesTree`_,"OD, R",Aerial,"CC-BY-4.0",100,6,"4,000x4,000",0.02,RGB
`RESISC45`_,C,Google Earth,"CC-BY-NC-4.0","31,500",45,256x256,0.2--30,RGB
`Rwanda Field Boundary`_,S,Planetscope,"NICFI AND CC-BY-4.0",70,2,256x256,4.7,RGB + NIR


tests/conf/quakeset.yaml Normal file

@@ -0,0 +1,14 @@
model:
  class_path: ClassificationTask
  init_args:
    loss: "ce"
    model: "resnet18"
    in_channels: 4
    num_classes: 2
data:
  class_path: QuakeSetDataModule
  init_args:
    batch_size: 2
  dict_kwargs:
    root: "tests/data/quakeset"
    download: false
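
For orientation: in this LightningCLI-style config, init_args map to constructor arguments while dict_kwargs are forwarded to the wrapped dataset. A minimal sketch of the datamodule instantiation that the data section describes (illustrative, not part of this diff):

from torchgeo.datamodules import QuakeSetDataModule

# batch_size comes from init_args; root and download are passed through
# to torchgeo.datasets.QuakeSet via dict_kwargs.
datamodule = QuakeSetDataModule(
    batch_size=2, root="tests/data/quakeset", download=False
)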


@@ -0,0 +1,49 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os

import h5py
import numpy as np

NUM_CHANNELS = 2
SIZE = 32

np.random.seed(0)

filename = "earthquakes.h5"

splits = {
    "train": ["611645479", "611658170"],
    "validation": ["611684805", "611744956"],
    "test": ["611798698", "611818836"],
}

# Remove old data
if os.path.exists(filename):
    os.remove(filename)

# Create dataset file
data = np.random.randn(SIZE, SIZE, NUM_CHANNELS)
data = data.astype(np.float32)

with h5py.File(filename, "w") as f:
    for split, keys in splits.items():
        for key in keys:
            sample = f.create_group(key)
            sample.attrs.create(name="magnitude", data=np.float32(0.0))
            sample.attrs.create(name="split", data=split)
            for i in range(2):
                patch = sample.create_group(f"patch_{i}")
                patch.create_dataset("before", data=data)
                patch.create_dataset("pre", data=data)
                patch.create_dataset("post", data=data)

# Compute checksums
with open(filename, "rb") as f:
    md5 = hashlib.md5(f.read()).hexdigest()
    print(f"md5: {md5}")

Binary data
tests/data/quakeset/earthquakes.h5 Normal file

Binary file not shown.


@@ -0,0 +1,88 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, QuakeSet


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    shutil.copy(url, root)


class TestQuakeSet:
    @pytest.fixture(params=["train", "val", "test"])
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> QuakeSet:
        monkeypatch.setattr(torchgeo.datasets.quakeset, "download_url", download_url)
        url = os.path.join("tests", "data", "quakeset", "earthquakes.h5")
        md5 = "127d0d6a1f82d517129535f50053a4c9"
        monkeypatch.setattr(QuakeSet, "md5", md5)
        monkeypatch.setattr(QuakeSet, "url", url)
        root = str(tmp_path)
        split = request.param
        transforms = nn.Identity()
        return QuakeSet(
            root, split, transforms=transforms, download=True, checksum=True
        )

    @pytest.fixture
    def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
        import_orig = builtins.__import__

        def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name == "h5py":
                raise ImportError()
            return import_orig(name, *args, **kwargs)

        monkeypatch.setattr(builtins, "__import__", mocked_import)

    def test_mock_missing_module(
        self, dataset: QuakeSet, tmp_path: Path, mock_missing_module: None
    ) -> None:
        with pytest.raises(
            ImportError,
            match="h5py is not installed and is required to use this dataset",
        ):
            QuakeSet(dataset.root, download=True, checksum=True)

    def test_getitem(self, dataset: QuakeSet) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["label"], torch.Tensor)
        assert x["image"].shape[0] == 4

    def test_len(self, dataset: QuakeSet) -> None:
        assert len(dataset) == 8

    def test_already_downloaded(self, dataset: QuakeSet, tmp_path: Path) -> None:
        QuakeSet(root=str(tmp_path), download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
            QuakeSet(str(tmp_path))

    def test_plot(self, dataset: QuakeSet) -> None:
        x = dataset[0].copy()
        dataset.plot(x, suptitle="Test")
        plt.close()
        dataset.plot(x, show_titles=False)
        plt.close()
        x["prediction"] = x["label"].clone()
        x["magnitude"] = torch.tensor(0.0)
        dataset.plot(x)
        plt.close()


@@ -76,6 +76,7 @@ class TestClassificationTask:
            "eurosat",
            "eurosat100",
            "fire_risk",
            "quakeset",
            "resisc45",
            "so2sat_all",
            "so2sat_s1",
@@ -87,7 +88,7 @@ class TestClassificationTask:
    def test_trainer(
        self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
    ) -> None:
-        if name.startswith("so2sat"):
+        if name.startswith("so2sat") or name == "quakeset":
            pytest.importorskip("h5py", minversion="3")

        config = os.path.join("tests", "conf", name + ".yaml")


@@ -27,6 +27,7 @@ from .naip import NAIPChesapeakeDataModule
from .nasa_marine_debris import NASAMarineDebrisDataModule
from .oscd import OSCDDataModule
from .potsdam import Potsdam2DDataModule
from .quakeset import QuakeSetDataModule
from .resisc45 import RESISC45DataModule
from .seco import SeasonalContrastS2DataModule
from .sen12ms import SEN12MSDataModule
@@ -78,6 +79,7 @@ __all__ = (
    "NASAMarineDebrisDataModule",
    "OSCDDataModule",
    "Potsdam2DDataModule",
    "QuakeSetDataModule",
    "RESISC45DataModule",
    "SeasonalContrastS2DataModule",
    "SEN12MSDataModule",


@@ -0,0 +1,42 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""QuakeSet datamodule."""

from typing import Any

import kornia.augmentation as K
import torch

from ..datasets import QuakeSet
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule


class QuakeSetDataModule(NonGeoDataModule):
    """LightningDataModule implementation for the QuakeSet dataset.

    .. versionadded:: 0.6
    """

    mean = torch.tensor(0.0)
    std = torch.tensor(1.0)

    def __init__(
        self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a new QuakeSetDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.QuakeSet`.
        """
        super().__init__(QuakeSet, batch_size, num_workers, **kwargs)
        self.train_aug = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            data_keys=["image"],
        )
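
A minimal sketch of driving this datamodule end to end with Lightning (the Trainer settings are illustrative, not part of this diff); in_channels=4 because QuakeSet stacks the pre- and post-event VV/VH images along the channel dimension:

import lightning.pytorch as pl

from torchgeo.datamodules import QuakeSetDataModule
from torchgeo.trainers import ClassificationTask

task = ClassificationTask(model="resnet18", loss="ce", in_channels=4, num_classes=2)
datamodule = QuakeSetDataModule(batch_size=64, num_workers=4, root="data", download=True)
trainer = pl.Trainer(fast_dev_run=True)  # one train/val batch as a smoke test
trainer.fit(model=task, datamodule=datamodule)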


@@ -91,6 +91,7 @@ from .pastis import PASTIS
from .patternnet import PatternNet
from .potsdam import Potsdam2D
from .prisma import PRISMA
from .quakeset import QuakeSet
from .reforestree import ReforesTree
from .resisc45 import RESISC45
from .rwanda_field_boundary import RwandaFieldBoundary
@@ -228,6 +229,7 @@ __all__ = (
    "PASTIS",
    "PatternNet",
    "Potsdam2D",
    "QuakeSet",
    "RESISC45",
    "ReforesTree",
    "RwandaFieldBoundary",


@@ -0,0 +1,290 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""QuakeSet dataset."""

import os
from collections.abc import Callable
from typing import Any, cast

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from torch import Tensor

from .geo import NonGeoDataset
from .utils import DatasetNotFoundError, download_url, percentile_normalization


class QuakeSet(NonGeoDataset):
    """QuakeSet dataset.

    `QuakeSet <https://huggingface.co/datasets/DarthReca/quakeset>`__
    is a dataset for Earthquake Change Detection and Magnitude Estimation and is used
    for the Seismic Monitoring and Analysis (SMAC) ECML-PKDD 2024 Discovery Challenge.

    Dataset features:

    * Sentinel-1 SAR imagery
    * before/pre/post imagery of areas affected by earthquakes
    * 2 SAR bands (VV/VH)
    * 3,327 pairs of pre and post images with 10 m per pixel resolution (512x512 px)
    * 2 classification labels (unaffected / affected by earthquake)
    * pre/post image pairs represent earthquake affected areas
    * before/pre image pairs represent hard negative unaffected areas
    * earthquake magnitudes for each sample

    Dataset format:

    * single hdf5 file containing images, magnitudes, hypocenters, and splits

    Dataset classes:

    0. unaffected area
    1. earthquake affected area

    If you use this dataset in your research, please cite the following paper:

    * https://arxiv.org/abs/2403.18116

    .. note::

       This dataset requires the following additional library to be installed:

       * `h5py <https://pypi.org/project/h5py/>`_ to load the dataset

    .. versionadded:: 0.6
    """

    filename = "earthquakes.h5"
    url = "https://hf.co/datasets/DarthReca/quakeset/resolve/bead1d25fb9979dbf703f9ede3e8b349f73b29f7/earthquakes.h5"
    md5 = "76fc7c76b7ca56f4844d852e175e1560"
    splits = {"train": "train", "val": "validation", "test": "test"}
    classes = ["unaffected_area", "earthquake_affected_area"]

    def __init__(
        self,
        root: str = "data",
        split: str = "train",
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new QuakeSet dataset instance.

        Args:
            root: root directory where dataset can be found
            split: one of "train", "val", or "test"
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: If ``split`` argument is invalid.
            DatasetNotFoundError: If dataset is not found and *download* is False.
            ImportError: If h5py is not installed.
        """
        assert split in self.splits

        self.root = root
        self.split = split
        self.transforms = transforms
        self.download = download
        self.checksum = checksum
        self.filepath = os.path.join(root, self.filename)

        self._verify()

        try:
            import h5py  # noqa: F401
        except ImportError:
            raise ImportError(
                "h5py is not installed and is required to use this dataset"
            )

        self.data = self._load_data()

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            sample containing image, label, and magnitude
        """
        image = self._load_image(index)
        label = torch.tensor(self.data[index]["label"])
        magnitude = torch.tensor(self.data[index]["magnitude"])

        sample = {"image": image, "label": label, "magnitude": magnitude}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.data)

    def _load_data(self) -> list[dict[str, Any]]:
        """Return the metadata for a given split.

        Returns:
            the sample keys, patches, images, labels, and magnitudes
        """
        import h5py

        data = []
        with h5py.File(self.filepath) as f:
            for k in sorted(f.keys()):
                if f[k].attrs["split"] != self.splits[self.split]:
                    continue

                for patch in sorted(f[k].keys()):
                    # Skip the "x"/"y" hypocenter coordinate datasets
                    if patch not in ["x", "y"]:
                        # Positive sample
                        magnitude = float(f[k].attrs["magnitude"])
                        data.append(
                            dict(
                                key=k,
                                patch=patch,
                                images=("pre", "post"),
                                label=1,
                                magnitude=magnitude,
                            )
                        )

                        # Hard negative sample (if available)
                        if "before" in f[k][patch].keys():
                            data.append(
                                dict(
                                    key=k,
                                    patch=patch,
                                    images=("before", "pre"),
                                    label=0,
                                    magnitude=0.0,
                                )
                            )

        return data

    def _load_image(self, index: int) -> Tensor:
        """Load a single image.

        Args:
            index: index to return

        Returns:
            the image
        """
        import h5py

        key = self.data[index]["key"]
        patch = self.data[index]["patch"]
        images = self.data[index]["images"]

        with h5py.File(self.filepath) as f:
            pre_array = f[key][patch][images[0]][:]
            pre_array = np.nan_to_num(pre_array, nan=0)
            post_array = f[key][patch][images[1]][:]
            post_array = np.nan_to_num(post_array, nan=0)

            # Stack the pre and post imagery along the channel dimension
            array = np.concatenate([pre_array, post_array], axis=-1)
            array = array.astype(np.float32)

        tensor = torch.from_numpy(array)
        # Convert from HxWxC to CxHxW
        tensor = tensor.permute((2, 0, 1))
        return tensor

    def _verify(self) -> None:
        """Verify the integrity of the dataset."""
        # Check if the file already exists
        if os.path.exists(self.filepath):
            return

        # Check if the user requested to download the dataset
        if not self.download:
            raise DatasetNotFoundError(self)

        # Download the dataset
        self._download()

    def _download(self) -> None:
        """Download the dataset."""
        if not os.path.exists(self.filepath):
            download_url(
                self.url,
                self.root,
                filename=self.filename,
                md5=self.md5 if self.checksum else None,
            )

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: str | None = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional suptitle to use for figure

        Returns:
            a matplotlib Figure with the rendered sample
        """
        image = sample["image"].permute((1, 2, 0)).numpy()
        label = cast(int, sample["label"].item())
        label_class = self.classes[label]

        # Create false color image for the pre image
        vv = percentile_normalization(image[..., 0]) + 1e-16
        vh = percentile_normalization(image[..., 1]) + 1e-16
        fci1 = np.stack([vv, vh, vv / vh], axis=-1).clip(0, 1)

        # Create false color image for the post image
        vv = percentile_normalization(image[..., 2]) + 1e-16
        vh = percentile_normalization(image[..., 3]) + 1e-16
        fci2 = np.stack([vv, vh, vv / vh], axis=-1).clip(0, 1)

        showing_predictions = "prediction" in sample
        if showing_predictions:
            prediction = cast(int, sample["prediction"].item())
            prediction_class = self.classes[prediction]

        ncols = 2
        fig, axs = plt.subplots(
            nrows=1, ncols=ncols, figsize=(ncols * 5, 10), sharex=True
        )
        axs[0].imshow(fci1)
        axs[0].axis("off")
        axs[0].set_title("Image Pre")
        axs[1].imshow(fci2)
        axs[1].axis("off")
        axs[1].set_title("Image Post")

        if show_titles:
            title = f"Label: {label_class}"
            if "magnitude" in sample:
                magnitude = cast(float, sample["magnitude"].item())
                title += f" | Magnitude: {magnitude:.2f}"
            if showing_predictions:
                title += f"\nPrediction: {prediction_class}"
            fig.supxlabel(title, y=0.22)

        if suptitle is not None:
            fig.suptitle(suptitle, y=0.8)

        fig.tight_layout()

        return fig
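
Finally, a usage sketch of the dataset class on its own (the root path is illustrative; download=True requires network access and h5py):

from torchgeo.datasets import QuakeSet

ds = QuakeSet(root="data/quakeset", split="train", download=True, checksum=True)
sample = ds[0]
print(sample["image"].shape)  # torch.Size([4, 512, 512]): pre + post VV/VH stacked
print(sample["label"].item(), sample["magnitude"].item())

fig = ds.plot(sample, suptitle="QuakeSet sample")
fig.savefig("quakeset_sample.png")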