This commit is contained in:
Adam J. Stewart 2021-08-11 21:32:23 +00:00
Родитель d5b4a5c06e
Коммит c385433ca3
9 изменённых файлов: 349 добавлений и 71 удалений

Просмотреть файл

@ -94,5 +94,6 @@ intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"pytorch-lightning": ("https://pytorch-lightning.readthedocs.io/en/latest/", None),
"rasterio": ("https://rasterio.readthedocs.io/en/latest/", None),
"rtree": ("https://rtree.readthedocs.io/en/latest/", None),
"torch": ("https://pytorch.org/docs/stable", None),
}

Просмотреть файл

@ -3,7 +3,7 @@ torchgeo.datasets
.. module:: torchgeo.datasets
In :mod:`torchgeo`, we define two types of datasets: :ref:`Geospatial Datasets` and :ref:`Non-geospatial Datasets`. These abstract base classes are documented in more detail in :ref:`Base Classes`.
In :mod:`torchgeo`, we define two types of datasets: :ref:`Geospatial Datasets` and :ref:`Non-geospatial Datasets`. These abstract base classes are documented in more detail in :ref:`Dataset Base Classes`.
Geospatial Datasets
-------------------
@ -107,8 +107,8 @@ NWPU VHR-10
.. autoclass:: VHR10
Base Classes
------------
Dataset Base Classes
--------------------
If you want to write your own custom dataset, you can extend one of these abstract base classes.

Просмотреть файл

@ -1,4 +1,68 @@
torchgeo.samplers
=================
.. automodule:: torchgeo.samplers
.. module:: torchgeo.samplers
Samplers
--------
Samplers are used to index a dataset, retrieving a single query at a time. For :class:`~torchgeo.datasets.VisionDataset`, dataset objects can be indexed with integers, and PyTorch's builtin samplers are sufficient. For :class:`~torchgeo.datasets.GeoDataset`, dataset objects require a bounding box for indexing. For this reason, we define our own :class:`GeoSampler` implementations below. These can be used like so:
.. code-block:: python
from torch.utils.data import DataLoader
from torchgeo.datasets import Landsat
from torchgeo.samplers import RandomGeoSampler
dataset = Landsat(...)
sampler = RandomGeoSampler(dataset.index, size=1000, length=100)
dataloader = DataLoader(dataset, sampler=sampler)
Random Geo Sampler
^^^^^^^^^^^^^^^^^^
.. autoclass:: RandomGeoSampler
Grid Geo Sampler
^^^^^^^^^^^^^^^^
.. autoclass:: GridGeoSampler
Batch Samplers
--------------
When working with large tile-based datasets, randomly sampling patches from each tile can be extremely time consuming. It's much more efficient to choose a tile, load it, warp it to the appropriate :term:`coordinate reference system (CRS)` and resolution, and then sample random patches from that tile to construct a mini-batch of data. For this reason, we define our own :class:`BatchGeoSampler` implementations below. These can be used like so:
.. code-block:: python
from torch.utils.data import DataLoader
from torchgeo.datasets import Landsat
from torchgeo.samplers import RandomBatchGeoSampler
dataset = Landsat(...)
sampler = RandomBatchGeoSampler(dataset.index, size=1000, batch_size=10, length=100)
dataloader = DataLoader(dataset, batch_sampler=sampler)
Random Batch Geo Sampler
^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: RandomBatchGeoSampler
Sampler Base Classes
--------------------
If you want to write your own custom sampler, you can extend one of these abstract base classes.
Geo Sampler
^^^^^^^^^^^
.. autoclass:: GeoSampler
Batch Geo Sampler
^^^^^^^^^^^^^^^^^
.. autoclass:: BatchGeoSampler

Просмотреть файл

@ -0,0 +1,64 @@
import math
from typing import Iterator, List
import pytest
from _pytest.fixtures import SubRequest
from rtree.index import Index, Property
from torchgeo.datasets import BoundingBox
from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler
class CustomBatchGeoSampler(BatchGeoSampler):
def __init__(self) -> None:
pass
def __iter__(self) -> Iterator[List[BoundingBox]]:
for i in range(len(self)):
yield [BoundingBox(j, j, j, j, j, j) for j in range(len(self))]
def __len__(self) -> int:
return 2
class TestBatchGeoSampler:
@pytest.fixture(scope="function")
def sampler(self) -> CustomBatchGeoSampler:
return CustomBatchGeoSampler()
def test_iter(self, sampler: CustomBatchGeoSampler) -> None:
expected = [BoundingBox(0, 0, 0, 0, 0, 0), BoundingBox(1, 1, 1, 1, 1, 1)]
assert next(iter(sampler)) == expected
def test_len(self, sampler: CustomBatchGeoSampler) -> None:
assert len(sampler) == 2
def test_abstract(self) -> None:
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
BatchGeoSampler(None) # type: ignore[abstract]
class TestRandomBatchGeoSampler:
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
def sampler(self, request: SubRequest) -> RandomBatchGeoSampler:
index = Index(interleaved=False, properties=Property(dimension=3))
index.insert(0, (0, 10, 20, 30, 40, 50))
index.insert(1, (0, 10, 20, 30, 40, 50))
size = request.param
return RandomBatchGeoSampler(index, size, batch_size=2, length=10)
def test_iter(self, sampler: RandomBatchGeoSampler) -> None:
for batch in sampler:
for query in batch:
assert sampler.roi.minx <= query.minx <= query.maxx <= sampler.roi.maxx
assert sampler.roi.miny <= query.miny <= query.miny <= sampler.roi.maxy
assert sampler.roi.mint <= query.mint <= query.maxt <= sampler.roi.maxt
assert math.isclose(query.maxx - query.minx, sampler.size[1])
assert math.isclose(query.maxy - query.miny, sampler.size[0])
assert math.isclose(
query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint
)
def test_len(self, sampler: RandomBatchGeoSampler) -> None:
assert len(sampler) == sampler.length // sampler.batch_size

Просмотреть файл

@ -3,6 +3,7 @@ from typing import Iterator
import pytest
from _pytest.fixtures import SubRequest
from rtree.index import Index, Property
from torchgeo.datasets import BoundingBox
from torchgeo.samplers import GeoSampler, GridGeoSampler, RandomGeoSampler
@ -22,13 +23,13 @@ class CustomGeoSampler(GeoSampler):
class TestGeoSampler:
@pytest.fixture(scope="function")
def sampler(self) -> GeoSampler:
def sampler(self) -> CustomGeoSampler:
return CustomGeoSampler()
def test_iter(self, sampler: GeoSampler) -> None:
def test_iter(self, sampler: CustomGeoSampler) -> None:
assert next(iter(sampler)) == BoundingBox(0, 0, 0, 0, 0, 0)
def test_len(self, sampler: GeoSampler) -> None:
def test_len(self, sampler: CustomGeoSampler) -> None:
assert len(sampler) == 2
def test_abstract(self) -> None:
@ -39,9 +40,11 @@ class TestGeoSampler:
class TestRandomGeoSampler:
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
def sampler(self, request: SubRequest) -> RandomGeoSampler:
roi = BoundingBox(0, 10, 20, 30, 40, 50)
index = Index(interleaved=False, properties=Property(dimension=3))
index.insert(0, (0, 10, 20, 30, 40, 50))
index.insert(1, (0, 10, 20, 30, 40, 50))
size = request.param
return RandomGeoSampler(roi, size, length=10)
return RandomGeoSampler(index, size, length=10)
def test_iter(self, sampler: RandomGeoSampler) -> None:
for query in sampler:
@ -72,9 +75,11 @@ class TestGridGeoSampler:
],
)
def sampler(self, request: SubRequest) -> GridGeoSampler:
roi = BoundingBox(0, 10, 20, 30, 40, 50)
index = Index(interleaved=False, properties=Property(dimension=3))
index.insert(0, (0, 10, 20, 30, 40, 50))
index.insert(1, (0, 10, 20, 30, 40, 50))
size, stride = request.param
return GridGeoSampler(roi, size, stride)
return GridGeoSampler(index, size, stride)
def test_iter(self, sampler: GridGeoSampler) -> None:
for query in sampler:
@ -87,6 +92,3 @@ class TestGridGeoSampler:
assert math.isclose(
query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint
)
def test_len(self, sampler: RandomGeoSampler) -> None:
assert len(sampler) == 9

Просмотреть файл

@ -1,8 +1,18 @@
"""TorchGeo samplers."""
from .samplers import GeoSampler, GridGeoSampler, RandomGeoSampler
from .batch import BatchGeoSampler, RandomBatchGeoSampler
from .single import GeoSampler, GridGeoSampler, RandomGeoSampler
__all__ = ("GeoSampler", "GridGeoSampler", "RandomGeoSampler")
__all__ = (
# Samplers
"GridGeoSampler",
"RandomGeoSampler",
# Batch samplers
"RandomBatchGeoSampler",
# Base classes
"GeoSampler",
"BatchGeoSampler",
)
# https://stackoverflow.com/questions/40018681
for module in __all__:

114
torchgeo/samplers/batch.py Normal file
Просмотреть файл

@ -0,0 +1,114 @@
"""TorchGeo batch samplers."""
import abc
import random
from typing import Iterator, List, Optional, Tuple, Union
from rtree.index import Index
from torch.utils.data import Sampler
from torchgeo.datasets import BoundingBox
from .utils import _to_tuple
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Sampler.__module__ = "torch.utils.data"
class BatchGeoSampler(Sampler[List[BoundingBox]], abc.ABC):
"""Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`.
Unlike PyTorch's :class:`~torch.utils.data.BatchSampler`, :class:`BatchGeoSampler`
returns enough geospatial information to uniquely index any
:class:`~torchgeo.datasets.GeoDataset`. This includes things like latitude,
longitude, height, width, projection, coordinate system, and time.
"""
@abc.abstractmethod
def __iter__(self) -> Iterator[List[BoundingBox]]:
"""Return a batch of indices of a dataset.
Returns:
batch of (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
"""
class RandomBatchGeoSampler(BatchGeoSampler):
"""Samples batches of elements from a region of interest randomly.
This is particularly useful during training when you want to maximize the size of
the dataset and return as many random :term:`chips <chip>` as possible.
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``index`` should come
from a tile-based dataset if possible.
"""
def __init__(
self,
index: Index,
size: Union[Tuple[float, float], float],
batch_size: int,
length: int,
roi: Optional[BoundingBox] = None,
) -> None:
"""Initialize a new Sampler instance.
The ``size`` argument can either be:
* a single ``float`` - in which case the same value is used for the height and
width dimension
* a ``tuple`` of two floats - in which case, the first *float* is used for the
height dimension, and the second *float* for the width dimension
Args:
index: index of a :class:`~torchgeo.datasets.GeoDataset`
size: dimensions of each :term:`patch` in units of CRS
batch_size: number of samples per batch
length: number of samples per epoch
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``index``)
"""
self.index = index
self.size = _to_tuple(size)
self.batch_size = batch_size
self.length = length
if roi is None:
roi = BoundingBox(*index.bounds)
self.roi = roi
self.hits = list(index.intersection(roi, objects=True))
def __iter__(self) -> Iterator[List[BoundingBox]]:
"""Return the indices of a dataset.
Returns:
batch of (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
"""
for _ in range(len(self)):
# Choose a random tile
hit = random.choice(self.hits)
bounds = BoundingBox(*hit.bounds)
# Choose random indices within that tile
batch = []
for _ in range(self.batch_size):
minx = random.uniform(bounds.minx, bounds.maxx - self.size[1])
maxx = minx + self.size[1]
miny = random.uniform(bounds.miny, bounds.maxy - self.size[0])
maxy = miny + self.size[0]
mint = bounds.mint
maxt = bounds.maxt
batch.append(BoundingBox(minx, maxx, miny, maxy, mint, maxt))
yield batch
def __len__(self) -> int:
"""Return the number of batches in a single epoch.
Returns:
number of batches in an epoch
"""
return self.length // self.batch_size

Просмотреть файл

@ -2,33 +2,21 @@
import abc
import random
from typing import Any, Iterator, Tuple, Union
from typing import Iterator, Optional, Tuple, Union
from rtree.index import Index
from torch.utils.data import Sampler
from torchgeo.datasets import BoundingBox
from .utils import _to_tuple
# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Sampler.__module__ = "torch.utils.data"
def _to_tuple(value: Union[Tuple[Any, Any], Any]) -> Tuple[Any, Any]:
"""Convert value to a tuple if it is not already a tuple.
Args:
value: input value
Returns:
value if value is a tuple, else (value, value)
"""
if isinstance(value, (float, int)):
return (value, value)
else:
return value
class GeoSampler(Sampler[Tuple[Any, ...]], abc.ABC):
class GeoSampler(Sampler[BoundingBox], abc.ABC):
"""Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`.
Unlike PyTorch's :class:`~torch.utils.data.Sampler`, :class:`GeoSampler`
@ -45,26 +33,25 @@ class GeoSampler(Sampler[Tuple[Any, ...]], abc.ABC):
(minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
"""
@abc.abstractmethod
def __len__(self) -> int:
"""Return the number of samples in a single epoch.
Returns:
length of the epoch
"""
class RandomGeoSampler(GeoSampler):
"""Samples elements from a region of interest randomly.
This is particularly useful during training when you want to maximize the size of
the dataset and return as many random :term:`chips <chip>` as possible.
This sampler is not recommended for use with tile-based datasets. Use
:class:`RandomBatchGeoSampler` instead.
"""
def __init__(
self, roi: BoundingBox, size: Union[Tuple[float, float], float], length: int
self,
index: Index,
size: Union[Tuple[float, float], float],
length: int,
roi: Optional[BoundingBox] = None,
) -> None:
"""Initialize a new RandomGeoSampler.
"""Initialize a new Sampler instance.
The ``size`` argument can either be:
@ -74,13 +61,19 @@ class RandomGeoSampler(GeoSampler):
height dimension, and the second *float* for the width dimension
Args:
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
index: index of a :class:`~torchgeo.datasets.GeoDataset`
size: dimensions of each :term:`patch` in units of CRS
length: number of random samples to draw per epoch
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``index``)
"""
self.roi = roi
self.index = index
self.size = _to_tuple(size)
self.length = length
if roi is None:
roi = BoundingBox(*index.bounds)
self.roi = roi
self.hits = list(index.intersection(roi, objects=True))
def __iter__(self) -> Iterator[BoundingBox]:
"""Return the index of a dataset.
@ -89,15 +82,19 @@ class RandomGeoSampler(GeoSampler):
(minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
"""
for _ in range(len(self)):
minx = random.uniform(self.roi.minx, self.roi.maxx - self.size[1])
# Choose a random tile
hit = random.choice(self.hits)
bounds = BoundingBox(*hit.bounds)
# Choose a random index within that tile
minx = random.uniform(bounds.minx, bounds.maxx - self.size[1])
maxx = minx + self.size[1]
miny = random.uniform(self.roi.miny, self.roi.maxy - self.size[0])
miny = random.uniform(bounds.miny, bounds.maxy - self.size[0])
maxy = miny + self.size[0]
# TODO: figure out how to handle time
mint = self.roi.mint
maxt = self.roi.maxt
mint = bounds.mint
maxt = bounds.maxt
yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)
@ -123,15 +120,19 @@ class GridGeoSampler(GeoSampler):
The overlap between each chip (``chip_size - stride``) should be approximately equal
to the `receptive field <https://distill.pub/2019/computing-receptive-fields/>`_ of
the CNN.
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``index`` should come
from a non-tile-based dataset if possible.
"""
def __init__(
self,
roi: BoundingBox,
index: Index,
size: Union[Tuple[float, float], float],
stride: Union[Tuple[float, float], float],
roi: Optional[BoundingBox] = None,
) -> None:
"""Initialize a new RandomGeoSampler.
"""Initialize a new Sampler instance.
The ``size`` and ``stride`` arguments can either be:
@ -141,15 +142,18 @@ class GridGeoSampler(GeoSampler):
height dimension, and the second *float* for the width dimension
Args:
index: index of a :class:`~torchgeo.datasets.GeoDataset`
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
size: dimensions of each :term:`patch` in units of CRS
stride: distance to skip between each patch
"""
self.roi = roi
self.index = index
self.size = _to_tuple(size)
self.stride = _to_tuple(stride)
self.rows = int((roi.maxy - roi.miny - self.size[0]) // self.stride[0]) + 1
self.cols = int((roi.maxx - roi.minx - self.size[1]) // self.stride[1]) + 1
if roi is None:
roi = BoundingBox(*index.bounds)
self.roi = roi
self.hits = index.intersection(roi, objects=True)
def __iter__(self) -> Iterator[BoundingBox]:
"""Return the index of a dataset.
@ -157,23 +161,24 @@ class GridGeoSampler(GeoSampler):
Returns:
(minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
"""
for i in range(self.rows):
miny = self.roi.miny + i * self.stride[0]
maxy = miny + self.size[0]
for j in range(self.cols):
minx = self.roi.minx + j * self.stride[1]
maxx = minx + self.size[1]
# For each tile...
for hit in self.hits:
bounds = BoundingBox(*hit.bounds)
# TODO: figure out how to handle time
mint = self.roi.mint
maxt = self.roi.maxt
rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)
mint = bounds.mint
maxt = bounds.maxt
def __len__(self) -> int:
"""Return the number of samples in a single epoch.
# For each row...
for i in range(rows):
miny = bounds.miny + i * self.stride[0]
maxy = miny + self.size[0]
Returns:
length of the epoch
"""
return self.rows * self.cols
# For each column...
for j in range(cols):
minx = bounds.minx + j * self.stride[1]
maxx = minx + self.size[1]
yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)

Просмотреть файл

@ -0,0 +1,18 @@
"""Common sampler utilities."""
from typing import Tuple, Union
def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]:
"""Convert value to a tuple if it is not already a tuple.
Args:
value: input value
Returns:
value if value is a tuple, else (value, value)
"""
if isinstance(value, (float, int)):
return (value, value)
else:
return value