* updated docs

* added torchaudio as optional dependency

* added sample data for tests

* added dataset

* added advance unit tests

* replaced torchaudio with scipy.io.wavefile.read

* Revert "added torchaudio as optional dependency"

This reverts commit 960d94f67533e83facc080d9a6f9b965a8820294.

* updated to lazy import scipy, updated docstring

* add pytest.importorskip check for scipy

* add sample audio wav file creation details

* add scipy dependency

* downgrading scipy dep to scipy>=1.5.4

* fix pytest.importorskip to return none

* update scipy import error message

* fixed dummy audio data dims

* downgrading scipy dep to scipy>=0.9.0

* added tests for missing h5py

* format

* fixed missing import test

* Update tests/datasets/test_advance.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
isaac 2021-09-19 18:00:56 -05:00 коммит произвёл GitHub
Родитель 58d05d5950
Коммит 77094c21fa
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
12 изменённых файлов: 348 добавлений и 0 удалений

Просмотреть файл

@ -70,6 +70,11 @@ Non-geospatial Datasets
:class:`VisionDataset` is designed for datasets that lack geospatial information. These datasets can still be combined using :class:`ConcatDataset <torch.utils.data.ConcatDataset>`.
ADVANCE (AuDio Visual Aerial sceNe reCognition datasEt)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: ADVANCE
Smallholder Cashew Plantations in Benin
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Просмотреть файл

@ -49,6 +49,8 @@ rasterio>=1.0.16
rtree>=0.5
# scikit-learn 0.18+ required for sklearn.model_selection module
scikit-learn>=0.18
# scipy 0.9.0+ required for scipy.io.wavfile.read
scipy>=0.9.0
# segmentation-models-pytorch 0.2+ required for smp.losses module
segmentation-models-pytorch>=0.2
# setuptools 30.4+ required for options.packages.find section in setup.cfg

Просмотреть файл

@ -36,6 +36,7 @@ dependencies:
- radiant-mlhub>=0.2.1
- rtree>=0.5
- scikit-learn>=0.18
- scipy>=0.9.0
- segmentation-models-pytorch>=0.2
- setuptools>=30.4
- sphinx>=3

Просмотреть файл

@ -49,6 +49,8 @@ rasterio>=1.0.16
rtree>=0.5
# scikit-learn 0.18+ required for sklearn.model_selection module
scikit-learn>=0.18
# scipy 0.9.0+ required for scipy.io.wavfile.read
scipy>=0.9.0
# segmentation-models-pytorch 0.2+ required for smp.losses module
segmentation-models-pytorch>=0.2
# setuptools 30.4+ required for options.packages.find section in setup.cfg

Просмотреть файл

@ -51,6 +51,7 @@ datasets =
pycocotools
radiant-mlhub>=0.2.1
rarfile>=3
scipy>=0.9.0
# Optional developer requirements
docs =

Просмотреть файл

@ -29,6 +29,7 @@ spack:
- "py-rasterio@1.0.16:"
- "py-rtree@0.5:"
- "py-scikit-learn@0.18:"
- "py-scipy@0.9.0:"
- "py-segmentation-models-pytorch@0.2:"
- "py-setuptools@30.4:"
- "py-shapely@1.3.0:"

Просмотреть файл

@ -65,3 +65,13 @@ from PIL import Image
img = Image.new("L", (1, 1))
img.save("02.jpg")
```
### Audio wav files
```python
import numpy as np
from scipy.io import wavfile
audio = np.random.randn(1).astype(np.float32)
wavfile.write("01.wav", rate=22050, data=audio)
```

Двоичные данные
tests/data/advance/ADVANCE_sound.zip Normal file

Двоичный файл не отображается.

Двоичные данные
tests/data/advance/ADVANCE_vision.zip Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -0,0 +1,90 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import builtins
import os
import shutil
from pathlib import Path
from typing import Any, Generator
import pytest
import torch
from _pytest.monkeypatch import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import ADVANCE
from torchgeo.transforms import Identity
def download_url(url: str, root: str, *args: str) -> None:
shutil.copy(url, root)
class TestADVANCE:
@pytest.fixture
def dataset(
self,
monkeypatch: Generator[MonkeyPatch, None, None],
tmp_path: Path,
) -> ADVANCE:
pytest.importorskip("scipy", minversion="0.9.0")
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.utils, "download_url", download_url
)
data_dir = os.path.join("tests", "data", "advance")
urls = [
os.path.join(data_dir, "ADVANCE_vision.zip"),
os.path.join(data_dir, "ADVANCE_sound.zip"),
]
md5s = ["43acacecebecd17a82bc2c1e719fd7e4", "039b7baa47879a8a4e32b9dd8287f6ad"]
monkeypatch.setattr(ADVANCE, "urls", urls) # type: ignore[attr-defined]
monkeypatch.setattr(ADVANCE, "md5s", md5s) # type: ignore[attr-defined]
root = str(tmp_path)
transforms = Identity()
return ADVANCE(root, transforms, download=True, checksum=True)
@pytest.fixture
def mock_missing_module(
self, monkeypatch: Generator[MonkeyPatch, None, None]
) -> None:
import_orig = builtins.__import__
def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == "scipy.io":
raise ImportError()
return import_orig(name, *args, **kwargs)
monkeypatch.setattr( # type: ignore[attr-defined]
builtins, "__import__", mocked_import
)
def test_getitem(self, dataset: ADVANCE) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["audio"], torch.Tensor)
assert isinstance(x["label"], torch.Tensor)
assert x["image"].shape[0] == 3
assert x["image"].ndim == 3
assert x["audio"].shape[0] == 1
assert x["audio"].ndim == 2
assert x["label"].ndim == 0
def test_len(self, dataset: ADVANCE) -> None:
assert len(dataset) == 2
def test_already_downloaded(self, dataset: ADVANCE) -> None:
ADVANCE(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
ADVANCE(str(tmp_path))
def test_mock_missing_module(
self, dataset: ADVANCE, mock_missing_module: None
) -> None:
with pytest.raises(
ImportError,
match="scipy is not installed and is required to use this dataset",
):
dataset[0]

Просмотреть файл

@ -3,6 +3,7 @@
"""TorchGeo datasets."""
from .advance import ADVANCE
from .benin_cashews import BeninSmallHolderCashews
from .cbf import CanadianBuildingFootprints
from .cdl import CDL
@ -80,6 +81,7 @@ __all__ = (
"Sentinel",
"Sentinel2",
# VisionDataset
"ADVANCE",
"BeninSmallHolderCashews",
"COWC",
"COWCCounting",

Просмотреть файл

@ -0,0 +1,234 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""ADVANCE dataset."""
import glob
import os
from typing import Callable, Dict, List, Optional
import numpy as np
import torch
from PIL import Image
from torch import Tensor
from .geo import VisionDataset
from .utils import download_and_extract_archive
class ADVANCE(VisionDataset):
"""ADVANCE dataset.
The `ADVANCE <https://akchen.github.io/ADVANCE-DATASET/>`_
dataset is a dataset for audio visual scene recognition.
Dataset features:
* 5,075 pairs of geotagged audio recordings and images
* three spectral bands - RGB (512x512 px)
* 10-second audio recordings
Dataset format:
* images are three-channel jpgs
* audio files are in wav format
Dataset classes:
0. airport
1. beach
2. bridge
3. farmland
4. forest
5. grassland
6. harbour
7. lake
8. orchard
9. residential
10. sparse shrub land
11. sports land
12. train station
If you use this dataset in your research, please cite the following paper:
* https://doi.org/10.1007/978-3-030-58586-0_5
.. note::
This dataset requires the following additional library to be installed:
* `scipy <https://pypi.org/project/scipy/>`_ to load the audio files to tensors
"""
urls = [
"https://zenodo.org/record/3828124/files/ADVANCE_vision.zip?download=1",
"https://zenodo.org/record/3828124/files/ADVANCE_sound.zip?download=1",
]
filenames = ["ADVANCE_vision.zip", "ADVANCE_sound.zip"]
md5s = ["a9e8748219ef5864d3b5a8979a67b471", "a2d12f2d2a64f5c3d3a9d8c09aaf1c31"]
directories = ["vision", "sound"]
classes = [
"airport",
"beach",
"bridge",
"farmland",
"forest",
"grassland",
"harbour",
"lake",
"orchard",
"residential",
"sparse shrub land",
"sports land",
"train station",
]
def __init__(
self,
root: str = "data",
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
download: bool = False,
checksum: bool = False,
) -> None:
"""Initialize a new ADVANCE dataset instance.
Args:
root: root directory where dataset can be found
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
"""
self.root = root
self.transforms = transforms
self.checksum = checksum
if download:
self._download()
if not self._check_integrity():
raise RuntimeError(
"Dataset not found or corrupted. "
+ "You can use download=True to download it"
)
self.files = self._load_files(self.root)
self.classes = sorted(set(f["cls"] for f in self.files))
self.class_to_idx: Dict[str, int] = {c: i for i, c in enumerate(self.classes)}
def __getitem__(self, index: int) -> Dict[str, Tensor]:
"""Return an index within the dataset.
Args:
index: index to return
Returns:
data and label at that index
"""
files = self.files[index]
image = self._load_image(files["image"])
audio = self._load_target(files["audio"])
cls_label = self.class_to_idx[files["cls"]]
label = torch.tensor(cls_label, dtype=torch.long) # type: ignore[attr-defined]
sample = {"image": image, "audio": audio, "label": label}
if self.transforms is not None:
sample = self.transforms(sample)
return sample
def __len__(self) -> int:
"""Return the number of data points in the dataset.
Returns:
length of the dataset
"""
return len(self.files)
def _load_files(self, root: str) -> List[Dict[str, str]]:
"""Return the paths of the files in the dataset.
Args:
root: root dir of dataset
Returns:
list of dicts containing paths for each pair of image, audio, label
"""
images = sorted(glob.glob(os.path.join(root, "vision", "**", "*.jpg")))
wavs = sorted(glob.glob(os.path.join(root, "sound", "**", "*.wav")))
labels = [image.split(os.sep)[-2] for image in images]
files = [
dict(image=image, audio=wav, cls=label)
for image, wav, label in zip(images, wavs, labels)
]
return files
def _load_image(self, path: str) -> Tensor:
"""Load a single image.
Args:
path: path to the image
Returns:
the image
"""
with Image.open(path) as img:
array = np.array(img.convert("RGB"))
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
# Convert from HxWxC to CxHxW
tensor = tensor.permute((2, 0, 1))
return tensor
def _load_target(self, path: str) -> Tensor:
"""Load the target audio for a single image.
Args:
path: path to the target
Returns:
the target audio
"""
try:
from scipy.io import wavfile
except ImportError:
raise ImportError(
"scipy is not installed and is required to use this dataset"
)
array = wavfile.read(path)[1]
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
tensor = tensor.unsqueeze(0)
return tensor
def _check_integrity(self) -> bool:
"""Checks the integrity of the dataset structure.
Returns:
True if the dataset directories are found, else False
"""
for directory in self.directories:
filepath = os.path.join(self.root, directory)
if not os.path.exists(filepath):
return False
return True
def _download(self) -> None:
"""Download the dataset and extract it.
Raises:
AssertionError: if the checksum of split.py does not match
"""
if self._check_integrity():
print("Files already downloaded and verified")
return
for filename, url, md5 in zip(self.filenames, self.urls, self.md5s):
download_and_extract_archive(
url,
self.root,
filename=filename,
md5=md5 if self.checksum else None,
)