# Height and width, in pixels, of every synthetic raster written by this script.
SIZE = 32

# Seed both RNGs so regenerating the fixtures yields identical files (and MD5s).
np.random.seed(0)
random.seed(0)


# Relative paths of the fake rasters for the labeled training split.
train_set = [
    {
        "image": "labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif",  # noqa: E501
        "dem": "labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif",  # noqa: E501
        "target": "labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif",  # noqa: E501
    },
    {
        "image": "labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif",  # noqa: E501
        "dem": "labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif",  # noqa: E501
        "target": "labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif",  # noqa: E501
    },
]

# Relative paths of the fake rasters for the unlabeled training split (no masks).
unlabeled_set = [
    {
        "image": "unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif",  # noqa: E501
        "dem": "unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif",  # noqa: E501
    },
    {
        "image": "unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif",  # noqa: E501
        "dem": "unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif",  # noqa: E501
    },
]

# Relative paths of the fake rasters for the validation split (no masks).
val_set = [
    {
        "image": "val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif",  # noqa: E501
        "dem": "val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif",  # noqa: E501
    },
    {
        "image": "val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif",  # noqa: E501
        "dem": "val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif",  # noqa: E501
    },
]


def create_file(path: str, dtype: str, num_channels: int) -> None:
    """Write a SIZE x SIZE GeoTIFF filled with random values.

    A single random array is drawn and written to every band, so all bands of
    the resulting file hold the same data.

    Args:
        path: destination of the GeoTIFF (its parent directory must exist)
        dtype: numpy dtype name of the raster, e.g. "uint8" or "float32"
        num_channels: number of bands to write
    """
    profile = {
        "driver": "GTiff",
        "dtype": dtype,
        "count": num_channels,
        "crs": "epsg:4326",
        "transform": rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1),
        "height": SIZE,
        "width": SIZE,
    }

    if "float" in dtype:
        data = np.random.randn(SIZE, SIZE).astype(dtype)
    else:
        # randint's upper bound is exclusive, so values lie in [0, max - 1].
        data = np.random.randint(
            np.iinfo(dtype).max, size=(SIZE, SIZE), dtype=dtype
        )

    # Use a context manager so the dataset is flushed and closed before the
    # caller archives and checksums it, instead of relying on garbage collection.
    with rasterio.open(path, "w", **profile) as dst:
        for band in range(1, num_channels + 1):
            dst.write(data, band)


if __name__ == "__main__":
    # Splits without an explicit entry here fall back to the validation files.
    split_files = {"train": train_set, "train-unlabeled": unlabeled_set}

    for split in DFC2022.metadata:
        directory = DFC2022.metadata[split]["directory"]
        filename = DFC2022.metadata[split]["filename"]

        # Remove old data
        if os.path.isdir(directory):
            shutil.rmtree(directory)
        if os.path.exists(filename):
            os.remove(filename)

        for file_dict in split_files.get(split, val_set):
            # Create image file
            path = file_dict["image"]
            os.makedirs(os.path.dirname(path), exist_ok=True)
            create_file(path, dtype="uint8", num_channels=3)

            # Create DEM file
            path = file_dict["dem"]
            os.makedirs(os.path.dirname(path), exist_ok=True)
            create_file(path, dtype="float32", num_channels=1)

            # Create mask file (only the labeled training split has targets)
            if split == "train":
                path = file_dict["target"]
                os.makedirs(os.path.dirname(path), exist_ok=True)
                create_file(path, dtype="uint8", num_channels=1)

        # Compress data
        shutil.make_archive(filename.replace(".zip", ""), "zip", ".", directory)

        # Compute checksums
        with open(filename, "rb") as f:
            md5 = hashlib.md5(f.read()).hexdigest()
            print(f"{filename}: {md5}")
b/tests/data/dfc2022/labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif differ diff --git a/tests/data/dfc2022/labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif b/tests/data/dfc2022/labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif new file mode 100644 index 000000000..3b81b21e2 Binary files /dev/null and b/tests/data/dfc2022/labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif differ diff --git a/tests/data/dfc2022/labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif b/tests/data/dfc2022/labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif new file mode 100644 index 000000000..ca14ac077 Binary files /dev/null and b/tests/data/dfc2022/labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif differ diff --git a/tests/data/dfc2022/unlabeled_train.zip b/tests/data/dfc2022/unlabeled_train.zip new file mode 100644 index 000000000..c2a3fb95f Binary files /dev/null and b/tests/data/dfc2022/unlabeled_train.zip differ diff --git a/tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif b/tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif new file mode 100644 index 000000000..1958f3269 Binary files /dev/null and b/tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif differ diff --git a/tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif b/tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif new file mode 100644 index 000000000..102f7a415 Binary files /dev/null and b/tests/data/dfc2022/unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif differ diff --git a/tests/data/dfc2022/unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif 
b/tests/data/dfc2022/unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif new file mode 100644 index 000000000..675851149 Binary files /dev/null and b/tests/data/dfc2022/unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif differ diff --git a/tests/data/dfc2022/unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif b/tests/data/dfc2022/unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif new file mode 100644 index 000000000..cc361ae8d Binary files /dev/null and b/tests/data/dfc2022/unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif differ diff --git a/tests/data/dfc2022/val.zip b/tests/data/dfc2022/val.zip new file mode 100644 index 000000000..5850a8331 Binary files /dev/null and b/tests/data/dfc2022/val.zip differ diff --git a/tests/data/dfc2022/val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif b/tests/data/dfc2022/val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif new file mode 100644 index 000000000..f81592768 Binary files /dev/null and b/tests/data/dfc2022/val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif differ diff --git a/tests/data/dfc2022/val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif b/tests/data/dfc2022/val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif new file mode 100644 index 000000000..6f386b137 Binary files /dev/null and b/tests/data/dfc2022/val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif differ diff --git a/tests/data/dfc2022/val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif b/tests/data/dfc2022/val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif new file mode 100644 index 000000000..c954167ad Binary files /dev/null and b/tests/data/dfc2022/val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif differ diff --git 
class TestDFC2022:
    """Tests for the DFC2022 dataset, run against the fixtures in tests/data."""

    @pytest.fixture(params=["train", "train-unlabeled", "val"])
    def dataset(
        self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest
    ) -> DFC2022:
        """Build a DFC2022 instance per split with MD5s patched to the fixtures."""
        fixture_md5s = {
            "train": "6e380c4fa659d05ca93be71b50cacd90",
            "train-unlabeled": "b2bf3839323d4eae636f198921442945",
            "val": "e018dc6865bd3086738038fff27b818a",
        }
        for split_name, md5 in fixture_md5s.items():
            monkeypatch.setitem(  # type: ignore[attr-defined]
                DFC2022.metadata[split_name], "md5", md5
            )
        return DFC2022(
            root=os.path.join("tests", "data", "dfc2022"),
            split=request.param,
            transforms=nn.Identity(),  # type: ignore[attr-defined]
            checksum=True,
        )

    def test_getitem(self, dataset: DFC2022) -> None:
        """A sample is a dict holding a 4-channel image (and a 2D mask for train)."""
        sample = dataset[0]
        assert isinstance(sample, dict)

        image = sample["image"]
        assert isinstance(image, torch.Tensor)
        assert image.ndim == 3
        assert image.shape[0] == 4

        if dataset.split != "train":
            return
        mask = sample["mask"]
        assert isinstance(mask, torch.Tensor)
        assert mask.ndim == 2

    def test_len(self, dataset: DFC2022) -> None:
        """Every fixture split contains exactly two samples."""
        assert len(dataset) == 2

    def test_extract(self, tmp_path: Path) -> None:
        """Archives placed in an empty root are extracted on construction."""
        fixture_dir = os.path.join("tests", "data", "dfc2022")
        for archive in ("labeled_train.zip", "unlabeled_train.zip", "val.zip"):
            shutil.copyfile(
                os.path.join(fixture_dir, archive), os.path.join(tmp_path, archive)
            )
        DFC2022(root=str(tmp_path))

    def test_corrupted(self, tmp_path: Path) -> None:
        """An archive failing its checksum raises a RuntimeError."""
        bad_archive = os.path.join(tmp_path, "labeled_train.zip")
        with open(bad_archive, "w") as f:
            f.write("bad")
        with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
            DFC2022(root=str(tmp_path), checksum=True)

    def test_invalid_split(self) -> None:
        """An unknown split name is rejected with an AssertionError."""
        with pytest.raises(AssertionError):
            DFC2022(split="foo")

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """An empty root raises a RuntimeError telling the user what to do."""
        with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
            DFC2022(str(tmp_path))

    def test_plot(self, dataset: DFC2022) -> None:
        """Plotting works with and without titles, masks and predictions."""
        sample = dataset[0].copy()

        dataset.plot(sample, suptitle="Test")
        plt.close()
        dataset.plot(sample, show_titles=False)
        plt.close()

        if dataset.split != "train":
            return

        # Mask and prediction together, then prediction alone.
        sample["prediction"] = sample["mask"].clone()
        dataset.plot(sample)
        plt.close()
        del sample["mask"]
        dataset.plot(sample)
        plt.close()
class DFC2022(VisionDataset):
    """DFC2022 dataset.

    The DFC2022 dataset is used as a benchmark dataset for the 2022 IEEE GRSS Data
    Fusion Contest and extends the MiniFrance dataset for semi-supervised semantic
    segmentation. The dataset consists of a train set containing labeled and
    unlabeled imagery and an unlabeled validation set. The dataset can be
    downloaded from the IEEEDataPort DFC2022 website.

    Dataset features:

    * RGB aerial images at 0.5 m per pixel spatial resolution (~2,000x2,000 px)
    * DEMs at 1 m per pixel spatial resolution (~1,000x1,000 px)
    * Masks at 0.5 m per pixel spatial resolution (~2,000x2,000 px)
    * 16 land use/land cover categories
    * Images collected from the IGN BD ORTHO database
    * DEMs collected from the IGN RGE ALTI database
    * Labels collected from the UrbanAtlas 2012 database
    * Data collected from 19 regions in France

    Dataset format:

    * images are three-channel geotiffs
    * DEMs are single-channel geotiffs
    * masks are single-channel geotiffs whose pixel values represent the class

    Dataset classes:

    0. No information
    1. Urban fabric
    2. Industrial, commercial, public, military, private and transport units
    3. Mine, dump and construction sites
    4. Artificial non-agricultural vegetated areas
    5. Arable land (annual crops)
    6. Permanent crops
    7. Pastures
    8. Complex and mixed cultivation patterns
    9. Orchards at the fringe of urban classes
    10. Forests
    11. Herbaceous vegetation associations
    12. Open spaces with little or no vegetation
    13. Wetlands
    14. Water
    15. Clouds and Shadows

    If you use this dataset in your research, please cite the following paper:

    * https://doi.org/10.1007/s10994-020-05943-y

    .. versionadded:: 0.3
    """  # noqa: E501

    # Class names, ordered so that the list index equals the mask pixel value.
    classes = [
        "No information",
        "Urban fabric",
        "Industrial, commercial, public, military, private and transport units",
        "Mine, dump and construction sites",
        "Artificial non-agricultural vegetated areas",
        "Arable land (annual crops)",
        "Permanent crops",
        "Pastures",
        "Complex and mixed cultivation patterns",
        "Orchards at the fringe of urban classes",
        "Forests",
        "Herbaceous vegetation associations",
        "Open spaces with little or no vegetation",
        "Wetlands",
        "Water",
        "Clouds and Shadows",
    ]
    # One hex color per entry of ``classes``, used by :meth:`plot`.
    colormap = [
        "#231F20",
        "#DB5F57",
        "#DB9757",
        "#DBD057",
        "#ADDB57",
        "#75DB57",
        "#7BC47B",
        "#58B158",
        "#D4F6D4",
        "#B0E2B0",
        "#008000",
        "#58B0A7",
        "#995D13",
        "#579BDB",
        "#0062FF",
        "#231F20",
    ]
    # Archive name, expected MD5 and extracted directory name for each split.
    metadata = {
        "train": {
            "filename": "labeled_train.zip",
            "md5": "2e87d6a218e466dd0566797d7298c7a9",
            "directory": "labeled_train",
        },
        "train-unlabeled": {
            "filename": "unlabeled_train.zip",
            "md5": "1016d724bc494b8c50760ae56bb0585e",
            "directory": "unlabeled_train",
        },
        "val": {
            "filename": "val.zip",
            "md5": "6ddd9c0f89d8e74b94ea352d4002073f",
            "directory": "val",
        },
    }

    # Sub-directory names (one set per region) holding each modality.
    image_root = "BDORTHO"
    dem_root = "RGEALTI"
    target_root = "UrbanAtlas"

    def __init__(
        self,
        root: str = "data",
        split: str = "train",
        transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
        checksum: bool = False,
    ) -> None:
        """Initialize a new DFC2022 dataset instance.

        Args:
            root: root directory where dataset can be found
            split: one of "train", "train-unlabeled", or "val"
            transforms: a function/transform that takes a sample dict as input
                and returns a transformed version
            checksum: if True, check the MD5 of the archives before extracting
                them (may be slow)

        Raises:
            AssertionError: if ``split`` is invalid
            RuntimeError: if the dataset is not found in ``root`` or an archive
                fails its checksum
        """
        assert split in self.metadata
        self.root = root
        self.split = split
        self.transforms = transforms
        self.checksum = checksum

        self._verify()

        self.class2idx = {c: i for i, c in enumerate(self.classes)}
        self.files = self._load_files()

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return the sample at the given index.

        Args:
            index: index to return

        Returns:
            a dict with an "image" tensor (image bands followed by the DEM, which
            is resampled to the image height and width) and, for the "train"
            split only, a "mask" tensor
        """
        files = self.files[index]
        image = self._load_image(files["image"])
        # Resample the DEM onto the image grid so both can share one tensor.
        dem = self._load_image(files["dem"], shape=image.shape[1:])
        image = torch.cat(tensors=[image, dem], dim=0)  # type: ignore[attr-defined]

        sample = {"image": image}

        if self.split == "train":
            mask = self._load_target(files["target"])
            sample["mask"] = mask

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.files)

    def _load_files(self) -> List[Dict[str, str]]:
        """Return the paths of the files in the dataset.

        DEM and mask paths are derived from each image path by swapping the
        modality directory name and appending the modality suffix.

        Returns:
            list of dicts containing paths for each pair of image/dem/mask
        """
        directory = os.path.join(self.root, self.metadata[self.split]["directory"])
        images = glob.glob(
            os.path.join(directory, "**", self.image_root, "*.tif"), recursive=True
        )

        files = []
        # Sort so that sample order does not depend on filesystem listing order.
        for image in sorted(images):
            dem = image.replace(self.image_root, self.dem_root)
            dem = f"{os.path.splitext(dem)[0]}_RGEALTI.tif"

            if self.split == "train":
                target = image.replace(self.image_root, self.target_root)
                target = f"{os.path.splitext(target)[0]}_UA2012.tif"
                files.append(dict(image=image, dem=dem, target=target))
            else:
                files.append(dict(image=image, dem=dem))

        return files

    def _load_image(self, path: str, shape: Optional[Sequence[int]] = None) -> Tensor:
        """Load a single image.

        Args:
            path: path to the image
            shape: the (h, w) to resample the image to

        Returns:
            the image as a float32 tensor
        """
        with rasterio.open(path) as f:
            array: "np.typing.NDArray[np.float_]" = f.read(
                out_shape=shape, out_dtype="float32", resampling=Resampling.bilinear
            )
            tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
            return tensor

    def _load_target(self, path: str) -> Tensor:
        """Load the target mask for a single image.

        Args:
            path: path to the image

        Returns:
            the target mask as a 2D long tensor
        """
        with rasterio.open(path) as f:
            # No resampling method is requested: the band is read at its native
            # shape, and class indices must never be interpolated.
            array: "np.typing.NDArray[np.int_]" = f.read(
                indexes=1, out_dtype="int32"
            )
            tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
            tensor = tensor.to(torch.long)  # type: ignore[attr-defined]
            return tensor

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Raises:
            RuntimeError: if checksum fails or the dataset is not downloaded
        """
        # Check if the extracted directories of every split already exist
        exists = []
        for split_info in self.metadata.values():
            exists.append(
                os.path.exists(os.path.join(self.root, split_info["directory"]))
            )

        if all(exists):
            return

        # Check if .zip files already exists (if so then extract)
        exists = []
        for split_info in self.metadata.values():
            filepath = os.path.join(self.root, split_info["filename"])
            if os.path.isfile(filepath):
                if self.checksum and not check_integrity(filepath, split_info["md5"]):
                    raise RuntimeError("Dataset found, but corrupted.")
                exists.append(True)
                extract_archive(filepath)
            else:
                exists.append(False)

        if all(exists):
            return

        # This class has no download option, so the user must supply the data.
        raise RuntimeError(
            "Dataset not found in `root` directory, either specify a different"
            + " `root` directory or manually download the dataset to this directory."
        )

    def plot(
        self,
        sample: Dict[str, Tensor],
        show_titles: bool = True,
        suptitle: Optional[str] = None,
    ) -> plt.Figure:
        """Plot a sample from the dataset.

        The first three channels of ``sample["image"]`` are shown as the image
        and the last channel as the DEM. Optional "mask" and "prediction"
        entries are shown as additional panels.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        ncols = 2
        image = sample["image"][:3]
        image = image.to(torch.uint8)  # type: ignore[attr-defined]
        image = image.permute(1, 2, 0).numpy()

        dem = sample["image"][-1].numpy()
        dem = percentile_normalization(dem, lower=0, upper=100, axis=(0, 1))

        showing_mask = "mask" in sample
        showing_prediction = "prediction" in sample

        cmap = colors.ListedColormap(self.colormap)
        # Pin the color range to the full class range so a class always gets the
        # same color, whichever classes happen to be present in this sample.
        vmax = len(self.colormap) - 1

        if showing_mask:
            mask = sample["mask"].numpy()
            ncols += 1
        if showing_prediction:
            pred = sample["prediction"].numpy()
            ncols += 1

        # figsize is (width, height): one 10x10 inch panel per column.
        fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 10, 10))

        axs[0].imshow(image)
        axs[0].axis("off")
        axs[1].imshow(dem)
        axs[1].axis("off")
        # interpolation="none" (the string) keeps class indices from being
        # blended; None would fall back to the rcParams default instead.
        if showing_mask:
            axs[2].imshow(mask, cmap=cmap, interpolation="none", vmin=0, vmax=vmax)
            axs[2].axis("off")
            if showing_prediction:
                axs[3].imshow(
                    pred, cmap=cmap, interpolation="none", vmin=0, vmax=vmax
                )
                axs[3].axis("off")
        elif showing_prediction:
            axs[2].imshow(pred, cmap=cmap, interpolation="none", vmin=0, vmax=vmax)
            axs[2].axis("off")

        if show_titles:
            axs[0].set_title("Image")
            axs[1].set_title("DEM")

            if showing_mask:
                axs[2].set_title("Ground Truth")
                if showing_prediction:
                    axs[3].set_title("Predictions")
            elif showing_prediction:
                axs[2].set_title("Predictions")

        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig