Mirror of https://github.com/microsoft/torchgeo.git
Add CaBuAr dataset (#2235)
* 🆕 Added CaBuAr dataset
* 🆕 Added CaBuAr datamodule
* 🔨 Added CaBuAr datamodule test
* 🔨 Corrected CaBuAr typing and datamodule test
* 🔨 Updated tests, corrected docs, minor fixes to dataset and datamodule
* 🔨 CaBuAr test fixes
This commit is contained in:
Parent
042b75ea9c
Commit
ccc314cd88
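For orientation, here is a minimal usage sketch of the dataset this commit adds; the root path is illustrative, and the sample keys and shapes follow the dataset code further down in the diff.

from torchgeo.datasets import CaBuAr

# Download the two HDF5 files if missing, then load the training split with all 12 bands
ds = CaBuAr(root='data/cabuar', split='train', download=True)
sample = ds[0]  # dict with 'image' (24 x 512 x 512) and 'mask' (512 x 512)
ds.plot(sample, suptitle='CaBuAr sample')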
@@ -57,6 +57,11 @@ BigEarthNet

 .. autoclass:: BigEarthNetDataModule

+CaBuAr
+^^^^^^
+
+.. autoclass:: CaBuArDataModule
+
 ChaBuD
 ^^^^^^
@@ -217,6 +217,11 @@ BioMassters

 .. autoclass:: BioMassters

+CaBuAr
+^^^^^^
+
+.. autoclass:: CaBuAr
+
 ChaBuD
 ^^^^^^
@@ -3,6 +3,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
 `Benin Cashew Plantations`_,S,Airbus Pléiades,"CC-BY-4.0",70,6,"1,122x1,186",10,MSI
 `BigEarthNet`_,C,Sentinel-1/2,"CDLA-Permissive-1.0","590,326",19--43,120x120,10,"SAR, MSI"
 `BioMassters`_,R,Sentinel-1/2 and Lidar,"CC-BY-4.0",,,256x256, 10, "SAR, MSI"
+`CaBuAr`_,CD,Sentinel-2,"OpenRAIL",424,2,512x512,20,MSI
 `ChaBuD`_,CD,Sentinel-2,"OpenRAIL",356,2,512x512,10,MSI
 `Cloud Cover Detection`_,S,Sentinel-2,"CC-BY-4.0","22,728",2,512x512,10,MSI
 `COWC`_,"C, R","CSUAV AFRL, ISPRS, LINZ, AGRC","AGPL-3.0-only","388,435",2,256x256,0.15,RGB
@@ -0,0 +1,16 @@
model:
  class_path: SemanticSegmentationTask
  init_args:
    loss: "ce"
    model: "unet"
    backbone: "resnet18"
    in_channels: 24
    num_classes: 2
    num_filters: 1
    ignore_index: null
data:
  class_path: CaBuArDataModule
  init_args:
    batch_size: 2
  dict_kwargs:
    root: "tests/data/cabuar"
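For reference, a rough Python equivalent of what this test config wires together (a sketch only: the keyword arguments mirror the YAML above, and in_channels is 24 because the datamodule stacks the pre- and post-fire acquisitions of 12 bands each).

from torchgeo.datamodules import CaBuArDataModule
from torchgeo.trainers import SemanticSegmentationTask

task = SemanticSegmentationTask(
    loss='ce',
    model='unet',
    backbone='resnet18',
    in_channels=24,  # 2 acquisitions x 12 Sentinel-2 bands
    num_classes=2,
    num_filters=1,
    ignore_index=None,
)
datamodule = CaBuArDataModule(batch_size=2, root='tests/data/cabuar')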
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,69 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import random

import h5py
import numpy as np

# Sentinel-2 is 12-bit with range 0-4095
SENTINEL2_MAX = 4096

NUM_CHANNELS = 12
NUM_CLASSES = 2
SIZE = 32

np.random.seed(0)
random.seed(0)

filenames = ['512x512.hdf5', 'chabud_test.h5']
fold_mapping = {'train': [1, 2, 3, 4], 'val': [0], 'test': ['chabud']}

uris = [
    'feb08801-64b1-4d11-a3fc-0efaad1f4274_0',
    'e4d4dbcb-dd92-40cf-a7fe-fda8dd35f367_1',
    '9fc8c1f4-1858-47c3-953e-1dc8b179a',
    '3a1358a2-6155-445a-a269-13bebd9741a8_0',
    '2f8e659c-f457-4527-a57f-bffc3bbe0baa_0',
    '299ee670-19b1-4a76-bef3-34fd55580711_1',
    '05cfef86-3e27-42be-a0cb-a61fe2f89e40_0',
    '0328d12a-4ad8-4504-8ac5-70089db10b4e_1',
    '04800581-b540-4f9b-9df8-7ee433e83f46_0',
    '108ae2a9-d7d6-42f7-b89a-90bb75c23ccb_0',
    '29413474-04b8-4bb1-8b89-fd640023d4a6_0',
    '43f2e60a-73b4-4f33-b99e-319d892fcab4_0',
]
folds = random.choices(fold_mapping['train'], k=4) + [0] * 4 + ['chabud'] * 4
files = ['512x512.hdf5'] * 8 + ['chabud_test.h5'] * 4

# Remove old data
for filename in filenames:
    if os.path.exists(filename):
        os.remove(filename)

# Create dataset file
data = np.random.randint(
    SENTINEL2_MAX, size=(SIZE, SIZE, NUM_CHANNELS), dtype=np.uint16
)
gt = np.random.randint(NUM_CLASSES, size=(SIZE, SIZE, 1), dtype=np.uint16)

for filename, uri, fold in zip(files, uris, folds):
    with h5py.File(filename, 'a') as f:
        sample = f.create_group(uri)
        sample.attrs.create(
            name='fold', data=np.int64(fold) if fold != 'chabud' else fold
        )
        sample.create_dataset('pre_fire', data=data)
        sample.create_dataset('post_fire', data=data)
        sample.create_dataset('mask', data=gt)

# Compute checksums
for filename in filenames:
    with open(filename, 'rb') as f:
        md5 = hashlib.md5(f.read()).hexdigest()
        print(f'{filename} md5: {md5}')
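A quick way to sanity-check the generated fixtures, assuming the layout written by the script above (one HDF5 group per patch with a 'fold' attribute and 'pre_fire'/'post_fire'/'mask' datasets); this snippet is illustrative and not part of the commit.

import h5py

for filename in ('512x512.hdf5', 'chabud_test.h5'):
    with h5py.File(filename, 'r') as f:
        for uuid, group in f.items():
            print(filename, uuid, group.attrs['fold'], group['pre_fire'].shape)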
@@ -0,0 +1,92 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

from torchgeo.datasets import CaBuAr, DatasetNotFoundError

pytest.importorskip('h5py', minversion='3.6')


class TestCaBuAr:
    @pytest.fixture(
        params=product([CaBuAr.all_bands, CaBuAr.rgb_bands], ['train', 'val', 'test'])
    )
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> CaBuAr:
        data_dir = os.path.join('tests', 'data', 'cabuar')
        urls = (
            os.path.join(data_dir, '512x512.hdf5'),
            os.path.join(data_dir, 'chabud_test.h5'),
        )
        monkeypatch.setattr(CaBuAr, 'urls', urls)
        bands, split = request.param
        root = tmp_path
        transforms = nn.Identity()
        return CaBuAr(
            root=root,
            split=split,
            bands=bands,
            transforms=transforms,
            download=True,
            checksum=True,
        )

    def test_getitem(self, dataset: CaBuAr) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x['image'], torch.Tensor)
        assert isinstance(x['mask'], torch.Tensor)

        # Image tests
        assert x['image'].ndim == 3

        if dataset.bands == CaBuAr.rgb_bands:
            assert x['image'].shape[0] == 2 * 3
        elif dataset.bands == CaBuAr.all_bands:
            assert x['image'].shape[0] == 2 * 12

        # Mask tests:
        assert x['mask'].ndim == 2

    def test_len(self, dataset: CaBuAr) -> None:
        assert len(dataset) == 4

    def test_already_downloaded(self, dataset: CaBuAr) -> None:
        CaBuAr(root=dataset.root, download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            CaBuAr(tmp_path)

    def test_invalid_bands(self) -> None:
        with pytest.raises(AssertionError):
            CaBuAr(bands=('OK', 'BK'))

    def test_plot(self, dataset: CaBuAr) -> None:
        dataset.plot(dataset[0], suptitle='Test')
        plt.close()

        sample = dataset[0]
        sample['prediction'] = sample['mask'].clone()
        dataset.plot(sample, suptitle='prediction')
        plt.close()

    def test_plot_rgb(self, dataset: CaBuAr) -> None:
        dataset = CaBuAr(root=dataset.root, bands=('B02',))
        with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"):
            dataset.plot(dataset[0], suptitle='Single Band')

    def test_invalid_split(self, dataset: CaBuAr) -> None:
        with pytest.raises(AssertionError):
            CaBuAr(dataset.root, split='foo')
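To exercise just these tests, a small driver sketch (equivalent to running pytest tests/datasets/test_cabuar.py -q from the repository root):

import pytest

raise SystemExit(pytest.main(['tests/datasets/test_cabuar.py', '-q']))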
@@ -50,6 +50,7 @@ class TestSemanticSegmentationTask:
         'name',
         [
             'agrifieldnet',
+            'cabuar',
             'chabud',
             'chesapeake_cvpr_5',
             'chesapeake_cvpr_7',
@@ -83,7 +84,7 @@ class TestSemanticSegmentationTask:
         self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
     ) -> None:
         match name:
-            case 'chabud':
+            case 'chabud' | 'cabuar':
                 pytest.importorskip('h5py', minversion='3.6')
             case 'landcoverai':
                 sha256 = (
@@ -5,6 +5,7 @@

 from .agrifieldnet import AgriFieldNetDataModule
 from .bigearthnet import BigEarthNetDataModule
+from .cabuar import CaBuArDataModule
 from .chabud import ChaBuDDataModule
 from .chesapeake import ChesapeakeCVPRDataModule
 from .cowc import COWCCountingDataModule
@@ -65,6 +66,7 @@ __all__ = (
     'SouthAfricaCropTypeDataModule',
     # NonGeoDataset
     'BigEarthNetDataModule',
+    'CaBuArDataModule',
     'ChaBuDDataModule',
     'COWCCountingDataModule',
     'DeepGlobeLandCoverDataModule',
@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""CaBuAr datamodule."""

from typing import Any

import torch
from einops import repeat

from ..datasets import CaBuAr
from .geo import NonGeoDataModule


class CaBuArDataModule(NonGeoDataModule):
    """LightningDataModule implementation for the CaBuAr dataset.

    Uses the train/val/test splits from the dataset.

    .. versionadded:: 0.6
    """

    # min/max values computed on train set using 2/98 percentiles
    min = torch.tensor(
        [0.0, 1.0, 73.0, 39.0, 46.0, 25.0, 26.0, 21.0, 17.0, 1.0, 20.0, 21.0]
    )
    max = torch.tensor(
        [
            1926.0,
            2174.0,
            2527.0,
            2950.0,
            3237.0,
            3717.0,
            4087.0,
            4271.0,
            4290.0,
            4219.0,
            4568.0,
            3753.0,
        ]
    )

    def __init__(
        self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
    ) -> None:
        """Initialize a new CaBuArDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.CaBuAr`.
        """
        bands = kwargs.get('bands', CaBuAr.all_bands)
        band_indices = [CaBuAr.all_bands.index(b) for b in bands]
        mins = self.min[band_indices]
        maxs = self.max[band_indices]

        # Change detection, 2 images from different times
        mins = repeat(mins, 'c -> (t c)', t=2)
        maxs = repeat(maxs, 'c -> (t c)', t=2)

        self.mean = mins
        self.std = maxs - mins

        super().__init__(CaBuAr, batch_size, num_workers, **kwargs)
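The mean/std assignment above leans on the base datamodule's per-channel normalization, so with mean = min and std = max - min the transform amounts to min-max scaling; a small sketch of that arithmetic (the band values are the first three entries of the tensors above, and the base-class behaviour is assumed rather than shown here).

import torch
from einops import repeat

mins = repeat(torch.tensor([0.0, 1.0, 73.0]), 'c -> (t c)', t=2)  # pre + post copies
maxs = repeat(torch.tensor([1926.0, 2174.0, 2527.0]), 'c -> (t c)', t=2)
x = torch.rand(6, 4, 4) * 2000  # fake 6-channel patch
x_scaled = (x - mins[:, None, None]) / (maxs - mins)[:, None, None]  # == (x - mean) / std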
@@ -11,6 +11,7 @@ from .astergdem import AsterGDEM
 from .benin_cashews import BeninSmallHolderCashews
 from .bigearthnet import BigEarthNet
 from .biomassters import BioMassters
+from .cabuar import CaBuAr
 from .cbf import CanadianBuildingFootprints
 from .cdl import CDL
 from .chabud import ChaBuD
@@ -199,6 +200,7 @@ __all__ = (
     'BeninSmallHolderCashews',
     'BigEarthNet',
     'BioMassters',
+    'CaBuAr',
     'ChaBuD',
     'CloudCoverDetection',
     'COWC',
@@ -0,0 +1,303 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""CaBuAr dataset."""

import os
from collections.abc import Callable
from typing import ClassVar

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import Path, download_url, lazy_import, percentile_normalization


class CaBuAr(NonGeoDataset):
    """CaBuAr dataset.

    `CaBuAr <https://huggingface.co/datasets/DarthReca/california_burned_areas>`__
    is a dataset for Change detection for Burned area Delineation; part of its
    splits were used for the ChaBuD ECML-PKDD 2023 Discovery Challenge.

    Dataset features:

    * Sentinel-2 multispectral imagery
    * binary masks of burned areas
    * 12 multispectral bands
    * 424 pairs of pre and post images with 20 m per pixel resolution (512x512 px)

    Dataset format:

    * single hdf5 dataset containing images and masks

    Dataset classes:

    0. no change
    1. burned area

    If you use this dataset in your research, please cite the following paper:

    * https://doi.org/10.1109/MGRS.2023.3292467

    .. note::

       This dataset requires the following additional library to be installed:

       * `h5py <https://pypi.org/project/h5py/>`_ to load the dataset

    .. versionadded:: 0.6
    """

    all_bands = (
        'B01',
        'B02',
        'B03',
        'B04',
        'B05',
        'B06',
        'B07',
        'B08',
        'B8A',
        'B09',
        'B11',
        'B12',
    )
    rgb_bands = ('B04', 'B03', 'B02')
    folds: ClassVar[dict[str, list[object]]] = {
        'train': [1, 2, 3, 4],
        'val': [0],
        'test': ['chabud'],
    }
    urls = (
        'https://huggingface.co/datasets/DarthReca/california_burned_areas/resolve/main/raw/patched/512x512.hdf5',
        'https://huggingface.co/datasets/DarthReca/california_burned_areas/resolve/main/raw/patched/chabud_test.h5',
    )
    filenames = ('512x512.hdf5', 'chabud_test.h5')
    md5s = ('15d78fb825f9a81dad600db828d22c08', 'a70bb7e4a2788657c2354c4c3d9296fe')

    def __init__(
        self,
        root: Path = 'data',
        split: str = 'train',
        bands: tuple[str, ...] = all_bands,
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new CaBuAr dataset instance.

        Args:
            root: root directory where dataset can be found
            split: one of "train", "val", "test"
            bands: the subset of bands to load
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: If ``split`` or ``bands`` arguments are invalid.
            DatasetNotFoundError: If dataset is not found and *download* is False.
            DependencyNotFoundError: If h5py is not installed.
        """
        lazy_import('h5py')

        assert split in self.folds
        assert set(bands) <= set(self.all_bands)

        # Set the file index based on the split
        file_index = 1 if split == 'test' else 0

        self.root = root
        self.split = split
        self.bands = bands
        self.transforms = transforms
        self.download = download
        self.checksum = checksum
        self.filepath = os.path.join(root, self.filenames[file_index])
        self.band_indices = [self.all_bands.index(b) for b in bands]

        self._verify()

        self.uuids = self._load_uuids()

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            sample containing image and mask
        """
        image = self._load_image(index)
        mask = self._load_target(index)

        sample = {'image': image, 'mask': mask}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.uuids)

    def _load_uuids(self) -> list[str]:
        """Return the image uuids for the given split.

        Returns:
            the image uuids
        """
        h5py = lazy_import('h5py')
        uuids = []
        with h5py.File(self.filepath, 'r') as f:
            for k, v in f.items():
                if v.attrs['fold'] in self.folds[self.split] and 'pre_fire' in v.keys():
                    uuids.append(k)
        return sorted(uuids)

    def _load_image(self, index: int) -> Tensor:
        """Load a single image.

        Args:
            index: index to return

        Returns:
            the image
        """
        h5py = lazy_import('h5py')
        uuid = self.uuids[index]
        with h5py.File(self.filepath, 'r') as f:
            pre_array = f[uuid]['pre_fire'][:]
            post_array = f[uuid]['post_fire'][:]

        # index specified bands and concatenate
        pre_array = pre_array[..., self.band_indices]
        post_array = post_array[..., self.band_indices]
        array = np.concatenate([pre_array, post_array], axis=-1).astype(np.float32)

        tensor = torch.from_numpy(array)
        # Convert from HxWxC to CxHxW
        tensor = tensor.permute((2, 0, 1))
        return tensor

    def _load_target(self, index: int) -> Tensor:
        """Load the target mask for a single image.

        Args:
            index: index to return

        Returns:
            the target mask
        """
        h5py = lazy_import('h5py')
        uuid = self.uuids[index]
        with h5py.File(self.filepath, 'r') as f:
            array = f[uuid]['mask'][:].astype(np.int32).squeeze(axis=-1)

        tensor = torch.from_numpy(array)
        tensor = tensor.to(torch.long)
        return tensor

    def _verify(self) -> None:
        """Verify the integrity of the dataset."""
        # Check if the files already exist
        exists = []
        for filename in self.filenames:
            filepath = os.path.join(self.root, filename)
            exists.append(os.path.exists(filepath))

        if all(exists):
            return

        # Check if the user requested to download the dataset
        if not self.download:
            raise DatasetNotFoundError(self)

        # Download the dataset
        self._download()

    def _download(self) -> None:
        """Download the dataset."""
        for url, filename, md5 in zip(self.urls, self.filenames, self.md5s):
            filepath = os.path.join(self.root, filename)
            if not os.path.exists(filepath):
                download_url(
                    url,
                    self.root,
                    filename=filename,
                    md5=md5 if self.checksum else None,
                )

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: str | None = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional suptitle to use for figure

        Returns:
            a matplotlib Figure with the rendered sample
        """
        rgb_indices = []
        for band in self.rgb_bands:
            if band in self.bands:
                rgb_indices.append(self.bands.index(band))
            else:
                raise ValueError("Dataset doesn't contain some of the RGB bands")

        mask = sample['mask'].numpy()
        image_pre = sample['image'][: len(self.bands)][rgb_indices].numpy()
        image_post = sample['image'][len(self.bands) :][rgb_indices].numpy()
        image_pre = percentile_normalization(image_pre)
        image_post = percentile_normalization(image_post)

        ncols = 3

        showing_predictions = 'prediction' in sample
        if showing_predictions:
            prediction = sample['prediction']
            ncols += 1

        fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5))

        axs[0].imshow(np.transpose(image_pre, (1, 2, 0)))
        axs[0].axis('off')
        axs[1].imshow(np.transpose(image_post, (1, 2, 0)))
        axs[1].axis('off')
        axs[2].imshow(mask)
        axs[2].axis('off')

        if showing_predictions:
            axs[3].imshow(prediction)
            axs[3].axis('off')

        if show_titles:
            axs[0].set_title('Image Pre')
            axs[1].set_title('Image Post')
            axs[2].set_title('Mask')
            if showing_predictions:
                axs[3].set_title('Prediction')

        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
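Because _load_image concatenates the pre- and post-fire bands along the channel axis, a sample's image tensor can be split back into the two acquisitions; a brief, illustrative sketch (the root path is an assumption and the data must already be on disk).

from torchgeo.datasets import CaBuAr

ds = CaBuAr(root='data/cabuar', split='val')
sample = ds[0]
pre, post = sample['image'].chunk(2, dim=0)  # each 12 x 512 x 512 with the default bands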