diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 552665c57..3d417c8a7 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -272,6 +272,11 @@ OSCD
 
 .. autoclass:: OSCD
 
+PASTIS
+^^^^^^
+
+.. autoclass:: PASTIS
+
 PatternNet
 ^^^^^^^^^^
 
diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv
index 475a302cb..064257f8a 100644
--- a/docs/api/non_geo_datasets.csv
+++ b/docs/api/non_geo_datasets.csv
@@ -21,6 +21,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
 `Million-AID`_,C,Google Earth,1M,51--73,,0.5--153,RGB
 `NASA Marine Debris`_,OD,PlanetScope,707,1,256x256,3,RGB
 `OSCD`_,CD,Sentinel-2,24,2,"40--1,180",60,MSI
+`PASTIS`_,"S, I",Sentinel-1/2,"2,433",19,128x128xT,10,MSI
 `PatternNet`_,C,Google Earth,"30,400",38,256x256,0.06--5,RGB
 `Potsdam`_,S,Aerial,38,6,"6,000x6,000",0.05,MSI
 `ReforesTree`_,"OD, R",Aerial,100,6,"4,000x4,000",0.02,RGB
diff --git a/tests/data/pastis/PASTIS-R.zip b/tests/data/pastis/PASTIS-R.zip
new file mode 100644
index 000000000..6c40f2b2a
Binary files /dev/null and b/tests/data/pastis/PASTIS-R.zip differ
diff --git a/tests/data/pastis/data.py b/tests/data/pastis/data.py
new file mode 100644
index 000000000..36742e8a4
--- /dev/null
+++ b/tests/data/pastis/data.py
@@ -0,0 +1,91 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import hashlib
+import os
+import shutil
+from typing import Union
+
+import fiona
+import numpy as np
+
+SIZE = 32
+NUM_SAMPLES = 5
+MAX_NUM_TIME_STEPS = 10
+np.random.seed(0)
+
+FILENAME_HIERARCHY = Union[dict[str, "FILENAME_HIERARCHY"], list[str]]
+
+filenames: FILENAME_HIERARCHY = {
+    "DATA_S2": ["S2"],
+    "DATA_S1A": ["S1A"],
+    "DATA_S1D": ["S1D"],
+    "ANNOTATIONS": ["TARGET"],
+    "INSTANCE_ANNOTATIONS": ["INSTANCES"],
+}
+
+
+def create_file(path: str) -> None:
+    # Write NUM_SAMPLES random arrays per prefix, one per patch, with a random
+    # number of time steps for the image time series
+    for i in range(NUM_SAMPLES):
+        new_path = f"{path}_{i}.npy"
+        fn = os.path.basename(new_path)
+        t = np.random.randint(1, MAX_NUM_TIME_STEPS)
+        if fn.startswith("S2"):
+            data = np.random.randint(0, 256, size=(t, 10, SIZE, SIZE)).astype(np.int16)
+        elif fn.startswith("S1A"):
+            data = np.random.randint(0, 256, size=(t, 3, SIZE, SIZE)).astype(np.float16)
+        elif fn.startswith("S1D"):
+            data = np.random.randint(0, 256, size=(t, 3, SIZE, SIZE)).astype(np.float16)
+        elif fn.startswith("TARGET"):
+            data = np.random.randint(0, 20, size=(3, SIZE, SIZE)).astype(np.uint8)
+        elif fn.startswith("INSTANCES"):
+            data = np.random.randint(0, 100, size=(SIZE, SIZE)).astype(np.int64)
+        np.save(new_path, data)
+
+
+def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:
+    if isinstance(hierarchy, dict):
+        # Recursive case
+        for key, value in hierarchy.items():
+            path = os.path.join(directory, key)
+            os.makedirs(path, exist_ok=True)
+            create_directory(path, value)
+    else:
+        # Base case
+        for value in hierarchy:
+            path = os.path.join(directory, value)
+            create_file(path)
+
+
+if __name__ == "__main__":
+    create_directory("PASTIS-R", filenames)
+
+    schema = {"geometry": "Polygon", "properties": {"Fold": "int", "ID_PATCH": "int"}}
+    with fiona.open(
+        os.path.join("PASTIS-R", "metadata.geojson"),
+        "w",
+        "GeoJSON",
+        crs="EPSG:4326",
+        schema=schema,
+    ) as f:
+        for i in range(NUM_SAMPLES):
+            f.write(
+                {
+                    "geometry": {
+                        "type": "Polygon",
+                        "coordinates": [[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]],
+                    },
+                    "id": str(i),
+                    "properties": {"Fold": i % 5, "ID_PATCH": i},
+                }
+            )
+
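+    # Optional sanity check: each generated Sentinel-2 series should already
+    # have the (T, 10, SIZE, SIZE) layout that the dataset loader expects.
+    # Illustrative only; nothing in the test suite depends on it.
+    s2 = np.load(os.path.join("PASTIS-R", "DATA_S2", "S2_0.npy"))
+    assert s2.ndim == 4 and s2.shape[1:] == (10, SIZE, SIZE)
+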
""), "zip", ".", "PASTIS-R") + + # Compute checksums + with open(filename, "rb") as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(f"{filename}: {md5}") diff --git a/tests/datasets/test_pastis.py b/tests/datasets/test_pastis.py new file mode 100644 index 000000000..698d12487 --- /dev/null +++ b/tests/datasets/test_pastis.py @@ -0,0 +1,110 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch +from torch.utils.data import ConcatDataset + +import torchgeo.datasets.utils +from torchgeo.datasets import PASTIS + + +def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: + shutil.copy(url, root) + + +class TestPASTIS: + @pytest.fixture( + params=[ + {"folds": (0, 1), "bands": "s2", "mode": "semantic"}, + {"folds": (0, 1), "bands": "s1a", "mode": "semantic"}, + {"folds": (0, 1), "bands": "s1d", "mode": "instance"}, + ] + ) + def dataset( + self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + ) -> PASTIS: + monkeypatch.setattr(torchgeo.datasets.pastis, "download_url", download_url) + + md5 = "9b11ae132623a0d13f7f0775d2003703" + monkeypatch.setattr(PASTIS, "md5", md5) + url = os.path.join("tests", "data", "pastis", "PASTIS-R.zip") + monkeypatch.setattr(PASTIS, "url", url) + root = str(tmp_path) + folds = request.param["folds"] + bands = request.param["bands"] + mode = request.param["mode"] + transforms = nn.Identity() + return PASTIS( + root, folds, bands, mode, transforms, download=True, checksum=True + ) + + def test_getitem_semantic(self, dataset: PASTIS) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert isinstance(x["mask"], torch.Tensor) + + def test_getitem_instance(self, dataset: PASTIS) -> None: + dataset.mode = "instance" + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x["boxes"], torch.Tensor) + assert isinstance(x["label"], torch.Tensor) + + def test_len(self, dataset: PASTIS) -> None: + assert len(dataset) == 2 + + def test_add(self, dataset: PASTIS) -> None: + ds = dataset + dataset + assert isinstance(ds, ConcatDataset) + assert len(ds) == 4 + + def test_already_extracted(self, dataset: PASTIS) -> None: + PASTIS(root=dataset.root, download=True) + + def test_already_downloaded(self, tmp_path: Path) -> None: + url = os.path.join("tests", "data", "pastis", "PASTIS-R.zip") + root = str(tmp_path) + shutil.copy(url, root) + PASTIS(root) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found"): + PASTIS(str(tmp_path)) + + def test_corrupted(self, tmp_path: Path) -> None: + with open(os.path.join(tmp_path, "PASTIS-R.zip"), "w") as f: + f.write("bad") + with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + PASTIS(root=str(tmp_path), checksum=True) + + def test_invalid_fold(self) -> None: + with pytest.raises(AssertionError): + PASTIS(folds=(6,)) + + def test_invalid_mode(self) -> None: + with pytest.raises(AssertionError): + PASTIS(mode="invalid") + + def test_plot(self, dataset: PASTIS) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = 
x["mask"].clone() + if dataset.mode == "instance": + x["prediction_labels"] = x["label"].clone() + dataset.plot(x) + plt.close() diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index ae0efe90e..2cbed1f76 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -78,6 +78,7 @@ from .nasa_marine_debris import NASAMarineDebris from .nlcd import NLCD from .openbuildings import OpenBuildings from .oscd import OSCD +from .pastis import PASTIS from .patternnet import PatternNet from .potsdam import Potsdam2D from .reforestree import ReforesTree @@ -194,6 +195,7 @@ __all__ = ( "MillionAID", "NASAMarineDebris", "OSCD", + "PASTIS", "PatternNet", "Potsdam2D", "RESISC45", diff --git a/torchgeo/datasets/pastis.py b/torchgeo/datasets/pastis.py new file mode 100644 index 000000000..f3e771c92 --- /dev/null +++ b/torchgeo/datasets/pastis.py @@ -0,0 +1,405 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""PASTIS dataset.""" + +import os +from collections.abc import Sequence +from typing import Callable, Optional + +import fiona +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.colors import ListedColormap +from torch import Tensor + +from .geo import NonGeoDataset +from .utils import check_integrity, download_url, extract_archive + + +class PASTIS(NonGeoDataset): + """PASTIS dataset. + + The `PASTIS `__ + dataset is a dataset for time-series panoptic segmentation of agricultural parcels. + + Dataset features: + + * support for the original PASTIS and PASTIS-R versions of the dataset + * 2,433 time-series with 10 m per pixel resolution (128x128 px) + * 18 crop categories, 1 background category, 1 void category + * semantic and instance annotations + * 3 Sentinel-1 Ascending bands + * 3 Sentinel-1 Descending bands + * 10 Sentinel-2 L2A multispectral bands + + Dataset format: + + * time-series and annotations are in numpy format (.npy) + + Dataset classes: + + 0. Background + 1. Meadow + 2. Soft Winter Wheat + 3. Corn + 4. Winter Barley + 5. Winter Rapeseed + 6. Spring Barley + 7. Sunflower + 8. Grapevine + 9. Beet + 10. Winter Triticale + 11. Winter Durum Wheat + 12. Fruits Vegetables Flowers + 13. Potatoes + 14. Leguminous Fodder + 15. Soybeans + 16. Orchard + 17. Mixed Cereal + 18. Sorghum + 19. Void Label + + If you use this dataset in your research, please cite the following papers: + + * https://doi.org/10.1109/ICCV48922.2021.00483 + * https://doi.org/10.1016/j.isprsjprs.2022.03.012 + + .. 
+    .. versionadded:: 0.5
+    """
+
+    classes = [
+        "background",  # all non-agricultural land
+        "meadow",
+        "soft_winter_wheat",
+        "corn",
+        "winter_barley",
+        "winter_rapeseed",
+        "spring_barley",
+        "sunflower",
+        "grapevine",
+        "beet",
+        "winter_triticale",
+        "winter_durum_wheat",
+        "fruits_vegetables_flowers",
+        "potatoes",
+        "leguminous_fodder",
+        "soybeans",
+        "orchard",
+        "mixed_cereal",
+        "sorghum",
+        "void_label",  # for parcels mostly outside their patch
+    ]
+    cmap = {
+        0: (0, 0, 0, 255),
+        1: (174, 199, 232, 255),
+        2: (255, 127, 14, 255),
+        3: (255, 187, 120, 255),
+        4: (44, 160, 44, 255),
+        5: (152, 223, 138, 255),
+        6: (214, 39, 40, 255),
+        7: (255, 152, 150, 255),
+        8: (148, 103, 189, 255),
+        9: (197, 176, 213, 255),
+        10: (140, 86, 75, 255),
+        11: (196, 156, 148, 255),
+        12: (227, 119, 194, 255),
+        13: (247, 182, 210, 255),
+        14: (127, 127, 127, 255),
+        15: (199, 199, 199, 255),
+        16: (188, 189, 34, 255),
+        17: (219, 219, 141, 255),
+        18: (23, 190, 207, 255),
+        19: (255, 255, 255, 255),
+    }
+    directory = "PASTIS-R"
+    filename = "PASTIS-R.zip"
+    url = "https://zenodo.org/record/5735646/files/PASTIS-R.zip?download=1"
+    md5 = "4887513d6c2d2b07fa935d325bd53e09"
+    prefix = {
+        "s2": os.path.join("DATA_S2", "S2_"),
+        "s1a": os.path.join("DATA_S1A", "S1A_"),
+        "s1d": os.path.join("DATA_S1D", "S1D_"),
+        "semantic": os.path.join("ANNOTATIONS", "TARGET_"),
+        "instance": os.path.join("INSTANCE_ANNOTATIONS", "INSTANCES_"),
+    }
+
+    def __init__(
+        self,
+        root: str = "data",
+        folds: Sequence[int] = (0, 1, 2, 3, 4),
+        bands: str = "s2",
+        mode: str = "semantic",
+        transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
+        download: bool = False,
+        checksum: bool = False,
+    ) -> None:
+        """Initialize a new PASTIS dataset instance.
+
+        Args:
+            root: root directory where dataset can be found
+            folds: a sequence of integers from 0 to 4 specifying which of the five
+                dataset folds to include
+            bands: load Sentinel-1 ascending path data (s1a), Sentinel-1 descending
+                path data (s1d), or Sentinel-2 data (s2)
+            mode: load semantic (semantic) or instance (instance) annotations
+            transforms: a function/transform that takes input sample and its target as
+                entry and returns a transformed version
+            download: if True, download dataset and store it in the root directory
+            checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+        Raises:
+            RuntimeError: if ``download=False`` but dataset is missing or checksum
+                fails
+        """
+        assert set(folds) <= set(range(5))
+        assert bands in ["s1a", "s1d", "s2"]
+        assert mode in ["semantic", "instance"]
+        self.root = root
+        self.folds = folds
+        self.bands = bands
+        self.mode = mode
+        self.transforms = transforms
+        self.download = download
+        self.checksum = checksum
+        self._verify()
+        self.files = self._load_files()
+
+        # Convert the 0-255 RGBA class colors to the 0-1 RGB floats that
+        # matplotlib's ListedColormap expects
+        colors = []
+        for i in range(len(self.cmap)):
+            colors.append(
+                (
+                    self.cmap[i][0] / 255.0,
+                    self.cmap[i][1] / 255.0,
+                    self.cmap[i][2] / 255.0,
+                )
+            )
+        self._cmap = ListedColormap(colors)
+
+    def __getitem__(self, index: int) -> dict[str, Tensor]:
+        """Return an index within the dataset.
+
+        Args:
+            index: index to return
+
+        Returns:
+            data and label at that index
+        """
+        image = self._load_image(index)
+        if self.mode == "semantic":
+            mask = self._load_semantic_targets(index)
+            sample = {"image": image, "mask": mask}
+        elif self.mode == "instance":
+            mask, boxes, labels = self._load_instance_targets(index)
+            sample = {"image": image, "mask": mask, "boxes": boxes, "label": labels}
+
+        if self.transforms is not None:
+            sample = self.transforms(sample)
+
+        return sample
+
+    def __len__(self) -> int:
+        """Return the number of data points in the dataset.
+
+        Returns:
+            length of the dataset
+        """
+        return len(self.files)
+
+    def _load_image(self, index: int) -> Tensor:
+        """Load a single time-series.
+
+        Args:
+            index: index to return
+
+        Returns:
+            the time-series
+        """
+        path = self.files[index][self.bands]
+        array = np.load(path)
+
+        tensor = torch.from_numpy(array)
+        return tensor
+
+    def _load_semantic_targets(self, index: int) -> Tensor:
+        """Load the target mask for a single image.
+
+        Args:
+            index: index to return
+
+        Returns:
+            the target mask
+        """
+        # See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201  # noqa: E501
+        # even though the mask file is 3 bands, we just select the first band
+        array = np.load(self.files[index]["semantic"])[0].astype(np.uint8)
+        tensor = torch.from_numpy(array).long()
+        return tensor
+
+    def _load_instance_targets(self, index: int) -> tuple[Tensor, Tensor, Tensor]:
+        """Load the instance segmentation targets for a single sample.
+
+        Args:
+            index: index to return
+
+        Returns:
+            the instance segmentation mask, box, and label for each instance
+        """
+        mask_array = np.load(self.files[index]["semantic"])[0]
+        instance_array = np.load(self.files[index]["instance"])
+
+        mask_tensor = torch.from_numpy(mask_array)
+        instance_tensor = torch.from_numpy(instance_array)
+
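+        # Illustrative toy case: an instance_tensor of [[0, 7], [9, 9]] has
+        # ids {7, 9} after 0 is dropped, yielding binary masks [[0, 1], [0, 0]]
+        # and [[0, 0], [1, 1]]; each mask then produces one label and one box.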
+        # Convert instance mask of N instances to N binary instance masks
+        instance_ids = torch.unique(instance_tensor)
+        # Exclude a mask for unknown/background
+        instance_ids = instance_ids[instance_ids != 0]
+        instance_ids = instance_ids[:, None, None]
+        masks: Tensor = instance_tensor == instance_ids
+
+        # Parse labels for each instance (each parcel mask is assumed to carry
+        # a single semantic class)
+        labels_list = []
+        for mask in masks:
+            label = mask_tensor[mask]
+            label = torch.unique(label)[0]
+            labels_list.append(label)
+
+        # Get bounding boxes for each instance
+        boxes_list = []
+        for mask in masks:
+            pos = torch.where(mask)
+            xmin = torch.min(pos[1])
+            xmax = torch.max(pos[1])
+            ymin = torch.min(pos[0])
+            ymax = torch.max(pos[0])
+            boxes_list.append([xmin, ymin, xmax, ymax])
+
+        masks = masks.to(torch.uint8)
+        boxes = torch.tensor(boxes_list).to(torch.float)
+        labels = torch.tensor(labels_list).to(torch.long)
+
+        return masks, boxes, labels
+
+    def _load_files(self) -> list[dict[str, str]]:
+        """List the image and target files.
+
+        Returns:
+            list of dicts containing image and semantic/instance target file paths
+        """
+        self.idxs = []
+        metadata_fn = os.path.join(self.root, self.directory, "metadata.geojson")
+        with fiona.open(metadata_fn) as f:
+            for row in f:
+                fold = int(row["properties"]["Fold"])
+                if fold in self.folds:
+                    self.idxs.append(row["properties"]["ID_PATCH"])
+
+        files = []
+        for i in self.idxs:
+            # Build "<root>/PASTIS-R/{}<ID_PATCH>.npy" and fill "{}" with a
+            # per-band or per-annotation prefix such as "DATA_S2/S2_"
+            path = os.path.join(self.root, self.directory, "{}") + str(i) + ".npy"
+            files.append(
+                {
+                    "s2": path.format(self.prefix["s2"]),
+                    "s1a": path.format(self.prefix["s1a"]),
+                    "s1d": path.format(self.prefix["s1d"]),
+                    "semantic": path.format(self.prefix["semantic"]),
+                    "instance": path.format(self.prefix["instance"]),
+                }
+            )
+        return files
+
+    def _verify(self) -> None:
+        """Verify the integrity of the dataset.
+
+        Raises:
+            RuntimeError: if ``download=False`` but dataset is missing or checksum
+                fails
+        """
+        # Check if the directory already exists
+        path = os.path.join(self.root, self.directory)
+        if os.path.exists(path):
+            return
+
+        # Check if zip file already exists (if so then extract)
+        filepath = os.path.join(self.root, self.filename)
+        if os.path.exists(filepath):
+            if self.checksum and not check_integrity(filepath, self.md5):
+                raise RuntimeError("Dataset found, but corrupted.")
+            extract_archive(filepath)
+            return
+
+        # Check if the user requested to download the dataset
+        if not self.download:
+            raise RuntimeError(
+                f"Dataset not found in `root={self.root}` and `download=False`, "
+                "either specify a different `root` directory or use `download=True` "
+                "to automatically download the dataset."
+            )
+
+        # Download and extract the dataset
+        self._download()
+
+    def _download(self) -> None:
+        """Download the dataset."""
+        download_url(
+            self.url,
+            self.root,
+            filename=self.filename,
+            md5=self.md5 if self.checksum else None,
+        )
+        extract_archive(os.path.join(self.root, self.filename), self.root)
+
+    def plot(
+        self,
+        sample: dict[str, Tensor],
+        show_titles: bool = True,
+        suptitle: Optional[str] = None,
+    ) -> plt.Figure:
+        """Plot a sample from the dataset.
+
+        Args:
+            sample: a sample returned by :meth:`__getitem__`
+            show_titles: flag indicating whether to show titles above each panel
+            suptitle: optional string to use as a suptitle
+
+        Returns:
+            a matplotlib Figure with the rendered sample
+        """
+        # Keep bands [2, 1, 0] (RGB for Sentinel-2, a false-color composite for
+        # Sentinel-1) and convert to T x H x W x C format
+        images = sample["image"][:, [2, 1, 0], :, :].numpy().transpose(0, 2, 3, 1)
+        mask = sample["mask"].numpy()
+
+        if self.mode == "instance":
+            # Collapse the per-instance binary masks into a single semantic
+            # mask by looking up each pixel's instance label
+            label = sample["label"]
+            mask = label[mask.argmax(axis=0)].numpy()
+
+        num_panels = 3
+        showing_predictions = "prediction" in sample
+        if showing_predictions:
+            predictions = sample["prediction"].numpy()
+            num_panels += 1
+            if self.mode == "instance":
+                predictions = predictions.argmax(axis=0)
+                label = sample["prediction_labels"]
+                predictions = label[predictions].numpy()
+
+        fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 4))
+        # Show the first two time steps, rescaled for display
+        axs[0].imshow(images[0] / 5000)
+        axs[1].imshow(images[1] / 5000)
+        axs[2].imshow(mask, vmin=0, vmax=19, cmap=self._cmap, interpolation="none")
+        axs[0].axis("off")
+        axs[1].axis("off")
+        axs[2].axis("off")
+        if showing_predictions:
+            axs[3].imshow(
+                predictions, vmin=0, vmax=19, cmap=self._cmap, interpolation="none"
+            )
+            axs[3].axis("off")
+
+        if show_titles:
+            axs[0].set_title("Image 0")
+            axs[1].set_title("Image 1")
+            axs[2].set_title("Mask")
+            if showing_predictions:
+                axs[3].set_title("Prediction")
+
+        if suptitle is not None:
+            plt.suptitle(suptitle)
+        return fig