зеркало из https://github.com/microsoft/torchgeo.git
Add DFC2022 dataset (#354)
* add DFC2022 dataset * plot fix * mypy fixes * add tests and tests data * maximum coverage * remove local dir * update per suggestions * update monkeypatching * update docstring * fix indentation in docstring
This commit is contained in:
Родитель
ff28a3b358
Коммит
45f370389f
|
@ -97,6 +97,11 @@ CV4A Kenya Crop Type Competition
|
||||||
|
|
||||||
.. autoclass:: CV4AKenyaCropType
|
.. autoclass:: CV4AKenyaCropType
|
||||||
|
|
||||||
|
2022 IEEE GRSS Data Fusion Contest (DFC2022)
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
.. autoclass:: DFC2022
|
||||||
|
|
||||||
ETCI2021 Flood Detection
|
ETCI2021 Flood Detection
|
||||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,121 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import rasterio
|
||||||
|
|
||||||
|
from torchgeo.datasets import DFC2022
|
||||||
|
|
||||||
|
SIZE = 32
|
||||||
|
|
||||||
|
np.random.seed(0)
|
||||||
|
random.seed(0)
|
||||||
|
|
||||||
|
|
||||||
|
train_set = [
|
||||||
|
{
|
||||||
|
"image": "labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif", # noqa: E501
|
||||||
|
"dem": "labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif", # noqa: E501
|
||||||
|
"target": "labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif", # noqa: E501
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"image": "labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif", # noqa: E501
|
||||||
|
"dem": "labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif", # noqa: E501
|
||||||
|
"target": "labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif", # noqa: E501
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
unlabeled_set = [
|
||||||
|
{
|
||||||
|
"image": "unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif", # noqa: E501
|
||||||
|
"dem": "unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif", # noqa: E501
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"image": "unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif", # noqa: E501
|
||||||
|
"dem": "unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif", # noqa: E501
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
val_set = [
|
||||||
|
{
|
||||||
|
"image": "val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif", # noqa: E501
|
||||||
|
"dem": "val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif", # noqa: E501
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"image": "val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif", # noqa: E501
|
||||||
|
"dem": "val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif", # noqa: E501
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_file(path: str, dtype: str, num_channels: int) -> None:
|
||||||
|
profile = {}
|
||||||
|
profile["driver"] = "GTiff"
|
||||||
|
profile["dtype"] = dtype
|
||||||
|
profile["count"] = num_channels
|
||||||
|
profile["crs"] = "epsg:4326"
|
||||||
|
profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1)
|
||||||
|
profile["height"] = SIZE
|
||||||
|
profile["width"] = SIZE
|
||||||
|
|
||||||
|
if "float" in profile["dtype"]:
|
||||||
|
Z = np.random.randn(SIZE, SIZE).astype(profile["dtype"])
|
||||||
|
else:
|
||||||
|
Z = np.random.randint(
|
||||||
|
np.iinfo(profile["dtype"]).max, size=(SIZE, SIZE), dtype=profile["dtype"]
|
||||||
|
)
|
||||||
|
|
||||||
|
src = rasterio.open(path, "w", **profile)
|
||||||
|
for i in range(1, profile["count"] + 1):
|
||||||
|
src.write(Z, i)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
for split in DFC2022.metadata:
|
||||||
|
directory = DFC2022.metadata[split]["directory"]
|
||||||
|
filename = DFC2022.metadata[split]["filename"]
|
||||||
|
|
||||||
|
# Remove old data
|
||||||
|
if os.path.isdir(directory):
|
||||||
|
shutil.rmtree(directory)
|
||||||
|
if os.path.exists(filename):
|
||||||
|
os.remove(filename)
|
||||||
|
|
||||||
|
if split == "train":
|
||||||
|
files = train_set
|
||||||
|
elif split == "train-unlabeled":
|
||||||
|
files = unlabeled_set
|
||||||
|
else:
|
||||||
|
files = val_set
|
||||||
|
|
||||||
|
for file_dict in files:
|
||||||
|
# Create image file
|
||||||
|
path = file_dict["image"]
|
||||||
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
|
create_file(path, dtype="uint8", num_channels=3)
|
||||||
|
|
||||||
|
# Create DEM file
|
||||||
|
path = file_dict["dem"]
|
||||||
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
|
create_file(path, dtype="float32", num_channels=1)
|
||||||
|
|
||||||
|
# Create mask file
|
||||||
|
if split == "train":
|
||||||
|
path = file_dict["target"]
|
||||||
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
|
create_file(path, dtype="uint8", num_channels=1)
|
||||||
|
|
||||||
|
# Compress data
|
||||||
|
shutil.make_archive(filename.replace(".zip", ""), "zip", ".", directory)
|
||||||
|
|
||||||
|
# Compute checksums
|
||||||
|
with open(filename, "rb") as f:
|
||||||
|
md5 = hashlib.md5(f.read()).hexdigest()
|
||||||
|
print(f"{filename}: {md5}")
|
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif
Normal file
Двоичные данные
tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичные данные
tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif
Normal file
Двоичные данные
tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif
Normal file
Двоичные данные
tests/data/dfc2022/labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичные данные
tests/data/dfc2022/labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif
Normal file
Двоичные данные
tests/data/dfc2022/labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif
Normal file
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif
Normal file
Двоичные данные
tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичные данные
tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif
Normal file
Двоичные данные
tests/data/dfc2022/unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичные данные
tests/data/dfc2022/unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif
Normal file
Двоичные данные
tests/data/dfc2022/val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичные данные
tests/data/dfc2022/val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif
Normal file
Двоичные данные
tests/data/dfc2022/val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичные данные
tests/data/dfc2022/val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичный файл не отображается.
|
@ -0,0 +1,96 @@
|
||||||
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from _pytest.fixtures import SubRequest
|
||||||
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
|
|
||||||
|
from torchgeo.datasets import DFC2022
|
||||||
|
|
||||||
|
|
||||||
|
class TestDFC2022:
|
||||||
|
@pytest.fixture(params=["train", "train-unlabeled", "val"])
|
||||||
|
def dataset(
|
||||||
|
self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest
|
||||||
|
) -> DFC2022:
|
||||||
|
monkeypatch.setitem( # type: ignore[attr-defined]
|
||||||
|
DFC2022.metadata["train"], "md5", "6e380c4fa659d05ca93be71b50cacd90"
|
||||||
|
)
|
||||||
|
monkeypatch.setitem( # type: ignore[attr-defined]
|
||||||
|
DFC2022.metadata["train-unlabeled"],
|
||||||
|
"md5",
|
||||||
|
"b2bf3839323d4eae636f198921442945",
|
||||||
|
)
|
||||||
|
monkeypatch.setitem( # type: ignore[attr-defined]
|
||||||
|
DFC2022.metadata["val"], "md5", "e018dc6865bd3086738038fff27b818a"
|
||||||
|
)
|
||||||
|
root = os.path.join("tests", "data", "dfc2022")
|
||||||
|
split = request.param
|
||||||
|
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||||
|
return DFC2022(root, split, transforms, checksum=True)
|
||||||
|
|
||||||
|
def test_getitem(self, dataset: DFC2022) -> None:
|
||||||
|
x = dataset[0]
|
||||||
|
assert isinstance(x, dict)
|
||||||
|
assert isinstance(x["image"], torch.Tensor)
|
||||||
|
assert x["image"].ndim == 3
|
||||||
|
assert x["image"].shape[0] == 4
|
||||||
|
|
||||||
|
if dataset.split == "train":
|
||||||
|
assert isinstance(x["mask"], torch.Tensor)
|
||||||
|
assert x["mask"].ndim == 2
|
||||||
|
|
||||||
|
def test_len(self, dataset: DFC2022) -> None:
|
||||||
|
assert len(dataset) == 2
|
||||||
|
|
||||||
|
def test_extract(self, tmp_path: Path) -> None:
|
||||||
|
shutil.copyfile(
|
||||||
|
os.path.join("tests", "data", "dfc2022", "labeled_train.zip"),
|
||||||
|
os.path.join(tmp_path, "labeled_train.zip"),
|
||||||
|
)
|
||||||
|
shutil.copyfile(
|
||||||
|
os.path.join("tests", "data", "dfc2022", "unlabeled_train.zip"),
|
||||||
|
os.path.join(tmp_path, "unlabeled_train.zip"),
|
||||||
|
)
|
||||||
|
shutil.copyfile(
|
||||||
|
os.path.join("tests", "data", "dfc2022", "val.zip"),
|
||||||
|
os.path.join(tmp_path, "val.zip"),
|
||||||
|
)
|
||||||
|
DFC2022(root=str(tmp_path))
|
||||||
|
|
||||||
|
def test_corrupted(self, tmp_path: Path) -> None:
|
||||||
|
with open(os.path.join(tmp_path, "labeled_train.zip"), "w") as f:
|
||||||
|
f.write("bad")
|
||||||
|
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
|
||||||
|
DFC2022(root=str(tmp_path), checksum=True)
|
||||||
|
|
||||||
|
def test_invalid_split(self) -> None:
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
DFC2022(split="foo")
|
||||||
|
|
||||||
|
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||||
|
with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
|
||||||
|
DFC2022(str(tmp_path))
|
||||||
|
|
||||||
|
def test_plot(self, dataset: DFC2022) -> None:
|
||||||
|
x = dataset[0].copy()
|
||||||
|
dataset.plot(x, suptitle="Test")
|
||||||
|
plt.close()
|
||||||
|
dataset.plot(x, show_titles=False)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
if dataset.split == "train":
|
||||||
|
x["prediction"] = x["mask"].clone()
|
||||||
|
dataset.plot(x)
|
||||||
|
plt.close()
|
||||||
|
del x["mask"]
|
||||||
|
dataset.plot(x)
|
||||||
|
plt.close()
|
|
@ -24,6 +24,7 @@ from .chesapeake import (
|
||||||
from .cowc import COWC, COWCCounting, COWCDetection
|
from .cowc import COWC, COWCCounting, COWCDetection
|
||||||
from .cv4a_kenya_crop_type import CV4AKenyaCropType
|
from .cv4a_kenya_crop_type import CV4AKenyaCropType
|
||||||
from .cyclone import TropicalCycloneWindEstimation
|
from .cyclone import TropicalCycloneWindEstimation
|
||||||
|
from .dfc2022 import DFC2022
|
||||||
from .etci2021 import ETCI2021
|
from .etci2021 import ETCI2021
|
||||||
from .eurosat import EuroSAT
|
from .eurosat import EuroSAT
|
||||||
from .fair1m import FAIR1M
|
from .fair1m import FAIR1M
|
||||||
|
@ -115,6 +116,7 @@ __all__ = (
|
||||||
"COWCCounting",
|
"COWCCounting",
|
||||||
"COWCDetection",
|
"COWCDetection",
|
||||||
"CV4AKenyaCropType",
|
"CV4AKenyaCropType",
|
||||||
|
"DFC2022",
|
||||||
"ETCI2021",
|
"ETCI2021",
|
||||||
"EuroSAT",
|
"EuroSAT",
|
||||||
"FAIR1M",
|
"FAIR1M",
|
||||||
|
|
|
@ -0,0 +1,361 @@
|
||||||
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
"""2022 IEEE GRSS Data Fusion Contest (DFC2022) dataset."""
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
from typing import Callable, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import rasterio
|
||||||
|
import torch
|
||||||
|
from matplotlib import colors
|
||||||
|
from rasterio.enums import Resampling
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from .geo import VisionDataset
|
||||||
|
from .utils import check_integrity, extract_archive, percentile_normalization
|
||||||
|
|
||||||
|
|
||||||
|
class DFC2022(VisionDataset):
|
||||||
|
"""DFC2022 dataset.
|
||||||
|
|
||||||
|
The `DFC2022 <https://www.grss-ieee.org/community/technical-committees/2022-ieee-grss-data-fusion-contest/>`_
|
||||||
|
dataset is used as a benchmark dataset for the 2022 IEEE GRSS Data Fusion Contest
|
||||||
|
and extends the MiniFrance dataset for semi-supervised semantic segmentation.
|
||||||
|
The dataset consists of a train set containing labeled and unlabeled imagery and an
|
||||||
|
unlabeled validation set. The dataset can be downloaded from the
|
||||||
|
`IEEEDataPort DFC2022 website <https://ieee-dataport.org/competitions/data-fusion-contest-2022-dfc2022/>`_.
|
||||||
|
|
||||||
|
Dataset features:
|
||||||
|
|
||||||
|
* RGB aerial images at 0.5 m per pixel spatial resolution (~2,000x2,0000 px)
|
||||||
|
* DEMs at 1 m per pixel spatial resolution (~1,000x1,0000 px)
|
||||||
|
* Masks at 0.5 m per pixel spatial resolution (~2,000x2,0000 px)
|
||||||
|
* 16 land use/land cover categories
|
||||||
|
* Images collected from the
|
||||||
|
`IGN BD ORTHO database <https://geoservices.ign.fr/documentation/donnees/ortho/bdortho/>`_
|
||||||
|
* DEMs collected from the
|
||||||
|
`IGN RGE ALTI database <https://geoservices.ign.fr/documentation/donnees/alti/rgealti/>`_
|
||||||
|
* Labels collected from the
|
||||||
|
`UrbanAtlas 2012 database <https://land.copernicus.eu/local/urban-atlas/urban-atlas-2012/view/>`_
|
||||||
|
* Data collected from 19 regions in France
|
||||||
|
|
||||||
|
Dataset format:
|
||||||
|
|
||||||
|
* images are three-channel geotiffs
|
||||||
|
* DEMS are single-channel geotiffs
|
||||||
|
* masks are single-channel geotiffs with the pixel values represent the class
|
||||||
|
|
||||||
|
Dataset classes:
|
||||||
|
|
||||||
|
0. No information
|
||||||
|
1. Urban fabric
|
||||||
|
2. Industrial, commercial, public, military, private and transport units
|
||||||
|
3. Mine, dump and construction sites
|
||||||
|
4. Artificial non-agricultural vegetated areas
|
||||||
|
5. Arable land (annual crops)
|
||||||
|
6. Permanent crops
|
||||||
|
7. Pastures
|
||||||
|
8. Complex and mixed cultivation patterns
|
||||||
|
9. Orchards at the fringe of urban classes
|
||||||
|
10. Forests
|
||||||
|
11. Herbaceous vegetation associations
|
||||||
|
12. Open spaces with little or no vegetation
|
||||||
|
13. Wetlands
|
||||||
|
14. Water
|
||||||
|
15. Clouds and Shadows
|
||||||
|
|
||||||
|
If you use this dataset in your research, please cite the following paper:
|
||||||
|
|
||||||
|
* https://doi.org/10.1007/s10994-020-05943-y
|
||||||
|
|
||||||
|
.. versionadded:: 0.3
|
||||||
|
""" # noqa: E501
|
||||||
|
|
||||||
|
classes = [
|
||||||
|
"No information",
|
||||||
|
"Urban fabric",
|
||||||
|
"Industrial, commercial, public, military, private and transport units",
|
||||||
|
"Mine, dump and construction sites",
|
||||||
|
"Artificial non-agricultural vegetated areas",
|
||||||
|
"Arable land (annual crops)",
|
||||||
|
"Permanent crops",
|
||||||
|
"Pastures",
|
||||||
|
"Complex and mixed cultivation patterns",
|
||||||
|
"Orchards at the fringe of urban classes",
|
||||||
|
"Forests",
|
||||||
|
"Herbaceous vegetation associations",
|
||||||
|
"Open spaces with little or no vegetation",
|
||||||
|
"Wetlands",
|
||||||
|
"Water",
|
||||||
|
"Clouds and Shadows",
|
||||||
|
]
|
||||||
|
colormap = [
|
||||||
|
"#231F20",
|
||||||
|
"#DB5F57",
|
||||||
|
"#DB9757",
|
||||||
|
"#DBD057",
|
||||||
|
"#ADDB57",
|
||||||
|
"#75DB57",
|
||||||
|
"#7BC47B",
|
||||||
|
"#58B158",
|
||||||
|
"#D4F6D4",
|
||||||
|
"#B0E2B0",
|
||||||
|
"#008000",
|
||||||
|
"#58B0A7",
|
||||||
|
"#995D13",
|
||||||
|
"#579BDB",
|
||||||
|
"#0062FF",
|
||||||
|
"#231F20",
|
||||||
|
]
|
||||||
|
metadata = {
|
||||||
|
"train": {
|
||||||
|
"filename": "labeled_train.zip",
|
||||||
|
"md5": "2e87d6a218e466dd0566797d7298c7a9",
|
||||||
|
"directory": "labeled_train",
|
||||||
|
},
|
||||||
|
"train-unlabeled": {
|
||||||
|
"filename": "unlabeled_train.zip",
|
||||||
|
"md5": "1016d724bc494b8c50760ae56bb0585e",
|
||||||
|
"directory": "unlabeled_train",
|
||||||
|
},
|
||||||
|
"val": {
|
||||||
|
"filename": "val.zip",
|
||||||
|
"md5": "6ddd9c0f89d8e74b94ea352d4002073f",
|
||||||
|
"directory": "val",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
image_root = "BDORTHO"
|
||||||
|
dem_root = "RGEALTI"
|
||||||
|
target_root = "UrbanAtlas"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
root: str = "data",
|
||||||
|
split: str = "train",
|
||||||
|
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
|
||||||
|
checksum: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize a new DFC2022 dataset instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
root: root directory where dataset can be found
|
||||||
|
split: one of "train" or "test"
|
||||||
|
transforms: a function/transform that takes input sample and its target as
|
||||||
|
entry and returns a transformed version
|
||||||
|
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: if ``split`` is invalid
|
||||||
|
"""
|
||||||
|
assert split in self.metadata
|
||||||
|
self.root = root
|
||||||
|
self.split = split
|
||||||
|
self.transforms = transforms
|
||||||
|
self.checksum = checksum
|
||||||
|
|
||||||
|
self._verify()
|
||||||
|
|
||||||
|
self.class2idx = {c: i for i, c in enumerate(self.classes)}
|
||||||
|
self.files = self._load_files()
|
||||||
|
|
||||||
|
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||||
|
"""Return an index within the dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index: index to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
data and label at that index
|
||||||
|
"""
|
||||||
|
files = self.files[index]
|
||||||
|
image = self._load_image(files["image"])
|
||||||
|
dem = self._load_image(files["dem"], shape=image.shape[1:])
|
||||||
|
image = torch.cat(tensors=[image, dem], dim=0) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
sample = {"image": image}
|
||||||
|
|
||||||
|
if self.split == "train":
|
||||||
|
mask = self._load_target(files["target"])
|
||||||
|
sample["mask"] = mask
|
||||||
|
|
||||||
|
if self.transforms is not None:
|
||||||
|
sample = self.transforms(sample)
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Return the number of data points in the dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
length of the dataset
|
||||||
|
"""
|
||||||
|
return len(self.files)
|
||||||
|
|
||||||
|
def _load_files(self) -> List[Dict[str, str]]:
|
||||||
|
"""Return the paths of the files in the dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of dicts containing paths for each pair of image/dem/mask
|
||||||
|
"""
|
||||||
|
directory = os.path.join(self.root, self.metadata[self.split]["directory"])
|
||||||
|
images = glob.glob(
|
||||||
|
os.path.join(directory, "**", self.image_root, "*.tif"), recursive=True
|
||||||
|
)
|
||||||
|
|
||||||
|
files = []
|
||||||
|
for image in sorted(images):
|
||||||
|
dem = image.replace(self.image_root, self.dem_root)
|
||||||
|
dem = f"{os.path.splitext(dem)[0]}_RGEALTI.tif"
|
||||||
|
|
||||||
|
if self.split == "train":
|
||||||
|
target = image.replace(self.image_root, self.target_root)
|
||||||
|
target = f"{os.path.splitext(target)[0]}_UA2012.tif"
|
||||||
|
files.append(dict(image=image, dem=dem, target=target))
|
||||||
|
else:
|
||||||
|
files.append(dict(image=image, dem=dem))
|
||||||
|
|
||||||
|
return files
|
||||||
|
|
||||||
|
def _load_image(self, path: str, shape: Optional[Sequence[int]] = None) -> Tensor:
|
||||||
|
"""Load a single image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: path to the image
|
||||||
|
shape: the (h, w) to resample the image to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the image
|
||||||
|
"""
|
||||||
|
with rasterio.open(path) as f:
|
||||||
|
array: "np.typing.NDArray[np.float_]" = f.read(
|
||||||
|
out_shape=shape, out_dtype="float32", resampling=Resampling.bilinear
|
||||||
|
)
|
||||||
|
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
def _load_target(self, path: str) -> Tensor:
|
||||||
|
"""Load the target mask for a single image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: path to the image
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the target mask
|
||||||
|
"""
|
||||||
|
with rasterio.open(path) as f:
|
||||||
|
array: "np.typing.NDArray[np.int_]" = f.read(
|
||||||
|
indexes=1, out_dtype="int32", resampling=Resampling.bilinear
|
||||||
|
)
|
||||||
|
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
|
||||||
|
tensor = tensor.to(torch.long) # type: ignore[attr-defined]
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
def _verify(self) -> None:
|
||||||
|
"""Verify the integrity of the dataset.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: if checksum fails or the dataset is not downloaded
|
||||||
|
"""
|
||||||
|
# Check if the files already exist
|
||||||
|
exists = []
|
||||||
|
for split_info in self.metadata.values():
|
||||||
|
exists.append(
|
||||||
|
os.path.exists(os.path.join(self.root, split_info["directory"]))
|
||||||
|
)
|
||||||
|
|
||||||
|
if all(exists):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if .zip files already exists (if so then extract)
|
||||||
|
exists = []
|
||||||
|
for split_info in self.metadata.values():
|
||||||
|
filepath = os.path.join(self.root, split_info["filename"])
|
||||||
|
if os.path.isfile(filepath):
|
||||||
|
if self.checksum and not check_integrity(filepath, split_info["md5"]):
|
||||||
|
raise RuntimeError("Dataset found, but corrupted.")
|
||||||
|
exists.append(True)
|
||||||
|
extract_archive(filepath)
|
||||||
|
else:
|
||||||
|
exists.append(False)
|
||||||
|
|
||||||
|
if all(exists):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if the user requested to download the dataset
|
||||||
|
raise RuntimeError(
|
||||||
|
"Dataset not found in `root` directory, either specify a different"
|
||||||
|
+ " `root` directory or manually download the dataset to this directory."
|
||||||
|
)
|
||||||
|
|
||||||
|
def plot(
|
||||||
|
self,
|
||||||
|
sample: Dict[str, Tensor],
|
||||||
|
show_titles: bool = True,
|
||||||
|
suptitle: Optional[str] = None,
|
||||||
|
) -> plt.Figure:
|
||||||
|
"""Plot a sample from the dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample: a sample returned by :meth:`__getitem__`
|
||||||
|
show_titles: flag indicating whether to show titles above each panel
|
||||||
|
suptitle: optional string to use as a suptitle
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a matplotlib Figure with the rendered sample
|
||||||
|
"""
|
||||||
|
ncols = 2
|
||||||
|
image = sample["image"][:3]
|
||||||
|
image = image.to(torch.uint8) # type: ignore[attr-defined]
|
||||||
|
image = image.permute(1, 2, 0).numpy()
|
||||||
|
|
||||||
|
dem = sample["image"][-1].numpy()
|
||||||
|
dem = percentile_normalization(dem, lower=0, upper=100, axis=(0, 1))
|
||||||
|
|
||||||
|
showing_mask = "mask" in sample
|
||||||
|
showing_prediction = "prediction" in sample
|
||||||
|
|
||||||
|
cmap = colors.ListedColormap(self.colormap)
|
||||||
|
|
||||||
|
if showing_mask:
|
||||||
|
mask = sample["mask"].numpy()
|
||||||
|
ncols += 1
|
||||||
|
if showing_prediction:
|
||||||
|
pred = sample["prediction"].numpy()
|
||||||
|
ncols += 1
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 10))
|
||||||
|
|
||||||
|
axs[0].imshow(image)
|
||||||
|
axs[0].axis("off")
|
||||||
|
axs[1].imshow(dem)
|
||||||
|
axs[1].axis("off")
|
||||||
|
if showing_mask:
|
||||||
|
axs[2].imshow(mask, cmap=cmap, interpolation=None)
|
||||||
|
axs[2].axis("off")
|
||||||
|
if showing_prediction:
|
||||||
|
axs[3].imshow(pred, cmap=cmap, interpolation=None)
|
||||||
|
axs[3].axis("off")
|
||||||
|
elif showing_prediction:
|
||||||
|
axs[2].imshow(pred, cmap=cmap, interpolation=None)
|
||||||
|
axs[2].axis("off")
|
||||||
|
|
||||||
|
if show_titles:
|
||||||
|
axs[0].set_title("Image")
|
||||||
|
axs[1].set_title("DEM")
|
||||||
|
|
||||||
|
if showing_mask:
|
||||||
|
axs[2].set_title("Ground Truth")
|
||||||
|
if showing_prediction:
|
||||||
|
axs[3].set_title("Predictions")
|
||||||
|
elif showing_prediction:
|
||||||
|
axs[2].set_title("Predictions")
|
||||||
|
|
||||||
|
if suptitle is not None:
|
||||||
|
plt.suptitle(suptitle)
|
||||||
|
|
||||||
|
return fig
|
Загрузка…
Ссылка в новой задаче