Mirror of https://github.com/microsoft/torchgeo.git
Add GeoNRW dataset (#2209)
* dataset and module
* test with training
* add tests
* start the fight with mypy
* kick off tests
* class var ruff
* don't download
* forgot tests data
* already downloaded
* coverage
* review
* mypy
* docs
* docs
* suggestion
* plotting
* versionadded: 2 digits
* Type hint unnecessary

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
Parent: 6d758abf81
Commit: 2d6e27ebd0
@@ -94,6 +94,11 @@ FireRisk
.. autoclass:: FireRiskDataModule

GeoNRW
^^^^^^

.. autoclass:: GeoNRWDataModule

GID-15
^^^^^^
@@ -281,6 +281,11 @@ Forest Damage
.. autoclass:: ForestDamage

GeoNRW
^^^^^^

.. autoclass:: GeoNRW

GID-15
^^^^^^
@@ -15,6 +15,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`FAIR1M`_,OD,Gaofen/Google Earth,"CC-BY-NC-SA-3.0","15,000",37,"1,024x1,024",0.3--0.8,RGB
`FireRisk`_,C,NAIP Aerial,"CC-BY-NC-4.0","91,872",7,"320x320",1,RGB
`Forest Damage`_,OD,Drone imagery,"CDLA-Permissive-1.0","1,543",4,"1,500x1,500",,RGB
`GeoNRW`_,S,Aerial,"CC-BY-4.0","7,783",11,"1,000x1,000",1,"RGB, DEM"
`GID-15`_,S,Gaofen-2,-,150,15,"6,800x7,200",3,RGB
`IDTReeS`_,"OD,C",Aerial,"CC-BY-4.0",591,33,200x200,0.1--1,RGB
`Inria Aerial Image Labeling`_,S,Aerial,-,360,2,"5,000x5,000",0.3,RGB
@@ -0,0 +1,16 @@
model:
  class_path: SemanticSegmentationTask
  init_args:
    loss: "ce"
    model: "unet"
    backbone: "resnet18"
    in_channels: 3
    num_classes: 11
    num_filters: 1
    ignore_index: null
data:
  class_path: GeoNRWDataModule
  init_args:
    batch_size: 1
    dict_kwargs:
      root: "tests/data/geonrw"
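The config above is exercised by the trainer tests. A minimal sketch of the same setup built directly in Python, mirroring the YAML (assumes the dummy fixture under tests/data/geonrw exists):

# Mirrors tests/conf/geonrw.yaml: same task and datamodule, built in code.
from lightning.pytorch import Trainer

from torchgeo.datamodules import GeoNRWDataModule
from torchgeo.trainers import SemanticSegmentationTask

task = SemanticSegmentationTask(
    loss='ce',
    model='unet',
    backbone='resnet18',
    in_channels=3,
    num_classes=11,
    num_filters=1,
    ignore_index=None,
)
datamodule = GeoNRWDataModule(batch_size=1, root='tests/data/geonrw')

# fast_dev_run pushes a single batch through fit as a smoke test
Trainer(fast_dev_run=True).fit(task, datamodule=datamodule)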
Binary files not shown.
@@ -0,0 +1,87 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil
import tarfile

import numpy as np
from PIL import Image

# Constants
IMAGE_SIZE = (100, 100)
TRAIN_CITIES = ['aachen', 'bergisch', 'bielefeld']
TEST_CITIES = ['duesseldorf']
CLASSES = [
    'background',
    'forest',
    'water',
    'agricultural',
    'residential,commercial,industrial',
    'grassland,swamp,shrubbery',
    'railway,trainstation',
    'highway,squares',
    'airport,shipyard',
    'roads',
    'buildings',
]
NUM_SAMPLES_PER_CITY = 2


def create_directories(cities: list[str]) -> None:
    for city in cities:
        if os.path.exists(city):
            shutil.rmtree(city)
        os.makedirs(city, exist_ok=True)


def generate_dummy_data(cities: list[str]) -> None:
    for city in cities:
        for i in range(NUM_SAMPLES_PER_CITY):
            utm_coords = f'{i}_{i}'
            rgb_image = np.random.randint(0, 256, (*IMAGE_SIZE, 3), dtype=np.uint8)
            dem_image = np.random.randint(0, 256, IMAGE_SIZE, dtype=np.uint8)
            seg_image = np.random.randint(0, len(CLASSES), IMAGE_SIZE, dtype=np.uint8)

            Image.fromarray(rgb_image).save(os.path.join(city, f'{utm_coords}_rgb.jp2'))
            Image.fromarray(dem_image).save(os.path.join(city, f'{utm_coords}_dem.tif'))
            Image.fromarray(seg_image).save(os.path.join(city, f'{utm_coords}_seg.tif'))


def create_tarball(output_filename: str, source_dirs: list[str]) -> None:
    with tarfile.open(output_filename, 'w:gz') as tar:
        for source_dir in source_dirs:
            tar.add(source_dir, arcname=os.path.basename(source_dir))


def calculate_md5(filename: str) -> str:
    hash_md5 = hashlib.md5()
    with open(filename, 'rb') as f:
        for chunk in iter(lambda: f.read(4096), b''):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


# Main function
def main() -> None:
    train_cities = TRAIN_CITIES
    test_cities = TEST_CITIES

    create_directories(train_cities)
    create_directories(test_cities)

    generate_dummy_data(train_cities)
    generate_dummy_data(test_cities)

    tarball_name = 'nrw_dataset.tar.gz'
    create_tarball(tarball_name, train_cities + test_cities)

    md5sum = calculate_md5(tarball_name)
    print(f'MD5 checksum: {md5sum}')


if __name__ == '__main__':
    main()
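The checksum printed by data.py is the value monkeypatched onto GeoNRW.md5 in the tests below. A quick sanity-check sketch (not part of the PR) for the generated archive:

# Lists the tarball members and recomputes the MD5 that data.py printed.
import hashlib
import tarfile

with tarfile.open('nrw_dataset.tar.gz', 'r:gz') as tar:
    print('\n'.join(member.name for member in tar.getmembers()))

with open('nrw_dataset.tar.gz', 'rb') as f:
    print(hashlib.md5(f.read()).hexdigest())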
Binary files not shown.
@@ -0,0 +1,74 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

from torchgeo.datasets import DatasetNotFoundError, GeoNRW


class TestGeoNRW:
    @pytest.fixture(params=['train', 'test'])
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> GeoNRW:
        md5 = '6ffc014d4b345bba3076e8d76ab481fa'
        monkeypatch.setattr(GeoNRW, 'md5', md5)
        url = os.path.join('tests', 'data', 'geonrw', 'nrw_dataset.tar.gz')
        monkeypatch.setattr(GeoNRW, 'url', url)
        monkeypatch.setattr(GeoNRW, 'train_list', ['aachen', 'bergisch', 'bielefeld'])
        monkeypatch.setattr(GeoNRW, 'test_list', ['duesseldorf'])
        root = tmp_path
        split = request.param
        transforms = nn.Identity()
        return GeoNRW(root, split, transforms, download=True, checksum=True)

    def test_getitem(self, dataset: GeoNRW) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x['image'], torch.Tensor)
        assert x['image'].shape[0] == 3
        assert isinstance(x['mask'], torch.Tensor)
        assert x['image'].shape[-2:] == x['mask'].shape[-2:]

    def test_len(self, dataset: GeoNRW) -> None:
        if dataset.split == 'train':
            assert len(dataset) == 6
        else:
            assert len(dataset) == 2

    def test_already_downloaded(self, dataset: GeoNRW) -> None:
        GeoNRW(root=dataset.root)

    def test_not_yet_extracted(self, tmp_path: Path) -> None:
        filename = 'nrw_dataset.tar.gz'
        dir = os.path.join('tests', 'data', 'geonrw')
        shutil.copyfile(
            os.path.join(dir, filename), os.path.join(str(tmp_path), filename)
        )
        GeoNRW(root=str(tmp_path))

    def test_invalid_split(self) -> None:
        with pytest.raises(AssertionError):
            GeoNRW(split='foo')

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            GeoNRW(tmp_path)

    def test_plot(self, dataset: GeoNRW) -> None:
        dataset.plot(dataset[0], suptitle='Test')
        plt.close()

        sample = dataset[0]
        sample['prediction'] = torch.clone(sample['mask'])
        dataset.plot(sample, suptitle='Prediction')
        plt.close()
@@ -55,6 +55,7 @@ class TestSemanticSegmentationTask:
'chesapeake_cvpr_7',
'deepglobelandcover',
'etci2021',
'geonrw',
'gid15',
'inria',
'l7irish',
@@ -15,6 +15,7 @@ from .eurosat import EuroSAT100DataModule, EuroSATDataModule, EuroSATSpatialData
from .fair1m import FAIR1MDataModule
from .fire_risk import FireRiskDataModule
from .geo import BaseDataModule, GeoDataModule, NonGeoDataModule
from .geonrw import GeoNRWDataModule
from .gid15 import GID15DataModule
from .inria import InriaAerialImageLabelingDataModule
from .iobench import IOBenchDataModule

@@ -73,6 +74,7 @@ __all__ = (
'EuroSAT100DataModule',
'FAIR1MDataModule',
'FireRiskDataModule',
'GeoNRWDataModule',
'GID15DataModule',
'InriaAerialImageLabelingDataModule',
'LandCoverAIDataModule',
@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""GeoNRW datamodule."""

import os
from typing import Any

import kornia.augmentation as K
from torch.utils.data import Subset

from ..datasets import GeoNRW
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule
from .utils import group_shuffle_split


class GeoNRWDataModule(NonGeoDataModule):
    """LightningDataModule implementation for the GeoNRW dataset.

    Implements 80/20 train/val splits based on city locations.
    See :func:`setup` for more details.

    .. versionadded:: 0.6
    """

    def __init__(
        self, batch_size: int = 64, num_workers: int = 0, size: int = 256, **kwargs: Any
    ) -> None:
        """Initialize a new GeoNRWDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            size: resize images of input size 1000x1000 to size x size
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.GeoNRW`.
        """
        super().__init__(GeoNRW, batch_size, num_workers, **kwargs)

        self.train_aug = AugmentationSequential(
            K.Resize(size),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            data_keys=['image', 'mask'],
        )

        self.aug = AugmentationSequential(K.Resize(size), data_keys=['image', 'mask'])

        self.size = size

    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        if stage in ['fit', 'validate']:
            dataset = GeoNRW(split='train', **self.kwargs)
            city_paths = [os.path.dirname(path) for path in dataset.file_list]
            train_indices, val_indices = group_shuffle_split(
                city_paths, test_size=0.2, random_state=0
            )
            self.train_dataset = Subset(dataset, train_indices)
            self.val_dataset = Subset(dataset, val_indices)
        if stage in ['test']:
            self.test_dataset = GeoNRW(split='test', **self.kwargs)
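Because the 80/20 split is grouped by city directory, every tile from a given city lands entirely in train or entirely in val, which keeps neighboring tiles from the same city from leaking across the split. A toy sketch of that behavior, assuming torchgeo's internal group_shuffle_split helper keeps its current signature:

# Five hypothetical city groups; test_size=0.2 moves one whole city to val.
from torchgeo.datamodules.utils import group_shuffle_split

city_paths = (
    ['aachen'] * 3 + ['bochum'] * 2 + ['essen'] * 2 + ['hagen'] * 2 + ['kleve'] * 2
)
train_indices, val_indices = group_shuffle_split(
    city_paths, test_size=0.2, random_state=0
)
# No city appears on both sides of the split.
train_cities = {city_paths[i] for i in train_indices}
val_cities = {city_paths[i] for i in val_indices}
assert not train_cities & val_cities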
@@ -54,6 +54,7 @@ from .geo import (
UnionDataset,
VectorDataset,
)
from .geonrw import GeoNRW
from .gid15 import GID15
from .globbiomass import GlobBiomass
from .idtrees import IDTReeS

@@ -213,6 +214,7 @@ __all__ = (
'FAIR1M',
'FireRisk',
'ForestDamage',
'GeoNRW',
'GID15',
'IDTReeS',
'InriaAerialImageLabeling',
@@ -0,0 +1,346 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""GeoNRW dataset."""

import os
from collections.abc import Callable
from glob import glob
from typing import ClassVar

import matplotlib
import matplotlib.cm
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from torchvision import transforms

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import Path, download_and_extract_archive, extract_archive


class GeoNRW(NonGeoDataset):
    """GeoNRW dataset.

    This dataset contains RGB, DEM, and segmentation label data from North Rhine-Westphalia, Germany.

    Dataset features:

    * 7298 training and 485 test samples
    * RGB images, 1000x1000px normalized to [0, 1]
    * DEM images, unnormalized
    * segmentation labels

    Dataset format:

    * RGB images are three-channel jp2
    * DEM images are single-channel tif
    * segmentation labels are single-channel tif

    Dataset classes:

    0. background
    1. forest
    2. water
    3. agricultural
    4. residential,commercial,industrial
    5. grassland,swamp,shrubbery
    6. railway,trainstation
    7. highway,squares
    8. airport,shipyard
    9. roads
    10. buildings

    Additional information about the dataset can be found `on this site <https://ieee-dataport.org/open-access/geonrw>`__.

    If you use this dataset in your research, please cite the following paper:

    * https://ieeexplore.ieee.org/document/9406194

    .. versionadded:: 0.6
    """

    # Splits taken from https://github.com/gbaier/geonrw/blob/ecfcdbca8cfaaeb490a9c6916980f385b9f3941a/pytorch/nrw.py#L48

    splits = ('train', 'test')

    train_list: tuple[str, ...] = (
        'aachen',
        'bergisch',
        'bielefeld',
        'bochum',
        'bonn',
        'borken',
        'bottrop',
        'coesfeld',
        'dortmund',
        'dueren',
        'duisburg',
        'ennepetal',
        'erftstadt',
        'essen',
        'euskirchen',
        'gelsenkirchen',
        'guetersloh',
        'hagen',
        'hamm',
        'heinsberg',
        'herford',
        'hoexter',
        'kleve',
        'koeln',
        'krefeld',
        'leverkusen',
        'lippetal',
        'lippstadt',
        'lotte',
        'moenchengladbach',
        'moers',
        'muelheim',
        'muenster',
        'oberhausen',
        'paderborn',
        'recklinghausen',
        'remscheid',
        'siegen',
        'solingen',
        'wuppertal',
    )

    test_list: tuple[str, ...] = ('duesseldorf', 'herne', 'neuss')

    classes = (
        'background',
        'forest',
        'water',
        'agricultural',
        'residential,commercial,industrial',
        'grassland,swamp,shrubbery',
        'railway,trainstation',
        'highway,squares',
        'airport,shipyard',
        'roads',
        'buildings',
    )

    colormap = mcolors.ListedColormap(
        [
            '#000000',  # matplotlib black for background
            '#2ca02c',  # matplotlib green for forest
            '#1f77b4',  # matplotlib blue for water
            '#8c564b',  # matplotlib brown for agricultural
            '#7f7f7f',  # matplotlib gray for residential_commercial_industrial
            '#bcbd22',  # matplotlib olive for grassland_swamp_shrubbery
            '#ff7f0e',  # matplotlib orange for railway_trainstation
            '#9467bd',  # matplotlib purple for highway_squares
            '#17becf',  # matplotlib cyan for airport_shipyard
            '#d62728',  # matplotlib red for roads
            '#e377c2',  # matplotlib pink for buildings
        ]
    )

    readers: ClassVar[dict[str, Callable[[str], Image.Image]]] = {
        'rgb': lambda path: Image.open(path).convert('RGB'),
        'dem': lambda path: Image.open(path).copy(),
        'seg': lambda path: Image.open(path).convert('I;16'),
    }

    modality_filenames: ClassVar[dict[str, Callable[[list[str]], str]]] = {
        'rgb': lambda utm_coords: '{}_{}_rgb.jp2'.format(*utm_coords),
        'dem': lambda utm_coords: '{}_{}_dem.tif'.format(*utm_coords),
        'seg': lambda utm_coords: '{}_{}_seg.tif'.format(*utm_coords),
    }

    modalities: tuple[str, ...] = ('rgb', 'dem', 'seg')

    url = 'https://hf.co/datasets/torchgeo/geonrw/resolve/3cb6bdf2a615b9e526c7dcff85fd1f20728081b7/{}'

    filename = 'nrw_dataset.tar.gz'
    md5 = 'd56ab50098d5452c33d08ff4e99ce281'

    def __init__(
        self,
        root: Path = 'data',
        split: str = 'train',
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize the GeoNRW dataset.

        Args:
            root: root directory where dataset can be found
            split: one of "train" or "test"
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: if ``split`` argument is invalid
            DatasetNotFoundError: If dataset is not found and *download* is False.
        """
        assert split in self.splits, f'split must be one of {self.splits}'

        self.root = root
        self.split = split
        self.transforms = transforms
        self.download = download
        self.checksum = checksum

        self.city_names = self.test_list if split == 'test' else self.train_list

        self._verify()

        self.file_list = self._get_file_list()

    def _get_file_list(self) -> list[str]:
        """Get a list of files for cities in the dataset split.

        Returns:
            list of filenames in the dataset split
        """
        file_list: list[str] = []
        for cn in self.city_names:
            pattern = os.path.join(self.root, cn, '*rgb.jp2')
            file_list.extend(glob(pattern))
        return sorted(file_list)

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.file_list)

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            data and label at that index
        """
        to_tensor = transforms.ToTensor()

        path: str = self.file_list[index]
        utm_coords = os.path.basename(path).split('_')[:2]
        base_dir = os.path.dirname(path)

        sample: dict[str, Tensor] = {}
        for modality in self.modalities:
            modality_path = os.path.join(
                base_dir, self.modality_filenames[modality](utm_coords)
            )
            sample[modality] = to_tensor(self.readers[modality](modality_path))

        # rename to torchgeo standard keys
        sample['image'] = sample.pop('rgb').float()
        sample['mask'] = sample.pop('seg').long()

        if self.transforms:
            sample = self.transforms(sample)

        return sample

    def _verify(self) -> None:
        """Verify the integrity of the dataset."""
        # check if city name directories exist
        all_exist = all(
            os.path.exists(os.path.join(self.root, cn)) for cn in self.city_names
        )
        if all_exist:
            return

        # Check if the tar file has been downloaded
        if os.path.exists(os.path.join(self.root, self.filename)):
            extract_archive(os.path.join(self.root, self.filename), self.root)
            return

        # Check if the user requested to download the dataset
        if not self.download:
            raise DatasetNotFoundError(self)

        # Download the dataset
        self._download()

    def _download(self) -> None:
        """Download the dataset."""
        download_and_extract_archive(
            self.url.format(self.filename),
            download_root=self.root,
            md5=self.md5 if self.checksum else None,
        )

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: str | None = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional suptitle to use for figure

        Returns:
            a matplotlib Figure with the rendered sample
        """
        showing_predictions = 'prediction' in sample
        ncols = 3
        if showing_predictions:
            prediction = sample['prediction'].long()
            ncols += 1

        fig, axs = plt.subplots(
            nrows=1, ncols=ncols, figsize=(ncols * 5, 10), sharex=True
        )

        axs[0].imshow(sample['image'].permute(1, 2, 0))
        axs[0].axis('off')
        axs[1].imshow(sample['dem'].squeeze(0), cmap='gray')
        axs[1].axis('off')
        axs[2].imshow(
            sample['mask'].squeeze(0),
            self.colormap,
            vmin=0,
            vmax=10,
            interpolation='none',
        )
        axs[2].axis('off')

        if showing_predictions:
            axs[3].imshow(
                prediction.squeeze(0),
                self.colormap,
                vmin=0,
                vmax=10,
                interpolation='none',
            )

        # show classes in legend
        if show_titles:
            patches = [matplotlib.patches.Patch(color=c) for c in self.colormap.colors]  # type: ignore
            axs[2].legend(
                patches, self.classes, loc='center left', bbox_to_anchor=(1, 0.5)
            )

        if show_titles:
            axs[0].set_title('RGB Image')
            axs[1].set_title('DEM')
            axs[2].set_title('Labels')

        if suptitle is not None:
            fig.suptitle(suptitle, y=0.8)

        fig.tight_layout()

        return fig
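A minimal usage sketch for the new dataset (root path and download flag are illustrative; keys and shapes follow __getitem__ above):

# Loads one sample and renders it with the dataset's own plot method.
import matplotlib.pyplot as plt

from torchgeo.datasets import GeoNRW

ds = GeoNRW(root='data', split='train', download=True)
sample = ds[0]  # {'image': 3xHxW float, 'dem': 1xHxW, 'mask': 1xHxW long}
print(len(ds), sample['image'].shape, sample['mask'].shape)

ds.plot(sample, suptitle='GeoNRW sample')
plt.show()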