Mirror of https://github.com/microsoft/torchgeo.git
Add MapInWild dataset (#1131)
* add mapinwild dataset
* add copyright and move the header
* Apply suggestions from code review
  accept suggestions
  Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
* add spaces between sections
* test_black
* test_isort
* test_flake8
* dataset instance and test data
  Improves the dataset instance and test data script.
* improvements in test script and dataset class
* update test data
* Update mapinwild.py
* Update mapinwild.py
  fix typo
* Update mapinwild.py
* improved test coverage and bug fixes
* improved test coverage
* lazy import pandas
* test coverage for pandas
* test coverage for pandas
* Changes after the review
  The changes made after the review.
* delete data
* fix mypy
* fix mypy
* delete residual files
* fix type hinting
* Apply suggestions from code review
  Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
* Apply suggestions from code review
  Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
* Apply suggestions from code review
  Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
* fix RtD
* fix RtD
* fix syntax
* fix RtD
* fix file namings
* modality naming fix
* fix hidden method naming
* Update torchgeo/datasets/mapinwild.py
  Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
* address reviews
* refactoring and testing
  - passing test_download with monkeypatch
  - refactoring
  - addressing most of the comments
* fix mypy
* syntax and type conversion
* Update torchgeo/datasets/mapinwild.py
  Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
* addressing the comments
* fix mypy plt.Figure not defined
* make the _merge_parts slimmer
* pandas and reviews
* monkeypatch tvt sets
* Simplify MonkeyPatch import

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Parent
b6d78b74c3
Commit
c51014c656
@@ -257,6 +257,11 @@ LoveDA
.. autoclass:: LoveDA

MapInWild
^^^^^^^^^

.. autoclass:: MapInWild

Million-AID
^^^^^^^^^^^

@ -18,6 +18,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
|
|||
`LandCover.ai`_,S,Aerial,"10,674",5,512x512,0.25--0.5,RGB
|
||||
`LEVIR-CD+`_,CD,Google Earth,985,2,"1,024x1,024",0.5,RGB
|
||||
`LoveDA`_,S,Google Earth,"5,987",7,"1,024x1,024",0.3,RGB
|
||||
`MapInWild`_,S,"Sentinel-1/2, ESA WorldCover, NOAA VIIRS DNB",1018,1,1920x1920,10--463.83,"SAR, MSI, 2020_Map, avg_rad"
|
||||
`Million-AID`_,C,Google Earth,1M,51--73,,0.5--153,RGB
|
||||
`NASA Marine Debris`_,OD,PlanetScope,707,1,256x256,3,RGB
|
||||
`OSCD`_,CD,Sentinel-2,24,2,"40--1,180",60,MSI
|
||||
|
|
|
|
@ -0,0 +1,166 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import rasterio
|
||||
from rasterio.crs import CRS
|
||||
from rasterio.transform import Affine
|
||||
|
||||
SIZE = 32
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
meta = {
|
||||
"driver": "GTiff",
|
||||
"nodata": None,
|
||||
"width": SIZE,
|
||||
"height": SIZE,
|
||||
"crs": CRS.from_epsg(32720),
|
||||
"transform": Affine(10.0, 0.0, 612190.0, 0.0, -10.0, 7324250.0),
|
||||
}
|
||||
|
||||
count = {
|
||||
"ESA_WC": 1,
|
||||
"VIIRS": 1,
|
||||
"mask": 1,
|
||||
"s1_part1": 2,
|
||||
"s1_part2": 2,
|
||||
"s2_temporal_subset_part1": 10,
|
||||
"s2_temporal_subset_part2": 10,
|
||||
"s2_autumn_part1": 10,
|
||||
"s2_autumn_part2": 10,
|
||||
"s2_spring_part1": 10,
|
||||
"s2_spring_part2": 10,
|
||||
"s2_summer_part1": 10,
|
||||
"s2_summer_part2": 10,
|
||||
"s2_winter_part1": 10,
|
||||
"s2_winter_part2": 10,
|
||||
}
|
||||
dtype = {
|
||||
"ESA_WC": np.uint8,
|
||||
"VIIRS": np.float32,
|
||||
"mask": np.byte,
|
||||
"s1_part1": np.float64,
|
||||
"s1_part2": np.float64,
|
||||
"s2_temporal_subset_part1": np.uint16,
|
||||
"s2_temporal_subset_part2": np.uint16,
|
||||
"s2_autumn_part1": np.uint16,
|
||||
"s2_autumn_part2": np.uint16,
|
||||
"s2_spring_part1": np.uint16,
|
||||
"s2_spring_part2": np.uint16,
|
||||
"s2_summer_part1": np.uint16,
|
||||
"s2_summer_part2": np.uint16,
|
||||
"s2_winter_part1": np.uint16,
|
||||
"s2_winter_part2": np.uint16,
|
||||
}
|
||||
stop = {
|
||||
"ESA_WC": np.iinfo(np.uint8).max,
|
||||
"VIIRS": np.finfo(np.float32).max,
|
||||
"mask": np.iinfo(np.byte).max,
|
||||
"s1_part1": np.finfo(np.float64).max,
|
||||
"s1_part2": np.finfo(np.float64).max,
|
||||
"s2_temporal_subset_part1": np.iinfo(np.uint16).max,
|
||||
"s2_temporal_subset_part2": np.iinfo(np.uint16).max,
|
||||
"s2_autumn_part1": np.iinfo(np.uint16).max,
|
||||
"s2_autumn_part2": np.iinfo(np.uint16).max,
|
||||
"s2_spring_part1": np.iinfo(np.uint16).max,
|
||||
"s2_spring_part2": np.iinfo(np.uint16).max,
|
||||
"s2_summer_part1": np.iinfo(np.uint16).max,
|
||||
"s2_summer_part2": np.iinfo(np.uint16).max,
|
||||
"s2_winter_part1": np.iinfo(np.uint16).max,
|
||||
"s2_winter_part2": np.iinfo(np.uint16).max,
|
||||
}
|
||||
|
||||
folder_path = os.path.join(os.getcwd(), "tests", "data", "mapinwild")
|
||||
|
||||
dict_all = {
|
||||
"s2_sum": ["s2_summer_part1", "s2_summer_part2"],
|
||||
"s2_spr": ["s2_spring_part1", "s2_spring_part2"],
|
||||
"s2_win": ["s2_winter_part1", "s2_winter_part2"],
|
||||
"s2_aut": ["s2_autumn_part1", "s2_autumn_part2"],
|
||||
"s1": ["s1_part1", "s1_part2"],
|
||||
"s2_temp": ["s2_temporal_subset_part1", "s2_temporal_subset_part2"],
|
||||
}
|
||||
|
||||
md5s = {}
|
||||
keys = count.keys()
|
||||
modality_download_list = list(count.keys())
|
||||
|
||||
for source in modality_download_list:
|
||||
directory = os.path.join(folder_path, source)
|
||||
|
||||
# Remove old data
|
||||
if os.path.exists(directory):
|
||||
shutil.rmtree(directory)
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
# Random images
|
||||
for i in range(1, 3):
|
||||
filename = f"{i}.tif"
|
||||
filepath = os.path.join(directory, filename)
|
||||
|
||||
meta["count"] = count[source]
|
||||
meta["dtype"] = dtype[source]
|
||||
with rasterio.open(filepath, "w", **meta) as f:
|
||||
for j in range(1, count[source] + 1):
|
||||
if meta["dtype"] is np.float32 or meta["dtype"] is np.float64:
|
||||
data = np.random.randn(SIZE, SIZE).astype(dtype[source])
|
||||
|
||||
else:
|
||||
data = np.random.randint(stop[source], size=(SIZE, SIZE)).astype(
|
||||
dtype[source]
|
||||
)
|
||||
f.write(data, j)
|
||||
|
||||
# Mimic the two-part structure of the dataset
|
||||
for key in dict_all.keys():
|
||||
path_list = dict_all[key]
|
||||
path_list_dir_p1 = os.path.join(folder_path, path_list[0])
|
||||
path_list_dir_p2 = os.path.join(folder_path, path_list[1])
|
||||
n_ims = len(os.listdir(path_list_dir_p1))
|
||||
|
||||
p1_list = os.listdir(path_list_dir_p1)
|
||||
p2_list = os.listdir(path_list_dir_p2)
|
||||
|
||||
fh_idx = np.arange(0, n_ims / 2, dtype=int)
|
||||
sh_idx = np.arange(n_ims / 2, n_ims, dtype=int)
|
||||
|
||||
for idx in sh_idx:
|
||||
sh_del = os.path.join(path_list_dir_p1, p1_list[idx])
|
||||
os.remove(sh_del)
|
||||
|
||||
for idx in fh_idx:
|
||||
fh_del = os.path.join(path_list_dir_p2, p2_list[idx])
|
||||
os.remove(fh_del)
|
||||
|
||||
for i, source in zip(keys, modality_download_list):
|
||||
directory = os.path.join(folder_path, source)
|
||||
root = os.path.dirname(directory)
|
||||
|
||||
# Compress data
|
||||
shutil.make_archive(directory, "zip", root_dir=root, base_dir=source)
|
||||
|
||||
# Compute checksums
|
||||
with open(directory + ".zip", "rb") as f:
|
||||
md5 = hashlib.md5(f.read()).hexdigest()
|
||||
print(f"{directory}: {md5}")
|
||||
name = i + ".zip"
|
||||
md5s[name] = md5
|
||||
|
||||
tvt_split = pd.DataFrame(
|
||||
[["1", "2", "3"], [np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan]],
|
||||
index=["0", "1", "2"],
|
||||
columns=["train", "validation", "test"],
|
||||
)
|
||||
tvt_split.dropna()
|
||||
tvt_split.to_csv(os.path.join(folder_path, "split_IDs.csv"))
|
||||
|
||||
with open(os.path.join(folder_path, "split_IDs.csv"), "rb") as f:
|
||||
csv_md5 = hashlib.md5(f.read()).hexdigest()
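As a quick sanity check (not part of the commit), one of the rasters written above can be opened back with rasterio; a minimal sketch, assuming the script was run from the repository root:

import rasterio

# Open one generated single-band dummy raster and confirm its layout.
with rasterio.open("tests/data/mapinwild/ESA_WC/1.tif") as src:
    assert src.width == src.height == 32  # SIZE above
    assert src.count == 1                 # count["ESA_WC"]
    print(src.dtypes, src.crs)            # ('uint8',), EPSG:32720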
Binary files not shown (14 files).
@@ -0,0 +1,2 @@
,train,validation,test
0,1,1,1
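For reference, `MapInWild.__init__` (later in this commit) consumes this file with pandas, taking the column for the requested split and dropping empty cells; a minimal sketch of that logic, assuming the CSV sits in the current directory:

import pandas as pd

# Recover the per-split sample IDs the same way the dataset class does.
split_df = pd.read_csv("split_IDs.csv")
train_ids = list(map(int, split_df["train"].dropna().values.tolist()))
print(train_ids)  # [1] for the two-line fixture above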
Binary file not shown.
@@ -0,0 +1,137 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import glob
import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset

import torchgeo.datasets.utils
from torchgeo.datasets import MapInWild


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    shutil.copy(url, root)


class TestMapInWild:
    @pytest.fixture(params=["train", "validation", "test"])
    def dataset(
        self, tmp_path: Path, monkeypatch: MonkeyPatch, request: SubRequest
    ) -> MapInWild:
        monkeypatch.setattr(torchgeo.datasets.mapinwild, "download_url", download_url)

        md5s = {
            "ESA_WC.zip": "3a1e696353d238c50996958855da02fc",
            "VIIRS.zip": "e8b0e230edb1183c02092357af83bd52",
            "mask.zip": "15245bb6368d27dbb4bd16310f4604fa",
            "s1_part1.zip": "e660da4175518af993b63644e44a9d03",
            "s1_part2.zip": "620cf0a7d598a2893bc7642ad7ee6087",
            "s2_autumn_part1.zip": "624b6cf0191c5e0bc0d51f92b568e676",
            "s2_autumn_part2.zip": "f848c62b8de36f06f12fb6b1b065c7b6",
            "s2_spring_part1.zip": "3296f3a7da7e485708dd16be91deb111",
            "s2_spring_part2.zip": "d27e94387a59f0558fe142a791682861",
            "s2_summer_part1.zip": "41d783706c3c1e4238556a772d3232fb",
            "s2_summer_part2.zip": "3495c87b67a771cfac5153d1958daa0c",
            "s2_temporal_subset_part1.zip": "06fa463888cb033011a06cf69f82273e",
            "s2_temporal_subset_part2.zip": "93e5383adeeea27f00051ecf110fcef8",
            "s2_winter_part1.zip": "617abe1c6ad8d38725aa27c9dcc38ceb",
            "s2_winter_part2.zip": "4e40d7bb0eec4ddea0b7b00314239a49",
            "split_IDs.csv": "ca22c3d30d0b62e001ed0c327c147127",
        }

        monkeypatch.setattr(MapInWild, "md5s", md5s)

        urls = os.path.join("tests", "data", "mapinwild")
        monkeypatch.setattr(MapInWild, "url", urls)

        root = str(tmp_path)
        split = request.param

        transforms = nn.Identity()
        modality = [
            "mask",
            "viirs",
            "esa_wc",
            "s2_winter",
            "s1",
            "s2_summer",
            "s2_spring",
            "s2_autumn",
            "s2_temporal_subset",
        ]
        return MapInWild(
            root,
            modality=modality,
            split=split,
            transforms=transforms,
            download=True,
            checksum=True,
        )

    def test_getitem(self, dataset: MapInWild) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["mask"], torch.Tensor)
        assert x["image"].ndim == 3

    def test_len(self, dataset: MapInWild) -> None:
        assert len(dataset) == 1

    def test_add(self, dataset: MapInWild) -> None:
        ds = dataset + dataset
        assert isinstance(ds, ConcatDataset)
        assert len(ds) == 2

    def test_invalid_split(self) -> None:
        with pytest.raises(AssertionError):
            MapInWild(split="foo")

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found"):
            MapInWild(root=str(tmp_path))

    def test_downloaded_not_extracted(self, tmp_path: Path) -> None:
        pathname = os.path.join("tests", "data", "mapinwild", "**", "*.zip")
        pathname_glob = glob.glob(pathname, recursive=True)
        root = str(tmp_path)
        for zipfile in pathname_glob:
            shutil.copy(zipfile, root)
        MapInWild(root, download=True)

    def test_corrupted(self, tmp_path: Path) -> None:
        pathname = os.path.join("tests", "data", "mapinwild", "**", "*.zip")
        pathname_glob = glob.glob(pathname, recursive=True)
        root = str(tmp_path)
        for zipfile in pathname_glob:
            shutil.copy(zipfile, root)
        splitfile = os.path.join(
            "tests", "data", "mapinwild", "split_IDs", "split_IDs.csv"
        )
        shutil.copy(splitfile, root)
        with open(os.path.join(tmp_path, "mask.zip"), "w") as f:
            f.write("bad")
        with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
            MapInWild(root=str(tmp_path), download=True, checksum=True)

    def test_already_downloaded(self, dataset: MapInWild, tmp_path: Path) -> None:
        MapInWild(root=str(tmp_path), modality=dataset.modality, download=True)

    def test_plot(self, dataset: MapInWild) -> None:
        x = dataset[0].copy()
        dataset.plot(x, suptitle="Test")
        plt.close()
        dataset.plot(x, show_titles=False)
        plt.close()
        x["prediction"] = x["mask"].clone()
        dataset.plot(x)
        plt.close()
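`test_add` relies on `Dataset.__add__`, which PyTorch implements by wrapping both operands in a `ConcatDataset`; a minimal standalone sketch of that behavior (the `Tiny` class is illustrative, not part of the commit):

from torch.utils.data import ConcatDataset, Dataset


class Tiny(Dataset):
    def __len__(self) -> int:
        return 1

    def __getitem__(self, index: int) -> int:
        return index


# `+` on two map-style datasets returns a ConcatDataset, so lengths add up.
combined = Tiny() + Tiny()
assert isinstance(combined, ConcatDataset)
assert len(combined) == 2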
@@ -72,6 +72,7 @@ from .landsat import (
)
from .levircd import LEVIRCDPlus
from .loveda import LoveDA
from .mapinwild import MapInWild
from .millionaid import MillionAID
from .naip import NAIP
from .nasa_marine_debris import NASAMarineDebris

@@ -194,6 +195,7 @@ __all__ = (
    "LandCoverAI",
    "LEVIRCDPlus",
    "LoveDA",
    "MapInWild",
    "MillionAID",
    "NASAMarineDebris",
    "OSCD",

@@ -0,0 +1,407 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""MapInWild dataset."""

import os
import shutil
from collections import defaultdict
from typing import Callable, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
import torch
from matplotlib.figure import Figure
from torch import Tensor

from .geo import NonGeoDataset
from .utils import (
    check_integrity,
    download_url,
    extract_archive,
    percentile_normalization,
)


class MapInWild(NonGeoDataset):
    """MapInWild dataset.

    The `MapInWild <https://ieeexplore.ieee.org/document/10089830>`__ dataset is
    curated for the task of wilderness mapping at the pixel level. MapInWild is a
    multi-modal dataset comprising various geodata acquired from different remote
    sensing sensors over 1018 locations: dual-pol Sentinel-1, four-season
    Sentinel-2 with 10 bands, the ESA WorldCover map, and the Visible Infrared
    Imaging Radiometer Suite (VIIRS) Nighttime Day/Night Band. The dataset
    consists of 8144 images with a shape of 1920 × 1920 pixels. The images are
    weakly annotated from the World Database of Protected Areas (WDPA).

    Dataset features:

    * 1018 areas globally sampled from the WDPA
    * 10-band Sentinel-2
    * Dual-pol Sentinel-1
    * ESA WorldCover land cover
    * Visible Infrared Imaging Radiometer Suite Nighttime Day/Night Band

    If you use this dataset in your research, please cite the following paper:

    * https://ieeexplore.ieee.org/document/10089830

    .. versionadded:: 0.5
    """

    url = "https://huggingface.co/datasets/burakekim/mapinwild/resolve/main/"

    modality_urls = {
        "esa_wc": {"esa_wc/ESA_WC.zip"},
        "viirs": {"viirs/VIIRS.zip"},
        "mask": {"mask/mask.zip"},
        "s1": {"s1/s1_part1.zip", "s1/s1_part2.zip"},
        "s2_temporal_subset": {
            "s2_temporal_subset/s2_temporal_subset_part1.zip",
            "s2_temporal_subset/s2_temporal_subset_part2.zip",
        },
        "s2_autumn": {"s2_autumn/s2_autumn_part1.zip", "s2_autumn/s2_autumn_part2.zip"},
        "s2_spring": {"s2_spring/s2_spring_part1.zip", "s2_spring/s2_spring_part2.zip"},
        "s2_summer": {"s2_summer/s2_summer_part1.zip", "s2_summer/s2_summer_part2.zip"},
        "s2_winter": {"s2_winter/s2_winter_part1.zip", "s2_winter/s2_winter_part2.zip"},
        "split_IDs": {"split_IDs/split_IDs.csv"},
    }

    md5s = {
        "ESA_WC.zip": "72b2ee578fe10f0df85bdb7f19311c92",
        "VIIRS.zip": "4eff014bae127fe536f8a5f17d89ecb4",
        "mask.zip": "87c83a23a73998ad60d448d240b66225",
        "s1_part1.zip": "d8a911f5c76b50eb0760b8f0047e4674",
        "s1_part2.zip": "a30369d17c62d2af5aa52a4189590e3c",
        "s2_temporal_subset_part1.zip": "78c2d05514458a036fe133f1e2f11d2a",
        "s2_temporal_subset_part2.zip": "076cd3bd00eb5b7f5d80c9e0a0de0275",
        "s2_autumn_part1.zip": "6ee7d1ac44b5107e3663636269aecf68",
        "s2_autumn_part2.zip": "4fc5e1d5c772421dba553722433ac3b9",
        "s2_spring_part1.zip": "2a89687d8fafa7fc7f5e641bfa97d472",
        "s2_spring_part2.zip": "5845dcae0ab3cdc174b7c41edd4283a9",
        "s2_summer_part1.zip": "73ca8291d3f4fb7533636220a816bb71",
        "s2_summer_part2.zip": "5b5816bbd32987619bf72cde5cacd032",
        "s2_winter_part1.zip": "ca958f7cd98e37cb59d6f3877573ee6d",
        "s2_winter_part2.zip": "e7aacb0806d6d619b6abc408e6b09fdc",
        "split_IDs.csv": "cb5c6c073702acee23544e1e6fe5856f",
    }

    mask_cmap = {1: (0, 153, 0), 0: (255, 255, 255)}

    wc_cmap = {
        10: (0, 160, 0),
        20: (150, 100, 0),
        30: (255, 180, 0),
        40: (255, 255, 100),
        50: (195, 20, 0),
        60: (255, 245, 215),
        70: (255, 255, 255),
        80: (0, 70, 200),
        90: (0, 220, 130),
        95: (0, 150, 120),
        100: (255, 235, 175),
    }

    def __init__(
        self,
        root: str = "data",
        modality: list[str] = ["mask", "esa_wc", "viirs", "s2_summer"],
        split: str = "train",
        transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new MapInWild dataset instance.

        Args:
            root: root directory where dataset can be found
            modality: the modalities to download. Choose from: "mask", "esa_wc",
                "viirs", "s1", "s2_temporal_subset", "s2_[season]".
            split: one of "train", "validation", or "test"
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: if ``split`` argument is invalid
        """
        assert split in ["train", "validation", "test"]

        self.checksum = checksum
        self.root = root
        self.transforms = transforms
        self.modality = modality
        self.download = download

        modality.append("split_IDs")
        for mode in modality:
            for modality_link in self.modality_urls[mode]:
                modality_url = os.path.join(self.url, modality_link)
                self._verify(
                    url=modality_url, md5=self.md5s[os.path.split(modality_link)[-1]]
                )

            # Merge modalities downloaded in two parts
            if (
                download
                and mode not in os.listdir(self.root)
                and len(self.modality_urls[mode]) == 2
            ):
                self._merge_parts(mode)

        # Masks are loaded separately in :meth:`__getitem__`
        if "mask" in self.modality:
            self.modality.remove("mask")

        # The split IDs have been downloaded and are not needed in the list
        if "split_IDs" in self.modality:
            self.modality.remove("split_IDs")

        if os.path.exists(os.path.join(self.root, "split_IDs.csv")):
            split_dataframe = pd.read_csv(os.path.join(self.root, "split_IDs.csv"))
            self.ids = split_dataframe[split].dropna().values.tolist()
            self.ids = list(map(int, self.ids))

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            data and label at that index
        """
        list_modalities = []
        id = self.ids[index]

        mask = self._load_raster(id, "mask")
        mask[mask != 0] = 1

        for mode in self.modality:
            mode = mode.upper() if mode in ["esa_wc", "viirs"] else mode
            data = self._load_raster(id, mode)
            list_modalities.append(data)

        image = torch.cat(list_modalities, dim=0)

        sample: dict[str, Tensor] = {"image": image, "mask": mask}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.ids)

    def _load_raster(self, filename: int, source: str) -> Tensor:
        """Load a single raster image or target.

        Args:
            filename: name of the file to load
            source: the directory of the modality

        Returns:
            the raster image or target
        """
        with rasterio.open(os.path.join(self.root, source, f"{filename}.tif")) as f:
            raw_array = f.read()
            array: "np.typing.NDArray[np.int_]" = np.stack(raw_array, axis=0)
            if array.dtype == np.uint16:
                array = array.astype(np.int32)
            tensor = torch.from_numpy(array).float()
            return tensor

    def _verify(self, url: str, md5: Optional[str] = None) -> None:
        """Verify the integrity of the dataset.

        Args:
            url: url to the file
            md5: md5 of the file to be verified

        Raises:
            RuntimeError: if dataset is not found
        """
        modality_folder_name = url.split("/")[-1]
        mod_fold_no_ext = modality_folder_name.split(".")[0]
        modality_path = os.path.join(self.root, mod_fold_no_ext)
        split_path = os.path.join(self.root, modality_folder_name)
        if mod_fold_no_ext == "split_IDs":
            modality_path = split_path

        # Check if the files already exist
        if os.path.exists(modality_path):
            return

        # Check if the zip files have already been downloaded; if so, extract
        filepath = os.path.join(self.root, url.split("/")[-1])
        if os.path.isfile(filepath) and filepath.endswith(".zip"):
            if self.checksum and not check_integrity(filepath, md5):
                raise RuntimeError("Dataset found, but corrupted.")
            self._extract(url)
            return

        # Check if the user requested to download the dataset
        if not self.download:
            raise RuntimeError(
                f"Dataset not found in `root={self.root}` directory and `download=False`, "  # noqa: E501
                "either specify a different `root` directory or use `download=True` "
                "to automatically download the dataset."
            )

        # Download the dataset
        self._download(url, md5)
        if not url.endswith(".csv"):
            self._extract(url)

    def _download(self, url: str, md5: Optional[str]) -> None:
        """Download a modality.

        Args:
            url: download url of a modality
            md5: md5 of a modality
        """
        download_url(
            url,
            self.root,
            filename=os.path.split(url)[1],
            md5=md5 if self.checksum else None,
        )

    def _extract(self, path: str) -> None:
        """Extract a modality.

        Args:
            path: path to the modality folder
        """
        filepath = os.path.join(self.root, os.path.split(path)[1])
        extract_archive(filepath)

    def _merge_parts(self, modality: str) -> None:
        """Merge the modalities that are downloaded and extracted in two parts.

        Args:
            modality: the filename of the modality
        """
        # Create a new folder named after the `modality` variable
        modality_folder = os.path.join(self.root, modality)
        # Will not raise an error if the folder already exists
        os.makedirs(modality_folder, exist_ok=True)

        # List of source folders
        source_folders = [
            os.path.join(self.root, modality + "_part1"),
            os.path.join(self.root, modality + "_part2"),
        ]

        # Copy files from each source folder into the merged `modality` folder
        for source_folder in source_folders:
            for file_name in os.listdir(source_folder):
                source = os.path.join(source_folder, file_name)
                destination = os.path.join(modality_folder, file_name)
                if os.path.isfile(source):
                    shutil.copy(source, destination)

    def _convert_to_color(
        self, arr_2d: Tensor, cmap: dict[int, tuple[int, int, int]]
    ) -> "np.typing.NDArray[np.uint8]":
        """Convert numeric labels to an RGB color encoding.

        Args:
            arr_2d: 2D array to be colorized
            cmap: colormap to use when mapping the labels

        Returns:
            3D colored image
        """
        arr_3d = np.zeros((arr_2d.shape[0], arr_2d.shape[1], 3), dtype=np.uint8)

        for c, i in cmap.items():
            m = arr_2d == c
            arr_3d[m] = i
        return arr_3d

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: Optional[str] = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample image-mask pair returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        modality_channels = defaultdict(lambda: 10, {"viirs": 1, "esa_wc": 1, "s1": 2})

        start_idx = 0
        split_images = {}

        for modality in self.modality:
            end_idx = start_idx + modality_channels[modality]  # Start + n of channels
            split_images[modality] = sample["image"][start_idx:end_idx, :, :]  # Slice
            start_idx = end_idx  # Advance to the next modality's channels

        # Prepare the mask
        mask = sample["mask"].squeeze()
        color_mask = self._convert_to_color(mask, cmap=self.mask_cmap)

        num_subplots = len(split_images) + 1  # +1 for color_mask
        showing_predictions = "prediction" in sample
        if showing_predictions:
            num_subplots += 1

        fig, axs = plt.subplots(1, num_subplots, figsize=(num_subplots * 4, 5))

        # Plot each modality in its respective axis
        for i, (modality, image) in enumerate(split_images.items()):
            ax = axs[i]
            img = np.transpose(image, (1, 2, 0)).squeeze()
            # Apply transformations based on modality type
            if modality.startswith("s2"):
                img = img[:, :, [4, 3, 2]]
            if modality == "esa_wc":
                img = self._convert_to_color(torch.as_tensor(img), cmap=self.wc_cmap)
            if modality == "s1":
                img = img[:, :, 0]

            if modality != "esa_wc":
                img = percentile_normalization(img)

            ax.imshow(img)
            if show_titles:
                ax.set_title(modality)
            ax.axis("off")

        # Plot color_mask in its own axis
        axs[len(split_images)].imshow(color_mask)
        if show_titles:
            axs[len(split_images)].set_title("Annotation")
        axs[len(split_images)].axis("off")

        # If available, plot predictions in a new axis
        if showing_predictions:
            prediction = sample["prediction"].squeeze()
            color_predictions = self._convert_to_color(prediction, cmap=self.mask_cmap)
            axs[-1].imshow(color_predictions, vmin=0, vmax=1, interpolation="none")
            if show_titles:
                axs[-1].set_title("Prediction")
            axs[-1].axis("off")

        plt.tight_layout()
        return fig
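End to end, the new dataset is used like any other torchgeo `NonGeoDataset`; a minimal usage sketch (the `root` path and modality selection here are illustrative), matching the constructor and `plot` signatures above:

from torchgeo.datasets import MapInWild

# Fetch three modalities plus the mask, then visualize one sample.
ds = MapInWild(
    root="data",
    modality=["mask", "s1", "s2_summer"],
    split="train",
    download=True,
    checksum=True,
)
sample = ds[0]  # {"image": stacked CxHxW tensor, "mask": binary wilderness mask}
fig = ds.plot(sample, suptitle="MapInWild sample")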