Mirror of https://github.com/microsoft/torchgeo.git

Add ZueriCrop dataset (#147)

Parent: 77094c21fa
Commit: 459524fedc
@@ -147,6 +147,11 @@ NWPU VHR-10

.. autoclass:: VHR10

ZueriCrop
^^^^^^^^^

.. autoclass:: ZueriCrop

.. _Base Classes:

Base Classes
@@ -75,3 +75,18 @@ from scipy.io import wavfile

audio = np.random.randn(1).astype(np.float32)
wavfile.write("01.wav", rate=22050, data=audio)
```

### HDF5 datasets

```python
import h5py
import numpy as np

f = h5py.File("data.hdf5", "w")

num_classes = 10
images = np.random.randint(low=0, high=255, size=(1, 1, 3)).astype(np.uint8)
masks = np.random.randint(low=0, high=num_classes, size=(1, 1)).astype(np.uint8)
f.create_dataset("images", data=images)
f.create_dataset("masks", data=masks)
f.close()
```
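To sanity-check a fake file like this, a quick read-back sketch (assuming the `data.hdf5` file and dataset names from the snippet above):

```python
import h5py

# Re-open the file written above and confirm that shapes/dtypes round-trip.
with h5py.File("data.hdf5", "r") as f:
    print(f["images"].shape, f["images"].dtype)  # (1, 1, 3) uint8
    print(f["masks"].shape, f["masks"].dtype)    # (1, 1) uint8
```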
Binary file not displayed.
@@ -0,0 +1,104 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any, Generator

import pytest
import torch
from _pytest.monkeypatch import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datasets import ZueriCrop
from torchgeo.transforms import Identity

pytest.importorskip("h5py")


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    shutil.copy(url, root)


class TestZueriCrop:
    @pytest.fixture
    def dataset(
        self,
        monkeypatch: Generator[MonkeyPatch, None, None],
        tmp_path: Path,
    ) -> ZueriCrop:
        monkeypatch.setattr(  # type: ignore[attr-defined]
            torchgeo.datasets.zuericrop, "download_url", download_url
        )
        data_dir = os.path.join("tests", "data", "zuericrop")
        urls = [
            os.path.join(data_dir, "ZueriCrop.hdf5"),
            os.path.join(data_dir, "labels.csv"),
        ]
        md5s = ["8c0ca5ad53903aeba8a1d06bba50a5ec", "d41d8cd98f00b204e9800998ecf8427e"]
        monkeypatch.setattr(ZueriCrop, "urls", urls)  # type: ignore[attr-defined]
        monkeypatch.setattr(ZueriCrop, "md5s", md5s)  # type: ignore[attr-defined]
        root = str(tmp_path)
        transforms = Identity()
        return ZueriCrop(root, transforms, download=True, checksum=True)

    @pytest.fixture
    def mock_missing_module(
        self, monkeypatch: Generator[MonkeyPatch, None, None]
    ) -> None:
        import_orig = builtins.__import__

        def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name == "h5py":
                raise ImportError()
            return import_orig(name, *args, **kwargs)

        monkeypatch.setattr(  # type: ignore[attr-defined]
            builtins, "__import__", mocked_import
        )

    def test_getitem(self, dataset: ZueriCrop) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["mask"], torch.Tensor)
        assert isinstance(x["boxes"], torch.Tensor)
        assert isinstance(x["label"], torch.Tensor)

        # Image tests
        assert x["image"].ndim == 4

        # Instance masks tests
        assert x["mask"].ndim == 3
        assert x["mask"].shape[-2:] == x["image"].shape[-2:]

        # Bboxes tests
        assert x["boxes"].ndim == 2
        assert x["boxes"].shape[1] == 4

        # Labels tests
        assert x["label"].ndim == 1

    def test_len(self, dataset: ZueriCrop) -> None:
        assert len(dataset) == 2

    def test_already_downloaded(self, dataset: ZueriCrop) -> None:
        ZueriCrop(root=dataset.root, download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        err = (
            "Dataset not found in `root` directory and `download=False`, "
            "either specify a different `root` directory or use `download=True` "
            "to automatically download the dataset."
        )
        with pytest.raises(RuntimeError, match=err):
            ZueriCrop(str(tmp_path))

    def test_mock_missing_module(
        self, dataset: ZueriCrop, tmp_path: Path, mock_missing_module: None
    ) -> None:
        with pytest.raises(
            ImportError,
            match="h5py is not installed and is required to use this dataset",
        ):
            ZueriCrop(dataset.root, download=True, checksum=True)
@@ -50,6 +50,7 @@ from .sentinel import Sentinel, Sentinel2
from .so2sat import So2Sat
from .spacenet import SpaceNet1
from .utils import BoundingBox, collate_dict
from .zuericrop import ZueriCrop

__all__ = (
    # GeoDataset
@@ -98,6 +99,7 @@ __all__ = (
    "SpaceNet1",
    "TropicalCycloneWindEstimation",
    "VHR10",
    "ZueriCrop",
    # Base classes
    "GeoDataset",
    "RasterDataset",
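With the export above in place, the dataset class is importable from the package root (a quick check; assumes a build of this commit is installed):

```python
from torchgeo.datasets import ZueriCrop

# Class attributes defined in the new module below.
print(ZueriCrop.filenames)  # ['ZueriCrop.hdf5', 'labels.csv']
```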
@@ -0,0 +1,232 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""ZueriCrop dataset."""

import os
from typing import Callable, Dict, Optional, Tuple

import torch
from torch import Tensor

from .geo import VisionDataset
from .utils import download_url


class ZueriCrop(VisionDataset):
    """ZueriCrop dataset.

    The `ZueriCrop <https://github.com/0zgur0/ms-convSTAR>`_
    dataset is a dataset for time-series instance segmentation of crops.

    Dataset features:

    * Sentinel-2 multispectral imagery
    * instance masks of 48 crop categories
    * nine multispectral bands
    * 116k images with 10 m per pixel resolution (24x24 px)
    * ~28k time-series containing 142 images each

    Dataset format:

    * single hdf5 dataset containing images, semantic masks, and instance masks
    * data is parsed into images and instance masks, boxes, and labels
    * one mask per time-series

    Dataset classes:

    * 48 fine-grained hierarchical crop
      `categories <https://github.com/0zgur0/ms-convSTAR/blob/master/labels.csv>`_

    If you use this dataset in your research, please cite the following paper:

    * https://doi.org/10.1016/j.rse.2021.112603

    .. note::

       This dataset requires the following additional library to be installed:

       * `h5py <https://pypi.org/project/h5py/>`_ to load the dataset
    """

    urls = [
        "https://polybox.ethz.ch/index.php/s/uXfdr2AcXE3QNB6/download",
        "https://raw.githubusercontent.com/0zgur0/ms-convSTAR/master/labels.csv",
    ]
    md5s = ["1635231df67f3d25f4f1e62c98e221a4", "5118398c7a5bbc246f5f6bb35d8d529b"]
    filenames = ["ZueriCrop.hdf5", "labels.csv"]

    def __init__(
        self,
        root: str = "data",
        transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new ZueriCrop dataset instance.

        Args:
            root: root directory where dataset can be found
            transforms: a function/transform that takes an input sample and its
                target as entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            RuntimeError: if ``download=False`` and data is not found, or checksums
                don't match
        """
        self.root = root
        self.transforms = transforms
        self.download = download
        self.checksum = checksum
        self.filepath = os.path.join(root, "ZueriCrop.hdf5")

        self._verify()

        try:
            import h5py  # noqa: F401
        except ImportError:
            raise ImportError(
                "h5py is not installed and is required to use this dataset"
            )

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return a sample at the given index within the dataset.

        Args:
            index: index to return

        Returns:
            sample containing image, mask, bounding boxes, and target label
        """
        image = self._load_image(index)
        mask, boxes, label = self._load_target(index)

        sample = {"image": image, "mask": mask, "boxes": boxes, "label": label}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        import h5py

        with h5py.File(self.filepath, "r") as f:
            length: int = f["data"].shape[0]
        return length

    def _load_image(self, index: int) -> Tensor:
        """Load a single image.

        Args:
            index: index to return

        Returns:
            the image
        """
        import h5py

        with h5py.File(self.filepath, "r") as f:
            array = f["data"][index, ...]

        tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
        # Convert from TxHxWxC to TxCxHxW
        tensor = tensor.permute((0, 3, 1, 2))
        return tensor

    def _load_target(self, index: int) -> Tuple[Tensor, Tensor, Tensor]:
        """Load the target masks, bounding boxes, and labels for a single image.

        Args:
            index: index to return

        Returns:
            the instance masks, bounding boxes, and label for each instance
        """
        import h5py

        with h5py.File(self.filepath, "r") as f:
            mask_array = f["gt"][index, ...]
            instance_array = f["gt_instance"][index, ...]

        mask_tensor = torch.from_numpy(mask_array)  # type: ignore[attr-defined]
        instance_tensor = torch.from_numpy(instance_array)  # type: ignore[attr-defined]

        # Convert from HxWxC to CxHxW
        mask_tensor = mask_tensor.permute((2, 0, 1))
        instance_tensor = instance_tensor.permute((2, 0, 1))

        # Convert instance mask of N instances to N binary instance masks
        instance_ids = torch.unique(instance_tensor)  # type: ignore[attr-defined]
        # Exclude a mask for unknown/background
        instance_ids = instance_ids[instance_ids != 0]
        instance_ids = instance_ids[:, None, None]
        masks: Tensor = instance_tensor == instance_ids

        # Parse labels for each instance
        labels_list = []
        for mask in masks:
            label = mask_tensor[mask[None, :, :]]
            label = torch.unique(label)[0]  # type: ignore[attr-defined]
            labels_list.append(label)

        # Get bounding boxes for each instance
        boxes_list = []
        for mask in masks:
            pos = torch.where(mask)  # type: ignore[attr-defined]
            xmin = torch.min(pos[1])  # type: ignore[attr-defined]
            xmax = torch.max(pos[1])  # type: ignore[attr-defined]
            ymin = torch.min(pos[0])  # type: ignore[attr-defined]
            ymax = torch.max(pos[0])  # type: ignore[attr-defined]
            boxes_list.append([xmin, ymin, xmax, ymax])

        masks = masks.to(torch.uint8)  # type: ignore[attr-defined]
        boxes = torch.tensor(boxes_list).to(torch.float)  # type: ignore[attr-defined]
        labels = torch.tensor(labels_list).to(torch.long)  # type: ignore[attr-defined]

        return masks, boxes, labels

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Raises:
            RuntimeError: if ``download=False`` but dataset is missing or checksum fails
        """
        # Check if the files already exist
        exists = []
        for filename in self.filenames:
            filepath = os.path.join(self.root, filename)
            exists.append(os.path.exists(filepath))

        if all(exists):
            return

        # Check if the user requested to download the dataset
        if not self.download:
            raise RuntimeError(
                "Dataset not found in `root` directory and `download=False`, "
                "either specify a different `root` directory or use `download=True` "
                "to automatically download the dataset."
            )

        # Download the dataset
        self._download()

    def _download(self) -> None:
        """Download the dataset."""
        for url, filename, md5 in zip(self.urls, self.filenames, self.md5s):
            filepath = os.path.join(self.root, filename)
            if not os.path.exists(filepath):
                download_url(
                    url,
                    self.root,
                    filename=filename,
                    md5=md5 if self.checksum else None,
                )
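For reference, a minimal usage sketch of the new dataset class (not part of the commit; assumes the files are present under `root` or that downloading is acceptable):

```python
from torchgeo.datasets import ZueriCrop

# download=True fetches ZueriCrop.hdf5 and labels.csv into root if missing.
ds = ZueriCrop(root="data", download=True, checksum=True)

sample = ds[0]
print(sample["image"].shape)  # TxCxHxW multispectral time-series
print(sample["mask"].shape)   # NxHxW binary instance masks (uint8)
print(sample["boxes"].shape)  # Nx4 boxes as (xmin, ymin, xmax, ymax)
print(sample["label"].shape)  # N per-instance crop labels
```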