зеркало из https://github.com/microsoft/torchgeo.git
PASTIS dataset (#315)
* draft * add dataset to __init__ * reorganize datasets and datamodules * fix mypy errors * draft * add dataset to __init__ * reorganize datasets and datamodules * fix mypy errors * refactor * Adding docs * Adding plotting, cleaning up some stuff * Black and isort * Fix the datamodule import * Pyupgrade * Fixing some docstrings * Flake8 * Isort * Fix docstrings in datamodules * Fixing fns and docstring * Trying to fix the docs * Trying to fix docs * Adding tests * Black * newline * Made the test dataset larger * Remove the datamodules * Update docs/api/non_geo_datasets.csv Co-authored-by: Isaac Corley <22203655+isaaccorley@users.noreply.github.com> * Update torchgeo/datasets/pastis.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/pastis.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/pastis.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Updating cmap * Describe the different band combinations * Merging datasets * Handle the instance segmentation case in plotting * Update torchgeo/datasets/pastis.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Made some code prettier * Adding instance plotting --------- Co-authored-by: Caleb Robinson <calebrob6@gmail.com> Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
f0cacd5402
Коммит
711a576e38
|
@ -272,6 +272,11 @@ OSCD
|
|||
|
||||
.. autoclass:: OSCD
|
||||
|
||||
PASTIS
|
||||
^^^^^^
|
||||
|
||||
.. autoclass:: PASTIS
|
||||
|
||||
PatternNet
|
||||
^^^^^^^^^^
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
|
|||
`Million-AID`_,C,Google Earth,1M,51--73,,0.5--153,RGB
|
||||
`NASA Marine Debris`_,OD,PlanetScope,707,1,256x256,3,RGB
|
||||
`OSCD`_,CD,Sentinel-2,24,2,"40--1,180",60,MSI
|
||||
`PASTIS`_,I,Sentinel-1/2,"2,433",19,128x128xT,10,MSI
|
||||
`PatternNet`_,C,Google Earth,"30,400",38,256x256,0.06--5,RGB
|
||||
`Potsdam`_,S,Aerial,38,6,"6,000x6,000",0.05,MSI
|
||||
`ReforesTree`_,"OD, R",Aerial,100,6,"4,000x4,000",0.02,RGB
|
||||
|
|
|
Двоичный файл не отображается.
|
@ -0,0 +1,91 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import shutil
|
||||
from typing import Union
|
||||
|
||||
import fiona
|
||||
import numpy as np
|
||||
|
||||
# Spatial edge length (pixels) of every generated array.
SIZE = 32
# Number of fake patches written per data/annotation directory.
NUM_SAMPLES = 5
# Exclusive upper bound used when drawing the number of time steps.
MAX_NUM_TIME_STEPS = 10

# Fixed seed so that regenerating the fixtures is reproducible.
np.random.seed(0)

# A hierarchy is either a mapping of directory name -> sub-hierarchy,
# or a list of file stems to create inside the current directory.
FILENAME_HIERARCHY = Union[dict[str, "FILENAME_HIERARCHY"], list[str]]

# Directory name -> file stems that are created inside that directory.
filenames: FILENAME_HIERARCHY = dict(
    DATA_S2=["S2"],
    DATA_S1A=["S1A"],
    DATA_S1D=["S1D"],
    ANNOTATIONS=["TARGET"],
    INSTANCE_ANNOTATIONS=["INSTANCES"],
)
|
||||
|
||||
|
||||
def create_file(path: str) -> None:
    """Write ``NUM_SAMPLES`` random ``.npy`` files named ``{path}_{i}.npy``.

    The kind of array that is written depends on the prefix of the file's
    basename:

    * ``S2``: ``(t, 10, SIZE, SIZE)`` int16 array
    * ``S1A`` / ``S1D``: ``(t, 3, SIZE, SIZE)`` float16 array
    * ``TARGET``: ``(3, SIZE, SIZE)`` uint8 array with values below 20
    * ``INSTANCES``: ``(SIZE, SIZE)`` int64 array with values below 100

    where ``t`` is drawn uniformly from ``[1, MAX_NUM_TIME_STEPS)``.

    Args:
        path: output path stem (directory plus file prefix, without index or
            extension)

    Raises:
        ValueError: if the basename does not start with a known prefix
    """
    for i in range(NUM_SAMPLES):
        new_path = f"{path}_{i}.npy"
        fn = os.path.basename(new_path)
        # Drawn for every file kind (even those without a time axis) so the
        # random stream, and therefore the generated archive, stays unchanged.
        t = np.random.randint(1, MAX_NUM_TIME_STEPS)
        if fn.startswith("S2"):
            data = np.random.randint(0, 256, size=(t, 10, SIZE, SIZE)).astype(np.int16)
        elif fn.startswith("S1A"):
            data = np.random.randint(0, 256, size=(t, 3, SIZE, SIZE)).astype(np.float16)
        elif fn.startswith("S1D"):
            data = np.random.randint(0, 256, size=(t, 3, SIZE, SIZE)).astype(np.float16)
        elif fn.startswith("TARGET"):
            data = np.random.randint(0, 20, size=(3, SIZE, SIZE)).astype(np.uint8)
        elif fn.startswith("INSTANCES"):
            data = np.random.randint(0, 100, size=(SIZE, SIZE)).astype(np.int64)
        else:
            # Previously this fell through with ``data`` unbound (or stale).
            raise ValueError(f"Unrecognized file prefix: {fn}")
        np.save(new_path, data)
|
||||
|
||||
|
||||
def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:
    """Materialize a directory/file hierarchy below ``directory``.

    Args:
        directory: directory in which the hierarchy is created
        hierarchy: either a dict mapping sub-directory names to nested
            hierarchies, or a list of file stems handed to ``create_file``
    """
    if not isinstance(hierarchy, dict):
        # Leaf level: every entry is a file stem inside ``directory``.
        for stem in hierarchy:
            create_file(os.path.join(directory, stem))
        return

    # Inner level: create each sub-directory, then descend into it.
    for name, children in hierarchy.items():
        subdirectory = os.path.join(directory, name)
        os.makedirs(subdirectory, exist_ok=True)
        create_directory(subdirectory, children)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # 1. Random .npy time series and annotations.
    create_directory("PASTIS-R", filenames)

    # 2. Metadata: one unit-square polygon per patch, folds assigned
    #    round-robin over five folds.
    metadata_path = os.path.join("PASTIS-R", "metadata.geojson")
    schema = {"geometry": "Polygon", "properties": {"Fold": "int", "ID_PATCH": "int"}}
    unit_square = [[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]]
    with fiona.open(
        metadata_path, "w", "GeoJSON", crs="EPSG:4326", schema=schema
    ) as dst:
        for patch_id in range(NUM_SAMPLES):
            feature = {
                "geometry": {"type": "Polygon", "coordinates": unit_square},
                "id": str(patch_id),
                "properties": {"Fold": patch_id % 5, "ID_PATCH": patch_id},
            }
            dst.write(feature)

    # 3. Zip the directory next to it.
    filename = "PASTIS-R.zip"
    archive_stem = os.path.splitext(filename)[0]
    shutil.make_archive(archive_stem, "zip", ".", "PASTIS-R")

    # 4. Report the checksum of the archive.
    with open(filename, "rb") as archive:
        md5 = hashlib.md5(archive.read()).hexdigest()
        print(f"{filename}: {md5}")
|
|
@ -0,0 +1,110 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from pytest import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import PASTIS
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    """Stand-in downloader: copy the local file ``url`` into ``root``.

    Any additional positional or keyword arguments are accepted and ignored.
    """
    shutil.copy(src=url, dst=root)
|
||||
|
||||
|
||||
class TestPASTIS:
    """Tests for the PASTIS dataset, run against the small local archive."""

    @pytest.fixture(
        params=[
            {"folds": (0, 1), "bands": "s2", "mode": "semantic"},
            {"folds": (0, 1), "bands": "s1a", "mode": "semantic"},
            {"folds": (0, 1), "bands": "s1d", "mode": "instance"},
        ]
    )
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> PASTIS:
        """Build a PASTIS instance that "downloads" the local test archive."""
        # Replace the network download with a plain file copy.
        monkeypatch.setattr(torchgeo.datasets.pastis, "download_url", download_url)

        # Point the dataset at the test archive and its checksum.
        monkeypatch.setattr(PASTIS, "md5", "9b11ae132623a0d13f7f0775d2003703")
        monkeypatch.setattr(
            PASTIS, "url", os.path.join("tests", "data", "pastis", "PASTIS-R.zip")
        )

        config = request.param
        return PASTIS(
            str(tmp_path),
            config["folds"],
            config["bands"],
            config["mode"],
            nn.Identity(),
            download=True,
            checksum=True,
        )

    def test_getitem_semantic(self, dataset: PASTIS) -> None:
        sample = dataset[0]
        assert isinstance(sample, dict)
        for key in ("image", "mask"):
            assert isinstance(sample[key], torch.Tensor)

    def test_getitem_instance(self, dataset: PASTIS) -> None:
        # Force instance mode regardless of how the fixture was parametrized.
        dataset.mode = "instance"
        sample = dataset[0]
        assert isinstance(sample, dict)
        for key in ("image", "mask", "boxes", "label"):
            assert isinstance(sample[key], torch.Tensor)

    def test_len(self, dataset: PASTIS) -> None:
        assert len(dataset) == 2

    def test_add(self, dataset: PASTIS) -> None:
        combined = dataset + dataset
        assert isinstance(combined, ConcatDataset)
        assert len(combined) == 4

    def test_already_extracted(self, dataset: PASTIS) -> None:
        PASTIS(root=dataset.root, download=True)

    def test_already_downloaded(self, tmp_path: Path) -> None:
        archive = os.path.join("tests", "data", "pastis", "PASTIS-R.zip")
        root = str(tmp_path)
        shutil.copy(archive, root)
        PASTIS(root)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found"):
            PASTIS(str(tmp_path))

    def test_corrupted(self, tmp_path: Path) -> None:
        # A file with the right name but garbage content must fail the checksum.
        bad_archive = os.path.join(tmp_path, "PASTIS-R.zip")
        with open(bad_archive, "w") as f:
            f.write("bad")
        with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
            PASTIS(root=str(tmp_path), checksum=True)

    def test_invalid_fold(self) -> None:
        with pytest.raises(AssertionError):
            PASTIS(folds=(6,))

    def test_invalid_mode(self) -> None:
        with pytest.raises(AssertionError):
            PASTIS(mode="invalid")

    def test_plot(self, dataset: PASTIS) -> None:
        sample = dataset[0].copy()

        dataset.plot(sample, suptitle="Test")
        plt.close()

        dataset.plot(sample, show_titles=False)
        plt.close()

        # Re-use the targets as fake predictions to exercise the fourth panel.
        sample["prediction"] = sample["mask"].clone()
        if dataset.mode == "instance":
            sample["prediction_labels"] = sample["label"].clone()
        dataset.plot(sample)
        plt.close()
|
|
@ -78,6 +78,7 @@ from .nasa_marine_debris import NASAMarineDebris
|
|||
from .nlcd import NLCD
|
||||
from .openbuildings import OpenBuildings
|
||||
from .oscd import OSCD
|
||||
from .pastis import PASTIS
|
||||
from .patternnet import PatternNet
|
||||
from .potsdam import Potsdam2D
|
||||
from .reforestree import ReforesTree
|
||||
|
@ -194,6 +195,7 @@ __all__ = (
|
|||
"MillionAID",
|
||||
"NASAMarineDebris",
|
||||
"OSCD",
|
||||
"PASTIS",
|
||||
"PatternNet",
|
||||
"Potsdam2D",
|
||||
"RESISC45",
|
||||
|
|
|
@ -0,0 +1,405 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""PASTIS dataset."""
|
||||
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Callable, Optional
|
||||
|
||||
import fiona
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.colors import ListedColormap
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import check_integrity, download_url, extract_archive
|
||||
|
||||
|
||||
class PASTIS(NonGeoDataset):
    """PASTIS dataset.

    The `PASTIS <https://github.com/VSainteuf/pastis-benchmark>`__
    dataset is a dataset for time-series panoptic segmentation of agricultural parcels.

    Dataset features:

    * support for the original PASTIS and PASTIS-R versions of the dataset
    * 2,433 time-series with 10 m per pixel resolution (128x128 px)
    * 18 crop categories, 1 background category, 1 void category
    * semantic and instance annotations
    * 3 Sentinel-1 Ascending bands
    * 3 Sentinel-1 Descending bands
    * 10 Sentinel-2 L2A multispectral bands

    Dataset format:

    * time-series and annotations are in numpy format (.npy)

    Dataset classes:

    0. Background
    1. Meadow
    2. Soft Winter Wheat
    3. Corn
    4. Winter Barley
    5. Winter Rapeseed
    6. Spring Barley
    7. Sunflower
    8. Grapevine
    9. Beet
    10. Winter Triticale
    11. Winter Durum Wheat
    12. Fruits Vegetables Flowers
    13. Potatoes
    14. Leguminous Fodder
    15. Soybeans
    16. Orchard
    17. Mixed Cereal
    18. Sorghum
    19. Void Label

    If you use this dataset in your research, please cite the following papers:

    * https://doi.org/10.1109/ICCV48922.2021.00483
    * https://doi.org/10.1016/j.isprsjprs.2022.03.012

    .. versionadded:: 0.5
    """

    classes = [
        "background",  # all non-agricultural land
        "meadow",
        "soft_winter_wheat",
        "corn",
        "winter_barley",
        "winter_rapeseed",
        "spring_barley",
        "sunflower",
        "grapevine",
        "beet",
        "winter_triticale",
        "winter_durum_wheat",
        "fruits_vegetables_flowers",
        "potatoes",
        "leguminous_fodder",
        "soybeans",
        "orchard",
        "mixed_cereal",
        "sorghum",
        "void_label",  # for parcels mostly outside their patch
    ]
    # RGBA color (0-255 per channel) for each class index in ``classes``.
    cmap = {
        0: (0, 0, 0, 255),
        1: (174, 199, 232, 255),
        2: (255, 127, 14, 255),
        3: (255, 187, 120, 255),
        4: (44, 160, 44, 255),
        5: (152, 223, 138, 255),
        6: (214, 39, 40, 255),
        7: (255, 152, 150, 255),
        8: (148, 103, 189, 255),
        9: (197, 176, 213, 255),
        10: (140, 86, 75, 255),
        11: (196, 156, 148, 255),
        12: (227, 119, 194, 255),
        13: (247, 182, 210, 255),
        14: (127, 127, 127, 255),
        15: (199, 199, 199, 255),
        16: (188, 189, 34, 255),
        17: (219, 219, 141, 255),
        18: (23, 190, 207, 255),
        19: (255, 255, 255, 255),
    }
    directory = "PASTIS-R"
    filename = "PASTIS-R.zip"
    url = "https://zenodo.org/record/5735646/files/PASTIS-R.zip?download=1"
    md5 = "4887513d6c2d2b07fa935d325bd53e09"
    # Sub-directory and file-name prefix (relative to ``directory``) for each
    # band set / annotation type; the patch ID and ".npy" are appended.
    prefix = {
        "s2": os.path.join("DATA_S2", "S2_"),
        "s1a": os.path.join("DATA_S1A", "S1A_"),
        "s1d": os.path.join("DATA_S1D", "S1D_"),
        "semantic": os.path.join("ANNOTATIONS", "TARGET_"),
        "instance": os.path.join("INSTANCE_ANNOTATIONS", "INSTANCES_"),
    }

    def __init__(
        self,
        root: str = "data",
        folds: Sequence[int] = (0, 1, 2, 3, 4),
        bands: str = "s2",
        mode: str = "semantic",
        transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new PASTIS dataset instance.

        Args:
            root: root directory where dataset can be found
            folds: a sequence of integers from 0 to 4 specifying which of the five
                dataset folds to include
            bands: load Sentinel-1 ascending path data (s1a), Sentinel-1 descending path
                data (s1d), or Sentinel-2 data (s2)
            mode: load semantic (semantic) or instance (instance) annotations
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: if ``folds``, ``bands`` or ``mode`` are invalid
            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
        """
        # NOTE(review): this check accepts fold values 0-5 while the docstring
        # above mentions folds 0 to 4 -- confirm against the "Fold" values in
        # metadata.geojson before tightening either one.
        assert set(folds) <= set(range(6))
        assert bands in ["s1a", "s1d", "s2"]
        assert mode in ["semantic", "instance"]
        self.root = root
        self.folds = folds
        self.bands = bands
        self.mode = mode
        self.transforms = transforms
        self.download = download
        self.checksum = checksum
        self._verify()
        self.files = self._load_files()

        # Matplotlib colormap: RGB scaled to [0, 1], alpha dropped, ordered by
        # class index.
        colors = [
            tuple(channel / 255.0 for channel in self.cmap[i][:3])
            for i in range(len(self.cmap))
        ]
        self._cmap = ListedColormap(colors)

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """Return the sample at the given index.

        Args:
            index: index to return

        Returns:
            a dict with the time-series under "image" and the target under
            "mask"; in instance mode it also contains "boxes" and "label"

        Raises:
            ValueError: if ``self.mode`` was changed to an unsupported value
        """
        image = self._load_image(index)
        if self.mode == "semantic":
            mask = self._load_semantic_targets(index)
            sample = {"image": image, "mask": mask}
        elif self.mode == "instance":
            mask, boxes, labels = self._load_instance_targets(index)
            sample = {"image": image, "mask": mask, "boxes": boxes, "label": labels}
        else:
            # ``mode`` is a plain attribute and can be reassigned after
            # construction; fail clearly instead of with an unbound local.
            raise ValueError(f"Unsupported mode: {self.mode!r}")

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.idxs)

    def _load_image(self, index: int) -> Tensor:
        """Load a single time-series.

        Args:
            index: index to return

        Returns:
            the time-series for the band set selected by ``self.bands``
        """
        path = self.files[index][self.bands]
        array = np.load(path)

        tensor = torch.from_numpy(array)
        return tensor

    def _load_semantic_targets(self, index: int) -> Tensor:
        """Load the target mask for a single image.

        Args:
            index: index to return

        Returns:
            the target mask
        """
        # See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201 # noqa: E501
        # even though the mask file is 3 bands, we just select the first band
        array = np.load(self.files[index]["semantic"])[0].astype(np.uint8)
        tensor = torch.from_numpy(array).long()
        return tensor

    def _load_instance_targets(self, index: int) -> tuple[Tensor, Tensor, Tensor]:
        """Load the instance segmentation targets for a single sample.

        Args:
            index: index to return

        Returns:
            the instance segmentation mask, box, and label for each instance;
            for N instances the masks have N entries along the first axis,
            the boxes have shape (N, 4) as (xmin, ymin, xmax, ymax), and the
            labels have shape (N,)
        """
        mask_array = np.load(self.files[index]["semantic"])[0]
        instance_array = np.load(self.files[index]["instance"])

        mask_tensor = torch.from_numpy(mask_array)
        instance_tensor = torch.from_numpy(instance_array)

        # Convert instance mask of N instances to N binary instance masks
        instance_ids = torch.unique(instance_tensor)
        # Exclude a mask for unknown/background
        instance_ids = instance_ids[instance_ids != 0]
        masks: Tensor = instance_tensor == instance_ids[:, None, None]

        labels_list = []
        boxes_list = []
        for mask in masks:
            # Label of the instance: smallest semantic class id under its mask
            labels_list.append(int(torch.unique(mask_tensor[mask])[0]))

            # Tight bounding box around the instance
            rows, cols = torch.where(mask)
            boxes_list.append(
                [int(cols.min()), int(rows.min()), int(cols.max()), int(rows.max())]
            )

        masks = masks.to(torch.uint8)
        # reshape keeps the (N, 4) layout even when there are no instances
        boxes = torch.tensor(boxes_list, dtype=torch.float).reshape(-1, 4)
        labels = torch.tensor(labels_list, dtype=torch.long)

        return masks, boxes, labels

    def _load_files(self) -> list[dict[str, str]]:
        """List the image and target files.

        Also populates ``self.idxs`` with the patch IDs whose fold is in
        ``self.folds``.

        Returns:
            list of dicts containing image and semantic/instance target file paths
        """
        self.idxs = []
        metadata_fn = os.path.join(self.root, self.directory, "metadata.geojson")
        with fiona.open(metadata_fn) as f:
            for row in f:
                fold = int(row["properties"]["Fold"])
                if fold in self.folds:
                    self.idxs.append(row["properties"]["ID_PATCH"])

        # Paths are joined directly rather than via ``str.format`` so that a
        # ``root`` containing "{" or "}" cannot corrupt them.
        base = os.path.join(self.root, self.directory)
        files = []
        for i in self.idxs:
            files.append(
                {
                    key: os.path.join(base, f"{prefix}{i}.npy")
                    for key, prefix in self.prefix.items()
                }
            )
        return files

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Raises:
            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
        """
        # Check if the directory already exists
        path = os.path.join(self.root, self.directory)
        if os.path.exists(path):
            return

        # Check if zip file already exists (if so then extract)
        filepath = os.path.join(self.root, self.filename)
        if os.path.exists(filepath):
            if self.checksum and not check_integrity(filepath, self.md5):
                raise RuntimeError("Dataset found, but corrupted.")
            extract_archive(filepath)
            return

        # Check if the user requested to download the dataset
        if not self.download:
            raise RuntimeError(
                f"Dataset not found in `root={self.root}` and `download=False`, "
                "either specify a different `root` directory or use `download=True` "
                "to automatically download the dataset."
            )

        # Download and extract the dataset
        self._download()

    def _download(self) -> None:
        """Download the dataset."""
        download_url(
            self.url,
            self.root,
            filename=self.filename,
            md5=self.md5 if self.checksum else None,
        )
        extract_archive(os.path.join(self.root, self.filename), self.root)

    @staticmethod
    def _instances_to_semantic(
        masks: "np.typing.NDArray[np.int_]", labels: Tensor
    ) -> "np.typing.NDArray[np.int_]":
        """Collapse per-instance binary masks into a single class-index map.

        Args:
            masks: binary instance masks, one per entry along the first axis
            labels: class index of each instance, in the same order as ``masks``

        Returns:
            an array with the class index of the first instance covering each
            pixel, and 0 (background) where no instance covers it
        """
        if masks.shape[0] == 0:
            return np.zeros(masks.shape[1:], dtype=np.int64)
        semantic = labels.numpy()[masks.argmax(axis=0)]
        # argmax yields index 0 for pixels no instance covers, which would
        # paint them with the first instance's class; reset them to background.
        semantic[masks.max(axis=0) == 0] = 0
        return semantic

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: Optional[str] = None,
    ) -> plt.Figure:
        """Plot a sample from the dataset.

        The first two panels show the first two time steps of the series, the
        third the target mask and, if the sample has a "prediction" key, a
        fourth shows the prediction. In instance mode, the instance masks are
        rendered as a class map using "label" / "prediction_labels".

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        # Keep the RGB bands and convert to T x H x W x C format
        # NOTE(review): indices [2, 1, 0] presumably pick RGB for Sentinel-2;
        # for the 3-band Sentinel-1 inputs this is only a false-color view.
        images = sample["image"][:, [2, 1, 0], :, :].numpy().transpose(0, 2, 3, 1)
        mask = sample["mask"].numpy()

        if self.mode == "instance":
            mask = self._instances_to_semantic(mask, sample["label"])

        num_panels = 3
        showing_predictions = "prediction" in sample
        if showing_predictions:
            predictions = sample["prediction"].numpy()
            num_panels += 1
            if self.mode == "instance":
                predictions = self._instances_to_semantic(
                    predictions, sample["prediction_labels"]
                )

        # A series with a single time step has no second image; repeat the
        # first one instead of raising an IndexError.
        second_image = images[1] if len(images) > 1 else images[0]

        fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 4))
        axs[0].imshow(images[0] / 5000)
        axs[1].imshow(second_image / 5000)
        axs[2].imshow(mask, vmin=0, vmax=19, cmap=self._cmap, interpolation="none")
        axs[0].axis("off")
        axs[1].axis("off")
        axs[2].axis("off")
        if showing_predictions:
            axs[3].imshow(
                predictions, vmin=0, vmax=19, cmap=self._cmap, interpolation="none"
            )
            axs[3].axis("off")

        if show_titles:
            axs[0].set_title("Image 0")
            axs[1].set_title("Image 1")
            axs[2].set_title("Mask")
            if showing_predictions:
                axs[3].set_title("Prediction")

        if suptitle is not None:
            plt.suptitle(suptitle)
        return fig
|
Загрузка…
Ссылка в новой задаче