зеркало из https://github.com/microsoft/torchgeo.git
Add PreChippedGeoSampler for pre-chipped geospatial datasets (#479)
* Add PreChippedGeoSampler for pre-chipped geospatial datasets * Add shuffle parameter * Add tests, fix type hints * Warn about multi-CRS datasets
This commit is contained in:
Родитель
f37c154fed
Коммит
e8474e46e2
|
@ -32,6 +32,11 @@ Grid Geo Sampler
|
|||
|
||||
.. autoclass:: GridGeoSampler
|
||||
|
||||
Pre-chipped Geo Sampler
|
||||
^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: PreChippedGeoSampler
|
||||
|
||||
Batch Samplers
|
||||
--------------
|
||||
|
||||
|
|
|
@ -11,7 +11,13 @@ from rasterio.crs import CRS
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
|
||||
from torchgeo.samplers import GeoSampler, GridGeoSampler, RandomGeoSampler, Units
|
||||
from torchgeo.samplers import (
|
||||
GeoSampler,
|
||||
GridGeoSampler,
|
||||
PreChippedGeoSampler,
|
||||
RandomGeoSampler,
|
||||
Units,
|
||||
)
|
||||
|
||||
|
||||
class CustomGeoSampler(GeoSampler):
|
||||
|
@ -189,3 +195,48 @@ class TestGridGeoSampler:
|
|||
)
|
||||
for _ in dl:
|
||||
continue
|
||||
|
||||
|
||||
class TestPreChippedGeoSampler:
|
||||
@pytest.fixture(scope="class")
|
||||
def dataset(self) -> CustomGeoDataset:
|
||||
ds = CustomGeoDataset()
|
||||
ds.index.insert(0, (0, 20, 0, 20, 0, 20))
|
||||
ds.index.insert(1, (0, 30, 0, 30, 0, 30))
|
||||
return ds
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def sampler(self, dataset: CustomGeoDataset) -> PreChippedGeoSampler:
|
||||
return PreChippedGeoSampler(dataset, shuffle=True)
|
||||
|
||||
def test_iter(self, sampler: GridGeoSampler) -> None:
|
||||
for _ in sampler:
|
||||
continue
|
||||
|
||||
def test_len(self, sampler: GridGeoSampler) -> None:
|
||||
assert len(sampler) == 2
|
||||
|
||||
def test_roi(self, dataset: CustomGeoDataset) -> None:
|
||||
roi = BoundingBox(5, 15, 5, 15, 5, 15)
|
||||
sampler = PreChippedGeoSampler(dataset, roi=roi)
|
||||
for query in sampler:
|
||||
assert query == roi
|
||||
|
||||
def test_point_data(self) -> None:
|
||||
ds = CustomGeoDataset()
|
||||
ds.index.insert(0, (0, 0, 0, 0, 0, 0))
|
||||
ds.index.insert(1, (1, 1, 1, 1, 1, 1))
|
||||
sampler = PreChippedGeoSampler(ds)
|
||||
for _ in sampler:
|
||||
continue
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("num_workers", [0, 1, 2])
|
||||
def test_dataloader(
|
||||
self, dataset: CustomGeoDataset, sampler: PreChippedGeoSampler, num_workers: int
|
||||
) -> None:
|
||||
dl = DataLoader(
|
||||
dataset, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples
|
||||
)
|
||||
for _ in dl:
|
||||
continue
|
||||
|
|
|
@ -5,11 +5,12 @@
|
|||
|
||||
from .batch import BatchGeoSampler, RandomBatchGeoSampler
|
||||
from .constants import Units
|
||||
from .single import GeoSampler, GridGeoSampler, RandomGeoSampler
|
||||
from .single import GeoSampler, GridGeoSampler, PreChippedGeoSampler, RandomGeoSampler
|
||||
|
||||
__all__ = (
|
||||
# Samplers
|
||||
"GridGeoSampler",
|
||||
"PreChippedGeoSampler",
|
||||
"RandomGeoSampler",
|
||||
# Batch samplers
|
||||
"RandomBatchGeoSampler",
|
||||
|
|
|
@ -5,8 +5,9 @@
|
|||
|
||||
import abc
|
||||
import random
|
||||
from typing import Iterator, Optional, Tuple, Union
|
||||
from typing import Callable, Iterable, Iterator, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from rtree.index import Index, Property
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
|
@ -240,3 +241,62 @@ class GridGeoSampler(GeoSampler):
|
|||
number of patches that will be sampled
|
||||
"""
|
||||
return self.length
|
||||
|
||||
|
||||
class PreChippedGeoSampler(GeoSampler):
|
||||
"""Samples entire files at a time.
|
||||
|
||||
This is particularly useful for datasets that contain geospatial metadata
|
||||
and subclass :class:`~torchgeo.datasets.GeoDataset` but have already been
|
||||
pre-processed into :term:`chips <chip>`.
|
||||
|
||||
This sampler should not be used with :class:`~torchgeo.datasets.VisionDataset`.
|
||||
You may encounter problems when using an :term:`ROI <region of interest (ROI)>`
|
||||
that partially intersects with one of the file bounding boxes, when using an
|
||||
:class:`~torchgeo.datasets.IntersectionDataset`, or when each file is in a
|
||||
different CRS. These issues can be solved by adding padding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: GeoDataset,
|
||||
roi: Optional[BoundingBox] = None,
|
||||
shuffle: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new Sampler instance.
|
||||
|
||||
Args:
|
||||
dataset: dataset to index from
|
||||
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
|
||||
(defaults to the bounds of ``dataset.index``)
|
||||
shuffle: if True, reshuffle data at every epoch
|
||||
|
||||
.. versionadded:: 0.3
|
||||
"""
|
||||
super().__init__(dataset, roi)
|
||||
self.shuffle = shuffle
|
||||
|
||||
self.hits = []
|
||||
for hit in self.index.intersection(tuple(self.roi), objects=True):
|
||||
self.hits.append(hit)
|
||||
|
||||
def __iter__(self) -> Iterator[BoundingBox]:
|
||||
"""Return the index of a dataset.
|
||||
|
||||
Returns:
|
||||
(minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
|
||||
"""
|
||||
generator: Callable[[int], Iterable[int]] = range
|
||||
if self.shuffle:
|
||||
generator = torch.randperm
|
||||
|
||||
for idx in generator(len(self)):
|
||||
yield BoundingBox(*self.hits[idx].bounds)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of samples over the ROI.
|
||||
|
||||
Returns:
|
||||
number of patches that will be sampled
|
||||
"""
|
||||
return len(self.hits)
|
||||
|
|
Загрузка…
Ссылка в новой задаче