зеркало из https://github.com/microsoft/torchgeo.git
Add PreChippedGeoSampler for pre-chipped geospatial datasets (#479)
* Add PreChippedGeoSampler for pre-chipped geospatial datasets * Add shuffle parameter * Add tests, fix type hints * Warn about multi-CRS datasets
This commit is contained in:
Родитель
f37c154fed
Коммит
e8474e46e2
|
@ -32,6 +32,11 @@ Grid Geo Sampler
|
||||||
|
|
||||||
.. autoclass:: GridGeoSampler
|
.. autoclass:: GridGeoSampler
|
||||||
|
|
||||||
|
Pre-chipped Geo Sampler
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
.. autoclass:: PreChippedGeoSampler
|
||||||
|
|
||||||
Batch Samplers
|
Batch Samplers
|
||||||
--------------
|
--------------
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,13 @@ from rasterio.crs import CRS
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
|
from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
|
||||||
from torchgeo.samplers import GeoSampler, GridGeoSampler, RandomGeoSampler, Units
|
from torchgeo.samplers import (
|
||||||
|
GeoSampler,
|
||||||
|
GridGeoSampler,
|
||||||
|
PreChippedGeoSampler,
|
||||||
|
RandomGeoSampler,
|
||||||
|
Units,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CustomGeoSampler(GeoSampler):
|
class CustomGeoSampler(GeoSampler):
|
||||||
|
@ -189,3 +195,48 @@ class TestGridGeoSampler:
|
||||||
)
|
)
|
||||||
for _ in dl:
|
for _ in dl:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
class TestPreChippedGeoSampler:
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def dataset(self) -> CustomGeoDataset:
|
||||||
|
ds = CustomGeoDataset()
|
||||||
|
ds.index.insert(0, (0, 20, 0, 20, 0, 20))
|
||||||
|
ds.index.insert(1, (0, 30, 0, 30, 0, 30))
|
||||||
|
return ds
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def sampler(self, dataset: CustomGeoDataset) -> PreChippedGeoSampler:
|
||||||
|
return PreChippedGeoSampler(dataset, shuffle=True)
|
||||||
|
|
||||||
|
def test_iter(self, sampler: GridGeoSampler) -> None:
|
||||||
|
for _ in sampler:
|
||||||
|
continue
|
||||||
|
|
||||||
|
def test_len(self, sampler: GridGeoSampler) -> None:
|
||||||
|
assert len(sampler) == 2
|
||||||
|
|
||||||
|
def test_roi(self, dataset: CustomGeoDataset) -> None:
|
||||||
|
roi = BoundingBox(5, 15, 5, 15, 5, 15)
|
||||||
|
sampler = PreChippedGeoSampler(dataset, roi=roi)
|
||||||
|
for query in sampler:
|
||||||
|
assert query == roi
|
||||||
|
|
||||||
|
def test_point_data(self) -> None:
|
||||||
|
ds = CustomGeoDataset()
|
||||||
|
ds.index.insert(0, (0, 0, 0, 0, 0, 0))
|
||||||
|
ds.index.insert(1, (1, 1, 1, 1, 1, 1))
|
||||||
|
sampler = PreChippedGeoSampler(ds)
|
||||||
|
for _ in sampler:
|
||||||
|
continue
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
@pytest.mark.parametrize("num_workers", [0, 1, 2])
|
||||||
|
def test_dataloader(
|
||||||
|
self, dataset: CustomGeoDataset, sampler: PreChippedGeoSampler, num_workers: int
|
||||||
|
) -> None:
|
||||||
|
dl = DataLoader(
|
||||||
|
dataset, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples
|
||||||
|
)
|
||||||
|
for _ in dl:
|
||||||
|
continue
|
||||||
|
|
|
@ -5,11 +5,12 @@
|
||||||
|
|
||||||
from .batch import BatchGeoSampler, RandomBatchGeoSampler
|
from .batch import BatchGeoSampler, RandomBatchGeoSampler
|
||||||
from .constants import Units
|
from .constants import Units
|
||||||
from .single import GeoSampler, GridGeoSampler, RandomGeoSampler
|
from .single import GeoSampler, GridGeoSampler, PreChippedGeoSampler, RandomGeoSampler
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
# Samplers
|
# Samplers
|
||||||
"GridGeoSampler",
|
"GridGeoSampler",
|
||||||
|
"PreChippedGeoSampler",
|
||||||
"RandomGeoSampler",
|
"RandomGeoSampler",
|
||||||
# Batch samplers
|
# Batch samplers
|
||||||
"RandomBatchGeoSampler",
|
"RandomBatchGeoSampler",
|
||||||
|
|
|
@ -5,8 +5,9 @@
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import random
|
import random
|
||||||
from typing import Iterator, Optional, Tuple, Union
|
from typing import Callable, Iterable, Iterator, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
from rtree.index import Index, Property
|
from rtree.index import Index, Property
|
||||||
from torch.utils.data import Sampler
|
from torch.utils.data import Sampler
|
||||||
|
|
||||||
|
@ -240,3 +241,62 @@ class GridGeoSampler(GeoSampler):
|
||||||
number of patches that will be sampled
|
number of patches that will be sampled
|
||||||
"""
|
"""
|
||||||
return self.length
|
return self.length
|
||||||
|
|
||||||
|
|
||||||
|
class PreChippedGeoSampler(GeoSampler):
|
||||||
|
"""Samples entire files at a time.
|
||||||
|
|
||||||
|
This is particularly useful for datasets that contain geospatial metadata
|
||||||
|
and subclass :class:`~torchgeo.datasets.GeoDataset` but have already been
|
||||||
|
pre-processed into :term:`chips <chip>`.
|
||||||
|
|
||||||
|
This sampler should not be used with :class:`~torchgeo.datasets.VisionDataset`.
|
||||||
|
You may encounter problems when using an :term:`ROI <region of interest (ROI)>`
|
||||||
|
that partially intersects with one of the file bounding boxes, when using an
|
||||||
|
:class:`~torchgeo.datasets.IntersectionDataset`, or when each file is in a
|
||||||
|
different CRS. These issues can be solved by adding padding.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset: GeoDataset,
|
||||||
|
roi: Optional[BoundingBox] = None,
|
||||||
|
shuffle: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize a new Sampler instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: dataset to index from
|
||||||
|
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
|
||||||
|
(defaults to the bounds of ``dataset.index``)
|
||||||
|
shuffle: if True, reshuffle data at every epoch
|
||||||
|
|
||||||
|
.. versionadded:: 0.3
|
||||||
|
"""
|
||||||
|
super().__init__(dataset, roi)
|
||||||
|
self.shuffle = shuffle
|
||||||
|
|
||||||
|
self.hits = []
|
||||||
|
for hit in self.index.intersection(tuple(self.roi), objects=True):
|
||||||
|
self.hits.append(hit)
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[BoundingBox]:
|
||||||
|
"""Return the index of a dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
|
||||||
|
"""
|
||||||
|
generator: Callable[[int], Iterable[int]] = range
|
||||||
|
if self.shuffle:
|
||||||
|
generator = torch.randperm
|
||||||
|
|
||||||
|
for idx in generator(len(self)):
|
||||||
|
yield BoundingBox(*self.hits[idx].bounds)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Return the number of samples over the ROI.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
number of patches that will be sampled
|
||||||
|
"""
|
||||||
|
return len(self.hits)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче