Add PreChippedGeoSampler for pre-chipped geospatial datasets (#479)

* Add PreChippedGeoSampler for pre-chipped geospatial datasets

* Add shuffle parameter

* Add tests, fix type hints

* Warn about multi-CRS datasets
This commit is contained in:
Adam J. Stewart 2022-04-05 11:10:39 -05:00 коммит произвёл GitHub
Родитель f37c154fed
Коммит e8474e46e2
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 120 добавлений и 3 удалений

Просмотреть файл

@ -32,6 +32,11 @@ Grid Geo Sampler
.. autoclass:: GridGeoSampler
Pre-chipped Geo Sampler
^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: PreChippedGeoSampler
Batch Samplers
--------------

Просмотреть файл

@ -11,7 +11,13 @@ from rasterio.crs import CRS
from torch.utils.data import DataLoader
from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
from torchgeo.samplers import GeoSampler, GridGeoSampler, RandomGeoSampler, Units
from torchgeo.samplers import (
GeoSampler,
GridGeoSampler,
PreChippedGeoSampler,
RandomGeoSampler,
Units,
)
class CustomGeoSampler(GeoSampler):
@ -189,3 +195,48 @@ class TestGridGeoSampler:
)
for _ in dl:
continue
class TestPreChippedGeoSampler:
@pytest.fixture(scope="class")
def dataset(self) -> CustomGeoDataset:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 20, 0, 20, 0, 20))
ds.index.insert(1, (0, 30, 0, 30, 0, 30))
return ds
@pytest.fixture(scope="function")
def sampler(self, dataset: CustomGeoDataset) -> PreChippedGeoSampler:
return PreChippedGeoSampler(dataset, shuffle=True)
def test_iter(self, sampler: GridGeoSampler) -> None:
for _ in sampler:
continue
def test_len(self, sampler: GridGeoSampler) -> None:
assert len(sampler) == 2
def test_roi(self, dataset: CustomGeoDataset) -> None:
roi = BoundingBox(5, 15, 5, 15, 5, 15)
sampler = PreChippedGeoSampler(dataset, roi=roi)
for query in sampler:
assert query == roi
def test_point_data(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 0, 0, 0, 0, 0))
ds.index.insert(1, (1, 1, 1, 1, 1, 1))
sampler = PreChippedGeoSampler(ds)
for _ in sampler:
continue
@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(
self, dataset: CustomGeoDataset, sampler: PreChippedGeoSampler, num_workers: int
) -> None:
dl = DataLoader(
dataset, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples
)
for _ in dl:
continue

Просмотреть файл

@ -5,11 +5,12 @@
from .batch import BatchGeoSampler, RandomBatchGeoSampler
from .constants import Units
from .single import GeoSampler, GridGeoSampler, RandomGeoSampler
from .single import GeoSampler, GridGeoSampler, PreChippedGeoSampler, RandomGeoSampler
__all__ = (
# Samplers
"GridGeoSampler",
"PreChippedGeoSampler",
"RandomGeoSampler",
# Batch samplers
"RandomBatchGeoSampler",

Просмотреть файл

@ -5,8 +5,9 @@
import abc
import random
from typing import Iterator, Optional, Tuple, Union
from typing import Callable, Iterable, Iterator, Optional, Tuple, Union
import torch
from rtree.index import Index, Property
from torch.utils.data import Sampler
@ -240,3 +241,62 @@ class GridGeoSampler(GeoSampler):
number of patches that will be sampled
"""
return self.length
class PreChippedGeoSampler(GeoSampler):
"""Samples entire files at a time.
This is particularly useful for datasets that contain geospatial metadata
and subclass :class:`~torchgeo.datasets.GeoDataset` but have already been
pre-processed into :term:`chips <chip>`.
This sampler should not be used with :class:`~torchgeo.datasets.VisionDataset`.
You may encounter problems when using an :term:`ROI <region of interest (ROI)>`
that partially intersects with one of the file bounding boxes, when using an
:class:`~torchgeo.datasets.IntersectionDataset`, or when each file is in a
different CRS. These issues can be solved by adding padding.
"""
def __init__(
self,
dataset: GeoDataset,
roi: Optional[BoundingBox] = None,
shuffle: bool = False,
) -> None:
"""Initialize a new Sampler instance.
Args:
dataset: dataset to index from
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
shuffle: if True, reshuffle data at every epoch
.. versionadded:: 0.3
"""
super().__init__(dataset, roi)
self.shuffle = shuffle
self.hits = []
for hit in self.index.intersection(tuple(self.roi), objects=True):
self.hits.append(hit)
def __iter__(self) -> Iterator[BoundingBox]:
"""Return the index of a dataset.
Returns:
(minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
"""
generator: Callable[[int], Iterable[int]] = range
if self.shuffle:
generator = torch.randperm
for idx in generator(len(self)):
yield BoundingBox(*self.hits[idx].bounds)
def __len__(self) -> int:
"""Return the number of samples over the ROI.
Returns:
number of patches that will be sampled
"""
return len(self.hits)