Mirror of https://github.com/microsoft/torchgeo.git
PASTIS dataset (#315)
* draft
* add dataset to __init__
* reorganize datasets and datamodules
* fix mypy errors
* draft
* add dataset to __init__
* reorganize datasets and datamodules
* fix mypy errors
* refactor
* Adding docs
* Adding plotting, cleaning up some stuff
* Black and isort
* Fix the datamodule import
* Pyupgrade
* Fixing some docstrings
* Flake8
* Isort
* Fix docstrings in datamodules
* Fixing fns and docstring
* Trying to fix the docs
* Trying to fix docs
* Adding tests
* Black
* newline
* Made the test dataset larger
* Remove the datamodules
* Update docs/api/non_geo_datasets.csv
  Co-authored-by: Isaac Corley <22203655+isaaccorley@users.noreply.github.com>
* Update torchgeo/datasets/pastis.py
  Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
* Update torchgeo/datasets/pastis.py
  Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
* Update torchgeo/datasets/pastis.py
  Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
* Updating cmap
* Describe the different band combinations
* Merging datasets
* Handle the instance segmentation case in plotting
* Update torchgeo/datasets/pastis.py
  Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
* Made some code prettier
* Adding instance plotting

---------

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:

Parent: f0cacd5402
Commit: 711a576e38
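In short, the commit adds a new `PASTIS` class to `torchgeo.datasets`. A minimal usage sketch of the API introduced in the diffs below (the `root` path is illustrative):

    import matplotlib.pyplot as plt

    from torchgeo.datasets import PASTIS

    # Load folds 0-1 of the Sentinel-2 series with semantic masks;
    # download=True fetches and extracts PASTIS-R.zip into `root`.
    ds = PASTIS(root="data", folds=(0, 1), bands="s2", mode="semantic", download=True)
    sample = ds[0]
    fig = ds.plot(sample, suptitle="PASTIS sample")
    plt.show()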
docs/api/non_geo_datasets.rst
@@ -272,6 +272,11 @@ OSCD
 
 .. autoclass:: OSCD
 
+PASTIS
+^^^^^^
+
+.. autoclass:: PASTIS
+
 PatternNet
 ^^^^^^^^^^
 
docs/api/non_geo_datasets.csv
@@ -21,6 +21,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
 `Million-AID`_,C,Google Earth,1M,51--73,,0.5--153,RGB
 `NASA Marine Debris`_,OD,PlanetScope,707,1,256x256,3,RGB
 `OSCD`_,CD,Sentinel-2,24,2,"40--1,180",60,MSI
+`PASTIS`_,I,Sentinel-1/2,"2,433",19,128x128xT,10,MSI
 `PatternNet`_,C,Google Earth,"30,400",38,256x256,0.06--5,RGB
 `Potsdam`_,S,Aerial,38,6,"6,000x6,000",0.05,MSI
 `ReforesTree`_,"OD, R",Aerial,100,6,"4,000x4,000",0.02,RGB
Binary file not shown.
tests/data/pastis/data.py (new file)
@@ -0,0 +1,91 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil
from typing import Union

import fiona
import numpy as np

SIZE = 32
NUM_SAMPLES = 5
MAX_NUM_TIME_STEPS = 10
np.random.seed(0)

FILENAME_HIERARCHY = Union[dict[str, "FILENAME_HIERARCHY"], list[str]]

filenames: FILENAME_HIERARCHY = {
    "DATA_S2": ["S2"],
    "DATA_S1A": ["S1A"],
    "DATA_S1D": ["S1D"],
    "ANNOTATIONS": ["TARGET"],
    "INSTANCE_ANNOTATIONS": ["INSTANCES"],
}


def create_file(path: str) -> None:
    for i in range(NUM_SAMPLES):
        new_path = f"{path}_{i}.npy"
        fn = os.path.basename(new_path)
        t = np.random.randint(1, MAX_NUM_TIME_STEPS)
        if fn.startswith("S2"):
            data = np.random.randint(0, 256, size=(t, 10, SIZE, SIZE)).astype(np.int16)
        elif fn.startswith("S1A"):
            data = np.random.randint(0, 256, size=(t, 3, SIZE, SIZE)).astype(np.float16)
        elif fn.startswith("S1D"):
            data = np.random.randint(0, 256, size=(t, 3, SIZE, SIZE)).astype(np.float16)
        elif fn.startswith("TARGET"):
            data = np.random.randint(0, 20, size=(3, SIZE, SIZE)).astype(np.uint8)
        elif fn.startswith("INSTANCES"):
            data = np.random.randint(0, 100, size=(SIZE, SIZE)).astype(np.int64)
        np.save(new_path, data)


def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:
    if isinstance(hierarchy, dict):
        # Recursive case
        for key, value in hierarchy.items():
            path = os.path.join(directory, key)
            os.makedirs(path, exist_ok=True)
            create_directory(path, value)
    else:
        # Base case
        for value in hierarchy:
            path = os.path.join(directory, value)
            create_file(path)


if __name__ == "__main__":
    create_directory("PASTIS-R", filenames)

    schema = {"geometry": "Polygon", "properties": {"Fold": "int", "ID_PATCH": "int"}}
    with fiona.open(
        os.path.join("PASTIS-R", "metadata.geojson"),
        "w",
        "GeoJSON",
        crs="EPSG:4326",
        schema=schema,
    ) as f:
        for i in range(NUM_SAMPLES):
            f.write(
                {
                    "geometry": {
                        "type": "Polygon",
                        "coordinates": [[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]],
                    },
                    "id": str(i),
                    "properties": {"Fold": i % 5, "ID_PATCH": i},
                }
            )

    filename = "PASTIS-R.zip"
    shutil.make_archive(filename.replace(".zip", ""), "zip", ".", "PASTIS-R")

    # Compute checksums
    with open(filename, "rb") as f:
        md5 = hashlib.md5(f.read()).hexdigest()
    print(f"{filename}: {md5}")
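The script's final print gives the archive's MD5, and the test fixture below monkeypatches `PASTIS.md5` with that value. A sketch of recomputing it after regenerating the fixture (path relative to the repository root, taken from the tests):

    import hashlib

    with open("tests/data/pastis/PASTIS-R.zip", "rb") as f:
        print(hashlib.md5(f.read()).hexdigest())  # must match the md5 in the tests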
tests/datasets/test_pastis.py (new file)
@@ -0,0 +1,110 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset

import torchgeo.datasets.utils
from torchgeo.datasets import PASTIS


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    shutil.copy(url, root)


class TestPASTIS:
    @pytest.fixture(
        params=[
            {"folds": (0, 1), "bands": "s2", "mode": "semantic"},
            {"folds": (0, 1), "bands": "s1a", "mode": "semantic"},
            {"folds": (0, 1), "bands": "s1d", "mode": "instance"},
        ]
    )
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> PASTIS:
        monkeypatch.setattr(torchgeo.datasets.pastis, "download_url", download_url)
        md5 = "9b11ae132623a0d13f7f0775d2003703"
        monkeypatch.setattr(PASTIS, "md5", md5)
        url = os.path.join("tests", "data", "pastis", "PASTIS-R.zip")
        monkeypatch.setattr(PASTIS, "url", url)
        root = str(tmp_path)
        folds = request.param["folds"]
        bands = request.param["bands"]
        mode = request.param["mode"]
        transforms = nn.Identity()
        return PASTIS(
            root, folds, bands, mode, transforms, download=True, checksum=True
        )

    def test_getitem_semantic(self, dataset: PASTIS) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["mask"], torch.Tensor)

    def test_getitem_instance(self, dataset: PASTIS) -> None:
        dataset.mode = "instance"
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["mask"], torch.Tensor)
        assert isinstance(x["boxes"], torch.Tensor)
        assert isinstance(x["label"], torch.Tensor)

    def test_len(self, dataset: PASTIS) -> None:
        assert len(dataset) == 2

    def test_add(self, dataset: PASTIS) -> None:
        ds = dataset + dataset
        assert isinstance(ds, ConcatDataset)
        assert len(ds) == 4

    def test_already_extracted(self, dataset: PASTIS) -> None:
        PASTIS(root=dataset.root, download=True)

    def test_already_downloaded(self, tmp_path: Path) -> None:
        url = os.path.join("tests", "data", "pastis", "PASTIS-R.zip")
        root = str(tmp_path)
        shutil.copy(url, root)
        PASTIS(root)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found"):
            PASTIS(str(tmp_path))

    def test_corrupted(self, tmp_path: Path) -> None:
        with open(os.path.join(tmp_path, "PASTIS-R.zip"), "w") as f:
            f.write("bad")
        with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
            PASTIS(root=str(tmp_path), checksum=True)

    def test_invalid_fold(self) -> None:
        with pytest.raises(AssertionError):
            PASTIS(folds=(6,))

    def test_invalid_mode(self) -> None:
        with pytest.raises(AssertionError):
            PASTIS(mode="invalid")

    def test_plot(self, dataset: PASTIS) -> None:
        x = dataset[0].copy()
        dataset.plot(x, suptitle="Test")
        plt.close()
        dataset.plot(x, show_titles=False)
        plt.close()
        x["prediction"] = x["mask"].clone()
        if dataset.mode == "instance":
            x["prediction_labels"] = x["label"].clone()
        dataset.plot(x)
        plt.close()
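The suite above can be run on its own; a sketch using pytest's Python entry point (the file path assumes torchgeo's standard test layout):

    import pytest

    # Equivalent to `pytest tests/datasets/test_pastis.py -v` from the repo root.
    pytest.main(["tests/datasets/test_pastis.py", "-v"])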
torchgeo/datasets/__init__.py
@@ -78,6 +78,7 @@ from .nasa_marine_debris import NASAMarineDebris
 from .nlcd import NLCD
 from .openbuildings import OpenBuildings
 from .oscd import OSCD
+from .pastis import PASTIS
 from .patternnet import PatternNet
 from .potsdam import Potsdam2D
 from .reforestree import ReforesTree
@@ -194,6 +195,7 @@ __all__ = (
     "MillionAID",
     "NASAMarineDebris",
     "OSCD",
+    "PASTIS",
     "PatternNet",
     "Potsdam2D",
     "RESISC45",
torchgeo/datasets/pastis.py (new file)
@@ -0,0 +1,405 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""PASTIS dataset."""

import os
from collections.abc import Sequence
from typing import Callable, Optional

import fiona
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.colors import ListedColormap
from torch import Tensor

from .geo import NonGeoDataset
from .utils import check_integrity, download_url, extract_archive


class PASTIS(NonGeoDataset):
    """PASTIS dataset.

    The `PASTIS <https://github.com/VSainteuf/pastis-benchmark>`__ dataset is a
    benchmark for time-series panoptic segmentation of agricultural parcels.

    Dataset features:

    * support for the original PASTIS and PASTIS-R versions of the dataset
    * 2,433 time-series with 10 m per pixel resolution (128x128 px)
    * 18 crop categories, 1 background category, 1 void category
    * semantic and instance annotations
    * 3 Sentinel-1 Ascending bands
    * 3 Sentinel-1 Descending bands
    * 10 Sentinel-2 L2A multispectral bands

    Dataset format:

    * time-series and annotations are in numpy format (.npy)

    Dataset classes:

    0. Background
    1. Meadow
    2. Soft Winter Wheat
    3. Corn
    4. Winter Barley
    5. Winter Rapeseed
    6. Spring Barley
    7. Sunflower
    8. Grapevine
    9. Beet
    10. Winter Triticale
    11. Winter Durum Wheat
    12. Fruits Vegetables Flowers
    13. Potatoes
    14. Leguminous Fodder
    15. Soybeans
    16. Orchard
    17. Mixed Cereal
    18. Sorghum
    19. Void Label

    If you use this dataset in your research, please cite the following papers:

    * https://doi.org/10.1109/ICCV48922.2021.00483
    * https://doi.org/10.1016/j.isprsjprs.2022.03.012

    .. versionadded:: 0.5
    """

    classes = [
        "background",  # all non-agricultural land
        "meadow",
        "soft_winter_wheat",
        "corn",
        "winter_barley",
        "winter_rapeseed",
        "spring_barley",
        "sunflower",
        "grapevine",
        "beet",
        "winter_triticale",
        "winter_durum_wheat",
        "fruits_vegetables_flowers",
        "potatoes",
        "leguminous_fodder",
        "soybeans",
        "orchard",
        "mixed_cereal",
        "sorghum",
        "void_label",  # for parcels mostly outside their patch
    ]
    cmap = {
        0: (0, 0, 0, 255),
        1: (174, 199, 232, 255),
        2: (255, 127, 14, 255),
        3: (255, 187, 120, 255),
        4: (44, 160, 44, 255),
        5: (152, 223, 138, 255),
        6: (214, 39, 40, 255),
        7: (255, 152, 150, 255),
        8: (148, 103, 189, 255),
        9: (197, 176, 213, 255),
        10: (140, 86, 75, 255),
        11: (196, 156, 148, 255),
        12: (227, 119, 194, 255),
        13: (247, 182, 210, 255),
        14: (127, 127, 127, 255),
        15: (199, 199, 199, 255),
        16: (188, 189, 34, 255),
        17: (219, 219, 141, 255),
        18: (23, 190, 207, 255),
        19: (255, 255, 255, 255),
    }
    directory = "PASTIS-R"
    filename = "PASTIS-R.zip"
    url = "https://zenodo.org/record/5735646/files/PASTIS-R.zip?download=1"
    md5 = "4887513d6c2d2b07fa935d325bd53e09"
    prefix = {
        "s2": os.path.join("DATA_S2", "S2_"),
        "s1a": os.path.join("DATA_S1A", "S1A_"),
        "s1d": os.path.join("DATA_S1D", "S1D_"),
        "semantic": os.path.join("ANNOTATIONS", "TARGET_"),
        "instance": os.path.join("INSTANCE_ANNOTATIONS", "INSTANCES_"),
    }

    def __init__(
        self,
        root: str = "data",
        folds: Sequence[int] = (0, 1, 2, 3, 4),
        bands: str = "s2",
        mode: str = "semantic",
        transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new PASTIS dataset instance.

        Args:
            root: root directory where dataset can be found
            folds: a sequence of integers from 0 to 4 specifying which of the five
                dataset folds to include
            bands: load Sentinel-1 ascending path data (s1a), Sentinel-1 descending
                path data (s1d), or Sentinel-2 data (s2)
            mode: load semantic (semantic) or instance (instance) annotations
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)
        """
        assert set(folds) <= set(range(6))
        assert bands in ["s1a", "s1d", "s2"]
        assert mode in ["semantic", "instance"]
        self.root = root
        self.folds = folds
        self.bands = bands
        self.mode = mode
        self.transforms = transforms
        self.download = download
        self.checksum = checksum
        self._verify()
        self.files = self._load_files()

        colors = []
        for i in range(len(self.cmap)):
            colors.append(
                (
                    self.cmap[i][0] / 255.0,
                    self.cmap[i][1] / 255.0,
                    self.cmap[i][2] / 255.0,
                )
            )
        self._cmap = ListedColormap(colors)

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """Return a sample within the dataset.

        Args:
            index: index to return

        Returns:
            data and label at that index
        """
        image = self._load_image(index)
        if self.mode == "semantic":
            mask = self._load_semantic_targets(index)
            sample = {"image": image, "mask": mask}
        elif self.mode == "instance":
            mask, boxes, labels = self._load_instance_targets(index)
            sample = {"image": image, "mask": mask, "boxes": boxes, "label": labels}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.idxs)

    def _load_image(self, index: int) -> Tensor:
        """Load a single time-series.

        Args:
            index: index to return

        Returns:
            the time-series
        """
        path = self.files[index][self.bands]
        array = np.load(path)

        tensor = torch.from_numpy(array)
        return tensor

    def _load_semantic_targets(self, index: int) -> Tensor:
        """Load the target mask for a single image.

        Args:
            index: index to return

        Returns:
            the target mask
        """
        # See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201  # noqa: E501
        # even though the mask file is 3 bands, we just select the first band
        array = np.load(self.files[index]["semantic"])[0].astype(np.uint8)
        tensor = torch.from_numpy(array).long()
        return tensor

    def _load_instance_targets(self, index: int) -> tuple[Tensor, Tensor, Tensor]:
        """Load the instance segmentation targets for a single sample.

        Args:
            index: index to return

        Returns:
            the instance segmentation mask, box, and label for each instance
        """
        mask_array = np.load(self.files[index]["semantic"])[0]
        instance_array = np.load(self.files[index]["instance"])

        mask_tensor = torch.from_numpy(mask_array)
        instance_tensor = torch.from_numpy(instance_array)

        # Convert instance mask of N instances to N binary instance masks
        instance_ids = torch.unique(instance_tensor)
        # Exclude a mask for unknown/background
        instance_ids = instance_ids[instance_ids != 0]
        instance_ids = instance_ids[:, None, None]
        masks: Tensor = instance_tensor == instance_ids

        # Parse labels for each instance
        labels_list = []
        for mask in masks:
            label = mask_tensor[mask]
            label = torch.unique(label)[0]
            labels_list.append(label)

        # Get bounding boxes for each instance
        boxes_list = []
        for mask in masks:
            pos = torch.where(mask)
            xmin = torch.min(pos[1])
            xmax = torch.max(pos[1])
            ymin = torch.min(pos[0])
            ymax = torch.max(pos[0])
            boxes_list.append([xmin, ymin, xmax, ymax])

        masks = masks.to(torch.uint8)
        boxes = torch.tensor(boxes_list).to(torch.float)
        labels = torch.tensor(labels_list).to(torch.long)

        return masks, boxes, labels

    def _load_files(self) -> list[dict[str, str]]:
        """List the image and target files.

        Returns:
            list of dicts containing image and semantic/instance target file paths
        """
        self.idxs = []
        metadata_fn = os.path.join(self.root, self.directory, "metadata.geojson")
        with fiona.open(metadata_fn) as f:
            for row in f:
                fold = int(row["properties"]["Fold"])
                if fold in self.folds:
                    self.idxs.append(row["properties"]["ID_PATCH"])

        files = []
        for i in self.idxs:
            path = os.path.join(self.root, self.directory, "{}") + str(i) + ".npy"
            files.append(
                {
                    "s2": path.format(self.prefix["s2"]),
                    "s1a": path.format(self.prefix["s1a"]),
                    "s1d": path.format(self.prefix["s1d"]),
                    "semantic": path.format(self.prefix["semantic"]),
                    "instance": path.format(self.prefix["instance"]),
                }
            )
        return files

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Raises:
            RuntimeError: if ``download=False`` but dataset is missing or checksum
                fails
        """
        # Check if the directory already exists
        path = os.path.join(self.root, self.directory)
        if os.path.exists(path):
            return

        # Check if zip file already exists (if so then extract)
        filepath = os.path.join(self.root, self.filename)
        if os.path.exists(filepath):
            if self.checksum and not check_integrity(filepath, self.md5):
                raise RuntimeError("Dataset found, but corrupted.")
            extract_archive(filepath)
            return

        # Check if the user requested to download the dataset
        if not self.download:
            raise RuntimeError(
                f"Dataset not found in `root={self.root}` and `download=False`, "
                "either specify a different `root` directory or use `download=True` "
                "to automatically download the dataset."
            )

        # Download and extract the dataset
        self._download()

    def _download(self) -> None:
        """Download the dataset."""
        download_url(
            self.url,
            self.root,
            filename=self.filename,
            md5=self.md5 if self.checksum else None,
        )
        extract_archive(os.path.join(self.root, self.filename), self.root)

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: Optional[str] = None,
    ) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        # Keep the RGB bands and convert to T x H x W x C format
        images = sample["image"][:, [2, 1, 0], :, :].numpy().transpose(0, 2, 3, 1)
        mask = sample["mask"].numpy()

        if self.mode == "instance":
            label = sample["label"]
            mask = label[mask.argmax(axis=0)].numpy()

        num_panels = 3
        showing_predictions = "prediction" in sample
        if showing_predictions:
            predictions = sample["prediction"].numpy()
            num_panels += 1
            if self.mode == "instance":
                predictions = predictions.argmax(axis=0)
                label = sample["prediction_labels"]
                predictions = label[predictions].numpy()

        fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 4))
        axs[0].imshow(images[0] / 5000)
        axs[1].imshow(images[1] / 5000)
        axs[2].imshow(mask, vmin=0, vmax=19, cmap=self._cmap, interpolation="none")
        axs[0].axis("off")
        axs[1].axis("off")
        axs[2].axis("off")
        if showing_predictions:
            axs[3].imshow(
                predictions, vmin=0, vmax=19, cmap=self._cmap, interpolation="none"
            )
            axs[3].axis("off")

        if show_titles:
            axs[0].set_title("Image 0")
            axs[1].set_title("Image 1")
            axs[2].set_title("Mask")
            if showing_predictions:
                axs[3].set_title("Prediction")

        if suptitle is not None:
            plt.suptitle(suptitle)
        return fig
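For reference, a minimal sketch of what instance mode returns (the `root` path is illustrative; shapes follow the code above, with N the number of parcel instances in a patch and T the length of the image time series):

    from torchgeo.datasets import PASTIS

    ds = PASTIS(root="data", folds=(0,), bands="s2", mode="instance", download=True)
    sample = ds[0]
    print(sample["image"].shape)  # (T, 10, 128, 128) Sentinel-2 time series
    print(sample["mask"].shape)   # (N, 128, 128) one binary mask per instance
    print(sample["boxes"].shape)  # (N, 4) xyxy bounding boxes
    print(sample["label"].shape)  # (N,) semantic class per instance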