* dataset and module

* test with training

* add tests

* start the fight with mypy

* kick off tests

* class var ruff

* don't download

* forgot tests data

* already downloaded

* coverage

* review

* mypy

* docs

* docs

* suggestion

* plotting

* versionadded: 2 digits

* Type hint unnecessary

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Nils Lehmann 2024-08-27 15:57:14 +02:00 коммит произвёл GitHub
Родитель 6d758abf81
Коммит 2d6e27ebd0
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
36 изменённых файлов: 606 добавлений и 0 удалений

Просмотреть файл

@ -94,6 +94,11 @@ FireRisk
.. autoclass:: FireRiskDataModule
GeoNRW
^^^^^^
.. autoclass:: GeoNRWDataModule
GID-15
^^^^^^

Просмотреть файл

@ -281,6 +281,11 @@ Forest Damage
.. autoclass:: ForestDamage
GeoNRW
^^^^^^^
.. autoclass:: GeoNRW
GID-15
^^^^^^

Просмотреть файл

@ -15,6 +15,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
`FAIR1M`_,OD,Gaofen/Google Earth,"CC-BY-NC-SA-3.0","15,000",37,"1,024x1,024",0.3--0.8,RGB
`FireRisk`_,C,NAIP Aerial,"CC-BY-NC-4.0","91,872",7,"320x320",1,RGB
`Forest Damage`_,OD,Drone imagery,"CDLA-Permissive-1.0","1,543",4,"1,500x1,500",,RGB
`GeoNRW`_,S,Aerial,"CC-BY-4.0","7,783",11,"1,000x1,000",1,"RGB, DEM"
`GID-15`_,S,Gaofen-2,-,150,15,"6,800x7,200",3,RGB
`IDTReeS`_,"OD,C",Aerial,"CC-BY-4.0",591,33,200x200,0.1--1,RGB
`Inria Aerial Image Labeling`_,S,Aerial,-,360,2,"5,000x5,000",0.3,RGB

1 Dataset Task Source License # Samples # Classes Size (px) Resolution (m) Bands
15 `FAIR1M`_ OD Gaofen/Google Earth CC-BY-NC-SA-3.0 15,000 37 1,024x1,024 0.3--0.8 RGB
16 `FireRisk`_ C NAIP Aerial CC-BY-NC-4.0 91,872 7 320x320 1 RGB
17 `Forest Damage`_ OD Drone imagery CDLA-Permissive-1.0 1,543 4 1,500x1,500 RGB
18 `GeoNRW`_ S Aerial CC-BY-4.0 7,783 11 1,000x1,000 1 RGB, DEM
19 `GID-15`_ S Gaofen-2 - 150 15 6,800x7,200 3 RGB
20 `IDTReeS`_ OD,C Aerial CC-BY-4.0 591 33 200x200 0.1--1 RGB
21 `Inria Aerial Image Labeling`_ S Aerial - 360 2 5,000x5,000 0.3 RGB

16
tests/conf/geonrw.yaml Normal file
Просмотреть файл

@ -0,0 +1,16 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: "ce"
model: "unet"
backbone: "resnet18"
in_channels: 3
num_classes: 11
num_filters: 1
ignore_index: null
data:
class_path: GeoNRWDataModule
init_args:
batch_size: 1
dict_kwargs:
root: "tests/data/geonrw"

Двоичные данные
tests/data/geonrw/aachen/0_0_dem.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/aachen/0_0_rgb.jp2 Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/aachen/0_0_seg.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/aachen/1_1_dem.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/aachen/1_1_rgb.jp2 Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/aachen/1_1_seg.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/bergisch/0_0_dem.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/bergisch/0_0_rgb.jp2 Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/bergisch/0_0_seg.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/bergisch/1_1_dem.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/bergisch/1_1_rgb.jp2 Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/bergisch/1_1_seg.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/bielefeld/0_0_dem.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/bielefeld/0_0_rgb.jp2 Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/bielefeld/0_0_seg.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/bielefeld/1_1_dem.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/bielefeld/1_1_rgb.jp2 Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/bielefeld/1_1_seg.tif Normal file

Двоичный файл не отображается.

87
tests/data/geonrw/data.py Normal file
Просмотреть файл

@ -0,0 +1,87 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import hashlib
import os
import shutil
import tarfile
import numpy as np
from PIL import Image
# Constants
# Spatial shape of every generated raster; the RGB array appends a trailing
# axis of length 3 to this shape.
IMAGE_SIZE = (100, 100)
# City directory names that main() creates, fills, and packs into the tarball.
TRAIN_CITIES = ['aachen', 'bergisch', 'bielefeld']
TEST_CITIES = ['duesseldorf']
# Class names. In this script only the list length is used, as the exclusive
# upper bound of the random label values written to the *_seg.tif files.
CLASSES = [
'background',
'forest',
'water',
'agricultural',
'residential,commercial,industrial',
'grassland,swamp,shrubbery',
'railway,trainstation',
'highway,squares',
'airport,shipyard',
'roads',
'buildings',
]
# Number of (rgb, dem, seg) file triplets generated inside each city directory.
NUM_SAMPLES_PER_CITY = 2
def create_directories(cities: list[str]) -> None:
    """Create a fresh, empty directory for every entry in *cities*.

    A directory that already exists is removed together with its contents
    before it is recreated.

    Args:
        cities: directory paths to (re)create
    """
    for city_dir in cities:
        already_there = os.path.exists(city_dir)
        if already_there:
            shutil.rmtree(city_dir)
        os.makedirs(city_dir, exist_ok=True)
def generate_dummy_data(cities: list[str]) -> None:
    """Write random RGB, DEM and segmentation rasters into each city directory.

    For every city, ``NUM_SAMPLES_PER_CITY`` triplets are written, named
    ``{i}_{i}_rgb.jp2``, ``{i}_{i}_dem.tif`` and ``{i}_{i}_seg.tif``.

    Args:
        cities: existing directories that receive the generated files
    """
    label_upper_bound = len(CLASSES)
    for city in cities:
        for idx in range(NUM_SAMPLES_PER_CITY):
            stem = os.path.join(city, f'{idx}_{idx}')

            # Draw in the order rgb -> dem -> seg so that a seeded RNG yields
            # the same files as before.
            rgb = np.random.randint(0, 256, (*IMAGE_SIZE, 3), dtype=np.uint8)
            dem = np.random.randint(0, 256, IMAGE_SIZE, dtype=np.uint8)
            seg = np.random.randint(0, label_upper_bound, IMAGE_SIZE, dtype=np.uint8)

            outputs = ((rgb, '_rgb.jp2'), (dem, '_dem.tif'), (seg, '_seg.tif'))
            for array, suffix in outputs:
                Image.fromarray(array).save(stem + suffix)
def create_tarball(output_filename: str, source_dirs: list[str]) -> None:
    """Pack directories into a gzip-compressed tar archive.

    Each directory is stored under its base name, so any leading path
    components of the source directory are dropped inside the archive.

    Args:
        output_filename: path of the ``.tar.gz`` file to write
        source_dirs: directories to add recursively
    """
    archive = tarfile.open(output_filename, 'w:gz')
    with archive:
        for directory in source_dirs:
            member_name = os.path.basename(directory)
            archive.add(directory, arcname=member_name)
def calculate_md5(filename: str) -> str:
    """Return the hexadecimal MD5 digest of a file.

    The file is read in 4096-byte blocks so it never has to fit in memory.

    Args:
        filename: path of the file to hash

    Returns:
        the MD5 checksum as a lowercase hex string
    """
    digest = hashlib.md5()
    with open(filename, 'rb') as stream:
        while block := stream.read(4096):
            digest.update(block)
    return digest.hexdigest()
# Script entry point
def main() -> None:
    """Regenerate the dummy GeoNRW test data and print the archive checksum.

    Recreates the city directories, fills them with random rasters, packs
    them into ``nrw_dataset.tar.gz`` and prints the MD5 of that archive.
    """
    city_groups = (TRAIN_CITIES, TEST_CITIES)

    # All directories are created before any data is generated, train first.
    for group in city_groups:
        create_directories(group)
    for group in city_groups:
        generate_dummy_data(group)

    tarball_name = 'nrw_dataset.tar.gz'
    create_tarball(tarball_name, TRAIN_CITIES + TEST_CITIES)

    print(f'MD5 checksum: {calculate_md5(tarball_name)}')


if __name__ == '__main__':
    main()

Двоичные данные
tests/data/geonrw/duesseldorf/0_0_dem.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/duesseldorf/0_0_rgb.jp2 Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/duesseldorf/0_0_seg.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/duesseldorf/1_1_dem.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/duesseldorf/1_1_rgb.jp2 Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/duesseldorf/1_1_seg.tif Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/geonrw/nrw_dataset.tar.gz Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -0,0 +1,74 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import shutil
from pathlib import Path
import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchgeo.datasets import DatasetNotFoundError, GeoNRW
class TestGeoNRW:
    """Tests for the GeoNRW dataset, run against the small fixture archive."""

    @pytest.fixture(params=['train', 'test'])
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> GeoNRW:
        """Return a GeoNRW instance for each split, sourced from the local tarball."""
        overrides = {
            'md5': '6ffc014d4b345bba3076e8d76ab481fa',
            'url': os.path.join('tests', 'data', 'geonrw', 'nrw_dataset.tar.gz'),
            'train_list': ['aachen', 'bergisch', 'bielefeld'],
            'test_list': ['duesseldorf'],
        }
        for attribute, value in overrides.items():
            monkeypatch.setattr(GeoNRW, attribute, value)

        return GeoNRW(
            tmp_path, request.param, nn.Identity(), download=True, checksum=True
        )

    def test_getitem(self, dataset: GeoNRW) -> None:
        """A sample is a dict holding a 3-channel image and a same-sized mask."""
        sample = dataset[0]
        assert isinstance(sample, dict)

        image, mask = sample['image'], sample['mask']
        assert isinstance(image, torch.Tensor)
        assert image.shape[0] == 3
        assert isinstance(mask, torch.Tensor)
        assert image.shape[-2:] == mask.shape[-2:]

    def test_len(self, dataset: GeoNRW) -> None:
        """Three train cities and one test city, two samples each."""
        expected = 6 if dataset.split == 'train' else 2
        assert len(dataset) == expected

    def test_already_downloaded(self, dataset: GeoNRW) -> None:
        """Constructing on an already-extracted root must not raise."""
        GeoNRW(root=dataset.root)

    def test_not_yet_extracted(self, tmp_path: Path) -> None:
        """An archive present in root is extracted without downloading."""
        filename = 'nrw_dataset.tar.gz'
        source = os.path.join('tests', 'data', 'geonrw', filename)
        target = os.path.join(str(tmp_path), filename)
        shutil.copyfile(source, target)
        GeoNRW(root=str(tmp_path))

    def test_invalid_split(self) -> None:
        """An unknown split name is rejected."""
        with pytest.raises(AssertionError):
            GeoNRW(split='foo')

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """An empty root without download=True raises DatasetNotFoundError."""
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            GeoNRW(tmp_path)

    def test_plot(self, dataset: GeoNRW) -> None:
        """Plotting works both without and with a prediction."""
        dataset.plot(dataset[0], suptitle='Test')
        plt.close()

        with_prediction = dataset[0]
        with_prediction['prediction'] = torch.clone(with_prediction['mask'])
        dataset.plot(with_prediction, suptitle='Prediction')
        plt.close()

Просмотреть файл

@ -55,6 +55,7 @@ class TestSemanticSegmentationTask:
'chesapeake_cvpr_7',
'deepglobelandcover',
'etci2021',
'geonrw',
'gid15',
'inria',
'l7irish',

Просмотреть файл

@ -15,6 +15,7 @@ from .eurosat import EuroSAT100DataModule, EuroSATDataModule, EuroSATSpatialData
from .fair1m import FAIR1MDataModule
from .fire_risk import FireRiskDataModule
from .geo import BaseDataModule, GeoDataModule, NonGeoDataModule
from .geonrw import GeoNRWDataModule
from .gid15 import GID15DataModule
from .inria import InriaAerialImageLabelingDataModule
from .iobench import IOBenchDataModule
@ -73,6 +74,7 @@ __all__ = (
'EuroSAT100DataModule',
'FAIR1MDataModule',
'FireRiskDataModule',
'GeoNRWDataModule',
'GID15DataModule',
'InriaAerialImageLabelingDataModule',
'LandCoverAIDataModule',

Просмотреть файл

@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""GeoNRW datamodule."""
import os
from typing import Any
import kornia.augmentation as K
from torch.utils.data import Subset
from ..datasets import GeoNRW
from ..transforms import AugmentationSequential
from .geo import NonGeoDataModule
from .utils import group_shuffle_split
class GeoNRWDataModule(NonGeoDataModule):
    """LightningDataModule implementation for the GeoNRW dataset.

    Implements 80/20 train/val splits based on city locations.
    See :func:`setup` for more details.

    .. versionadded:: 0.6
    """

    def __init__(
        self, batch_size: int = 64, num_workers: int = 0, size: int = 256, **kwargs: Any
    ) -> None:
        """Initialize a new GeoNRWDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            size: resize images of input size 1000x1000 to size x size
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.GeoNRW`.
        """
        super().__init__(GeoNRW, batch_size, num_workers, **kwargs)

        # Training: resize plus random flips, applied to image and mask.
        self.train_aug = AugmentationSequential(
            K.Resize(size),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            data_keys=['image', 'mask'],
        )

        # All other stages: resize only.
        self.aug = AugmentationSequential(K.Resize(size), data_keys=['image', 'mask'])

        self.size = size

    def setup(self, stage: str) -> None:
        """Set up datasets.

        For 'fit' and 'validate', the 'train' split is divided into train and
        validation subsets with ``group_shuffle_split``, using the directory
        of each file (its city) as the group and a fixed random state.
        For 'test', the dataset's own 'test' split is used.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        if stage in ['fit', 'validate']:
            dataset = GeoNRW(split='train', **self.kwargs)

            # The parent directory of every file is its city; it serves as
            # the grouping key for the split.
            city_paths = [os.path.dirname(path) for path in dataset.file_list]
            train_indices, val_indices = group_shuffle_split(
                city_paths, test_size=0.2, random_state=0
            )
            self.train_dataset = Subset(dataset, train_indices)
            self.val_dataset = Subset(dataset, val_indices)
        if stage in ['test']:
            self.test_dataset = GeoNRW(split='test', **self.kwargs)

Просмотреть файл

@ -54,6 +54,7 @@ from .geo import (
UnionDataset,
VectorDataset,
)
from .geonrw import GeoNRW
from .gid15 import GID15
from .globbiomass import GlobBiomass
from .idtrees import IDTReeS
@ -213,6 +214,7 @@ __all__ = (
'FAIR1M',
'FireRisk',
'ForestDamage',
'GeoNRW',
'GID15',
'IDTReeS',
'InriaAerialImageLabeling',

346
torchgeo/datasets/geonrw.py Normal file
Просмотреть файл

@ -0,0 +1,346 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""GeoNRW dataset."""
import os
from collections.abc import Callable
from glob import glob
from typing import ClassVar
import matplotlib
import matplotlib.cm
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from torchvision import transforms
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import Path, download_and_extract_archive, extract_archive
class GeoNRW(NonGeoDataset):
    """GeoNRW dataset.

    This dataset contains RGB, DEM and segmentation label data from North Rhine-Westphalia, Germany.

    Dataset features:

    * 7298 training and 485 test samples
    * RGB images, 1000x1000px normalized to [0, 1]
    * DEM images, unnormalized
    * segmentation labels

    Dataset format:

    * RGB images are three-channel jp2
    * DEM images are single-channel tif
    * segmentation labels are single-channel tif

    Dataset classes:

    0. background
    1. forest
    2. water
    3. agricultural
    4. residential,commercial,industrial
    5. grassland,swamp,shrubbery
    6. railway,trainstation
    7. highway,squares
    8. airport,shipyard
    9. roads
    10. buildings

    Additional information about the dataset can be found `on this site <https://ieee-dataport.org/open-access/geonrw>`__.

    If you use this dataset in your research, please cite the following paper:

    * https://ieeexplore.ieee.org/document/9406194

    .. versionadded:: 0.6
    """

    # Splits taken from https://github.com/gbaier/geonrw/blob/ecfcdbca8cfaaeb490a9c6916980f385b9f3941a/pytorch/nrw.py#L48
    splits = ('train', 'test')
    train_list: tuple[str, ...] = (
        'aachen',
        'bergisch',
        'bielefeld',
        'bochum',
        'bonn',
        'borken',
        'bottrop',
        'coesfeld',
        'dortmund',
        'dueren',
        'duisburg',
        'ennepetal',
        'erftstadt',
        'essen',
        'euskirchen',
        'gelsenkirchen',
        'guetersloh',
        'hagen',
        'hamm',
        'heinsberg',
        'herford',
        'hoexter',
        'kleve',
        'koeln',
        'krefeld',
        'leverkusen',
        'lippetal',
        'lippstadt',
        'lotte',
        'moenchengladbach',
        'moers',
        'muelheim',
        'muenster',
        'oberhausen',
        'paderborn',
        'recklinghausen',
        'remscheid',
        'siegen',
        'solingen',
        'wuppertal',
    )

    test_list: tuple[str, ...] = ('duesseldorf', 'herne', 'neuss')

    classes = (
        'background',
        'forest',
        'water',
        'agricultural',
        'residential,commercial,industrial',
        'grassland,swamp,shrubbery',
        'railway,trainstation',
        'highway,squares',
        'airport,shipyard',
        'roads',
        'buildings',
    )

    # One color per entry of ``classes``, in the same order.
    colormap = mcolors.ListedColormap(
        [
            '#000000',  # matplotlib black for background
            '#2ca02c',  # matplotlib green for forest
            '#1f77b4',  # matplotlib blue for water
            '#8c564b',  # matplotlib brown for agricultural
            '#7f7f7f',  # matplotlib gray residential_commercial_industrial
            '#bcbd22',  # matplotlib olive for grassland_swamp_shrubbery
            '#ff7f0e',  # matplotlib orange for railway_trainstation
            '#9467bd',  # matplotlib purple for highway_squares
            '#17becf',  # matplotlib cyan for airport_shipyard
            '#d62728',  # matplotlib red for roads
            '#e377c2',  # matplotlib pink for buildings
        ]
    )

    # Per-modality loader: maps a file path to a PIL image.
    readers: ClassVar[dict[str, Callable[[str], Image.Image]]] = {
        'rgb': lambda path: Image.open(path).convert('RGB'),
        'dem': lambda path: Image.open(path).copy(),
        'seg': lambda path: Image.open(path).convert('I;16'),
    }

    # Per-modality filename builder: maps the two leading filename fields
    # (the UTM coordinates) to the file name of that modality.
    modality_filenames: ClassVar[dict[str, Callable[[list[str]], str]]] = {
        'rgb': lambda utm_coords: '{}_{}_rgb.jp2'.format(*utm_coords),
        'dem': lambda utm_coords: '{}_{}_dem.tif'.format(*utm_coords),
        'seg': lambda utm_coords: '{}_{}_seg.tif'.format(*utm_coords),
    }

    modalities: tuple[str, ...] = ('rgb', 'dem', 'seg')

    url = 'https://hf.co/datasets/torchgeo/geonrw/resolve/3cb6bdf2a615b9e526c7dcff85fd1f20728081b7/{}'

    filename = 'nrw_dataset.tar.gz'
    md5 = 'd56ab50098d5452c33d08ff4e99ce281'

    def __init__(
        self,
        root: Path = 'data',
        split: str = 'train',
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize the GeoNRW dataset.

        Args:
            root: root directory where dataset can be found
            split: one of "train" or "test"
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: if ``split`` argument is invalid
            DatasetNotFoundError: If dataset is not found and *download* is False.
        """
        assert split in self.splits, f'split must be one of {self.splits}'

        self.root = root
        self.split = split
        self.transforms = transforms
        self.download = download
        self.checksum = checksum

        self.city_names = self.test_list if split == 'test' else self.train_list

        # Verification may extract or download, so it must run before globbing.
        self._verify()

        self.file_list = self._get_file_list()

    def _get_file_list(self) -> list[str]:
        """Get a list of files for cities in the dataset split.

        Only the RGB files are listed; the other modalities are derived from
        them in :meth:`__getitem__`.

        Returns:
            sorted list of RGB file paths in the dataset split
        """
        file_list: list[str] = []
        for cn in self.city_names:
            pattern = os.path.join(self.root, cn, '*rgb.jp2')
            file_list.extend(glob(pattern))
        return sorted(file_list)

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.file_list)

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            sample with the keys 'image' (RGB), 'dem' and 'mask' (labels)
        """
        to_tensor = transforms.ToTensor()

        path: str = self.file_list[index]
        # File names start with two underscore-separated coordinate fields.
        utm_coords = os.path.basename(path).split('_')[:2]
        base_dir = os.path.dirname(path)

        sample: dict[str, Tensor] = {}
        for modality in self.modalities:
            modality_path = os.path.join(
                base_dir, self.modality_filenames[modality](utm_coords)
            )
            sample[modality] = to_tensor(self.readers[modality](modality_path))

        # rename to torchgeo standard keys
        sample['image'] = sample.pop('rgb').float()
        sample['mask'] = sample.pop('seg').long()

        if self.transforms:
            sample = self.transforms(sample)

        return sample

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Checks, in order: extracted city directories, a local archive to
        extract, and finally a download if it was requested.

        Raises:
            DatasetNotFoundError: If dataset is not found and *download* is False.
        """
        # check if city names directories exist
        all_exist = all(
            os.path.exists(os.path.join(self.root, cn)) for cn in self.city_names
        )
        if all_exist:
            return

        # Check if the tar file has been downloaded
        archive_path = os.path.join(self.root, self.filename)
        if os.path.exists(archive_path):
            extract_archive(archive_path, self.root)
            return

        # Check if the user requested to download the dataset
        if not self.download:
            raise DatasetNotFoundError(self)

        # Download the dataset
        self._download()

    def _download(self) -> None:
        """Download the dataset."""
        download_and_extract_archive(
            self.url.format(self.filename),
            download_root=self.root,
            md5=self.md5 if self.checksum else None,
        )

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: str | None = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional suptitle to use for figure

        Returns:
            a matplotlib Figure with the rendered sample
        """
        showing_predictions = 'prediction' in sample
        ncols = 3
        if showing_predictions:
            prediction = sample['prediction'].long()
            ncols += 1

        # Highest class index; keeps the color scale fixed even when a sample
        # does not contain every class.
        vmax = len(self.classes) - 1

        fig, axs = plt.subplots(
            nrows=1, ncols=ncols, figsize=(ncols * 5, 10), sharex=True
        )

        axs[0].imshow(sample['image'].permute(1, 2, 0))
        axs[0].axis('off')

        axs[1].imshow(sample['dem'].squeeze(0), cmap='gray')
        axs[1].axis('off')

        axs[2].imshow(
            sample['mask'].squeeze(0),
            self.colormap,
            vmin=0,
            vmax=vmax,
            interpolation='none',
        )
        axs[2].axis('off')

        if showing_predictions:
            axs[3].imshow(
                prediction.squeeze(0),
                self.colormap,
                vmin=0,
                vmax=vmax,
                interpolation='none',
            )
            # Hide ticks here as well so all panels look alike.
            axs[3].axis('off')

        if show_titles:
            # show classes in legend
            patches = [matplotlib.patches.Patch(color=c) for c in self.colormap.colors]  # type: ignore
            axs[2].legend(
                patches, self.classes, loc='center left', bbox_to_anchor=(1, 0.5)
            )

            axs[0].set_title('RGB Image')
            axs[1].set_title('DEM')
            axs[2].set_title('Labels')
            if showing_predictions:
                axs[3].set_title('Predictions')

        if suptitle is not None:
            fig.suptitle(suptitle, y=0.8)

        fig.tight_layout()

        return fig