зеркало из https://github.com/microsoft/torchgeo.git
Add CaBuAr dataset (#2235)
* 🆕 Added CaBuAr dataset * 🆕 Added CaBuAr datamodule * 🔨 Added CaBuAr datamodule test * 🔨 Corrected CaBuAr typing and datamodule test * 🔨 updated test, corrected docs, minor fixes to dataset and datamodule * 🔨 CaBuAr test fixes
This commit is contained in:
Родитель
042b75ea9c
Коммит
ccc314cd88
|
@ -57,6 +57,11 @@ BigEarthNet
|
|||
|
||||
.. autoclass:: BigEarthNetDataModule
|
||||
|
||||
CaBuAr
|
||||
^^^^^^
|
||||
|
||||
.. autoclass:: CaBuArDataModule
|
||||
|
||||
ChaBuD
|
||||
^^^^^^
|
||||
|
||||
|
|
|
@ -217,6 +217,11 @@ BioMassters
|
|||
|
||||
.. autoclass:: BioMassters
|
||||
|
||||
CaBuAr
|
||||
^^^^^^
|
||||
|
||||
.. autoclass:: CaBuAr
|
||||
|
||||
ChaBuD
|
||||
^^^^^^
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
|
|||
`Benin Cashew Plantations`_,S,Airbus Pléiades,"CC-BY-4.0",70,6,"1,122x1,186",10,MSI
|
||||
`BigEarthNet`_,C,Sentinel-1/2,"CDLA-Permissive-1.0","590,326",19--43,120x120,10,"SAR, MSI"
|
||||
`BioMassters`_,R,Sentinel-1/2 and Lidar,"CC-BY-4.0",,,256x256, 10, "SAR, MSI"
|
||||
`CaBuAr`_,CD,Sentinel-2,"OpenRAIL",424,2,512x512,20,MSI
|
||||
`ChaBuD`_,CD,Sentinel-2,"OpenRAIL",356,2,512x512,10,MSI
|
||||
`Cloud Cover Detection`_,S,Sentinel-2,"CC-BY-4.0","22,728",2,512x512,10,MSI
|
||||
`COWC`_,"C, R","CSUAV AFRL, ISPRS, LINZ, AGRC","AGPL-3.0-only","388,435",2,256x256,0.15,RGB
|
||||
|
|
|
|
@ -0,0 +1,16 @@
|
|||
model:
|
||||
class_path: SemanticSegmentationTask
|
||||
init_args:
|
||||
loss: "ce"
|
||||
model: "unet"
|
||||
backbone: "resnet18"
|
||||
in_channels: 24
|
||||
num_classes: 2
|
||||
num_filters: 1
|
||||
ignore_index: null
|
||||
data:
|
||||
class_path: CaBuArDataModule
|
||||
init_args:
|
||||
batch_size: 2
|
||||
dict_kwargs:
|
||||
root: "tests/data/cabuar"
|
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -0,0 +1,69 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import random
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
|
||||
# Sentinel-2 is 12-bit with range 0-4095
|
||||
SENTINEL2_MAX = 4096
|
||||
|
||||
NUM_CHANNELS = 12
|
||||
NUM_CLASSES = 2
|
||||
SIZE = 32
|
||||
|
||||
np.random.seed(0)
|
||||
random.seed(0)
|
||||
|
||||
filenames = ['512x512.hdf5', 'chabud_test.h5']
|
||||
fold_mapping = {'train': [1, 2, 3, 4], 'val': [0], 'test': ['chabud']}
|
||||
|
||||
uris = [
|
||||
'feb08801-64b1-4d11-a3fc-0efaad1f4274_0',
|
||||
'e4d4dbcb-dd92-40cf-a7fe-fda8dd35f367_1',
|
||||
'9fc8c1f4-1858-47c3-953e-1dc8b179a',
|
||||
'3a1358a2-6155-445a-a269-13bebd9741a8_0',
|
||||
'2f8e659c-f457-4527-a57f-bffc3bbe0baa_0',
|
||||
'299ee670-19b1-4a76-bef3-34fd55580711_1',
|
||||
'05cfef86-3e27-42be-a0cb-a61fe2f89e40_0',
|
||||
'0328d12a-4ad8-4504-8ac5-70089db10b4e_1',
|
||||
'04800581-b540-4f9b-9df8-7ee433e83f46_0',
|
||||
'108ae2a9-d7d6-42f7-b89a-90bb75c23ccb_0',
|
||||
'29413474-04b8-4bb1-8b89-fd640023d4a6_0',
|
||||
'43f2e60a-73b4-4f33-b99e-319d892fcab4_0',
|
||||
]
|
||||
folds = random.choices(fold_mapping['train'], k=4) + [0] * 4 + ['chabud'] * 4
|
||||
files = ['512x512.hdf5'] * 8 + ['chabud_test.h5'] * 4
|
||||
|
||||
# Remove old data
|
||||
for filename in filenames:
|
||||
if os.path.exists(filename):
|
||||
os.remove(filename)
|
||||
|
||||
# Create dataset file
|
||||
data = np.random.randint(
|
||||
SENTINEL2_MAX, size=(SIZE, SIZE, NUM_CHANNELS), dtype=np.uint16
|
||||
)
|
||||
gt = np.random.randint(NUM_CLASSES, size=(SIZE, SIZE, 1), dtype=np.uint16)
|
||||
|
||||
for filename, uri, fold in zip(files, uris, folds):
|
||||
with h5py.File(filename, 'a') as f:
|
||||
sample = f.create_group(uri)
|
||||
sample.attrs.create(
|
||||
name='fold', data=np.int64(fold) if fold != 'chabud' else fold
|
||||
)
|
||||
sample.create_dataset
|
||||
sample.create_dataset('pre_fire', data=data)
|
||||
sample.create_dataset('post_fire', data=data)
|
||||
sample.create_dataset('mask', data=gt)
|
||||
|
||||
# Compute checksums
|
||||
for filename in filenames:
|
||||
with open(filename, 'rb') as f:
|
||||
md5 = hashlib.md5(f.read()).hexdigest()
|
||||
print(f'{filename} md5: {md5}')
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from itertools import product
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import CaBuAr, DatasetNotFoundError
|
||||
|
||||
pytest.importorskip('h5py', minversion='3.6')
|
||||
|
||||
|
||||
class TestCaBuAr:
|
||||
@pytest.fixture(
|
||||
params=product([CaBuAr.all_bands, CaBuAr.rgb_bands], ['train', 'val', 'test'])
|
||||
)
|
||||
def dataset(
|
||||
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
|
||||
) -> CaBuAr:
|
||||
data_dir = os.path.join('tests', 'data', 'cabuar')
|
||||
urls = (
|
||||
os.path.join(data_dir, '512x512.hdf5'),
|
||||
os.path.join(data_dir, 'chabud_test.h5'),
|
||||
)
|
||||
monkeypatch.setattr(CaBuAr, 'urls', urls)
|
||||
bands, split = request.param
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return CaBuAr(
|
||||
root=root,
|
||||
split=split,
|
||||
bands=bands,
|
||||
transforms=transforms,
|
||||
download=True,
|
||||
checksum=True,
|
||||
)
|
||||
|
||||
def test_getitem(self, dataset: CaBuAr) -> None:
|
||||
x = dataset[0]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x['image'], torch.Tensor)
|
||||
assert isinstance(x['mask'], torch.Tensor)
|
||||
|
||||
# Image tests
|
||||
assert x['image'].ndim == 3
|
||||
|
||||
if dataset.bands == CaBuAr.rgb_bands:
|
||||
assert x['image'].shape[0] == 2 * 3
|
||||
elif dataset.bands == CaBuAr.all_bands:
|
||||
assert x['image'].shape[0] == 2 * 12
|
||||
|
||||
# Mask tests:
|
||||
assert x['mask'].ndim == 2
|
||||
|
||||
def test_len(self, dataset: CaBuAr) -> None:
|
||||
assert len(dataset) == 4
|
||||
|
||||
def test_already_downloaded(self, dataset: CaBuAr) -> None:
|
||||
CaBuAr(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
CaBuAr(tmp_path)
|
||||
|
||||
def test_invalid_bands(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
CaBuAr(bands=('OK', 'BK'))
|
||||
|
||||
def test_plot(self, dataset: CaBuAr) -> None:
|
||||
dataset.plot(dataset[0], suptitle='Test')
|
||||
plt.close()
|
||||
|
||||
sample = dataset[0]
|
||||
sample['prediction'] = sample['mask'].clone()
|
||||
dataset.plot(sample, suptitle='prediction')
|
||||
plt.close()
|
||||
|
||||
def test_plot_rgb(self, dataset: CaBuAr) -> None:
|
||||
dataset = CaBuAr(root=dataset.root, bands=('B02',))
|
||||
with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"):
|
||||
dataset.plot(dataset[0], suptitle='Single Band')
|
||||
|
||||
def test_invalid_split(self, dataset: CaBuAr) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
CaBuAr(dataset.root, split='foo')
|
|
@ -50,6 +50,7 @@ class TestSemanticSegmentationTask:
|
|||
'name',
|
||||
[
|
||||
'agrifieldnet',
|
||||
'cabuar',
|
||||
'chabud',
|
||||
'chesapeake_cvpr_5',
|
||||
'chesapeake_cvpr_7',
|
||||
|
@ -83,7 +84,7 @@ class TestSemanticSegmentationTask:
|
|||
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
|
||||
) -> None:
|
||||
match name:
|
||||
case 'chabud':
|
||||
case 'chabud' | 'cabuar':
|
||||
pytest.importorskip('h5py', minversion='3.6')
|
||||
case 'landcoverai':
|
||||
sha256 = (
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
from .agrifieldnet import AgriFieldNetDataModule
|
||||
from .bigearthnet import BigEarthNetDataModule
|
||||
from .cabuar import CaBuArDataModule
|
||||
from .chabud import ChaBuDDataModule
|
||||
from .chesapeake import ChesapeakeCVPRDataModule
|
||||
from .cowc import COWCCountingDataModule
|
||||
|
@ -65,6 +66,7 @@ __all__ = (
|
|||
'SouthAfricaCropTypeDataModule',
|
||||
# NonGeoDataset
|
||||
'BigEarthNetDataModule',
|
||||
'CaBuArDataModule',
|
||||
'ChaBuDDataModule',
|
||||
'COWCCountingDataModule',
|
||||
'DeepGlobeLandCoverDataModule',
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""CaBuAr datamodule."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from einops import repeat
|
||||
|
||||
from ..datasets import CaBuAr
|
||||
from .geo import NonGeoDataModule
|
||||
|
||||
|
||||
class CaBuArDataModule(NonGeoDataModule):
|
||||
"""LightningDataModule implementation for the CaBuAr dataset.
|
||||
|
||||
Uses the train/val/test splits from the dataset
|
||||
|
||||
.. versionadded:: 0.6
|
||||
"""
|
||||
|
||||
# min/max values computed on train set using 2/98 percentiles
|
||||
min = torch.tensor(
|
||||
[0.0, 1.0, 73.0, 39.0, 46.0, 25.0, 26.0, 21.0, 17.0, 1.0, 20.0, 21.0]
|
||||
)
|
||||
max = torch.tensor(
|
||||
[
|
||||
1926.0,
|
||||
2174.0,
|
||||
2527.0,
|
||||
2950.0,
|
||||
3237.0,
|
||||
3717.0,
|
||||
4087.0,
|
||||
4271.0,
|
||||
4290.0,
|
||||
4219.0,
|
||||
4568.0,
|
||||
3753.0,
|
||||
]
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
|
||||
) -> None:
|
||||
"""Initialize a new CaBuArDataModule instance.
|
||||
|
||||
Args:
|
||||
batch_size: Size of each mini-batch.
|
||||
num_workers: Number of workers for parallel data loading.
|
||||
**kwargs: Additional keyword arguments passed to
|
||||
:class:`~torchgeo.datasets.CaBuAr`.
|
||||
"""
|
||||
bands = kwargs.get('bands', CaBuAr.all_bands)
|
||||
band_indices = [CaBuAr.all_bands.index(b) for b in bands]
|
||||
mins = self.min[band_indices]
|
||||
maxs = self.max[band_indices]
|
||||
|
||||
# Change detection, 2 images from different times
|
||||
mins = repeat(mins, 'c -> (t c)', t=2)
|
||||
maxs = repeat(maxs, 'c -> (t c)', t=2)
|
||||
|
||||
self.mean = mins
|
||||
self.std = maxs - mins
|
||||
|
||||
super().__init__(CaBuAr, batch_size, num_workers, **kwargs)
|
|
@ -11,6 +11,7 @@ from .astergdem import AsterGDEM
|
|||
from .benin_cashews import BeninSmallHolderCashews
|
||||
from .bigearthnet import BigEarthNet
|
||||
from .biomassters import BioMassters
|
||||
from .cabuar import CaBuAr
|
||||
from .cbf import CanadianBuildingFootprints
|
||||
from .cdl import CDL
|
||||
from .chabud import ChaBuD
|
||||
|
@ -199,6 +200,7 @@ __all__ = (
|
|||
'BeninSmallHolderCashews',
|
||||
'BigEarthNet',
|
||||
'BioMassters',
|
||||
'CaBuAr',
|
||||
'ChaBuD',
|
||||
'CloudCoverDetection',
|
||||
'COWC',
|
||||
|
|
|
@ -0,0 +1,303 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""CaBuAr dataset."""
|
||||
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from typing import ClassVar
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
|
||||
from .errors import DatasetNotFoundError
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import Path, download_url, lazy_import, percentile_normalization
|
||||
|
||||
|
||||
class CaBuAr(NonGeoDataset):
|
||||
"""CaBuAr dataset.
|
||||
|
||||
`CaBuAr <https://huggingface.co/datasets/DarthReca/california_burned_areas>`__
|
||||
is a dataset for Change detection for Burned area Delineation and part of
|
||||
the splits are used for the ChaBuD ECML-PKDD 2023 Discovery Challenge.
|
||||
|
||||
Dataset features:
|
||||
|
||||
* Sentinel-2 multispectral imagery
|
||||
* binary masks of burned areas
|
||||
* 12 multispectral bands
|
||||
* 424 pairs of pre and post images with 20 m per pixel resolution (512x512 px)
|
||||
|
||||
Dataset format:
|
||||
|
||||
* single hdf5 dataset containing images and masks
|
||||
|
||||
Dataset classes:
|
||||
|
||||
0. no change
|
||||
1. burned area
|
||||
|
||||
If you use this dataset in your research, please cite the following paper:
|
||||
|
||||
* https://doi.org/10.1109/MGRS.2023.3292467
|
||||
|
||||
.. note::
|
||||
|
||||
This dataset requires the following additional library to be installed:
|
||||
|
||||
* `h5py <https://pypi.org/project/h5py/>`_ to load the dataset
|
||||
|
||||
.. versionadded:: 0.6
|
||||
"""
|
||||
|
||||
all_bands = (
|
||||
'B01',
|
||||
'B02',
|
||||
'B03',
|
||||
'B04',
|
||||
'B05',
|
||||
'B06',
|
||||
'B07',
|
||||
'B08',
|
||||
'B8A',
|
||||
'B09',
|
||||
'B11',
|
||||
'B12',
|
||||
)
|
||||
rgb_bands = ('B04', 'B03', 'B02')
|
||||
folds: ClassVar[dict[str, list[object]]] = {
|
||||
'train': [1, 2, 3, 4],
|
||||
'val': [0],
|
||||
'test': ['chabud'],
|
||||
}
|
||||
urls = (
|
||||
'https://huggingface.co/datasets/DarthReca/california_burned_areas/resolve/main/raw/patched/512x512.hdf5',
|
||||
'https://huggingface.co/datasets/DarthReca/california_burned_areas/resolve/main/raw/patched/chabud_test.h5',
|
||||
)
|
||||
filenames = ('512x512.hdf5', 'chabud_test.h5')
|
||||
md5s = ('15d78fb825f9a81dad600db828d22c08', 'a70bb7e4a2788657c2354c4c3d9296fe')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: Path = 'data',
|
||||
split: str = 'train',
|
||||
bands: tuple[str, ...] = all_bands,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new CaBuAr dataset instance.
|
||||
|
||||
Args:
|
||||
root: root directory where dataset can be found
|
||||
split: one of "train", "val", "test"
|
||||
bands: the subset of bands to load
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory
|
||||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
AssertionError: If ``split`` or ``bands`` arguments are invalid.
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
DependencyNotFoundError: If h5py is not installed.
|
||||
"""
|
||||
lazy_import('h5py')
|
||||
|
||||
assert split in self.folds
|
||||
assert set(bands) <= set(self.all_bands)
|
||||
|
||||
# Set the file index based on the split
|
||||
file_index = 1 if split == 'test' else 0
|
||||
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.bands = bands
|
||||
self.transforms = transforms
|
||||
self.download = download
|
||||
self.checksum = checksum
|
||||
self.filepath = os.path.join(root, self.filenames[file_index])
|
||||
self.band_indices = [self.all_bands.index(b) for b in bands]
|
||||
|
||||
self._verify()
|
||||
|
||||
self.uuids = self._load_uuids()
|
||||
|
||||
def __getitem__(self, index: int) -> dict[str, Tensor]:
|
||||
"""Return an index within the dataset.
|
||||
|
||||
Args:
|
||||
index: index to return
|
||||
|
||||
Returns:
|
||||
sample containing image and mask
|
||||
"""
|
||||
image = self._load_image(index)
|
||||
mask = self._load_target(index)
|
||||
|
||||
sample = {'image': image, 'mask': mask}
|
||||
|
||||
if self.transforms is not None:
|
||||
sample = self.transforms(sample)
|
||||
|
||||
return sample
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of data points in the dataset.
|
||||
|
||||
Returns:
|
||||
length of the dataset
|
||||
"""
|
||||
return len(self.uuids)
|
||||
|
||||
def _load_uuids(self) -> list[str]:
|
||||
"""Return the image uuids for the given split.
|
||||
|
||||
Returns:
|
||||
the image uuids
|
||||
"""
|
||||
h5py = lazy_import('h5py')
|
||||
uuids = []
|
||||
with h5py.File(self.filepath, 'r') as f:
|
||||
for k, v in f.items():
|
||||
if v.attrs['fold'] in self.folds[self.split] and 'pre_fire' in v.keys():
|
||||
uuids.append(k)
|
||||
return sorted(uuids)
|
||||
|
||||
def _load_image(self, index: int) -> Tensor:
|
||||
"""Load a single image.
|
||||
|
||||
Args:
|
||||
index: index to return
|
||||
|
||||
Returns:
|
||||
the image
|
||||
"""
|
||||
h5py = lazy_import('h5py')
|
||||
uuid = self.uuids[index]
|
||||
with h5py.File(self.filepath, 'r') as f:
|
||||
pre_array = f[uuid]['pre_fire'][:]
|
||||
post_array = f[uuid]['post_fire'][:]
|
||||
|
||||
# index specified bands and concatenate
|
||||
pre_array = pre_array[..., self.band_indices]
|
||||
post_array = post_array[..., self.band_indices]
|
||||
array = np.concatenate([pre_array, post_array], axis=-1).astype(np.float32)
|
||||
|
||||
tensor = torch.from_numpy(array)
|
||||
# Convert from HxWxC to CxHxW
|
||||
tensor = tensor.permute((2, 0, 1))
|
||||
return tensor
|
||||
|
||||
def _load_target(self, index: int) -> Tensor:
|
||||
"""Load the target mask for a single image.
|
||||
|
||||
Args:
|
||||
index: index to return
|
||||
|
||||
Returns:
|
||||
the target mask
|
||||
"""
|
||||
h5py = lazy_import('h5py')
|
||||
uuid = self.uuids[index]
|
||||
with h5py.File(self.filepath, 'r') as f:
|
||||
array = f[uuid]['mask'][:].astype(np.int32).squeeze(axis=-1)
|
||||
|
||||
tensor = torch.from_numpy(array)
|
||||
tensor = tensor.to(torch.long)
|
||||
return tensor
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset."""
|
||||
# Check if the files already exist
|
||||
exists = []
|
||||
for filename in self.filenames:
|
||||
filepath = os.path.join(self.root, filename)
|
||||
exists.append(os.path.exists(filepath))
|
||||
|
||||
if all(exists):
|
||||
return
|
||||
|
||||
# Check if the user requested to download the dataset
|
||||
if not self.download:
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
# Download the dataset
|
||||
self._download()
|
||||
|
||||
def _download(self) -> None:
|
||||
"""Download the dataset."""
|
||||
for url, filename, md5 in zip(self.urls, self.filenames, self.md5s):
|
||||
filepath = os.path.join(self.root, filename)
|
||||
if not os.path.exists(filepath):
|
||||
download_url(
|
||||
url,
|
||||
self.root,
|
||||
filename=filename,
|
||||
md5=md5 if self.checksum else None,
|
||||
)
|
||||
|
||||
def plot(
|
||||
self,
|
||||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: str | None = None,
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
sample: a sample returned by :meth:`__getitem__`
|
||||
show_titles: flag indicating whether to show titles above each panel
|
||||
suptitle: optional suptitle to use for figure
|
||||
|
||||
Returns:
|
||||
a matplotlib Figure with the rendered sample
|
||||
"""
|
||||
rgb_indices = []
|
||||
for band in self.rgb_bands:
|
||||
if band in self.bands:
|
||||
rgb_indices.append(self.bands.index(band))
|
||||
else:
|
||||
raise ValueError("Dataset doesn't contain some of the RGB bands")
|
||||
|
||||
mask = sample['mask'].numpy()
|
||||
image_pre = sample['image'][: len(self.bands)][rgb_indices].numpy()
|
||||
image_post = sample['image'][len(self.bands) :][rgb_indices].numpy()
|
||||
image_pre = percentile_normalization(image_pre)
|
||||
image_post = percentile_normalization(image_post)
|
||||
|
||||
ncols = 3
|
||||
|
||||
showing_predictions = 'prediction' in sample
|
||||
if showing_predictions:
|
||||
prediction = sample['prediction']
|
||||
ncols += 1
|
||||
|
||||
fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5))
|
||||
|
||||
axs[0].imshow(np.transpose(image_pre, (1, 2, 0)))
|
||||
axs[0].axis('off')
|
||||
axs[1].imshow(np.transpose(image_post, (1, 2, 0)))
|
||||
axs[1].axis('off')
|
||||
axs[2].imshow(mask)
|
||||
axs[2].axis('off')
|
||||
|
||||
if showing_predictions:
|
||||
axs[3].imshow(prediction)
|
||||
axs[3].axis('off')
|
||||
|
||||
if show_titles:
|
||||
axs[0].set_title('Image Pre')
|
||||
axs[1].set_title('Image Post')
|
||||
axs[2].set_title('Mask')
|
||||
if showing_predictions:
|
||||
axs[3].set_title('Prediction')
|
||||
|
||||
if suptitle is not None:
|
||||
plt.suptitle(suptitle)
|
||||
|
||||
return fig
|
Загрузка…
Ссылка в новой задаче