Add VisionClassificationDataset (#171)

* updated docs

* added VisionClassificationDataset

* refactor PatternNet and RESISC45 to inherit VisionClassificationDataset

* added unit tests and sample data

* refactor PatternNet and RESISC45 to new download/verify checks and fix code coverage

* remove override of __str__

* set default_loader as loader default

* removed loader arg from datasets

* update tests

* format

* remove duplicate code

* updated docstrings
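For context, a minimal usage sketch of the new class (the root path points at the test fixtures added in this commit; the behavior shown follows the tests further down):

from torchgeo.datasets import VisionClassificationDataset

# The root is expected to contain one subfolder per class, e.g.
#   root/class0/001.jpg
#   root/class1/001.jpg
ds = VisionClassificationDataset(root="tests/data/visionclassificationdataset")

sample = ds[0]  # dict with "image" (CxHxW uint8 Tensor) and "label" (scalar Tensor)
assert sample["image"].shape[0] == 3

# Addition produces a torch.utils.data.ConcatDataset, as the tests below verify
combined = ds + ds
assert len(combined) == 2 * len(ds)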
This commit is contained in:
isaac 2021-09-27 20:55:50 -05:00 committed by GitHub
Parent 79476bb42a
Commit a3b636fe99
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 247 additions and 188 deletions

View file

@@ -189,6 +189,11 @@ VisionDataset
.. autoclass:: VisionDataset
VisionClassificationDataset
^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: VisionClassificationDataset
ZipDataset
^^^^^^^^^^

Binary data
tests/data/visionclassificationdataset/class0/001.jpg (new file, 631 B)

Binary file not shown.

Binary data
tests/data/visionclassificationdataset/class1/001.jpg (new file, 631 B)

Binary file not shown.

View file

@@ -17,6 +17,7 @@ from torchgeo.datasets import (
Landsat8,
RasterDataset,
VectorDataset,
VisionClassificationDataset,
VisionDataset,
ZipDataset,
)
@@ -162,6 +163,55 @@ class TestVisionDataset:
VisionDataset() # type: ignore[abstract]
class TestVisionClassificationDataset:
@pytest.fixture(scope="class")
def dataset(self, root: str) -> VisionClassificationDataset:
return VisionClassificationDataset(root)
@pytest.fixture(scope="class")
def root(self) -> str:
root = os.path.join("tests", "data", "visionclassificationdataset")
return root
def test_getitem(self, dataset: VisionClassificationDataset) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["label"], torch.Tensor)
assert x["image"].shape[0] == 3
def test_len(self, dataset: VisionClassificationDataset) -> None:
assert len(dataset) == 2
def test_add_two(self, root: str) -> None:
ds1 = VisionClassificationDataset(root)
ds2 = VisionClassificationDataset(root)
dataset = ds1 + ds2
assert isinstance(dataset, ConcatDataset)
assert len(dataset) == 4
def test_add_three(self, root: str) -> None:
ds1 = VisionClassificationDataset(root)
ds2 = VisionClassificationDataset(root)
ds3 = VisionClassificationDataset(root)
dataset = ds1 + ds2 + ds3
assert isinstance(dataset, ConcatDataset)
assert len(dataset) == 6
def test_add_four(self, root: str) -> None:
ds1 = VisionClassificationDataset(root)
ds2 = VisionClassificationDataset(root)
ds3 = VisionClassificationDataset(root)
ds4 = VisionClassificationDataset(root)
dataset = (ds1 + ds2) + (ds3 + ds4)
assert isinstance(dataset, ConcatDataset)
assert len(dataset) == 8
def test_str(self, dataset: VisionClassificationDataset) -> None:
assert "type: VisionDataset" in str(dataset)
assert "size: 2" in str(dataset)
class TestZipDataset:
@pytest.fixture(scope="class")
def dataset(self) -> ZipDataset:

View file

@@ -15,7 +15,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import PatternNet
def download_url(url: str, root: str, *args: str) -> None:
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@@ -25,7 +25,7 @@ class TestPatternNet:
self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path
) -> PatternNet:
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.utils, "download_url", download_url
torchgeo.datasets.patternnet, "download_url", download_url
)
md5 = "5649754c78219a2c19074ff93666cc61"
monkeypatch.setattr(PatternNet, "md5", md5) # type: ignore[attr-defined]
@@ -45,9 +45,19 @@ class TestPatternNet:
def test_len(self, dataset: PatternNet) -> None:
assert len(dataset) == 2
def test_already_downloaded(self, dataset: PatternNet) -> None:
PatternNet(root=dataset.root, download=True)
def test_already_downloaded(self, dataset: PatternNet, tmp_path: Path) -> None:
PatternNet(root=str(tmp_path), download=True)
def test_already_downloaded_not_extracted(
self, dataset: PatternNet, tmp_path: Path
) -> None:
shutil.rmtree(dataset.root)
download_url(dataset.url, root=str(tmp_path))
PatternNet(root=str(tmp_path), download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
err = "Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automaticaly download the dataset."
with pytest.raises(RuntimeError, match=err):
PatternNet(str(tmp_path))
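An aside on the monkeypatch target change above: patternnet.py now does `from .utils import download_url`, which binds the name into the dataset module at import time, so patching torchgeo.datasets.utils no longer reaches it. A generic sketch of the rule, with hypothetical module names:

# pkg/fetcher.py
from pkg.utils import download_url  # name is bound into pkg.fetcher at import

def fetch(root: str) -> None:
    download_url("https://example.com/data.zip", root)

# tests/test_fetcher.py
import pkg.fetcher

def fake_download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    print(f"skipping download of {url} into {root}")

def test_fetch(monkeypatch, tmp_path) -> None:
    # Patch where the name is looked up, not where it was defined; patching
    # pkg.utils.download_url would leave pkg.fetcher.download_url untouched.
    monkeypatch.setattr(pkg.fetcher, "download_url", fake_download_url)
    pkg.fetcher.fetch(str(tmp_path))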

View file

@@ -18,7 +18,7 @@ from torchgeo.datasets import RESISC45
pytest.importorskip("rarfile")
def download_url(url: str, root: str, *args: str) -> None:
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@@ -29,7 +29,7 @@ class TestRESISC45:
self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path
) -> RESISC45:
monkeypatch.setattr( # type: ignore[attr-defined]
torchgeo.datasets.utils, "download_url", download_url
torchgeo.datasets.resisc45, "download_url", download_url
)
md5 = "9c221122164d17b8118d2b6527ee5e9c"
monkeypatch.setattr(RESISC45, "md5", md5) # type: ignore[attr-defined]
@@ -49,9 +49,19 @@ class TestRESISC45:
def test_len(self, dataset: RESISC45) -> None:
assert len(dataset) == 2
def test_already_downloaded(self, dataset: RESISC45) -> None:
RESISC45(root=dataset.root, download=True)
def test_already_downloaded(self, dataset: RESISC45, tmp_path: Path) -> None:
RESISC45(root=str(tmp_path), download=True)
def test_already_downloaded_not_extracted(
self, dataset: RESISC45, tmp_path: Path
) -> None:
shutil.rmtree(dataset.root)
download_url(dataset.url, root=str(tmp_path))
RESISC45(root=str(tmp_path), download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
err = "Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automaticaly download the dataset."
with pytest.raises(RuntimeError, match=err):
RESISC45(str(tmp_path))

View file

@@ -25,7 +25,14 @@ from .cv4a_kenya_crop_type import CV4AKenyaCropType
from .cyclone import TropicalCycloneWindEstimation
from .etci2021 import ETCI2021
from .eurosat import EuroSAT
from .geo import GeoDataset, RasterDataset, VectorDataset, VisionDataset, ZipDataset
from .geo import (
GeoDataset,
RasterDataset,
VectorDataset,
VisionClassificationDataset,
VisionDataset,
ZipDataset,
)
from .gid15 import GID15
from .landcoverai import LandCoverAI
from .landsat import (
@@ -109,6 +116,7 @@ __all__ = (
"RasterDataset",
"VectorDataset",
"VisionDataset",
"VisionClassificationDataset",
"ZipDataset",
# Utilities
"BoundingBox",

View file

@@ -25,6 +25,7 @@ from rasterio.vrt import WarpedVRT
from rtree.index import Index, Property
from torch import Tensor
from torch.utils.data import Dataset
from torchvision.datasets.folder import ImageFolder, default_loader
from .utils import BoundingBox, disambiguate_timestamp
@@ -578,6 +579,79 @@ class VisionDataset(Dataset[Dict[str, Any]], abc.ABC):
size: {len(self)}"""
class VisionClassificationDataset(VisionDataset, ImageFolder): # type: ignore[misc]
"""Abstract base class for classification datasets lacking geospatial information.
This base class is designed for datasets with pre-defined image chips
stored in separate folders per class.
"""
def __init__(
self,
root: str,
transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
loader: Optional[Callable[[str], Any]] = default_loader,
) -> None:
"""Initialize a new VisionClassificationDataset instance.
Args:
root: root directory where dataset can be found
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
loader: a callable function which takes as input a path to an image and
returns a PIL Image or numpy array
"""
# When transform & target_transform are None, ImageFolder.__getitem__(index)
# returns a PIL.Image and int for image and label, respectively
super().__init__(
root=root, transform=None, target_transform=None, loader=loader
)
# Must be set after calling super().__init__()
self.transforms = transforms
def __getitem__(self, index: int) -> Dict[str, Tensor]:
"""Return an index within the dataset.
Args:
index: index to return
Returns:
data and label at that index
"""
image, label = self._load_image(index)
sample = {"image": image, "label": label}
if self.transforms is not None:
sample = self.transforms(sample)
return sample
def __len__(self) -> int:
"""Return the number of data points in the dataset.
Returns:
length of the dataset
"""
return len(self.imgs)
def _load_image(self, index: int) -> Tuple[Tensor, Tensor]:
"""Load a single image and it's class label.
Args:
index: index to return
Returns:
the image
the image class label
"""
img, label = ImageFolder.__getitem__(self, index)
array = np.array(img)
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
# Convert from HxWxC to CxHxW
tensor = tensor.permute((2, 0, 1))
label = torch.tensor(label) # type: ignore[attr-defined]
return tensor, label
class ZipDataset(GeoDataset):
"""Dataset for merging two or more GeoDatasets.

Просмотреть файл

@@ -4,18 +4,15 @@
"""PatternNet dataset."""
import os
from typing import Callable, Dict, Optional, Tuple
from typing import Callable, Dict, Optional
import numpy as np
import torch
from torch import Tensor
from torchvision.datasets import ImageFolder
from .geo import VisionDataset
from .utils import download_and_extract_archive
from .geo import VisionClassificationDataset
from .utils import download_url, extract_archive
class PatternNet(VisionDataset, ImageFolder): # type: ignore[misc]
class PatternNet(VisionClassificationDataset):
"""PatternNet dataset.
The `PatternNet <https://sites.google.com/view/zhouwx/dataset>`_
@@ -97,100 +94,55 @@ class PatternNet(VisionDataset, ImageFolder): # type: ignore[misc]
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
"""
self.root = root
self.download = download
self.checksum = checksum
if download:
self._download()
if not self._check_integrity():
raise RuntimeError(
"Dataset not found or corrupted. "
+ "You can use download=True to download it"
)
# When transform & target_transform are None, ImageFolder.__getitem__(index)
# returns a PIL.Image and int for image and label, respectively
self._verify()
super().__init__(
root=os.path.join(root, "images"), transform=None, target_transform=None
root=os.path.join(root, self.directory),
transforms=transforms,
)
# Must be set after calling super().__init__()
self.transforms = transforms
def __getitem__(self, index: int) -> Dict[str, Tensor]:
"""Return an index within the dataset.
Args:
index: index to return
Returns:
data and label at that index
"""
image, label = self._load_image(index)
sample = {"image": image, "label": label}
if self.transforms is not None:
sample = self.transforms(sample)
return sample
def __len__(self) -> int:
"""Return the number of data points in the dataset.
Returns:
length of the dataset
"""
return len(self.imgs)
def _load_image(self, index: int) -> Tuple[Tensor, Tensor]:
"""Load a single image and it's class label.
Args:
index: index to return
Returns:
the image
the image class label
"""
img, label = ImageFolder.__getitem__(self, index)
array = np.array(img)
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
# Convert from HxWxC to CxHxW
tensor = tensor.permute((2, 0, 1))
label = torch.tensor(label) # type: ignore[attr-defined]
return tensor, label
def _check_integrity(self) -> bool:
"""Checks the integrity of the dataset structure.
Returns:
True if the dataset directories and split files are found, else False
"""
filepath = os.path.join(self.root, self.directory)
if not os.path.exists(filepath):
return False
return True
def _download(self) -> None:
"""Download the dataset and extract it.
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
AssertionError: if the checksum of split.py does not match
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
if self._check_integrity():
print("Files already downloaded and verified")
# Check if the files already exist
filepath = os.path.join(self.root, self.directory)
if os.path.exists(filepath):
return
download_and_extract_archive(
# Check if zip file already exists (if so then extract)
filepath = os.path.join(self.root, self.filename)
if os.path.exists(filepath):
self._extract()
return
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
"Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automaticaly download the dataset."
)
# Download and extract the dataset
self._download()
self._extract()
def _download(self) -> None:
"""Download the dataset."""
download_url(
self.url,
self.root,
filename=self.filename,
md5=self.md5 if self.checksum else None,
)
def _extract(self) -> None:
"""Extract the dataset."""
filepath = os.path.join(self.root, self.filename)
extract_archive(filepath)
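From a user's perspective, the new _verify flow behaves as in this sketch (hypothetical root path):

from torchgeo.datasets import PatternNet

# Empty root: _verify() finds neither the extracted directory nor the archive,
# sees download=True, and falls through to _download() followed by _extract().
ds = PatternNet(root="data/patternnet", download=True, checksum=True)

# Later constructions: the extracted directory exists, so _verify() returns
# immediately and no network access happens.
ds = PatternNet(root="data/patternnet")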

View file

@@ -4,18 +4,15 @@
"""RESISC45 dataset."""
import os
from typing import Callable, Dict, Optional, Tuple
from typing import Callable, Dict, Optional
import numpy as np
import torch
from torch import Tensor
from torchvision.datasets import ImageFolder
from .geo import VisionDataset
from .utils import download_and_extract_archive
from .geo import VisionClassificationDataset
from .utils import download_url, extract_archive
class RESISC45(VisionDataset, ImageFolder): # type: ignore[misc]
class RESISC45(VisionClassificationDataset):
"""RESISC45 dataset.
The `RESISC45 <http://www.escience.cn/people/JunweiHan/NWPU-RESISC45.html>`_
@@ -107,102 +104,55 @@ class RESISC45(VisionDataset, ImageFolder): # type: ignore[misc]
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
"""
self.root = root
self.download = download
self.checksum = checksum
if download:
self._download()
if not self._check_integrity():
raise RuntimeError(
"Dataset not found or corrupted. "
+ "You can use download=True to download it"
)
# When transform & target_transform are None, ImageFolder.__getitem__(index)
# returns a PIL.Image and int for image and label, respectively
self._verify()
super().__init__(
root=os.path.join(root, self.directory),
transform=None,
target_transform=None,
transforms=transforms,
)
# Must be set after calling super().__init__()
self.transforms = transforms
def __getitem__(self, index: int) -> Dict[str, Tensor]:
"""Return an index within the dataset.
Args:
index: index to return
Returns:
data and label at that index
"""
image, label = self._load_image(index)
sample = {"image": image, "label": label}
if self.transforms is not None:
sample = self.transforms(sample)
return sample
def __len__(self) -> int:
"""Return the number of data points in the dataset.
Returns:
length of the dataset
"""
return len(self.imgs)
def _load_image(self, index: int) -> Tuple[Tensor, Tensor]:
"""Load a single image and it's class label.
Args:
index: index to return
Returns:
the image
the image class label
"""
img, label = ImageFolder.__getitem__(self, index)
array = np.array(img)
tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined]
# Convert from HxWxC to CxHxW
tensor = tensor.permute((2, 0, 1))
label = torch.tensor(label) # type: ignore[attr-defined]
return tensor, label
def _check_integrity(self) -> bool:
"""Checks the integrity of the dataset structure.
Returns:
True if the dataset directories and split files are found, else False
"""
filepath = os.path.join(self.root, self.directory)
if not os.path.exists(filepath):
return False
return True
def _download(self) -> None:
"""Download the dataset and extract it.
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
AssertionError: if the checksum of split.py does not match
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
if self._check_integrity():
print("Files already downloaded and verified")
# Check if the files already exist
filepath = os.path.join(self.root, self.directory)
if os.path.exists(filepath):
return
download_and_extract_archive(
# Check if zip file already exists (if so then extract)
filepath = os.path.join(self.root, self.filename)
if os.path.exists(filepath):
self._extract()
return
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
"Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automaticaly download the dataset."
)
# Download and extract the dataset
self._download()
self._extract()
def _download(self) -> None:
"""Download the dataset."""
download_url(
self.url,
self.root,
filename=self.filename,
md5=self.md5 if self.checksum else None,
)
def _extract(self) -> None:
"""Extract the dataset."""
filepath = os.path.join(self.root, self.filename)
extract_archive(filepath)
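The same verify/download/extract split applies here, and one payoff of keeping _extract separate is that the archive-present path is testable offline, as test_already_downloaded_not_extracted above does. A sketch of that path (both paths are hypothetical, and the archive name is assumed to match the class's filename attribute):

import shutil

from torchgeo.datasets import RESISC45

# Place a previously downloaded archive in a fresh root; _verify() finds it,
# extracts it, and never attempts a download even though download=False.
shutil.copy("downloads/NWPU-RESISC45.rar", "fresh_root/")
ds = RESISC45(root="fresh_root", download=False)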