* OSCD: initial template

* updating download pattern

* package: including OSCD in torchgeo.datasets

* download: adapting download method to OSCD dataset + adding simple test for debugging

* _load_files method: temporary implementation

* OCSD: minimum working example, needs plenty improvement

* adding OSCD to docs

* Moving test to appropriate location

* OSCD: remove sort_bands and use utils.sort_sentinel2_bands

* Using rasterio instead of tifffile

* remove useless import

* style changes

* fix: style

* Developing tests for OSCD dataset

* Updating dataset description

* change name

* style fixes

* fixing mypy errors

* style fixes

* cast to string to fix typing errors

* style change

* isort fix

* remove TODO

* adding dataset for testing

* change len

* check if sum is concatdataset

* isort fix

* fixing some issues + correct md5 in dataset

* closing rasterio file handles

* removing some TODO's

* transitioning to fake data

* mypy fix attempt

* set fake data md5

* flake8 fix

* starting plot method

* updating plot method

* no predictions for now

* fixing style errors

* add testing for plot

* making some changes to fake testing data

* full coverage

* Use RGB channels in the plot function

* adding shape tests in test_getitem

* remove features and add to description

* fixing some things

* transitioning to authors dataset link

* No need to change file names + adapt test dataset

* adapting tests to new data format

* Update docs/api/datasets.rst

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* closing plot at end of terst

* add versionadded

* style fixes + indentation fixes

* style fixes

* forgot the .zip

* fix zipfile name

* temporary fix for flake8

* Add link to docs

* forgot to adjust this

* changing flake8 solve

* slimming down the test dataset

* removing imgs_x files which aren't needed for current testing but might be in the future

* Revert "removing imgs_x files which aren't needed for current testing but might be in the future"

This reverts commit cfbf26c1d3.

* nevermind, this was the issue

* trying to remove these once again

* adding band choosing functionality

* removing double code

* removing more double code

* flake8 fix

* adding one more training sample to dummy dataset and testing split

* typing numpy array

* back to this

* Fixing tests and mypy

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Maciej Kilian 2021-11-19 12:21:51 -08:00 коммит произвёл GitHub
Родитель 68112c749c
Коммит 0b8a8461bc
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
91 изменённых файлов: 433 добавлений и 0 удалений

Просмотреть файл

@ -129,6 +129,11 @@ LEVIR-CD+ (LEVIR Change Detection +)
.. autoclass:: LEVIRCDPlus
OSCD (Onera Satellite Change Detection)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: OSCD
PatternNet
^^^^^^^^^^

Двоичный файл не отображается.

Просмотреть файл

@ -0,0 +1,2 @@
date_1: 20151211
date_2: 20180330

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Просмотреть файл

@ -0,0 +1,2 @@
date_1: 20161130
date_2: 20170829

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Просмотреть файл

@ -0,0 +1,2 @@
date_1: 20161130
date_2: 20170829

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 71 B

Двоичный файл не отображается.

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 71 B

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 71 B

107
tests/datasets/test_oscd.py Normal file
Просмотреть файл

@ -0,0 +1,107 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import glob
import os
import shutil
from pathlib import Path
from typing import Generator
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from matplotlib import pyplot as plt
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import OSCD
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
class TestOSCD:
@pytest.fixture(params=zip(["all", "rgb"], ["train", "test"]))
def dataset(
self,
monkeypatch: Generator[MonkeyPatch, None, None],
tmp_path: Path,
request: SubRequest,
) -> OSCD:
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.oscd, "download_url", download_url
)
md5 = "d6ebaae1ea0f3ae960af31531d394521"
monkeypatch.setattr(OSCD, "md5", md5) # type: ignore[attr-defined]
urls = {
"Onera Satellite Change Detection dataset - Images.zip": os.path.join(
"tests",
"data",
"oscd",
"Onera Satellite Change Detection dataset - Images.zip",
),
"Onera Satellite Change Detection dataset - Train Labels.zip": os.path.join(
"tests",
"data",
"oscd",
"Onera Satellite Change Detection dataset - Train Labels.zip",
),
"Onera Satellite Change Detection dataset - Test Labels.zip": os.path.join(
"tests",
"data",
"oscd",
"Onera Satellite Change Detection dataset - Test Labels.zip",
),
}
monkeypatch.setattr(OSCD, "urls", urls) # type: ignore[attr-defined]
bands, split = request.param
root = str(tmp_path)
transforms = nn.Identity() # type: ignore[attr-defined]
return OSCD(
root, split, bands, transforms=transforms, download=True, checksum=True
)
def test_getitem(self, dataset: OSCD) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert x["image"].ndim == 4
assert isinstance(x["mask"], torch.Tensor)
assert x["mask"].ndim == 2
if dataset.bands == "rgb":
assert x["image"].shape[:2] == (2, 3)
else:
assert x["image"].shape[:2] == (2, 13)
def test_len(self, dataset: OSCD) -> None:
if dataset.split == "train":
assert len(dataset) == 1
else:
assert len(dataset) == 1
def test_add(self, dataset: OSCD) -> None:
ds = dataset + dataset
assert isinstance(ds, ConcatDataset)
def test_already_extracted(self, dataset: OSCD) -> None:
OSCD(root=dataset.root, download=True)
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "oscd", "*Onera*.zip")
root = str(tmp_path)
for zipfile in glob.iglob(pathname):
shutil.copy(zipfile, root)
OSCD(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
OSCD(str(tmp_path))
def test_plot(self, dataset: OSCD) -> None:
dataset.plot(dataset[0], suptitle="Test")
plt.close()

Просмотреть файл

@ -53,6 +53,7 @@ from .landsat import (
from .levircd import LEVIRCDPlus
from .naip import NAIP, NAIPChesapeakeDataModule
from .nwpu import VHR10
from .oscd import OSCD
from .patternnet import PatternNet
from .potsdam import Potsdam2D, Potsdam2DDataModule
from .resisc45 import RESISC45, RESISC45DataModule
@ -116,6 +117,7 @@ __all__ = (
"LandCoverAI",
"LandCoverAIDataModule",
"LEVIRCDPlus",
"OSCD",
"PatternNet",
"Potsdam2D",
"Potsdam2DDataModule",

313
torchgeo/datasets/oscd.py Normal file
Просмотреть файл

@ -0,0 +1,313 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""OSCD dataset."""
import glob
import os
from typing import Callable, Dict, List, Optional, Sequence, Union
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from matplotlib.figure import Figure
from numpy import ndarray as Array
from PIL import Image
from torch import Tensor
from ..datasets.utils import draw_semantic_segmentation_masks
from .geo import VisionDataset
from .utils import download_url, extract_archive, sort_sentinel2_bands
class OSCD(VisionDataset):
"""OSCD dataset.
The `Onera Satellite Change Detection <https://rcdaudt.github.io/oscd/>`_
dataset addresses the issue of detecting changes between
satellite images from different dates. Imagery comes from
Sentinel-2 which contains varying resolutions per band.
Dataset format:
* images are 13-channel tifs
* masks are single-channel pngs where no change = 0, change = 255
Dataset classes:
0. no change
1. change
If you use this dataset in your research, please cite the following paper:
* https://doi.org/10.1109/IGARSS.2018.8518015
.. versionadded:: 0.2
"""
folder_prefix = "Onera Satellite Change Detection dataset - "
urls = {
"Onera Satellite Change Detection dataset - Images.zip": (
"https://partage.imt.fr/index.php/s/gKRaWgRnLMfwMGo/download"
),
"Onera Satellite Change Detection dataset - Train Labels.zip": (
"https://partage.mines-telecom.fr/index.php/s/2D6n03k58ygBSpu/download"
),
"Onera Satellite Change Detection dataset - Test Labels.zip": (
"https://partage.imt.fr/index.php/s/gpStKn4Mpgfnr63/download"
),
}
md5 = "7383412da7ece1dca1c12dc92ac77f09"
zipfile_glob = "*Onera*.zip"
filename_glob = "*Onera*"
splits = ["train", "test"]
colormap = ["blue"]
def __init__(
self,
root: str = "data",
split: str = "train",
bands: str = "all",
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
download: bool = False,
checksum: bool = False,
) -> None:
"""Initialize a new OSCD dataset instance.
Args:
root: root directory where dataset can be found
split: one of "train" or "test"
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
AssertionError: if ``split`` argument is invalid
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
"""
assert split in self.splits
assert bands in ["rgb", "all"]
self.root = root
self.split = split
self.bands = bands
self.transforms = transforms
self.download = download
self.checksum = checksum
self._verify()
self.files = self._load_files()
def __getitem__(self, index: int) -> Dict[str, Tensor]:
"""Return an index within the dataset.
Args:
index: index to return
Returns:
data and label at that index
"""
files = self.files[index]
image1 = self._load_image(files["images1"])
image2 = self._load_image(files["images2"])
mask = self._load_target(str(files["mask"]))
image = torch.stack(tensors=[image1, image2], dim=0)
sample = {"image": image, "mask": mask}
if self.transforms is not None:
sample = self.transforms(sample)
return sample
def __len__(self) -> int:
"""Return the number of data points in the dataset.
Returns:
length of the dataset
"""
return len(self.files)
def _load_files(self) -> List[Dict[str, Union[str, Sequence[str]]]]:
regions = []
labels_root = os.path.join(
self.root,
f"Onera Satellite Change Detection dataset - {self.split.capitalize()} "
+ "Labels",
)
images_root = os.path.join(
self.root, "Onera Satellite Change Detection dataset - Images"
)
folders = glob.glob(os.path.join(labels_root, "*/"))
for folder in folders:
region = folder.split(os.sep)[-2]
mask = os.path.join(labels_root, region, "cm", "cm.png")
def get_image_paths(ind: int) -> List[str]:
return sorted(
glob.glob(
os.path.join(images_root, region, f"imgs_{ind}_rect", "*.tif")
),
key=sort_sentinel2_bands,
)
images1, images2 = get_image_paths(1), get_image_paths(2)
if self.bands == "rgb":
images1, images2 = images1[1:4][::-1], images2[1:4][::-1]
with open(os.path.join(images_root, region, "dates.txt")) as f:
dates = tuple(
[line.split()[-1] for line in f.read().strip().splitlines()]
)
regions.append(
dict(
region=region,
images1=images1,
images2=images2,
mask=mask,
dates=dates,
)
)
return regions
def _load_image(self, paths: Sequence[str]) -> Tensor:
"""Load a single image.
Args:
path: path to the image
Returns:
the image
"""
images = []
for path in paths:
with rasterio.open(path) as f:
images.append(f.read())
array = np.stack(images, axis=0).astype(np.int_).squeeze()
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
return tensor
def _load_target(self, path: str) -> Tensor:
"""Load the target mask for a single image.
Args:
path: path to the image
Returns:
the target mask
"""
filename = os.path.join(path)
with Image.open(filename) as img:
array = np.array(img.convert("L"))
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
tensor = torch.clamp(tensor, min=0, max=1) # type: ignore[attr-defined]
tensor = tensor.to(torch.long) # type: ignore[attr-defined]
return tensor
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
# Check if the extracted files already exist
pathname = os.path.join(self.root, "**", self.filename_glob)
for fname in glob.iglob(pathname, recursive=True):
if not fname.endswith(".zip"):
return
# Check if the zip files have already been downloaded
pathname = os.path.join(self.root, self.zipfile_glob)
if glob.glob(pathname):
self._extract()
return
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `root={self.root}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automaticaly download the dataset."
)
# Download the dataset
self._download()
self._extract()
def _download(self) -> None:
"""Download the dataset."""
for f_name in self.urls:
download_url(
self.urls[f_name],
self.root,
filename=f_name,
md5=self.md5 if self.checksum else None,
)
def _extract(self) -> None:
"""Extract the dataset."""
pathname = os.path.join(self.root, self.zipfile_glob)
for zipfile in glob.iglob(pathname):
extract_archive(zipfile)
def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
alpha: float = 0.5,
) -> Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle
alpha: opacity with which to render predictions on top of the imagery
Returns:
a matplotlib Figure with the rendered sample
"""
ncols = 2
rgb_inds = [3, 2, 1] if self.bands == "all" else [0, 1, 2]
def get_masked(img: Tensor) -> Array: # type: ignore[type-arg]
rgb_img = img[rgb_inds].float().numpy()
per02 = np.percentile(rgb_img, 2) # type: ignore[no-untyped-call]
per98 = np.percentile(rgb_img, 98) # type: ignore[no-untyped-call]
rgb_img = (np.clip((rgb_img - per02) / (per98 - per02), 0, 1) * 255).astype(
np.uint8
)
array: Array = draw_semantic_segmentation_masks( # type: ignore[type-arg]
torch.from_numpy(rgb_img), # type: ignore[attr-defined]
sample["mask"],
alpha=alpha,
colors=self.colormap, # type: ignore[arg-type]
)
return array
image1, image2 = get_masked(sample["image"][0]), get_masked(sample["image"][1])
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
axs[0].imshow(image1)
axs[0].axis("off")
axs[1].imshow(image2)
axs[1].axis("off")
if show_titles:
axs[0].set_title("Pre change")
axs[1].set_title("Post change")
if suptitle is not None:
plt.suptitle(suptitle)
return fig