This commit is contained in:
isaac 2021-09-19 18:25:09 -05:00 коммит произвёл GitHub
Родитель 77094c21fa
Коммит 459524fedc
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 358 добавлений и 0 удалений

Просмотреть файл

@ -147,6 +147,11 @@ NWPU VHR-10
.. autoclass:: VHR10
ZueriCrop
^^^^^^^^^
.. autoclass:: ZueriCrop
.. _Base Classes:
Base Classes

Просмотреть файл

@ -75,3 +75,18 @@ from scipy.io import wavfile
audio = np.random.randn(1).astype(np.float32)
wavfile.write("01.wav", rate=22050, data=audio)
```
### HDF5 datasets
```python
import h5py
import numpy as np
f = h5py.File("data.hdf5", "w")
num_classes = 10
images = np.random.randint(low=0, high=255, size=(1, 1, 3)).astype(np.uint8)
masks = np.random.randint(low=0, high=num_classes, size=(1, 1)).astype(np.uint8)
f.create_dataset("images", data=images)
f.create_dataset("masks", data=masks)
f.close()

Двоичные данные
tests/data/zuericrop/ZueriCrop.hdf5 Normal file

Двоичный файл не отображается.

Просмотреть файл

Просмотреть файл

@ -0,0 +1,104 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import builtins
import os
import shutil
from pathlib import Path
from typing import Any, Generator
import pytest
import torch
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import ZueriCrop
from torchgeo.transforms import Identity
pytest.importorskip("h5py")
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
class TestZueriCrop:
@pytest.fixture
def dataset(
self,
monkeypatch: Generator[MonkeyPatch, None, None],
tmp_path: Path,
) -> ZueriCrop:
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.zuericrop, "download_url", download_url
)
data_dir = os.path.join("tests", "data", "zuericrop")
urls = [
os.path.join(data_dir, "ZueriCrop.hdf5"),
os.path.join(data_dir, "labels.csv"),
]
md5s = ["8c0ca5ad53903aeba8a1d06bba50a5ec", "d41d8cd98f00b204e9800998ecf8427e"]
monkeypatch.setattr(ZueriCrop, "urls", urls) # type: ignore[attr-defined]
monkeypatch.setattr(ZueriCrop, "md5s", md5s) # type: ignore[attr-defined]
root = str(tmp_path)
transforms = Identity()
return ZueriCrop(root, transforms, download=True, checksum=True)
@pytest.fixture
def mock_missing_module(
self, monkeypatch: Generator[MonkeyPatch, None, None]
) -> None:
import_orig = builtins.__import__
def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == "h5py":
raise ImportError()
return import_orig(name, *args, **kwargs)
monkeypatch.setattr( # type: ignore[attr-defined]
builtins, "__import__", mocked_import
)
def test_getitem(self, dataset: ZueriCrop) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)
assert isinstance(x["boxes"], torch.Tensor)
assert isinstance(x["label"], torch.Tensor)
# Image tests
assert x["image"].ndim == 4
# Instance masks tests
assert x["mask"].ndim == 3
assert x["mask"].shape[-2:] == x["image"].shape[-2:]
# Bboxes tests
assert x["boxes"].ndim == 2
assert x["boxes"].shape[1] == 4
# Labels tests
assert x["label"].ndim == 1
def test_len(self, dataset: ZueriCrop) -> None:
assert len(dataset) == 2
def test_already_downloaded(self, dataset: ZueriCrop) -> None:
ZueriCrop(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
err = "Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automaticaly download the dataset."
with pytest.raises(RuntimeError, match=err):
ZueriCrop(str(tmp_path))
def test_mock_missing_module(
self, dataset: ZueriCrop, tmp_path: Path, mock_missing_module: None
) -> None:
with pytest.raises(
ImportError,
match="h5py is not installed and is required to use this dataset",
):
ZueriCrop(dataset.root, download=True, checksum=True)

Просмотреть файл

@ -50,6 +50,7 @@ from .sentinel import Sentinel, Sentinel2
from .so2sat import So2Sat
from .spacenet import SpaceNet1
from .utils import BoundingBox, collate_dict
from .zuericrop import ZueriCrop
__all__ = (
# GeoDataset
@ -98,6 +99,7 @@ __all__ = (
"SpaceNet1",
"TropicalCycloneWindEstimation",
"VHR10",
"ZueriCrop",
# Base classes
"GeoDataset",
"RasterDataset",

Просмотреть файл

@ -0,0 +1,232 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""ZueriCrop dataset."""
import os
from typing import Callable, Dict, Optional, Tuple
import torch
from torch import Tensor
from .geo import VisionDataset
from .utils import download_url
class ZueriCrop(VisionDataset):
"""ZueriCrop dataset.
The `ZueriCrop <https://github.com/0zgur0/ms-convSTAR>`_
dataset is a dataset for time-series instance segmentation of crops.
Dataset features:
* Sentinel-2 multispectral imagery
* instance masks of 48 crop categories
* nine multispectral bands
* 116k images with 10 m per pixel resolution (24x24 px)
* ~28k time-series containing 142 images each
Dataset format:
* single hdf5 dataset containing images, semantic masks, and instance masks
* data is parsed into images and instance masks, boxes, and labels
* one mask per time-series
Dataset classes:
* 48 fine-grained hierarchical crop
`categories <https://github.com/0zgur0/ms-convSTAR/blob/master/labels.csv>`_
If you use this dataset in your research, please cite the following paper:
* https://doi.org/10.1016/j.rse.2021.112603
.. note::
This dataset requires the following additional library to be installed:
* `h5py <https://pypi.org/project/h5py/>`_ to load the dataset
"""
urls = [
"https://polybox.ethz.ch/index.php/s/uXfdr2AcXE3QNB6/download",
"https://raw.githubusercontent.com/0zgur0/ms-convSTAR/master/labels.csv",
]
md5s = ["1635231df67f3d25f4f1e62c98e221a4", "5118398c7a5bbc246f5f6bb35d8d529b"]
filenames = ["ZueriCrop.hdf5", "labels.csv"]
def __init__(
self,
root: str = "data",
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
download: bool = False,
checksum: bool = False,
) -> None:
"""Initialize a new ZueriCrop dataset instance.
Args:
root: root directory where dataset can be found
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
"""
self.root = root
self.transforms = transforms
self.download = download
self.checksum = checksum
self.filepath = os.path.join(root, "ZueriCrop.hdf5")
self._verify()
try:
import h5py # noqa: F401
except ImportError:
raise ImportError(
"h5py is not installed and is required to use this dataset"
)
def __getitem__(self, index: int) -> Dict[str, Tensor]:
"""Return an index within the dataset.
Args:
index: index to return
Returns:
sample containing image, mask, bounding boxes, and target label
"""
image = self._load_image(index)
mask, boxes, label = self._load_target(index)
sample = {"image": image, "mask": mask, "boxes": boxes, "label": label}
if self.transforms is not None:
sample = self.transforms(sample)
return sample
def __len__(self) -> int:
"""Return the number of data points in the dataset.
Returns:
length of the dataset
"""
import h5py
with h5py.File(self.filepath, "r") as f:
length: int = f["data"].shape[0]
return length
def _load_image(self, index: int) -> Tensor:
"""Load a single image.
Args:
index: index to return
Returns:
the image
"""
import h5py
with h5py.File(self.filepath, "r") as f:
array = f["data"][index, ...]
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
# Convert from TxHxWxC to TxCxHxW
tensor = tensor.permute((0, 3, 1, 2))
return tensor
def _load_target(self, index: int) -> Tuple[Tensor, Tensor, Tensor]:
"""Load the target mask for a single image.
Args:
index: index to return
Returns:
the target mask and label for each mask
"""
import h5py
with h5py.File(self.filepath, "r") as f:
mask_array = f["gt"][index, ...]
instance_array = f["gt_instance"][index, ...]
mask_tensor = torch.from_numpy(mask_array) # type: ignore[attr-defined]
instance_tensor = torch.from_numpy(instance_array) # type: ignore[attr-defined]
# Convert from HxWxC to CxHxW
mask_tensor = mask_tensor.permute((2, 0, 1))
instance_tensor = instance_tensor.permute((2, 0, 1))
# Convert instance mask of N instances to N binary instance masks
instance_ids = torch.unique(instance_tensor) # type: ignore[attr-defined]
# Exclude a mask for unknown/background
instance_ids = instance_ids[instance_ids != 0]
instance_ids = instance_ids[:, None, None]
masks: Tensor = instance_tensor == instance_ids
# Parse labels for each instance
labels_list = []
for mask in masks:
label = mask_tensor[mask[None, :, :]]
label = torch.unique(label)[0] # type: ignore[attr-defined]
labels_list.append(label)
# Get bounding boxes for each instance
boxes_list = []
for mask in masks:
pos = torch.where(mask) # type: ignore[attr-defined]
xmin = torch.min(pos[1]) # type: ignore[attr-defined]
xmax = torch.max(pos[1]) # type: ignore[attr-defined]
ymin = torch.min(pos[0]) # type: ignore[attr-defined]
ymax = torch.max(pos[0]) # type: ignore[attr-defined]
boxes_list.append([xmin, ymin, xmax, ymax])
masks = masks.to(torch.uint8) # type: ignore[attr-defined]
boxes = torch.tensor(boxes_list).to(torch.float) # type: ignore[attr-defined]
labels = torch.tensor(labels_list).to(torch.long) # type: ignore[attr-defined]
return masks, boxes, labels
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
# Check if the files already exist
exists = []
for filename in self.filenames:
filepath = os.path.join(self.root, filename)
exists.append(os.path.exists(filepath))
if all(exists):
return
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
"Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automaticaly download the dataset."
)
# Download the dataset
self._download()
def _download(self) -> None:
"""Download the dataset."""
for url, filename, md5 in zip(self.urls, self.filenames, self.md5s):
filepath = os.path.join(self.root, filename)
if not os.path.exists(filepath):
download_url(
url,
self.root,
filename=filename,
md5=md5 if self.checksum else None,
)