* add mapinwild dataset

* add copyright and move the header

* Apply suggestions from code review

accept suggestions

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* add spaces between sections

* test_black

* test_isort

* test_flake8

* dataset instance and test data

Improves the dataset instance and test data script.

* improvements in test script and dataset class

* update test data

* Update mapinwild.py

* Update mapinwild.py

fix typo

* Update mapinwild.py

* improved test coverage and bug fixes

* improved test coverage

* lazy import pandas

* test coverage for pandas

* test coverage for pandas

* Changes after the review

Address the changes requested in the review.

* delete data

* fix mypy

* fix mypy

* delete residual files

* fix type hinting

* Apply suggestions from code review

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Apply suggestions from code review

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Apply suggestions from code review

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* fix RtD

* fix RtD

* fix syntax

* fix RtD

* fix file namings

* modality naming fix

* fix hidden method naming

* Update torchgeo/datasets/mapinwild.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* address reviews

* refactoring and testing

- passing test_download with monkeypatch
- refactoring
- addressing most of the comments

* fix mypy

* syntax and type conversion

* Update torchgeo/datasets/mapinwild.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* addressing the comments

* fix mypy plt.Figure not defined

* make the _merge_parts slimmer

* pandas and reviews

* monkeypatch tvt sets

* Simplify MonkeyPatch import

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Burak 2023-09-29 12:52:08 +02:00 committed by GitHub
Parent b6d78b74c3
Commit c51014c656
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 720 additions and 0 deletions


@@ -257,6 +257,11 @@ LoveDA
.. autoclass:: LoveDA

MapInWild
^^^^^^^^^

.. autoclass:: MapInWild

Million-AID
^^^^^^^^^^^


@@ -18,6 +18,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
`LandCover.ai`_,S,Aerial,"10,674",5,512x512,0.25--0.5,RGB
`LEVIR-CD+`_,CD,Google Earth,985,2,"1,024x1,024",0.5,RGB
`LoveDA`_,S,Google Earth,"5,987",7,"1,024x1,024",0.3,RGB
`MapInWild`_,S,"Sentinel-1/2, ESA WorldCover, NOAA VIIRS DNB",1018,1,1920x1920,10--463.83,"SAR, MSI, 2020_Map, avg_rad"
`Million-AID`_,C,Google Earth,1M,51--73,,0.5--153,RGB
`NASA Marine Debris`_,OD,PlanetScope,707,1,256x256,3,RGB
`OSCD`_,CD,Sentinel-2,24,2,"40--1,180",60,MSI



@@ -0,0 +1,166 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import hashlib
import os
import shutil
import numpy as np
import pandas as pd
import rasterio
from rasterio.crs import CRS
from rasterio.transform import Affine
SIZE = 32
np.random.seed(0)
meta = {
"driver": "GTiff",
"nodata": None,
"width": SIZE,
"height": SIZE,
"crs": CRS.from_epsg(32720),
"transform": Affine(10.0, 0.0, 612190.0, 0.0, -10.0, 7324250.0),
}
count = {
"ESA_WC": 1,
"VIIRS": 1,
"mask": 1,
"s1_part1": 2,
"s1_part2": 2,
"s2_temporal_subset_part1": 10,
"s2_temporal_subset_part2": 10,
"s2_autumn_part1": 10,
"s2_autumn_part2": 10,
"s2_spring_part1": 10,
"s2_spring_part2": 10,
"s2_summer_part1": 10,
"s2_summer_part2": 10,
"s2_winter_part1": 10,
"s2_winter_part2": 10,
}
dtype = {
"ESA_WC": np.uint8,
"VIIRS": np.float32,
"mask": np.byte,
"s1_part1": np.float64,
"s1_part2": np.float64,
"s2_temporal_subset_part1": np.uint16,
"s2_temporal_subset_part2": np.uint16,
"s2_autumn_part1": np.uint16,
"s2_autumn_part2": np.uint16,
"s2_spring_part1": np.uint16,
"s2_spring_part2": np.uint16,
"s2_summer_part1": np.uint16,
"s2_summer_part2": np.uint16,
"s2_winter_part1": np.uint16,
"s2_winter_part2": np.uint16,
}
stop = {
"ESA_WC": np.iinfo(np.uint8).max,
"VIIRS": np.finfo(np.float32).max,
"mask": np.iinfo(np.byte).max,
"s1_part1": np.finfo(np.float64).max,
"s1_part2": np.finfo(np.float64).max,
"s2_temporal_subset_part1": np.iinfo(np.uint16).max,
"s2_temporal_subset_part2": np.iinfo(np.uint16).max,
"s2_autumn_part1": np.iinfo(np.uint16).max,
"s2_autumn_part2": np.iinfo(np.uint16).max,
"s2_spring_part1": np.iinfo(np.uint16).max,
"s2_spring_part2": np.iinfo(np.uint16).max,
"s2_summer_part1": np.iinfo(np.uint16).max,
"s2_summer_part2": np.iinfo(np.uint16).max,
"s2_winter_part1": np.iinfo(np.uint16).max,
"s2_winter_part2": np.iinfo(np.uint16).max,
}
folder_path = os.path.join(os.getcwd(), "tests", "data", "mapinwild")
dict_all = {
"s2_sum": ["s2_summer_part1", "s2_summer_part2"],
"s2_spr": ["s2_spring_part1", "s2_spring_part2"],
"s2_win": ["s2_winter_part1", "s2_winter_part2"],
"s2_aut": ["s2_autumn_part1", "s2_autumn_part2"],
"s1": ["s1_part1", "s1_part2"],
"s2_temp": ["s2_temporal_subset_part1", "s2_temporal_subset_part2"],
}
md5s = {}
modality_download_list = list(count.keys())
for source in modality_download_list:
directory = os.path.join(folder_path, source)
# Remove old data
if os.path.exists(directory):
shutil.rmtree(directory)
os.makedirs(directory, exist_ok=True)
# Random images
for i in range(1, 3):
filename = f"{i}.tif"
filepath = os.path.join(directory, filename)
meta["count"] = count[source]
meta["dtype"] = dtype[source]
with rasterio.open(filepath, "w", **meta) as f:
for j in range(1, count[source] + 1):
if meta["dtype"] is np.float32 or meta["dtype"] is np.float64:
data = np.random.randn(SIZE, SIZE).astype(dtype[source])
else:
data = np.random.randint(stop[source], size=(SIZE, SIZE)).astype(
dtype[source]
)
f.write(data, j)
# Mimic the two-part structure of the dataset
for key in dict_all.keys():
path_list = dict_all[key]
path_list_dir_p1 = os.path.join(folder_path, path_list[0])
path_list_dir_p2 = os.path.join(folder_path, path_list[1])
n_ims = len(os.listdir(path_list_dir_p1))
p1_list = os.listdir(path_list_dir_p1)
p2_list = os.listdir(path_list_dir_p2)
fh_idx = np.arange(0, n_ims / 2, dtype=int)
sh_idx = np.arange(n_ims / 2, n_ims, dtype=int)
for idx in sh_idx:
sh_del = os.path.join(path_list_dir_p1, p1_list[idx])
os.remove(sh_del)
for idx in fh_idx:
fh_del = os.path.join(path_list_dir_p2, p2_list[idx])
os.remove(fh_del)
for source in modality_download_list:
    directory = os.path.join(folder_path, source)
    root = os.path.dirname(directory)
    # Compress data
    shutil.make_archive(directory, "zip", root_dir=root, base_dir=source)
    # Compute checksums
    with open(directory + ".zip", "rb") as f:
        md5 = hashlib.md5(f.read()).hexdigest()
        print(f"{directory}: {md5}")
    md5s[source + ".zip"] = md5
tvt_split = pd.DataFrame(
    [["1", "2", "3"], [np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan]],
    index=["0", "1", "2"],
    columns=["train", "validation", "test"],
)
tvt_split = tvt_split.dropna()  # drop the all-NaN rows before writing the split file
tvt_split.to_csv(os.path.join(folder_path, "split_IDs.csv"))
with open(os.path.join(folder_path, "split_IDs.csv"), "rb") as f:
    csv_md5 = hashlib.md5(f.read()).hexdigest()
print(f"split_IDs.csv: {csv_md5}")

Binary data: tests/data/mapinwild/esa_wc/ESA_WC.zip (new file, not shown)

Binary data: tests/data/mapinwild/mask/mask.zip (new file, not shown)

Binary data: tests/data/mapinwild/s1/s1_part1.zip (new file, not shown)

Binary data: tests/data/mapinwild/s1/s1_part2.zip (new file, not shown)

Binary data: tests/data/mapinwild/s2_autumn/s2_autumn_part1.zip (new file, not shown)

Binary data: tests/data/mapinwild/s2_autumn/s2_autumn_part2.zip (new file, not shown)

Binary data: tests/data/mapinwild/s2_spring/s2_spring_part1.zip (new file, not shown)

Binary data: tests/data/mapinwild/s2_spring/s2_spring_part2.zip (new file, not shown)

Binary data: tests/data/mapinwild/s2_summer/s2_summer_part1.zip (new file, not shown)

Binary data: tests/data/mapinwild/s2_summer/s2_summer_part2.zip (new file, not shown)

Binary data: tests/data/mapinwild/s2_temporal_subset/s2_temporal_subset_part1.zip (new file, not shown)

Binary data: tests/data/mapinwild/s2_temporal_subset/s2_temporal_subset_part2.zip (new file, not shown)

Binary data: tests/data/mapinwild/s2_winter/s2_winter_part1.zip (new file, not shown)

Binary data: tests/data/mapinwild/s2_winter/s2_winter_part2.zip (new file, not shown)


@@ -0,0 +1,2 @@
,train,validation,test
0,1,1,1

Binary data: tests/data/mapinwild/viirs/VIIRS.zip (new file, not shown)


@@ -0,0 +1,137 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import glob
import os
import shutil
from pathlib import Path
import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import MapInWild
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
class TestMapInWild:
@pytest.fixture(params=["train", "validation", "test"])
def dataset(
self, tmp_path: Path, monkeypatch: MonkeyPatch, request: SubRequest
) -> MapInWild:
monkeypatch.setattr(torchgeo.datasets.mapinwild, "download_url", download_url)
md5s = {
"ESA_WC.zip": "3a1e696353d238c50996958855da02fc",
"VIIRS.zip": "e8b0e230edb1183c02092357af83bd52",
"mask.zip": "15245bb6368d27dbb4bd16310f4604fa",
"s1_part1.zip": "e660da4175518af993b63644e44a9d03",
"s1_part2.zip": "620cf0a7d598a2893bc7642ad7ee6087",
"s2_autumn_part1.zip": "624b6cf0191c5e0bc0d51f92b568e676",
"s2_autumn_part2.zip": "f848c62b8de36f06f12fb6b1b065c7b6",
"s2_spring_part1.zip": "3296f3a7da7e485708dd16be91deb111",
"s2_spring_part2.zip": "d27e94387a59f0558fe142a791682861",
"s2_summer_part1.zip": "41d783706c3c1e4238556a772d3232fb",
"s2_summer_part2.zip": "3495c87b67a771cfac5153d1958daa0c",
"s2_temporal_subset_part1.zip": "06fa463888cb033011a06cf69f82273e",
"s2_temporal_subset_part2.zip": "93e5383adeeea27f00051ecf110fcef8",
"s2_winter_part1.zip": "617abe1c6ad8d38725aa27c9dcc38ceb",
"s2_winter_part2.zip": "4e40d7bb0eec4ddea0b7b00314239a49",
"split_IDs.csv": "ca22c3d30d0b62e001ed0c327c147127",
}
monkeypatch.setattr(MapInWild, "md5s", md5s)
urls = os.path.join("tests", "data", "mapinwild")
monkeypatch.setattr(MapInWild, "url", urls)
root = str(tmp_path)
split = request.param
transforms = nn.Identity()
modality = [
"mask",
"viirs",
"esa_wc",
"s2_winter",
"s1",
"s2_summer",
"s2_spring",
"s2_autumn",
"s2_temporal_subset",
]
return MapInWild(
root,
modality=modality,
split=split,
transforms=transforms,
download=True,
checksum=True,
)
def test_getitem(self, dataset: MapInWild) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)
assert x["image"].ndim == 3
def test_len(self, dataset: MapInWild) -> None:
assert len(dataset) == 1
def test_add(self, dataset: MapInWild) -> None:
ds = dataset + dataset
assert isinstance(ds, ConcatDataset)
assert len(ds) == 2
def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
MapInWild(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
MapInWild(root=str(tmp_path))
def test_downloaded_not_extracted(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "mapinwild", "**", "*.zip")
pathname_glob = glob.glob(pathname, recursive=True)
root = str(tmp_path)
for zipfile in pathname_glob:
shutil.copy(zipfile, root)
MapInWild(root, download=True)
def test_corrupted(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "mapinwild", "**", "*.zip")
pathname_glob = glob.glob(pathname, recursive=True)
root = str(tmp_path)
for zipfile in pathname_glob:
shutil.copy(zipfile, root)
splitfile = os.path.join(
"tests", "data", "mapinwild", "split_IDs", "split_IDs.csv"
)
shutil.copy(splitfile, root)
with open(os.path.join(tmp_path, "mask.zip"), "w") as f:
f.write("bad")
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
MapInWild(root=str(tmp_path), download=True, checksum=True)
def test_already_downloaded(self, dataset: MapInWild, tmp_path: Path) -> None:
MapInWild(root=str(tmp_path), modality=dataset.modality, download=True)
def test_plot(self, dataset: MapInWild) -> None:
x = dataset[0].copy()
dataset.plot(x, suptitle="Test")
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x["prediction"] = x["mask"].clone()
dataset.plot(x)
plt.close()
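
For reference, the new test module can be exercised on its own with a standard pytest invocation (assuming the file lives at tests/datasets/test_mapinwild.py, following the repository layout):

pytest tests/datasets/test_mapinwild.py -v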


@@ -72,6 +72,7 @@ from .landsat import (
)
from .levircd import LEVIRCDPlus
from .loveda import LoveDA
from .mapinwild import MapInWild
from .millionaid import MillionAID
from .naip import NAIP
from .nasa_marine_debris import NASAMarineDebris
@@ -194,6 +195,7 @@ __all__ = (
"LandCoverAI",
"LEVIRCDPlus",
"LoveDA",
"MapInWild",
"MillionAID",
"NASAMarineDebris",
"OSCD",


@@ -0,0 +1,407 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""MapInWild dataset."""
import os
import shutil
from collections import defaultdict
from typing import Callable, Optional
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
import torch
from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
from .utils import (
check_integrity,
download_url,
extract_archive,
percentile_normalization,
)
class MapInWild(NonGeoDataset):
"""MapInWild dataset.
    The `MapInWild <https://ieeexplore.ieee.org/document/10089830>`__ dataset is
    curated for the task of pixel-level wilderness mapping. MapInWild is a
    multi-modal dataset comprising geodata acquired from different remote sensing
    sensors over 1018 locations: dual-pol Sentinel-1, four-season 10-band
    Sentinel-2, the ESA WorldCover map, and the Visible Infrared Imaging
    Radiometer Suite (VIIRS) Nighttime Day/Night Band. The dataset consists of
    8144 images of size 1920 × 1920 pixels, weakly annotated from the World
    Database of Protected Areas (WDPA).
Dataset features:
* 1018 areas globally sampled from the WDPA
* 10-Band Sentinel-2
* Dual-pol Sentinel-1
* ESA WorldCover Land Cover
    * Visible Infrared Imaging Radiometer Suite (VIIRS) Nighttime Day/Night Band
If you use this dataset in your research, please cite the following paper:
* https://ieeexplore.ieee.org/document/10089830
.. versionadded:: 0.5
"""
url = "https://huggingface.co/datasets/burakekim/mapinwild/resolve/main/"
modality_urls = {
"esa_wc": {"esa_wc/ESA_WC.zip"},
"viirs": {"viirs/VIIRS.zip"},
"mask": {"mask/mask.zip"},
"s1": {"s1/s1_part1.zip", "s1/s1_part2.zip"},
"s2_temporal_subset": {
"s2_temporal_subset/s2_temporal_subset_part1.zip",
"s2_temporal_subset/s2_temporal_subset_part2.zip",
},
"s2_autumn": {"s2_autumn/s2_autumn_part1.zip", "s2_autumn/s2_autumn_part2.zip"},
"s2_spring": {"s2_spring/s2_spring_part1.zip", "s2_spring/s2_spring_part2.zip"},
"s2_summer": {"s2_summer/s2_summer_part1.zip", "s2_summer/s2_summer_part2.zip"},
"s2_winter": {"s2_winter/s2_winter_part1.zip", "s2_winter/s2_winter_part2.zip"},
"split_IDs": {"split_IDs/split_IDs.csv"},
}
md5s = {
"ESA_WC.zip": "72b2ee578fe10f0df85bdb7f19311c92",
"VIIRS.zip": "4eff014bae127fe536f8a5f17d89ecb4",
"mask.zip": "87c83a23a73998ad60d448d240b66225",
"s1_part1.zip": "d8a911f5c76b50eb0760b8f0047e4674",
"s1_part2.zip": "a30369d17c62d2af5aa52a4189590e3c",
"s2_temporal_subset_part1.zip": "78c2d05514458a036fe133f1e2f11d2a",
"s2_temporal_subset_part2.zip": "076cd3bd00eb5b7f5d80c9e0a0de0275",
"s2_autumn_part1.zip": "6ee7d1ac44b5107e3663636269aecf68",
"s2_autumn_part2.zip": "4fc5e1d5c772421dba553722433ac3b9",
"s2_spring_part1.zip": "2a89687d8fafa7fc7f5e641bfa97d472",
"s2_spring_part2.zip": "5845dcae0ab3cdc174b7c41edd4283a9",
"s2_summer_part1.zip": "73ca8291d3f4fb7533636220a816bb71",
"s2_summer_part2.zip": "5b5816bbd32987619bf72cde5cacd032",
"s2_winter_part1.zip": "ca958f7cd98e37cb59d6f3877573ee6d",
"s2_winter_part2.zip": "e7aacb0806d6d619b6abc408e6b09fdc",
"split_IDs.csv": "cb5c6c073702acee23544e1e6fe5856f",
}
mask_cmap = {1: (0, 153, 0), 0: (255, 255, 255)}
wc_cmap = {
10: (0, 160, 0),
20: (150, 100, 0),
30: (255, 180, 0),
40: (255, 255, 100),
50: (195, 20, 0),
60: (255, 245, 215),
70: (255, 255, 255),
80: (0, 70, 200),
90: (0, 220, 130),
95: (0, 150, 120),
100: (255, 235, 175),
}
def __init__(
self,
root: str = "data",
modality: list[str] = ["mask", "esa_wc", "viirs", "s2_summer"],
split: str = "train",
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
download: bool = False,
checksum: bool = False,
) -> None:
"""Initialize a new MapInWild dataset instance.
Args:
root: root directory where dataset can be found
            modality: the modalities to download. Choose from: "mask", "esa_wc",
                "viirs", "s1", "s2_temporal_subset", "s2_[season]"
split: one of "train", "validation", or "test"
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
AssertionError: if ``split`` argument is invalid
"""
assert split in ["train", "validation", "test"]
self.checksum = checksum
self.root = root
self.transforms = transforms
        self.modality = list(modality)  # copy so the default argument is not mutated
        self.download = download
        self.modality.append("split_IDs")
        for mode in self.modality:
for modality_link in self.modality_urls[mode]:
modality_url = os.path.join(self.url, modality_link)
self._verify(
url=modality_url, md5=self.md5s[os.path.split(modality_link)[-1]]
)
# Merge modalities downloaded in two parts
if (
download
and mode not in os.listdir(self.root)
and len(self.modality_urls[mode]) == 2
):
self._merge_parts(mode)
        # Masks are loaded separately in :meth:`__getitem__`
if "mask" in self.modality:
self.modality.remove("mask")
        # The split IDs file has already been downloaded and is not needed in the modality list
if "split_IDs" in self.modality:
self.modality.remove("split_IDs")
if os.path.exists(os.path.join(self.root, "split_IDs.csv")):
split_dataframe = pd.read_csv(os.path.join(self.root, "split_IDs.csv"))
self.ids = split_dataframe[split].dropna().values.tolist()
self.ids = list(map(int, self.ids))
def __getitem__(self, index: int) -> dict[str, Tensor]:
"""Return an index within the dataset.
Args:
index: index to return
Returns:
data and label at that index
"""
list_modalities = []
id = self.ids[index]
mask = self._load_raster(id, "mask")
mask[mask != 0] = 1
for mode in self.modality:
mode = mode.upper() if mode in ["esa_wc", "viirs"] else mode
data = self._load_raster(id, mode)
list_modalities.append(data)
image = torch.cat(list_modalities, dim=0)
sample: dict[str, Tensor] = {"image": image, "mask": mask}
if self.transforms is not None:
sample = self.transforms(sample)
return sample
def __len__(self) -> int:
"""Return the number of data points in the dataset.
Returns:
length of the dataset
"""
return len(self.ids)
def _load_raster(self, filename: int, source: str) -> Tensor:
"""Load a single raster image or target.
Args:
filename: name of the file to load
source: the directory of the modality
Returns:
the raster image or target
"""
with rasterio.open(os.path.join(self.root, source, f"{filename}.tif")) as f:
raw_array = f.read()
array: "np.typing.NDArray[np.int_]" = np.stack(raw_array, axis=0)
if array.dtype == np.uint16:
array = array.astype(np.int32)
tensor = torch.from_numpy(array).float()
return tensor
def _verify(self, url: str, md5: Optional[str] = None) -> None:
"""Verify the integrity of the dataset.
Args:
url: url to the file
md5: md5 of the file to be verified
Raises:
RuntimeError: if dataset is not found
"""
modality_folder_name = url.split("/")[-1]
mod_fold_no_ext = modality_folder_name.split(".")[0]
modality_path = os.path.join(self.root, mod_fold_no_ext)
split_path = os.path.join(self.root, modality_folder_name)
if mod_fold_no_ext == "split_IDs":
modality_path = split_path
# Check if the files already exist
if os.path.exists(modality_path):
return
# Check if the zip files have already been downloaded, if so, extract
filepath = os.path.join(self.root, url.split("/")[-1])
if os.path.isfile(filepath) and filepath.endswith(".zip"):
if self.checksum and not check_integrity(filepath, md5):
raise RuntimeError("Dataset found, but corrupted.")
self._extract(url)
return
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `root={self.root}` directory and `download=False`, " # noqa: E501
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
# Download the dataset
self._download(url, md5)
if not url.endswith(".csv"):
self._extract(url)
def _download(self, url: str, md5: Optional[str]) -> None:
"""Downloads a modality.
Args:
url: download url of a modality
md5: md5 of a modality
"""
download_url(
url,
self.root,
filename=os.path.split(url)[1],
md5=md5 if self.checksum else None,
)
def _extract(self, path: str) -> None:
"""Extracts a modality.
Args:
path: path to the modality folder
"""
filepath = os.path.join(self.root, os.path.split(path)[1])
extract_archive(filepath)
def _merge_parts(self, modality: str) -> None:
"""Merge the modalities that are downloaded and extracted in two parts.
Args:
root: root directory where dataset can be found
modality: the filename of the modality
"""
# Create a new folder named after the 'modality' variable
modality_folder = os.path.join(self.root, modality)
# Will not raise an error if the folder already exists
os.makedirs(modality_folder, exist_ok=True)
# List of source folders
source_folders = [
os.path.join(self.root, modality + "_part1"),
os.path.join(self.root, modality + "_part2"),
]
        # Copy files from each part folder into the merged 'modality' folder
        for source_folder in source_folders:
            for file_name in os.listdir(source_folder):
                source = os.path.join(source_folder, file_name)
                destination = os.path.join(modality_folder, file_name)
                if os.path.isfile(source):
                    shutil.copy(source, destination)  # copy, leaving the part folders intact
def _convert_to_color(
self, arr_2d: Tensor, cmap: dict[int, tuple[int, int, int]]
) -> "np.typing.NDArray[np.uint8]":
"""Numeric labels to RGB-color encoding.
Args:
arr_2d: 2D array to be colorized
cmap: colormap to use when mapping the labels
Returns:
3D colored image
"""
arr_3d = np.zeros((arr_2d.shape[0], arr_2d.shape[1], 3), dtype=np.uint8)
for c, i in cmap.items():
m = arr_2d == c
arr_3d[m] = i
return arr_3d
def plot(
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample image-mask pair returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle
Returns:
a matplotlib Figure with the rendered sample
"""
modality_channels = defaultdict(lambda: 10, {"viirs": 1, "esa_wc": 1, "s1": 2})
start_idx = 0
split_images = {}
for modality in self.modality:
            end_idx = start_idx + modality_channels[modality]  # start + number of channels
            split_images[modality] = sample["image"][start_idx:end_idx, :, :]  # slice out this modality
            start_idx = end_idx  # advance to the next modality's channels
# Prepare the mask
mask = sample["mask"].squeeze()
color_mask = self._convert_to_color(mask, cmap=self.mask_cmap)
num_subplots = len(split_images) + 1 # +1 for color_mask
showing_predictions = "prediction" in sample
if showing_predictions:
num_subplots += 1
fig, axs = plt.subplots(1, num_subplots, figsize=(num_subplots * 4, 5))
# Plot each modality in its respective axis
for i, (modality, image) in enumerate(split_images.items()):
ax = axs[i]
img = np.transpose(image, (1, 2, 0)).squeeze()
# Apply transformations based on modality type
if modality.startswith("s2"):
img = img[:, :, [4, 3, 2]]
if modality == "esa_wc":
img = self._convert_to_color(torch.as_tensor(img), cmap=self.wc_cmap)
if modality == "s1":
img = img[:, :, 0]
if not "esa_wc":
img = percentile_normalization(img)
ax.imshow(img)
if show_titles:
ax.set_title(modality)
ax.axis("off")
# Plot color_mask in its own axis
axs[len(split_images)].imshow(color_mask)
if show_titles:
axs[len(split_images)].set_title("Annotation")
axs[len(split_images)].axis("off")
# If available, plot predictions in a new axis
if showing_predictions:
prediction = sample["prediction"].squeeze()
color_predictions = self._convert_to_color(prediction, cmap=self.mask_cmap)
axs[-1].imshow(color_predictions, vmin=0, vmax=1, interpolation="none")
if show_titles:
axs[-1].set_title("Prediction")
axs[-1].axis("off")
plt.tight_layout()
return fig
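
Putting the pieces together, a minimal usage sketch of the new dataset class (the root path is hypothetical, and download=True requires network access to the Hugging Face mirror):

from torchgeo.datasets import MapInWild

# Hypothetical root directory; archives are downloaded and extracted on first use
ds = MapInWild(
    root="data/mapinwild",
    modality=["mask", "esa_wc", "viirs", "s2_summer"],
    split="train",
    download=True,
    checksum=True,
)
sample = ds[0]  # dict with "image" (concatenated modalities) and "mask"
fig = ds.plot(sample, suptitle="MapInWild")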