Mirror of https://github.com/microsoft/torchgeo.git
Add ADVANCE dataset (#133)

* updated docs
* added torchaudio as optional dependency
* added sample data for tests
* added dataset
* added advance unit tests
* replaced torchaudio with scipy.io.wavfile.read
* Revert "added torchaudio as optional dependency"
  This reverts commit 960d94f67533e83facc080d9a6f9b965a8820294.
* updated to lazy import scipy, updated docstring
* add pytest.importorskip check for scipy
* add sample audio wav file creation details
* add scipy dependency
* downgrading scipy dep to scipy>=1.5.4
* fix pytest.importorskip to return none
* update scipy import error message
* fixed dummy audio data dims
* downgrading scipy dep to scipy>=0.9.0
* added tests for missing h5py
* format
* fixed missing import test
* Update tests/datasets/test_advance.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

This commit is contained in:
Parent
58d05d5950
Commit
77094c21fa
@ -70,6 +70,11 @@ Non-geospatial Datasets

:class:`VisionDataset` is designed for datasets that lack geospatial information. These datasets can still be combined using :class:`ConcatDataset <torch.utils.data.ConcatDataset>`.
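For context, combining two such datasets might look like the following sketch (not part of the patch; `COWCCounting` is used purely as an illustrative second dataset, and both `root` directories are assumed to already hold the data):

```python
from torch.utils.data import ConcatDataset

from torchgeo.datasets import ADVANCE, COWCCounting  # second dataset is illustrative only

# VisionDatasets carry no geospatial index, so plain index-based
# concatenation with the standard PyTorch utility is valid.
advance = ADVANCE(root="data/advance")
cowc = COWCCounting(root="data/cowc")
combined = ConcatDataset([advance, cowc])
print(len(combined))  # len(advance) + len(cowc)
```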

ADVANCE (AuDio Visual Aerial sceNe reCognition datasEt)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: ADVANCE

Smallholder Cashew Plantations in Benin
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@ -49,6 +49,8 @@ rasterio>=1.0.16
rtree>=0.5
# scikit-learn 0.18+ required for sklearn.model_selection module
scikit-learn>=0.18
# scipy 0.9.0+ required for scipy.io.wavfile.read
scipy>=0.9.0
# segmentation-models-pytorch 0.2+ required for smp.losses module
segmentation-models-pytorch>=0.2
# setuptools 30.4+ required for options.packages.find section in setup.cfg

@ -36,6 +36,7 @@ dependencies:
- radiant-mlhub>=0.2.1
- rtree>=0.5
- scikit-learn>=0.18
- scipy>=0.9.0
- segmentation-models-pytorch>=0.2
- setuptools>=30.4
- sphinx>=3

@ -49,6 +49,8 @@ rasterio>=1.0.16
rtree>=0.5
# scikit-learn 0.18+ required for sklearn.model_selection module
scikit-learn>=0.18
# scipy 0.9.0+ required for scipy.io.wavfile.read
scipy>=0.9.0
# segmentation-models-pytorch 0.2+ required for smp.losses module
segmentation-models-pytorch>=0.2
# setuptools 30.4+ required for options.packages.find section in setup.cfg

@ -51,6 +51,7 @@ datasets =
pycocotools
radiant-mlhub>=0.2.1
rarfile>=3
scipy>=0.9.0

# Optional developer requirements
docs =

@ -29,6 +29,7 @@ spack:
- "py-rasterio@1.0.16:"
- "py-rtree@0.5:"
- "py-scikit-learn@0.18:"
- "py-scipy@0.9.0:"
- "py-segmentation-models-pytorch@0.2:"
- "py-setuptools@30.4:"
- "py-shapely@1.3.0:"

@ -65,3 +65,13 @@ from PIL import Image
img = Image.new("L", (1, 1))
img.save("02.jpg")
```

### Audio wav files

```python
import numpy as np
from scipy.io import wavfile

audio = np.random.randn(1).astype(np.float32)
wavfile.write("01.wav", rate=22050, data=audio)
```
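
As a quick sanity check (a sketch, not part of the patch; it assumes `01.wav` was just written to the current directory), the dummy file can be read back the same way the dataset loads it:

```python
from scipy.io import wavfile

# wavfile.read returns (sample rate in Hz, numpy array of samples)
rate, data = wavfile.read("01.wav")
assert rate == 22050
assert data.shape == (1,)  # one float32 sample, matching the writer above
```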

Binary file not shown.
Binary file not shown.

@ -0,0 +1,90 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any, Generator

import pytest
import torch
from _pytest.monkeypatch import MonkeyPatch

import torchgeo.datasets.utils
from torchgeo.datasets import ADVANCE
from torchgeo.transforms import Identity


def download_url(url: str, root: str, *args: str) -> None:
    shutil.copy(url, root)


class TestADVANCE:
    @pytest.fixture
    def dataset(
        self,
        monkeypatch: Generator[MonkeyPatch, None, None],
        tmp_path: Path,
    ) -> ADVANCE:
        pytest.importorskip("scipy", minversion="0.9.0")
        monkeypatch.setattr(  # type: ignore[attr-defined]
            torchgeo.datasets.utils, "download_url", download_url
        )
        data_dir = os.path.join("tests", "data", "advance")
        urls = [
            os.path.join(data_dir, "ADVANCE_vision.zip"),
            os.path.join(data_dir, "ADVANCE_sound.zip"),
        ]
        md5s = ["43acacecebecd17a82bc2c1e719fd7e4", "039b7baa47879a8a4e32b9dd8287f6ad"]
        monkeypatch.setattr(ADVANCE, "urls", urls)  # type: ignore[attr-defined]
        monkeypatch.setattr(ADVANCE, "md5s", md5s)  # type: ignore[attr-defined]
        root = str(tmp_path)
        transforms = Identity()
        return ADVANCE(root, transforms, download=True, checksum=True)

    @pytest.fixture
    def mock_missing_module(
        self, monkeypatch: Generator[MonkeyPatch, None, None]
    ) -> None:
        import_orig = builtins.__import__

        def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name == "scipy.io":
                raise ImportError()
            return import_orig(name, *args, **kwargs)

        monkeypatch.setattr(  # type: ignore[attr-defined]
            builtins, "__import__", mocked_import
        )

    def test_getitem(self, dataset: ADVANCE) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x["image"], torch.Tensor)
        assert isinstance(x["audio"], torch.Tensor)
        assert isinstance(x["label"], torch.Tensor)
        assert x["image"].shape[0] == 3
        assert x["image"].ndim == 3
        assert x["audio"].shape[0] == 1
        assert x["audio"].ndim == 2
        assert x["label"].ndim == 0

    def test_len(self, dataset: ADVANCE) -> None:
        assert len(dataset) == 2

    def test_already_downloaded(self, dataset: ADVANCE) -> None:
        ADVANCE(root=dataset.root, download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
            ADVANCE(str(tmp_path))

    def test_mock_missing_module(
        self, dataset: ADVANCE, mock_missing_module: None
    ) -> None:
        with pytest.raises(
            ImportError,
            match="scipy is not installed and is required to use this dataset",
        ):
            dataset[0]
@ -3,6 +3,7 @@

"""TorchGeo datasets."""

from .advance import ADVANCE
from .benin_cashews import BeninSmallHolderCashews
from .cbf import CanadianBuildingFootprints
from .cdl import CDL

@ -80,6 +81,7 @@ __all__ = (
    "Sentinel",
    "Sentinel2",
    # VisionDataset
    "ADVANCE",
    "BeninSmallHolderCashews",
    "COWC",
    "COWCCounting",

@ -0,0 +1,234 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""ADVANCE dataset."""

import glob
import os
from typing import Callable, Dict, List, Optional

import numpy as np
import torch
from PIL import Image
from torch import Tensor

from .geo import VisionDataset
from .utils import download_and_extract_archive


class ADVANCE(VisionDataset):
    """ADVANCE dataset.

    The `ADVANCE <https://akchen.github.io/ADVANCE-DATASET/>`_
    dataset is an audio visual scene recognition dataset.

    Dataset features:

    * 5,075 pairs of geotagged audio recordings and images
    * three spectral bands - RGB (512x512 px)
    * 10-second audio recordings

    Dataset format:

    * images are three-channel jpgs
    * audio files are in wav format

    Dataset classes:

    0. airport
    1. beach
    2. bridge
    3. farmland
    4. forest
    5. grassland
    6. harbour
    7. lake
    8. orchard
    9. residential
    10. sparse shrub land
    11. sports land
    12. train station

    If you use this dataset in your research, please cite the following paper:

    * https://doi.org/10.1007/978-3-030-58586-0_5

    .. note::
       This dataset requires the following additional library to be installed:

       * `scipy <https://pypi.org/project/scipy/>`_ to load the audio files to tensors
    """

    urls = [
        "https://zenodo.org/record/3828124/files/ADVANCE_vision.zip?download=1",
        "https://zenodo.org/record/3828124/files/ADVANCE_sound.zip?download=1",
    ]
    filenames = ["ADVANCE_vision.zip", "ADVANCE_sound.zip"]
    md5s = ["a9e8748219ef5864d3b5a8979a67b471", "a2d12f2d2a64f5c3d3a9d8c09aaf1c31"]
    directories = ["vision", "sound"]
    classes = [
        "airport",
        "beach",
        "bridge",
        "farmland",
        "forest",
        "grassland",
        "harbour",
        "lake",
        "orchard",
        "residential",
        "sparse shrub land",
        "sports land",
        "train station",
    ]

    def __init__(
        self,
        root: str = "data",
        transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new ADVANCE dataset instance.

        Args:
            root: root directory where dataset can be found
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            RuntimeError: if ``download=False`` and data is not found, or checksums
                don't match
        """
        self.root = root
        self.transforms = transforms
        self.checksum = checksum

        if download:
            self._download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. "
                + "You can use download=True to download it"
            )

        self.files = self._load_files(self.root)
        self.classes = sorted(set(f["cls"] for f in self.files))
        self.class_to_idx: Dict[str, int] = {c: i for i, c in enumerate(self.classes)}

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            data and label at that index
        """
        files = self.files[index]
        image = self._load_image(files["image"])
        audio = self._load_target(files["audio"])
        cls_label = self.class_to_idx[files["cls"]]
        label = torch.tensor(cls_label, dtype=torch.long)  # type: ignore[attr-defined]
        sample = {"image": image, "audio": audio, "label": label}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.files)

    def _load_files(self, root: str) -> List[Dict[str, str]]:
        """Return the paths of the files in the dataset.

        Args:
            root: root dir of dataset

        Returns:
            list of dicts containing paths for each pair of image, audio, label
        """
        images = sorted(glob.glob(os.path.join(root, "vision", "**", "*.jpg")))
        wavs = sorted(glob.glob(os.path.join(root, "sound", "**", "*.wav")))
        labels = [image.split(os.sep)[-2] for image in images]
        files = [
            dict(image=image, audio=wav, cls=label)
            for image, wav, label in zip(images, wavs, labels)
        ]
        return files

    def _load_image(self, path: str) -> Tensor:
        """Load a single image.

        Args:
            path: path to the image

        Returns:
            the image
        """
        with Image.open(path) as img:
            array = np.array(img.convert("RGB"))
            tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
            # Convert from HxWxC to CxHxW
            tensor = tensor.permute((2, 0, 1))
            return tensor

    def _load_target(self, path: str) -> Tensor:
        """Load the target audio for a single image.

        Args:
            path: path to the target

        Returns:
            the target audio
        """
        try:
            from scipy.io import wavfile
        except ImportError:
            raise ImportError(
                "scipy is not installed and is required to use this dataset"
            )

        array = wavfile.read(path)[1]
        tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
        tensor = tensor.unsqueeze(0)
        return tensor

    def _check_integrity(self) -> bool:
        """Check the integrity of the dataset structure.

        Returns:
            True if the dataset directories are found, else False
        """
        for directory in self.directories:
            filepath = os.path.join(self.root, directory)
            if not os.path.exists(filepath):
                return False
        return True

    def _download(self) -> None:
        """Download the dataset and extract it.

        Raises:
            RuntimeError: if the downloaded files' checksums don't match
        """
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        for filename, url, md5 in zip(self.filenames, self.urls, self.md5s):
            download_and_extract_archive(
                url,
                self.root,
                filename=filename,
                md5=md5 if self.checksum else None,
            )
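
For orientation, here is a minimal usage sketch of the new dataset class (not part of the commit; it assumes the Zenodo archives are reachable and that `./data` is writable):

```python
from torchgeo.datasets import ADVANCE

# Download (if needed), verify checksums, and index image/audio pairs under ./data
ds = ADVANCE(root="data", download=True, checksum=True)

sample = ds[0]
print(sample["image"].shape)  # torch.Size([3, 512, 512]), CxHxW RGB tensor
print(sample["audio"].shape)  # torch.Size([1, N]), mono waveform as a 2D tensor
print(sample["label"])        # scalar tensor holding the class index
```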