зеркало из https://github.com/microsoft/torchgeo.git
Add DFC2022 dataset (#354)
* add DFC2022 dataset * plot fix * mypy fixes * add tests and tests data * maximum coverage * remove local dir * update per suggestions * update monkeypatching * update docstring * fix indentation in docstring
This commit is contained in:
Родитель
ff28a3b358
Коммит
45f370389f
|
@ -97,6 +97,11 @@ CV4A Kenya Crop Type Competition
|
|||
|
||||
.. autoclass:: CV4AKenyaCropType
|
||||
|
||||
2022 IEEE GRSS Data Fusion Contest (DFC2022)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: DFC2022
|
||||
|
||||
ETCI2021 Flood Detection
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import rasterio
|
||||
|
||||
from torchgeo.datasets import DFC2022
|
||||
|
||||
# Height and width (in pixels) of each dummy test raster; kept tiny so the
# generated fixture zips stay small while still exercising the dataset code.
SIZE = 32

# Seed both RNGs so the generated rasters (and hence the archive checksums
# printed at the end of this script) are reproducible across runs.
np.random.seed(0)
random.seed(0)


# Relative file paths for the labeled training split. Each entry mirrors the
# real DFC2022 layout: <split dir>/<region>/<BDORTHO|RGEALTI|UrbanAtlas>/<tile>
train_set = [
    {
        "image": "labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif",  # noqa: E501
        "dem": "labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif",  # noqa: E501
        "target": "labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif",  # noqa: E501
    },
    {
        "image": "labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif",  # noqa: E501
        "dem": "labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif",  # noqa: E501
        "target": "labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif",  # noqa: E501
    },
]

# Unlabeled training split: imagery and DEMs only (no UrbanAtlas masks).
unlabeled_set = [
    {
        "image": "unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif",  # noqa: E501
        "dem": "unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif",  # noqa: E501
    },
    {
        "image": "unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif",  # noqa: E501
        "dem": "unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif",  # noqa: E501
    },
]

# Validation split: also unlabeled (imagery and DEMs only).
val_set = [
    {
        "image": "val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif",  # noqa: E501
        "dem": "val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif",  # noqa: E501
    },
    {
        "image": "val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif",  # noqa: E501
        "dem": "val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif",  # noqa: E501
    },
]
|
||||
|
||||
|
||||
def create_file(path: str, dtype: str, num_channels: int) -> None:
    """Write a SIZE x SIZE dummy GeoTIFF filled with random values.

    Args:
        path: output path of the GeoTIFF to create
        dtype: rasterio/numpy dtype of the raster (e.g. "uint8", "float32")
        num_channels: number of bands to write (every band gets the same data)
    """
    profile = {
        "driver": "GTiff",
        "dtype": dtype,
        "count": num_channels,
        "crs": "epsg:4326",
        # Arbitrary unit-square georeferencing; the tests only need valid tags.
        "transform": rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1),
        "height": SIZE,
        "width": SIZE,
    }

    # Float rasters (DEMs) get Gaussian noise; integer rasters (imagery and
    # masks) get uniform values over the full range of the dtype.
    if "float" in dtype:
        Z = np.random.randn(SIZE, SIZE).astype(dtype)
    else:
        Z = np.random.randint(np.iinfo(dtype).max, size=(SIZE, SIZE), dtype=dtype)

    # Use a context manager so the dataset is flushed and closed; the original
    # left the writer handle open, which can leave the file truncated.
    with rasterio.open(path, "w", **profile) as src:
        for i in range(1, num_channels + 1):
            src.write(Z, i)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Regenerate the fixture data for every split declared by the dataset.
    for split in DFC2022.metadata:
        directory = DFC2022.metadata[split]["directory"]
        filename = DFC2022.metadata[split]["filename"]

        # Remove old data so stale files never end up inside the new archive
        if os.path.isdir(directory):
            shutil.rmtree(directory)
        if os.path.exists(filename):
            os.remove(filename)

        if split == "train":
            files = train_set
        elif split == "train-unlabeled":
            files = unlabeled_set
        else:
            files = val_set

        for file_dict in files:
            # Create image file (3-band uint8, mimicking BDORTHO RGB tiles)
            path = file_dict["image"]
            os.makedirs(os.path.dirname(path), exist_ok=True)
            create_file(path, dtype="uint8", num_channels=3)

            # Create DEM file (single-band float32, mimicking RGE ALTI)
            path = file_dict["dem"]
            os.makedirs(os.path.dirname(path), exist_ok=True)
            create_file(path, dtype="float32", num_channels=1)

            # Create mask file (only the labeled train split has targets)
            if split == "train":
                path = file_dict["target"]
                os.makedirs(os.path.dirname(path), exist_ok=True)
                create_file(path, dtype="uint8", num_channels=1)

        # Compress data into the archive name the dataset expects
        shutil.make_archive(filename.replace(".zip", ""), "zip", ".", directory)

        # Compute checksums, labeled by split so they can be pasted into the
        # md5 values monkeypatched in tests/datasets/test_dfc2022.py
        with open(filename, "rb") as f:
            md5 = hashlib.md5(f.read()).hexdigest()
            print(f"{split}: {md5}")
|
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif
Normal file
Двоичные данные
tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичные данные
tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif
Normal file
Двоичные данные
tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif
Normal file
Двоичные данные
tests/data/dfc2022/labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичные данные
tests/data/dfc2022/labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif
Normal file
Двоичные данные
tests/data/dfc2022/labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif
Normal file
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif
Normal file
Двоичные данные
tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичные данные
tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif
Normal file
Двоичные данные
tests/data/dfc2022/unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичные данные
tests/data/dfc2022/unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif
Normal file
Двоичные данные
tests/data/dfc2022/val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичные данные
tests/data/dfc2022/val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif
Normal file
Двоичные данные
tests/data/dfc2022/val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/dfc2022/val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичные данные
tests/data/dfc2022/val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif
Normal file
Двоичный файл не отображается.
|
@ -0,0 +1,96 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import DFC2022
|
||||
|
||||
|
||||
class TestDFC2022:
    """Tests for the DFC2022 dataset, driven by the tiny fixture archives
    under tests/data/dfc2022 (one fixture per split)."""

    @pytest.fixture(params=["train", "train-unlabeled", "val"])
    def dataset(
        self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest
    ) -> DFC2022:
        # Patch the expected md5s to match the small fixture zips (the real
        # archives are far too large to ship with the test suite).
        monkeypatch.setitem(  # type: ignore[attr-defined]
            DFC2022.metadata["train"], "md5", "6e380c4fa659d05ca93be71b50cacd90"
        )
        monkeypatch.setitem(  # type: ignore[attr-defined]
            DFC2022.metadata["train-unlabeled"],
            "md5",
            "b2bf3839323d4eae636f198921442945",
        )
        monkeypatch.setitem(  # type: ignore[attr-defined]
            DFC2022.metadata["val"], "md5", "e018dc6865bd3086738038fff27b818a"
        )
        root = os.path.join("tests", "data", "dfc2022")
        split = request.param
        transforms = nn.Identity()  # type: ignore[attr-defined]
        return DFC2022(root, split, transforms, checksum=True)

    def test_getitem(self, dataset: DFC2022) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        # 3 RGB channels + 1 DEM channel concatenated by __getitem__
        assert x["image"].ndim == 3
        assert x["image"].shape[0] == 4

        # Only the labeled train split carries a mask
        if dataset.split == "train":
            assert isinstance(x["mask"], torch.Tensor)
            assert x["mask"].ndim == 2

    def test_len(self, dataset: DFC2022) -> None:
        # Each fixture split contains exactly two image/dem(/mask) groups
        assert len(dataset) == 2

    def test_extract(self, tmp_path: Path) -> None:
        # Copy only the zips; _verify should extract them on construction
        shutil.copyfile(
            os.path.join("tests", "data", "dfc2022", "labeled_train.zip"),
            os.path.join(tmp_path, "labeled_train.zip"),
        )
        shutil.copyfile(
            os.path.join("tests", "data", "dfc2022", "unlabeled_train.zip"),
            os.path.join(tmp_path, "unlabeled_train.zip"),
        )
        shutil.copyfile(
            os.path.join("tests", "data", "dfc2022", "val.zip"),
            os.path.join(tmp_path, "val.zip"),
        )
        DFC2022(root=str(tmp_path))

    def test_corrupted(self, tmp_path: Path) -> None:
        # A zip with wrong contents must fail the checksum verification
        with open(os.path.join(tmp_path, "labeled_train.zip"), "w") as f:
            f.write("bad")
        with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
            DFC2022(root=str(tmp_path), checksum=True)

    def test_invalid_split(self) -> None:
        with pytest.raises(AssertionError):
            DFC2022(split="foo")

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
            DFC2022(str(tmp_path))

    def test_plot(self, dataset: DFC2022) -> None:
        x = dataset[0].copy()
        dataset.plot(x, suptitle="Test")
        plt.close()
        dataset.plot(x, show_titles=False)
        plt.close()

        # Exercise the prediction panel, with and without a ground-truth mask
        if dataset.split == "train":
            x["prediction"] = x["mask"].clone()
            dataset.plot(x)
            plt.close()
            del x["mask"]
            dataset.plot(x)
            plt.close()
|
|
@ -24,6 +24,7 @@ from .chesapeake import (
|
|||
from .cowc import COWC, COWCCounting, COWCDetection
|
||||
from .cv4a_kenya_crop_type import CV4AKenyaCropType
|
||||
from .cyclone import TropicalCycloneWindEstimation
|
||||
from .dfc2022 import DFC2022
|
||||
from .etci2021 import ETCI2021
|
||||
from .eurosat import EuroSAT
|
||||
from .fair1m import FAIR1M
|
||||
|
@ -115,6 +116,7 @@ __all__ = (
|
|||
"COWCCounting",
|
||||
"COWCDetection",
|
||||
"CV4AKenyaCropType",
|
||||
"DFC2022",
|
||||
"ETCI2021",
|
||||
"EuroSAT",
|
||||
"FAIR1M",
|
||||
|
|
|
@ -0,0 +1,361 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""2022 IEEE GRSS Data Fusion Contest (DFC2022) dataset."""
|
||||
|
||||
import glob
|
||||
import os
|
||||
from typing import Callable, Dict, List, Optional, Sequence
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import rasterio
|
||||
import torch
|
||||
from matplotlib import colors
|
||||
from rasterio.enums import Resampling
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import VisionDataset
|
||||
from .utils import check_integrity, extract_archive, percentile_normalization
|
||||
|
||||
|
||||
class DFC2022(VisionDataset):
    """DFC2022 dataset.

    The `DFC2022 <https://www.grss-ieee.org/community/technical-committees/2022-ieee-grss-data-fusion-contest/>`_
    dataset is used as a benchmark dataset for the 2022 IEEE GRSS Data Fusion Contest
    and extends the MiniFrance dataset for semi-supervised semantic segmentation.
    The dataset consists of a train set containing labeled and unlabeled imagery and an
    unlabeled validation set. The dataset can be downloaded from the
    `IEEEDataPort DFC2022 website <https://ieee-dataport.org/competitions/data-fusion-contest-2022-dfc2022/>`_.

    Dataset features:

    * RGB aerial images at 0.5 m per pixel spatial resolution (~2,000x2,0000 px)
    * DEMs at 1 m per pixel spatial resolution (~1,000x1,0000 px)
    * Masks at 0.5 m per pixel spatial resolution (~2,000x2,0000 px)
    * 16 land use/land cover categories
    * Images collected from the
      `IGN BD ORTHO database <https://geoservices.ign.fr/documentation/donnees/ortho/bdortho/>`_
    * DEMs collected from the
      `IGN RGE ALTI database <https://geoservices.ign.fr/documentation/donnees/alti/rgealti/>`_
    * Labels collected from the
      `UrbanAtlas 2012 database <https://land.copernicus.eu/local/urban-atlas/urban-atlas-2012/view/>`_
    * Data collected from 19 regions in France

    Dataset format:

    * images are three-channel geotiffs
    * DEMS are single-channel geotiffs
    * masks are single-channel geotiffs with the pixel values represent the class

    Dataset classes:

    0. No information
    1. Urban fabric
    2. Industrial, commercial, public, military, private and transport units
    3. Mine, dump and construction sites
    4. Artificial non-agricultural vegetated areas
    5. Arable land (annual crops)
    6. Permanent crops
    7. Pastures
    8. Complex and mixed cultivation patterns
    9. Orchards at the fringe of urban classes
    10. Forests
    11. Herbaceous vegetation associations
    12. Open spaces with little or no vegetation
    13. Wetlands
    14. Water
    15. Clouds and Shadows

    If you use this dataset in your research, please cite the following paper:

    * https://doi.org/10.1007/s10994-020-05943-y

    .. versionadded:: 0.3
    """  # noqa: E501

    # Class names indexed by mask pixel value (0-15)
    classes = [
        "No information",
        "Urban fabric",
        "Industrial, commercial, public, military, private and transport units",
        "Mine, dump and construction sites",
        "Artificial non-agricultural vegetated areas",
        "Arable land (annual crops)",
        "Permanent crops",
        "Pastures",
        "Complex and mixed cultivation patterns",
        "Orchards at the fringe of urban classes",
        "Forests",
        "Herbaceous vegetation associations",
        "Open spaces with little or no vegetation",
        "Wetlands",
        "Water",
        "Clouds and Shadows",
    ]
    # One hex color per class, used by plot(); "No information" and
    # "Clouds and Shadows" intentionally share the same dark color
    colormap = [
        "#231F20",
        "#DB5F57",
        "#DB9757",
        "#DBD057",
        "#ADDB57",
        "#75DB57",
        "#7BC47B",
        "#58B158",
        "#D4F6D4",
        "#B0E2B0",
        "#008000",
        "#58B0A7",
        "#995D13",
        "#579BDB",
        "#0062FF",
        "#231F20",
    ]
    # Per-split archive name, expected md5, and extracted directory name
    metadata = {
        "train": {
            "filename": "labeled_train.zip",
            "md5": "2e87d6a218e466dd0566797d7298c7a9",
            "directory": "labeled_train",
        },
        "train-unlabeled": {
            "filename": "unlabeled_train.zip",
            "md5": "1016d724bc494b8c50760ae56bb0585e",
            "directory": "unlabeled_train",
        },
        "val": {
            "filename": "val.zip",
            "md5": "6ddd9c0f89d8e74b94ea352d4002073f",
            "directory": "val",
        },
    }

    # Sub-directory names inside each region folder
    image_root = "BDORTHO"
    dem_root = "RGEALTI"
    target_root = "UrbanAtlas"

    def __init__(
        self,
        root: str = "data",
        split: str = "train",
        transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
        checksum: bool = False,
    ) -> None:
        """Initialize a new DFC2022 dataset instance.

        Args:
            root: root directory where dataset can be found
            split: one of "train", "train-unlabeled", or "val"
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: if ``split`` is invalid
        """
        assert split in self.metadata
        self.root = root
        self.split = split
        self.transforms = transforms
        self.checksum = checksum

        self._verify()

        self.class2idx = {c: i for i, c in enumerate(self.classes)}
        self.files = self._load_files()

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            data and label at that index
        """
        files = self.files[index]
        image = self._load_image(files["image"])
        # Resample the DEM to the image resolution so the two can be stacked
        dem = self._load_image(files["dem"], shape=image.shape[1:])
        image = torch.cat(tensors=[image, dem], dim=0)  # type: ignore[attr-defined]

        sample = {"image": image}

        # Only the labeled train split provides UrbanAtlas masks
        if self.split == "train":
            mask = self._load_target(files["target"])
            sample["mask"] = mask

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.files)

    def _load_files(self) -> List[Dict[str, str]]:
        """Return the paths of the files in the dataset.

        Returns:
            list of dicts containing paths for each pair of image/dem/mask
        """
        directory = os.path.join(self.root, self.metadata[self.split]["directory"])
        images = glob.glob(
            os.path.join(directory, "**", self.image_root, "*.tif"), recursive=True
        )

        files = []
        for image in sorted(images):
            # DEM and target paths are derived from the image path by swapping
            # the sub-directory and appending the database-specific suffix
            dem = image.replace(self.image_root, self.dem_root)
            dem = f"{os.path.splitext(dem)[0]}_RGEALTI.tif"

            if self.split == "train":
                target = image.replace(self.image_root, self.target_root)
                target = f"{os.path.splitext(target)[0]}_UA2012.tif"
                files.append(dict(image=image, dem=dem, target=target))
            else:
                files.append(dict(image=image, dem=dem))

        return files

    def _load_image(self, path: str, shape: Optional[Sequence[int]] = None) -> Tensor:
        """Load a single image.

        Args:
            path: path to the image
            shape: the (h, w) to resample the image to

        Returns:
            the image
        """
        with rasterio.open(path) as f:
            array: "np.typing.NDArray[np.float_]" = f.read(
                out_shape=shape, out_dtype="float32", resampling=Resampling.bilinear
            )
            tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
            return tensor

    def _load_target(self, path: str) -> Tensor:
        """Load the target mask for a single image.

        Args:
            path: path to the image

        Returns:
            the target mask
        """
        with rasterio.open(path) as f:
            # NOTE(review): bilinear resampling is specified for a categorical
            # mask, but no out_shape is passed so no resampling actually occurs
            array: "np.typing.NDArray[np.int_]" = f.read(
                indexes=1, out_dtype="int32", resampling=Resampling.bilinear
            )
            tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
            tensor = tensor.to(torch.long)  # type: ignore[attr-defined]
            return tensor

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Raises:
            RuntimeError: if checksum fails or the dataset is not downloaded
        """
        # Check if the files already exist
        exists = []
        for split_info in self.metadata.values():
            exists.append(
                os.path.exists(os.path.join(self.root, split_info["directory"]))
            )

        if all(exists):
            return

        # Check if .zip files already exists (if so then extract)
        exists = []
        for split_info in self.metadata.values():
            filepath = os.path.join(self.root, split_info["filename"])
            if os.path.isfile(filepath):
                if self.checksum and not check_integrity(filepath, split_info["md5"]):
                    raise RuntimeError("Dataset found, but corrupted.")
                exists.append(True)
                extract_archive(filepath)
            else:
                exists.append(False)

        if all(exists):
            return

        # Check if the user requested to download the dataset
        raise RuntimeError(
            "Dataset not found in `root` directory, either specify a different"
            + " `root` directory or manually download the dataset to this directory."
        )

    def plot(
        self,
        sample: Dict[str, Tensor],
        show_titles: bool = True,
        suptitle: Optional[str] = None,
    ) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        ncols = 2
        # First 3 channels are the RGB image; the 4th (last) is the DEM
        image = sample["image"][:3]
        image = image.to(torch.uint8)  # type: ignore[attr-defined]
        image = image.permute(1, 2, 0).numpy()

        dem = sample["image"][-1].numpy()
        dem = percentile_normalization(dem, lower=0, upper=100, axis=(0, 1))

        showing_mask = "mask" in sample
        showing_prediction = "prediction" in sample

        cmap = colors.ListedColormap(self.colormap)

        if showing_mask:
            mask = sample["mask"].numpy()
            ncols += 1
        if showing_prediction:
            pred = sample["prediction"].numpy()
            ncols += 1

        fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 10))

        axs[0].imshow(image)
        axs[0].axis("off")
        axs[1].imshow(dem)
        axs[1].axis("off")
        # NOTE(review): interpolation=None selects matplotlib's rcParams
        # default, not nearest-neighbor; "none" may be intended — confirm
        if showing_mask:
            axs[2].imshow(mask, cmap=cmap, interpolation=None)
            axs[2].axis("off")
            if showing_prediction:
                axs[3].imshow(pred, cmap=cmap, interpolation=None)
                axs[3].axis("off")
        elif showing_prediction:
            axs[2].imshow(pred, cmap=cmap, interpolation=None)
            axs[2].axis("off")

        if show_titles:
            axs[0].set_title("Image")
            axs[1].set_title("DEM")

            if showing_mask:
                axs[2].set_title("Ground Truth")
                if showing_prediction:
                    axs[3].set_title("Predictions")
            elif showing_prediction:
                axs[2].set_title("Predictions")

        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
|
Загрузка…
Ссылка в новой задаче