diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 6979999de..0e9cf396b 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -97,6 +97,11 @@ CV4A Kenya Crop Type Competition
.. autoclass:: CV4AKenyaCropType
+2022 IEEE GRSS Data Fusion Contest (DFC2022)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. autoclass:: DFC2022
+
ETCI2021 Flood Detection
^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/tests/data/dfc2022/data.py b/tests/data/dfc2022/data.py
new file mode 100644
index 000000000..60b0b3d7a
--- /dev/null
+++ b/tests/data/dfc2022/data.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import hashlib
+import os
+import random
+import shutil
+
+import numpy as np
+import rasterio
+
+from torchgeo.datasets import DFC2022
+
+# Height and width (in pixels) of every generated raster
+SIZE = 32
+
+# Seed both RNGs so that regenerating the fixtures is reproducible
+np.random.seed(0)
+random.seed(0)
+
+
+# Labeled training samples: each entry has an image, a DEM and a target mask
+train_set = [
+ {
+ "image": "labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif", # noqa: E501
+ "dem": "labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif", # noqa: E501
+ "target": "labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif", # noqa: E501
+ },
+ {
+ "image": "labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif", # noqa: E501
+ "dem": "labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif", # noqa: E501
+ "target": "labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif", # noqa: E501
+ },
+]
+
+# Unlabeled training samples: image and DEM only, no target mask
+unlabeled_set = [
+ {
+ "image": "unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif", # noqa: E501
+ "dem": "unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif", # noqa: E501
+ },
+ {
+ "image": "unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif", # noqa: E501
+ "dem": "unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif", # noqa: E501
+ },
+]
+
+# Validation samples: image and DEM only, no target mask
+val_set = [
+ {
+ "image": "val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif", # noqa: E501
+ "dem": "val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif", # noqa: E501
+ },
+ {
+ "image": "val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif", # noqa: E501
+ "dem": "val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif", # noqa: E501
+ },
+]
+
+
+def create_file(path: str, dtype: str, num_channels: int) -> None:
+    """Write a small GeoTIFF filled with random values to ``path``.
+
+    A single random ``SIZE`` x ``SIZE`` array is generated and written to every
+    band, so all bands of the output file hold identical data.
+
+    Args:
+        path: location of the GeoTIFF to create
+        dtype: name of the pixel data type (e.g. "uint8" or "float32")
+        num_channels: number of bands to write
+    """
+    profile = {
+        "driver": "GTiff",
+        "dtype": dtype,
+        "count": num_channels,
+        "crs": "epsg:4326",
+        "transform": rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1),
+        "height": SIZE,
+        "width": SIZE,
+    }
+
+    if "float" in dtype:
+        Z = np.random.randn(SIZE, SIZE).astype(dtype)
+    else:
+        Z = np.random.randint(np.iinfo(dtype).max, size=(SIZE, SIZE), dtype=dtype)
+
+    # Use a context manager so the dataset is flushed and closed before the
+    # caller archives the file and computes its checksum.
+    with rasterio.open(path, "w", **profile) as dst:
+        for band in range(1, num_channels + 1):
+            dst.write(Z, band)
+
+
+if __name__ == "__main__":
+    # File lists for the two training splits; every other split uses val_set
+    split_files = {"train": train_set, "train-unlabeled": unlabeled_set}
+
+    for split, info in DFC2022.metadata.items():
+        directory = info["directory"]
+        filename = info["filename"]
+
+        # Remove old data
+        if os.path.isdir(directory):
+            shutil.rmtree(directory)
+        if os.path.exists(filename):
+            os.remove(filename)
+
+        for file_dict in split_files.get(split, val_set):
+            # (path, dtype, number of bands) for the image and the DEM,
+            # followed by the mask for the labeled split only. The order is
+            # kept fixed so that the seeded random data is reproducible.
+            rasters = [
+                (file_dict["image"], "uint8", 3),
+                (file_dict["dem"], "float32", 1),
+            ]
+            if split == "train":
+                rasters.append((file_dict["target"], "uint8", 1))
+
+            for path, dtype, num_channels in rasters:
+                os.makedirs(os.path.dirname(path), exist_ok=True)
+                create_file(path, dtype=dtype, num_channels=num_channels)
+
+        # Compress data
+        shutil.make_archive(filename.replace(".zip", ""), "zip", ".", directory)
+
+        # Compute checksums
+        with open(filename, "rb") as f:
+            md5 = hashlib.md5(f.read()).hexdigest()
+            print(f"{filename}: {md5}")
diff --git a/tests/data/dfc2022/labeled_train.zip b/tests/data/dfc2022/labeled_train.zip
new file mode 100644
index 000000000..1ef072e59
Binary files /dev/null and b/tests/data/dfc2022/labeled_train.zip differ
diff --git a/tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif b/tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif
new file mode 100644
index 000000000..c457f748f
Binary files /dev/null and b/tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif differ
diff --git a/tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif b/tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif
new file mode 100644
index 000000000..beb1ee0b3
Binary files /dev/null and b/tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif differ
diff --git a/tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif b/tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif
new file mode 100644
index 000000000..910eb3379
Binary files /dev/null and b/tests/data/dfc2022/labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif differ
diff --git a/tests/data/dfc2022/labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif b/tests/data/dfc2022/labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif
new file mode 100644
index 000000000..cf8d48ea3
Binary files /dev/null and b/tests/data/dfc2022/labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif differ
diff --git a/tests/data/dfc2022/labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif b/tests/data/dfc2022/labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif
new file mode 100644
index 000000000..3b81b21e2
Binary files /dev/null and b/tests/data/dfc2022/labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif differ
diff --git a/tests/data/dfc2022/labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif b/tests/data/dfc2022/labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif
new file mode 100644
index 000000000..ca14ac077
Binary files /dev/null and b/tests/data/dfc2022/labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif differ
diff --git a/tests/data/dfc2022/unlabeled_train.zip b/tests/data/dfc2022/unlabeled_train.zip
new file mode 100644
index 000000000..c2a3fb95f
Binary files /dev/null and b/tests/data/dfc2022/unlabeled_train.zip differ
diff --git a/tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif b/tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif
new file mode 100644
index 000000000..1958f3269
Binary files /dev/null and b/tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif differ
diff --git a/tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif b/tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif
new file mode 100644
index 000000000..102f7a415
Binary files /dev/null and b/tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif differ
diff --git a/tests/data/dfc2022/unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif b/tests/data/dfc2022/unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif
new file mode 100644
index 000000000..675851149
Binary files /dev/null and b/tests/data/dfc2022/unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif differ
diff --git a/tests/data/dfc2022/unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif b/tests/data/dfc2022/unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif
new file mode 100644
index 000000000..cc361ae8d
Binary files /dev/null and b/tests/data/dfc2022/unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif differ
diff --git a/tests/data/dfc2022/val.zip b/tests/data/dfc2022/val.zip
new file mode 100644
index 000000000..5850a8331
Binary files /dev/null and b/tests/data/dfc2022/val.zip differ
diff --git a/tests/data/dfc2022/val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif b/tests/data/dfc2022/val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif
new file mode 100644
index 000000000..f81592768
Binary files /dev/null and b/tests/data/dfc2022/val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif differ
diff --git a/tests/data/dfc2022/val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif b/tests/data/dfc2022/val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif
new file mode 100644
index 000000000..6f386b137
Binary files /dev/null and b/tests/data/dfc2022/val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif differ
diff --git a/tests/data/dfc2022/val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif b/tests/data/dfc2022/val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif
new file mode 100644
index 000000000..c954167ad
Binary files /dev/null and b/tests/data/dfc2022/val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif differ
diff --git a/tests/data/dfc2022/val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif b/tests/data/dfc2022/val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif
new file mode 100644
index 000000000..c5444b2e2
Binary files /dev/null and b/tests/data/dfc2022/val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif differ
diff --git a/tests/datasets/test_dfc2022.py b/tests/datasets/test_dfc2022.py
new file mode 100644
index 000000000..a342a5d51
--- /dev/null
+++ b/tests/datasets/test_dfc2022.py
@@ -0,0 +1,96 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+import shutil
+from pathlib import Path
+from typing import Generator
+
+import matplotlib.pyplot as plt
+import pytest
+import torch
+import torch.nn as nn
+from _pytest.fixtures import SubRequest
+from _pytest.monkeypatch import MonkeyPatch
+
+from torchgeo.datasets import DFC2022
+
+
+class TestDFC2022:
+    """Tests for the DFC2022 dataset, run against the data in tests/data/dfc2022."""
+
+    @pytest.fixture(params=["train", "train-unlabeled", "val"])
+    def dataset(
+        self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest
+    ) -> DFC2022:
+        """Return a checksum-verified dataset for each split in turn."""
+        # Replace the archive checksums with those of the test archives
+        test_md5s = {
+            "train": "6e380c4fa659d05ca93be71b50cacd90",
+            "train-unlabeled": "b2bf3839323d4eae636f198921442945",
+            "val": "e018dc6865bd3086738038fff27b818a",
+        }
+        for split_name, md5 in test_md5s.items():
+            monkeypatch.setitem(  # type: ignore[attr-defined]
+                DFC2022.metadata[split_name], "md5", md5
+            )
+
+        return DFC2022(
+            os.path.join("tests", "data", "dfc2022"),
+            request.param,
+            nn.Identity(),  # type: ignore[attr-defined]
+            checksum=True,
+        )
+
+    def test_getitem(self, dataset: DFC2022) -> None:
+        """The first sample has a 4-channel image, plus a 2D mask for "train"."""
+        sample = dataset[0]
+        assert isinstance(sample, dict)
+
+        image = sample["image"]
+        assert isinstance(image, torch.Tensor)
+        assert image.ndim == 3
+        assert image.shape[0] == 4
+
+        if dataset.split == "train":
+            assert isinstance(sample["mask"], torch.Tensor)
+            assert sample["mask"].ndim == 2
+
+    def test_len(self, dataset: DFC2022) -> None:
+        """Each split of the test data holds two samples."""
+        assert len(dataset) == 2
+
+    def test_extract(self, tmp_path: Path) -> None:
+        """The dataset can be built from a root containing only the archives."""
+        for archive in ("labeled_train.zip", "unlabeled_train.zip", "val.zip"):
+            shutil.copyfile(
+                os.path.join("tests", "data", "dfc2022", archive),
+                os.path.join(tmp_path, archive),
+            )
+        DFC2022(root=str(tmp_path))
+
+    def test_corrupted(self, tmp_path: Path) -> None:
+        """An archive that fails its checksum raises a RuntimeError."""
+        with open(os.path.join(tmp_path, "labeled_train.zip"), "w") as f:
+            f.write("bad")
+        with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
+            DFC2022(root=str(tmp_path), checksum=True)
+
+    def test_invalid_split(self) -> None:
+        """An unknown split name raises an AssertionError."""
+        with pytest.raises(AssertionError):
+            DFC2022(split="foo")
+
+    def test_not_downloaded(self, tmp_path: Path) -> None:
+        """An empty root directory raises a RuntimeError."""
+        with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
+            DFC2022(str(tmp_path))
+
+    def test_plot(self, dataset: DFC2022) -> None:
+        """Plotting works with and without titles, masks and predictions."""
+        sample = dataset[0].copy()
+        dataset.plot(sample, suptitle="Test")
+        plt.close()
+        dataset.plot(sample, show_titles=False)
+        plt.close()
+
+        if dataset.split == "train":
+            # Mask and prediction shown together
+            sample["prediction"] = sample["mask"].clone()
+            dataset.plot(sample)
+            plt.close()
+            # Prediction shown without a mask
+            del sample["mask"]
+            dataset.plot(sample)
+            plt.close()
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 0571ea765..c1dbf61fc 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -24,6 +24,7 @@ from .chesapeake import (
from .cowc import COWC, COWCCounting, COWCDetection
from .cv4a_kenya_crop_type import CV4AKenyaCropType
from .cyclone import TropicalCycloneWindEstimation
+from .dfc2022 import DFC2022
from .etci2021 import ETCI2021
from .eurosat import EuroSAT
from .fair1m import FAIR1M
@@ -115,6 +116,7 @@ __all__ = (
"COWCCounting",
"COWCDetection",
"CV4AKenyaCropType",
+ "DFC2022",
"ETCI2021",
"EuroSAT",
"FAIR1M",
diff --git a/torchgeo/datasets/dfc2022.py b/torchgeo/datasets/dfc2022.py
new file mode 100644
index 000000000..1caf7e3d3
--- /dev/null
+++ b/torchgeo/datasets/dfc2022.py
@@ -0,0 +1,361 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""2022 IEEE GRSS Data Fusion Contest (DFC2022) dataset."""
+
+import glob
+import os
+from typing import Callable, Dict, List, Optional, Sequence
+
+import matplotlib.pyplot as plt
+import numpy as np
+import rasterio
+import torch
+from matplotlib import colors
+from rasterio.enums import Resampling
+from torch import Tensor
+
+from .geo import VisionDataset
+from .utils import check_integrity, extract_archive, percentile_normalization
+
+
+class DFC2022(VisionDataset):
+    """DFC2022 dataset.
+
+    The DFC2022 dataset is used as a benchmark dataset for the 2022 IEEE GRSS
+    Data Fusion Contest and extends the MiniFrance dataset for semi-supervised
+    semantic segmentation. The dataset consists of a train set containing labeled
+    and unlabeled imagery and an unlabeled validation set. The dataset can be
+    downloaded from the IEEEDataPort DFC2022 website.
+
+    Dataset features:
+
+    * RGB aerial images at 0.5 m per pixel spatial resolution (~2,000x2,000 px)
+    * DEMs at 1 m per pixel spatial resolution (~1,000x1,000 px)
+    * Masks at 0.5 m per pixel spatial resolution (~2,000x2,000 px)
+    * 16 land use/land cover categories
+    * Images collected from the IGN BD ORTHO database
+    * DEMs collected from the IGN RGE ALTI database
+    * Labels collected from the UrbanAtlas 2012 database
+    * Data collected from 19 regions in France
+
+    Dataset format:
+
+    * images are three-channel geotiffs
+    * DEMs are single-channel geotiffs
+    * masks are single-channel geotiffs whose pixel values represent the class
+
+    Dataset classes:
+
+    0. No information
+    1. Urban fabric
+    2. Industrial, commercial, public, military, private and transport units
+    3. Mine, dump and construction sites
+    4. Artificial non-agricultural vegetated areas
+    5. Arable land (annual crops)
+    6. Permanent crops
+    7. Pastures
+    8. Complex and mixed cultivation patterns
+    9. Orchards at the fringe of urban classes
+    10. Forests
+    11. Herbaceous vegetation associations
+    12. Open spaces with little or no vegetation
+    13. Wetlands
+    14. Water
+    15. Clouds and Shadows
+
+    If you use this dataset in your research, please cite the following paper:
+
+    * https://doi.org/10.1007/s10994-020-05943-y
+
+    .. versionadded:: 0.3
+    """  # noqa: E501
+
+    # Class names, indexed by the integer value stored in the masks
+    classes = [
+        "No information",
+        "Urban fabric",
+        "Industrial, commercial, public, military, private and transport units",
+        "Mine, dump and construction sites",
+        "Artificial non-agricultural vegetated areas",
+        "Arable land (annual crops)",
+        "Permanent crops",
+        "Pastures",
+        "Complex and mixed cultivation patterns",
+        "Orchards at the fringe of urban classes",
+        "Forests",
+        "Herbaceous vegetation associations",
+        "Open spaces with little or no vegetation",
+        "Wetlands",
+        "Water",
+        "Clouds and Shadows",
+    ]
+    # One plotting color per entry of ``classes``, in the same order
+    colormap = [
+        "#231F20",
+        "#DB5F57",
+        "#DB9757",
+        "#DBD057",
+        "#ADDB57",
+        "#75DB57",
+        "#7BC47B",
+        "#58B158",
+        "#D4F6D4",
+        "#B0E2B0",
+        "#008000",
+        "#58B0A7",
+        "#995D13",
+        "#579BDB",
+        "#0062FF",
+        "#231F20",
+    ]
+    # Archive name, archive MD5 and extracted directory name for each split
+    metadata = {
+        "train": {
+            "filename": "labeled_train.zip",
+            "md5": "2e87d6a218e466dd0566797d7298c7a9",
+            "directory": "labeled_train",
+        },
+        "train-unlabeled": {
+            "filename": "unlabeled_train.zip",
+            "md5": "1016d724bc494b8c50760ae56bb0585e",
+            "directory": "unlabeled_train",
+        },
+        "val": {
+            "filename": "val.zip",
+            "md5": "6ddd9c0f89d8e74b94ea352d4002073f",
+            "directory": "val",
+        },
+    }
+
+    # Names of the per-region subdirectories holding each kind of raster
+    image_root = "BDORTHO"
+    dem_root = "RGEALTI"
+    target_root = "UrbanAtlas"
+
+    def __init__(
+        self,
+        root: str = "data",
+        split: str = "train",
+        transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
+        checksum: bool = False,
+    ) -> None:
+        """Initialize a new DFC2022 dataset instance.
+
+        Args:
+            root: root directory where dataset can be found
+            split: one of "train", "train-unlabeled", or "val"
+            transforms: a function/transform that takes input sample and its target as
+                entry and returns a transformed version
+            checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+        Raises:
+            AssertionError: if ``split`` is invalid
+            RuntimeError: if the dataset is not found in ``root`` or an archive
+                fails its checksum
+        """
+        assert split in self.metadata
+        self.root = root
+        self.split = split
+        self.transforms = transforms
+        self.checksum = checksum
+
+        self._verify()
+
+        self.class2idx = {c: i for i, c in enumerate(self.classes)}
+        self.files = self._load_files()
+
+    def __getitem__(self, index: int) -> Dict[str, Tensor]:
+        """Return the sample at the given index.
+
+        The DEM is resampled to the image size and appended to the image as an
+        extra channel. A "mask" entry is only present for the "train" split.
+
+        Args:
+            index: index to return
+
+        Returns:
+            data and label at that index
+        """
+        files = self.files[index]
+        image = self._load_image(files["image"])
+        dem = self._load_image(files["dem"], shape=image.shape[1:])
+        image = torch.cat(tensors=[image, dem], dim=0)  # type: ignore[attr-defined]
+
+        sample = {"image": image}
+
+        if self.split == "train":
+            sample["mask"] = self._load_target(files["target"])
+
+        if self.transforms is not None:
+            sample = self.transforms(sample)
+
+        return sample
+
+    def __len__(self) -> int:
+        """Return the number of data points in the dataset.
+
+        Returns:
+            length of the dataset
+        """
+        return len(self.files)
+
+    def _load_files(self) -> List[Dict[str, str]]:
+        """Return the paths of the files in the dataset.
+
+        DEM and mask paths are derived from each image path by swapping the
+        subdirectory name and adding the corresponding filename suffix.
+
+        Returns:
+            list of dicts containing paths for each pair of image/dem/mask
+        """
+        directory = os.path.join(self.root, self.metadata[self.split]["directory"])
+        images = glob.glob(
+            os.path.join(directory, "**", self.image_root, "*.tif"), recursive=True
+        )
+
+        files = []
+        for image in sorted(images):
+            dem = image.replace(self.image_root, self.dem_root)
+            dem = f"{os.path.splitext(dem)[0]}_RGEALTI.tif"
+
+            if self.split == "train":
+                target = image.replace(self.image_root, self.target_root)
+                target = f"{os.path.splitext(target)[0]}_UA2012.tif"
+                files.append(dict(image=image, dem=dem, target=target))
+            else:
+                files.append(dict(image=image, dem=dem))
+
+        return files
+
+    def _load_image(self, path: str, shape: Optional[Sequence[int]] = None) -> Tensor:
+        """Load a single image.
+
+        Args:
+            path: path to the image
+            shape: the (h, w) to resample the image to
+
+        Returns:
+            the image
+        """
+        with rasterio.open(path) as f:
+            array: "np.typing.NDArray[np.float_]" = f.read(
+                out_shape=shape, out_dtype="float32", resampling=Resampling.bilinear
+            )
+            tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
+            return tensor
+
+    def _load_target(self, path: str) -> Tensor:
+        """Load the target mask for a single image.
+
+        Args:
+            path: path to the image
+
+        Returns:
+            the target mask
+        """
+        with rasterio.open(path) as f:
+            # Class labels are categorical and must never be interpolated, so
+            # nearest-neighbor is the only safe resampling method here.
+            array: "np.typing.NDArray[np.int_]" = f.read(
+                indexes=1, out_dtype="int32", resampling=Resampling.nearest
+            )
+            tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
+            tensor = tensor.to(torch.long)  # type: ignore[attr-defined]
+            return tensor
+
+    def _verify(self) -> None:
+        """Verify the integrity of the dataset.
+
+        Raises:
+            RuntimeError: if checksum fails or the dataset is not downloaded
+        """
+        # Check if the extracted directories of every split already exist
+        extracted = [
+            os.path.exists(os.path.join(self.root, split_info["directory"]))
+            for split_info in self.metadata.values()
+        ]
+        if all(extracted):
+            return
+
+        # Check if the .zip files already exist (if so then extract them)
+        exists = []
+        for split_info in self.metadata.values():
+            filepath = os.path.join(self.root, split_info["filename"])
+            if os.path.isfile(filepath):
+                if self.checksum and not check_integrity(filepath, split_info["md5"]):
+                    raise RuntimeError("Dataset found, but corrupted.")
+                exists.append(True)
+                extract_archive(filepath)
+            else:
+                exists.append(False)
+
+        if all(exists):
+            return
+
+        # This class has no download option, so the data must be fetched manually
+        raise RuntimeError(
+            "Dataset not found in `root` directory, either specify a different"
+            + " `root` directory or manually download the dataset to this directory."
+        )
+
+    def plot(
+        self,
+        sample: Dict[str, Tensor],
+        show_titles: bool = True,
+        suptitle: Optional[str] = None,
+    ) -> plt.Figure:
+        """Plot a sample from the dataset.
+
+        Args:
+            sample: a sample returned by :meth:`__getitem__`
+            show_titles: flag indicating whether to show titles above each panel
+            suptitle: optional string to use as a suptitle
+
+        Returns:
+            a matplotlib Figure with the rendered sample
+        """
+        image = sample["image"][:3]
+        image = image.to(torch.uint8)  # type: ignore[attr-defined]
+        image = image.permute(1, 2, 0).numpy()
+
+        dem = sample["image"][-1].numpy()
+        dem = percentile_normalization(dem, lower=0, upper=100, axis=(0, 1))
+
+        # Optional label panels, in display order: ground truth, then prediction
+        overlays = [
+            (title, sample[key].numpy())
+            for key, title in (("mask", "Ground Truth"), ("prediction", "Predictions"))
+            if key in sample
+        ]
+
+        cmap = colors.ListedColormap(self.colormap)
+        ncols = 2 + len(overlays)
+
+        # figsize is (width, height): one row of ``ncols`` square panels
+        fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 10, 10))
+
+        axs[0].imshow(image)
+        axs[1].imshow(dem)
+        for ax, (_, labels) in zip(axs[2:], overlays):
+            # Pin the color range so class index i always maps to colormap[i],
+            # whichever classes are present, and disable interpolation so
+            # labels are not blended.
+            ax.imshow(
+                labels,
+                cmap=cmap,
+                vmin=0,
+                vmax=len(self.colormap) - 1,
+                interpolation="none",
+            )
+        for ax in axs:
+            ax.axis("off")
+
+        if show_titles:
+            titles = ["Image", "DEM"] + [title for title, _ in overlays]
+            for ax, title in zip(axs, titles):
+                ax.set_title(title)
+
+        if suptitle is not None:
+            plt.suptitle(suptitle)
+
+        return fig