* add potsdam dataset and tests

* add dummy potsdam data

* update potsdam docstring

* mypy fix

* style fixes

* Update datasets.rst

* update per suggestions

* updated docs

* refactor _load_target to use pillow

* format

* Update tests/datasets/test_potsdam.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Update tests/datasets/test_potsdam.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Update torchgeo/datasets/potsdam.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Update torchgeo/datasets/potsdam.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Update torchgeo/datasets/potsdam.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
isaac 2021-11-16 11:13:41 -06:00 committed by GitHub
parent c818827dfd
commit dfad08ef3a
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 506 additions and 0 deletions

docs/api/datasets.rst

@@ -134,6 +134,12 @@ PatternNet
.. autoclass:: PatternNet

Potsdam
^^^^^^^

.. autoclass:: Potsdam2D
.. autoclass:: Potsdam2DDataModule

RESISC45 (Remote Sensing Image Scene Classification)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Binary data (new files, contents not shown):

tests/data/potsdam/4_Ortho_RGBIR.zip
tests/data/potsdam/4_Ortho_RGBIR/top_potsdam_2_10_RGBIR.tif
tests/data/potsdam/4_Ortho_RGBIR/top_potsdam_2_11_RGBIR.tif
tests/data/potsdam/4_Ortho_RGBIR/top_potsdam_5_15_RGBIR.tif
tests/data/potsdam/4_Ortho_RGBIR/top_potsdam_6_15_RGBIR.tif
tests/data/potsdam/5_Labels_all.zip
tests/data/potsdam/top_potsdam_2_10_label.tif
tests/data/potsdam/top_potsdam_2_11_label.tif
tests/data/potsdam/top_potsdam_5_15_label.tif
tests/data/potsdam/top_potsdam_6_15_label.tif

tests/datasets/test_potsdam.py

@@ -0,0 +1,97 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path
from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch

from torchgeo.datasets import Potsdam2D, Potsdam2DDataModule


class TestPotsdam2D:
@pytest.fixture(params=["train", "test"])
def dataset(
self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest
) -> Potsdam2D:
md5s = ["e47175da529c5844052c7d483b483a30", "0cb795003a01154a72db7efaabbc76ae"]
splits = {
"train": ["top_potsdam_2_10", "top_potsdam_2_11"],
"test": ["top_potsdam_5_15", "top_potsdam_6_15"],
}
monkeypatch.setattr(Potsdam2D, "md5s", md5s) # type: ignore[attr-defined]
monkeypatch.setattr(Potsdam2D, "splits", splits) # type: ignore[attr-defined]
root = os.path.join("tests", "data", "potsdam")
split = request.param
transforms = nn.Identity() # type: ignore[attr-defined]
return Potsdam2D(root, split, transforms, checksum=True)
def test_getitem(self, dataset: Potsdam2D) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)
def test_len(self, dataset: Potsdam2D) -> None:
assert len(dataset) == 2
def test_extract(self, tmp_path: Path) -> None:
root = os.path.join("tests", "data", "potsdam")
for filename in ["4_Ortho_RGBIR.zip", "5_Labels_all.zip"]:
shutil.copyfile(
os.path.join(root, filename), os.path.join(str(tmp_path), filename)
)
Potsdam2D(root=str(tmp_path))
def test_corrupted(self, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, "4_Ortho_RGBIR.zip"), "w") as f:
f.write("bad")
with open(os.path.join(tmp_path, "5_Labels_all.zip"), "w") as f:
f.write("bad")
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
Potsdam2D(root=str(tmp_path), checksum=True)
def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
Potsdam2D(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
Potsdam2D(str(tmp_path))
def test_plot(self, dataset: Potsdam2D) -> None:
x = dataset[0].copy()
dataset.plot(x, suptitle="Test")
dataset.plot(x, show_titles=False)
x["prediction"] = x["mask"].clone()
        dataset.plot(x)


class TestPotsdam2DDataModule:
@pytest.fixture(scope="class", params=[0.0, 0.5])
def datamodule(self, request: SubRequest) -> Potsdam2DDataModule:
root = os.path.join("tests", "data", "potsdam")
batch_size = 1
num_workers = 0
val_split_size = request.param
dm = Potsdam2DDataModule(
root, batch_size, num_workers, val_split_pct=val_split_size
)
dm.prepare_data()
dm.setup()
return dm
def test_train_dataloader(self, datamodule: Potsdam2DDataModule) -> None:
next(iter(datamodule.train_dataloader()))
def test_val_dataloader(self, datamodule: Potsdam2DDataModule) -> None:
next(iter(datamodule.val_dataloader()))
def test_test_dataloader(self, datamodule: Potsdam2DDataModule) -> None:
next(iter(datamodule.test_dataloader()))
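These tests exercise the dataset against the dummy archives under tests/data/potsdam via the monkeypatched md5s and splits. A quick way to run just this module (a sketch, assuming a source checkout with pytest installed):

import pytest

# run only the Potsdam tests, quietly; returns a nonzero exit code on failure
pytest.main(["tests/datasets/test_potsdam.py", "-q"])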

torchgeo/datasets/__init__.py

@@ -54,6 +54,7 @@ from .levircd import LEVIRCDPlus
from .naip import NAIP, NAIPChesapeakeDataModule
from .nwpu import VHR10
from .patternnet import PatternNet
from .potsdam import Potsdam2D, Potsdam2DDataModule
from .resisc45 import RESISC45, RESISC45DataModule
from .seco import SeasonalContrastS2
from .sen12ms import SEN12MS, SEN12MSDataModule
@@ -116,6 +117,8 @@ __all__ = (
"LandCoverAIDataModule",
"LEVIRCDPlus",
"PatternNet",
"Potsdam2D",
"Potsdam2DDataModule",
"RESISC45",
"RESISC45DataModule",
"SeasonalContrastS2",

torchgeo/datasets/potsdam.py

@@ -0,0 +1,400 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Potsdam dataset."""

import os
from typing import Any, Callable, Dict, Optional
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import rasterio
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from ..datasets.utils import dataset_split, draw_semantic_segmentation_masks
from .geo import VisionDataset
from .utils import check_integrity, extract_archive, rgb_to_mask


class Potsdam2D(VisionDataset):
"""Potsdam 2D Semantic Segmentation dataset.
The `Potsdam <https://www2.isprs.org/commissions/comm2/wg4/benchmark/2d-sem-label-potsdam/>`_
dataset is a dataset for urban semantic segmentation used in the 2D Semantic Labeling
Contest - Potsdam. This dataset uses the "4_Ortho_RGBIR.zip" and "5_Labels_all.zip"
files to create the train/test sets used in the challenge. The dataset can be
requested at the challenge homepage. Note, the server contains additional data
for 3D Semantic Labeling which are currently not supported.
Dataset format:
* images are 4-channel geotiffs
* masks are 3-channel geotiffs with unique RGB values representing the class
Dataset classes:
0. Clutter/background
1. Impervious surfaces
2. Building
3. Low Vegetation
4. Tree
5. Car
If you use this dataset in your research, please cite the following paper:
* https://doi.org/10.5194/isprsannals-I-3-293-2012
""" # noqa: E501
filenames = ["4_Ortho_RGBIR.zip", "5_Labels_all.zip"]
md5s = ["c4a8f7d8c7196dd4eba4addd0aae10c1", "cf7403c1a97c0d279414db"]
image_root = "4_Ortho_RGBIR"
splits = {
"train": [
"top_potsdam_2_10",
"top_potsdam_2_11",
"top_potsdam_2_12",
"top_potsdam_3_10",
"top_potsdam_3_11",
"top_potsdam_3_12",
"top_potsdam_4_10",
"top_potsdam_4_11",
"top_potsdam_4_12",
"top_potsdam_5_10",
"top_potsdam_5_11",
"top_potsdam_5_12",
"top_potsdam_6_10",
"top_potsdam_6_11",
"top_potsdam_6_12",
"top_potsdam_6_7",
"top_potsdam_6_8",
"top_potsdam_6_9",
"top_potsdam_7_10",
"top_potsdam_7_11",
"top_potsdam_7_12",
"top_potsdam_7_7",
"top_potsdam_7_8",
"top_potsdam_7_9",
],
"test": [
"top_potsdam_5_15",
"top_potsdam_6_15",
"top_potsdam_6_13",
"top_potsdam_3_13",
"top_potsdam_4_14",
"top_potsdam_6_14",
"top_potsdam_5_14",
"top_potsdam_2_13",
"top_potsdam_4_15",
"top_potsdam_2_14",
"top_potsdam_5_13",
"top_potsdam_4_13",
"top_potsdam_3_14",
"top_potsdam_7_13",
],
}
classes = [
"Clutter/background",
"Impervious surfaces",
"Building",
"Low Vegetation",
"Tree",
"Car",
]
colormap = [
(255, 0, 0),
(255, 255, 255),
(0, 0, 255),
(0, 255, 255),
(0, 255, 0),
(255, 255, 0),
]
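    # ``colormap`` pairs with ``classes`` by index. A hedged sketch of the
    # RGB-to-index conversion that ``rgb_to_mask`` (imported from .utils) is
    # assumed to perform on an (H, W, 3) uint8 label image ``rgb``:
    #
    #     mask = np.zeros(rgb.shape[:2], dtype=np.uint8)
    #     for i, color in enumerate(colormap):
    #         mask[(rgb == color).all(axis=-1)] = i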
def __init__(
self,
root: str = "data",
split: str = "train",
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
checksum: bool = False,
) -> None:
"""Initialize a new Potsdam dataset instance.
Args:
root: root directory where dataset can be found
split: one of "train" or "test"
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
checksum: if True, check the MD5 of the downloaded files (may be slow)
"""
assert split in self.splits
self.root = root
self.split = split
self.transforms = transforms
self.checksum = checksum
self._verify()
self.files = []
for name in self.splits[split]:
image = os.path.join(root, self.image_root, name) + "_RGBIR.tif"
mask = os.path.join(root, name) + "_label.tif"
if os.path.exists(image) and os.path.exists(mask):
self.files.append(dict(image=image, mask=mask))
def __getitem__(self, index: int) -> Dict[str, Tensor]:
"""Return an index within the dataset.
Args:
index: index to return
Returns:
data and label at that index
"""
image = self._load_image(index)
mask = self._load_target(index)
sample = {"image": image, "mask": mask}
if self.transforms is not None:
sample = self.transforms(sample)
return sample
def __len__(self) -> int:
"""Return the number of data points in the dataset.
Returns:
length of the dataset
"""
return len(self.files)
def _load_image(self, index: int) -> Tensor:
"""Load a single image.
Args:
index: index to return
Returns:
the image
"""
path = self.files[index]["image"]
with rasterio.open(path) as f:
array = f.read()
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
return tensor
def _load_target(self, index: int) -> Tensor:
"""Load the target mask for a single image.
Args:
index: index to return
Returns:
the target mask
"""
path = self.files[index]["mask"]
with Image.open(path) as img:
array = np.array(img.convert("RGB"))
array = rgb_to_mask(array, self.colormap)
        tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
        # rgb_to_mask returns an (H, W) array of class indices, so no axis
        # permutation is needed; cast to long for use as a segmentation target
        tensor = tensor.to(torch.long)  # type: ignore[attr-defined]
return tensor
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if checksum fails or the dataset is not downloaded
"""
# Check if the files already exist
if os.path.exists(os.path.join(self.root, self.image_root)):
return
        # Check if the .zip files already exist (if so, extract them)
exists = []
for filename, md5 in zip(self.filenames, self.md5s):
filepath = os.path.join(self.root, filename)
if os.path.isfile(filepath):
if self.checksum and not check_integrity(filepath, md5):
raise RuntimeError("Dataset found, but corrupted.")
exists.append(True)
extract_archive(filepath)
else:
exists.append(False)
if all(exists):
return
# Check if the user requested to download the dataset
raise RuntimeError(
"Dataset not found in `root` directory, either specify a different"
+ " `root` directory or manually download the dataset to this directory."
)
def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
alpha: float = 0.5,
) -> Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle
alpha: opacity with which to render predictions on top of the imagery
Returns:
a matplotlib Figure with the rendered sample
"""
ncols = 1
image1 = draw_semantic_segmentation_masks(
sample["image"][:3],
sample["mask"],
alpha=alpha,
colors=self.colormap, # type: ignore[arg-type]
)
if "prediction" in sample:
ncols += 1
image2 = draw_semantic_segmentation_masks(
sample["image"][:3],
sample["prediction"],
alpha=alpha,
colors=self.colormap, # type: ignore[arg-type]
)
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
if ncols > 1:
(ax0, ax1) = axs
else:
ax0 = axs
ax0.imshow(image1)
ax0.axis("off")
if ncols > 1:
ax1.imshow(image2)
ax1.axis("off")
if show_titles:
ax0.set_title("Ground Truth")
if ncols > 1:
ax1.set_title("Predictions")
if suptitle is not None:
plt.suptitle(suptitle)
return fig
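    # A usage sketch for ``plot`` (illustrative paths/values; assumes a
    # dataset instance constructed as above):
    #
    #     ds = Potsdam2D(root="data/potsdam", split="train")
    #     sample = ds[0]
    #     sample["prediction"] = sample["mask"].clone()
    #     fig = ds.plot(sample, suptitle="Potsdam", alpha=0.5)
    #     fig.savefig("potsdam_sample.png")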
class Potsdam2DDataModule(pl.LightningDataModule):
"""LightningDataModule implementation for the Potsdam2D dataset.
Uses the train/test splits from the dataset.
"""
def __init__(
self,
root_dir: str,
batch_size: int = 64,
num_workers: int = 0,
val_split_pct: float = 0.2,
**kwargs: Any,
) -> None:
"""Initialize a LightningDataModule for Potsdam2D based DataLoaders.
Args:
root_dir: The ``root`` argument to pass to the Potsdam2D Dataset classes
batch_size: The batch size to use in all created DataLoaders
num_workers: The number of workers to use in all created DataLoaders
val_split_pct: What percentage of the dataset to use as a validation set
"""
super().__init__() # type: ignore[no-untyped-call]
self.root_dir = root_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.val_split_pct = val_split_pct
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
return sample
def setup(self, stage: Optional[str] = None) -> None:
"""Initialize the main ``Dataset`` objects.
This method is called once per GPU per run.
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])
dataset = Potsdam2D(self.root_dir, "train", transforms=transforms)
if self.val_split_pct > 0.0:
self.train_dataset, self.val_dataset, _ = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=0.0
)
else:
self.train_dataset = dataset # type: ignore[assignment]
self.val_dataset = None # type: ignore[assignment]
self.test_dataset = Potsdam2D(self.root_dir, "test", transforms=transforms)
def train_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for training.
Returns:
training data loader
"""
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for validation.
Returns:
validation data loader
"""
if self.val_split_pct == 0.0:
return self.train_dataloader()
else:
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
"""Return a DataLoader for testing.
Returns:
testing data loader
"""
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
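
For reference, a hedged end-to-end sketch of the DataModule (batch size, worker count, and root path are illustrative; assumes the archives are already present under ``root_dir``):

dm = Potsdam2DDataModule(
    root_dir="data/potsdam", batch_size=2, num_workers=0, val_split_pct=0.2
)
dm.prepare_data()  # no-op here; inherited from LightningDataModule
dm.setup()
batch = next(iter(dm.train_dataloader()))
print(batch["image"].shape)  # (2, 4, H, W)
print(batch["mask"].shape)   # (2, H, W)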