зеркало из https://github.com/microsoft/torchgeo.git
More intelligent sampling
This commit is contained in:
Родитель
d5b4a5c06e
Коммит
c385433ca3
|
@ -94,5 +94,6 @@ intersphinx_mapping = {
|
|||
"python": ("https://docs.python.org/3", None),
|
||||
"pytorch-lightning": ("https://pytorch-lightning.readthedocs.io/en/latest/", None),
|
||||
"rasterio": ("https://rasterio.readthedocs.io/en/latest/", None),
|
||||
"rtree": ("https://rtree.readthedocs.io/en/latest/", None),
|
||||
"torch": ("https://pytorch.org/docs/stable", None),
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ torchgeo.datasets
|
|||
|
||||
.. module:: torchgeo.datasets
|
||||
|
||||
In :mod:`torchgeo`, we define two types of datasets: :ref:`Geospatial Datasets` and :ref:`Non-geospatial Datasets`. These abstract base classes are documented in more detail in :ref:`Base Classes`.
|
||||
In :mod:`torchgeo`, we define two types of datasets: :ref:`Geospatial Datasets` and :ref:`Non-geospatial Datasets`. These abstract base classes are documented in more detail in :ref:`Dataset Base Classes`.
|
||||
|
||||
Geospatial Datasets
|
||||
-------------------
|
||||
|
@ -107,8 +107,8 @@ NWPU VHR-10
|
|||
|
||||
.. autoclass:: VHR10
|
||||
|
||||
Base Classes
|
||||
------------
|
||||
Dataset Base Classes
|
||||
--------------------
|
||||
|
||||
If you want to write your own custom dataset, you can extend one of these abstract base classes.
|
||||
|
||||
|
|
|
@ -1,4 +1,68 @@
|
|||
torchgeo.samplers
|
||||
=================
|
||||
|
||||
.. automodule:: torchgeo.samplers
|
||||
.. module:: torchgeo.samplers
|
||||
|
||||
Samplers
|
||||
--------
|
||||
|
||||
Samplers are used to index a dataset, retrieving a single query at a time. For :class:`~torchgeo.datasets.VisionDataset`, dataset objects can be indexed with integers, and PyTorch's builtin samplers are sufficient. For :class:`~torchgeo.datasets.GeoDataset`, dataset objects require a bounding box for indexing. For this reason, we define our own :class:`GeoSampler` implementations below. These can be used like so:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from torchgeo.datasets import Landsat
|
||||
from torchgeo.samplers import RandomGeoSampler
|
||||
|
||||
dataset = Landsat(...)
|
||||
sampler = RandomGeoSampler(dataset.index, size=1000, length=100)
|
||||
dataloader = DataLoader(dataset, sampler=sampler)
|
||||
|
||||
|
||||
Random Geo Sampler
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: RandomGeoSampler
|
||||
|
||||
Grid Geo Sampler
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: GridGeoSampler
|
||||
|
||||
Batch Samplers
|
||||
--------------
|
||||
|
||||
When working with large tile-based datasets, randomly sampling patches from each tile can be extremely time consuming. It's much more efficient to choose a tile, load it, warp it to the appropriate :term:`coordinate reference system (CRS)` and resolution, and then sample random patches from that tile to construct a mini-batch of data. For this reason, we define our own :class:`BatchGeoSampler` implementations below. These can be used like so:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from torchgeo.datasets import Landsat
|
||||
from torchgeo.samplers import RandomBatchGeoSampler
|
||||
|
||||
dataset = Landsat(...)
|
||||
sampler = RandomBatchGeoSampler(dataset.index, size=1000, batch_size=10, length=100)
|
||||
dataloader = DataLoader(dataset, batch_sampler=sampler)
|
||||
|
||||
|
||||
Random Batch Geo Sampler
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: RandomBatchGeoSampler
|
||||
|
||||
Sampler Base Classes
|
||||
--------------------
|
||||
|
||||
If you want to write your own custom sampler, you can extend one of these abstract base classes.
|
||||
|
||||
Geo Sampler
|
||||
^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: GeoSampler
|
||||
|
||||
Batch Geo Sampler
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: BatchGeoSampler
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
import math
|
||||
from typing import Iterator, List
|
||||
|
||||
import pytest
|
||||
from _pytest.fixtures import SubRequest
|
||||
from rtree.index import Index, Property
|
||||
|
||||
from torchgeo.datasets import BoundingBox
|
||||
from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler
|
||||
|
||||
|
||||
class CustomBatchGeoSampler(BatchGeoSampler):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def __iter__(self) -> Iterator[List[BoundingBox]]:
|
||||
for i in range(len(self)):
|
||||
yield [BoundingBox(j, j, j, j, j, j) for j in range(len(self))]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return 2
|
||||
|
||||
|
||||
class TestBatchGeoSampler:
|
||||
@pytest.fixture(scope="function")
|
||||
def sampler(self) -> CustomBatchGeoSampler:
|
||||
return CustomBatchGeoSampler()
|
||||
|
||||
def test_iter(self, sampler: CustomBatchGeoSampler) -> None:
|
||||
expected = [BoundingBox(0, 0, 0, 0, 0, 0), BoundingBox(1, 1, 1, 1, 1, 1)]
|
||||
assert next(iter(sampler)) == expected
|
||||
|
||||
def test_len(self, sampler: CustomBatchGeoSampler) -> None:
|
||||
assert len(sampler) == 2
|
||||
|
||||
def test_abstract(self) -> None:
|
||||
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
|
||||
BatchGeoSampler(None) # type: ignore[abstract]
|
||||
|
||||
|
||||
class TestRandomBatchGeoSampler:
|
||||
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
|
||||
def sampler(self, request: SubRequest) -> RandomBatchGeoSampler:
|
||||
index = Index(interleaved=False, properties=Property(dimension=3))
|
||||
index.insert(0, (0, 10, 20, 30, 40, 50))
|
||||
index.insert(1, (0, 10, 20, 30, 40, 50))
|
||||
size = request.param
|
||||
return RandomBatchGeoSampler(index, size, batch_size=2, length=10)
|
||||
|
||||
def test_iter(self, sampler: RandomBatchGeoSampler) -> None:
|
||||
for batch in sampler:
|
||||
for query in batch:
|
||||
assert sampler.roi.minx <= query.minx <= query.maxx <= sampler.roi.maxx
|
||||
assert sampler.roi.miny <= query.miny <= query.miny <= sampler.roi.maxy
|
||||
assert sampler.roi.mint <= query.mint <= query.maxt <= sampler.roi.maxt
|
||||
|
||||
assert math.isclose(query.maxx - query.minx, sampler.size[1])
|
||||
assert math.isclose(query.maxy - query.miny, sampler.size[0])
|
||||
assert math.isclose(
|
||||
query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint
|
||||
)
|
||||
|
||||
def test_len(self, sampler: RandomBatchGeoSampler) -> None:
|
||||
assert len(sampler) == sampler.length // sampler.batch_size
|
|
@ -3,6 +3,7 @@ from typing import Iterator
|
|||
|
||||
import pytest
|
||||
from _pytest.fixtures import SubRequest
|
||||
from rtree.index import Index, Property
|
||||
|
||||
from torchgeo.datasets import BoundingBox
|
||||
from torchgeo.samplers import GeoSampler, GridGeoSampler, RandomGeoSampler
|
||||
|
@ -22,13 +23,13 @@ class CustomGeoSampler(GeoSampler):
|
|||
|
||||
class TestGeoSampler:
|
||||
@pytest.fixture(scope="function")
|
||||
def sampler(self) -> GeoSampler:
|
||||
def sampler(self) -> CustomGeoSampler:
|
||||
return CustomGeoSampler()
|
||||
|
||||
def test_iter(self, sampler: GeoSampler) -> None:
|
||||
def test_iter(self, sampler: CustomGeoSampler) -> None:
|
||||
assert next(iter(sampler)) == BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
|
||||
def test_len(self, sampler: GeoSampler) -> None:
|
||||
def test_len(self, sampler: CustomGeoSampler) -> None:
|
||||
assert len(sampler) == 2
|
||||
|
||||
def test_abstract(self) -> None:
|
||||
|
@ -39,9 +40,11 @@ class TestGeoSampler:
|
|||
class TestRandomGeoSampler:
|
||||
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
|
||||
def sampler(self, request: SubRequest) -> RandomGeoSampler:
|
||||
roi = BoundingBox(0, 10, 20, 30, 40, 50)
|
||||
index = Index(interleaved=False, properties=Property(dimension=3))
|
||||
index.insert(0, (0, 10, 20, 30, 40, 50))
|
||||
index.insert(1, (0, 10, 20, 30, 40, 50))
|
||||
size = request.param
|
||||
return RandomGeoSampler(roi, size, length=10)
|
||||
return RandomGeoSampler(index, size, length=10)
|
||||
|
||||
def test_iter(self, sampler: RandomGeoSampler) -> None:
|
||||
for query in sampler:
|
||||
|
@ -72,9 +75,11 @@ class TestGridGeoSampler:
|
|||
],
|
||||
)
|
||||
def sampler(self, request: SubRequest) -> GridGeoSampler:
|
||||
roi = BoundingBox(0, 10, 20, 30, 40, 50)
|
||||
index = Index(interleaved=False, properties=Property(dimension=3))
|
||||
index.insert(0, (0, 10, 20, 30, 40, 50))
|
||||
index.insert(1, (0, 10, 20, 30, 40, 50))
|
||||
size, stride = request.param
|
||||
return GridGeoSampler(roi, size, stride)
|
||||
return GridGeoSampler(index, size, stride)
|
||||
|
||||
def test_iter(self, sampler: GridGeoSampler) -> None:
|
||||
for query in sampler:
|
||||
|
@ -87,6 +92,3 @@ class TestGridGeoSampler:
|
|||
assert math.isclose(
|
||||
query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint
|
||||
)
|
||||
|
||||
def test_len(self, sampler: RandomGeoSampler) -> None:
|
||||
assert len(sampler) == 9
|
|
@ -1,8 +1,18 @@
|
|||
"""TorchGeo samplers."""
|
||||
|
||||
from .samplers import GeoSampler, GridGeoSampler, RandomGeoSampler
|
||||
from .batch import BatchGeoSampler, RandomBatchGeoSampler
|
||||
from .single import GeoSampler, GridGeoSampler, RandomGeoSampler
|
||||
|
||||
__all__ = ("GeoSampler", "GridGeoSampler", "RandomGeoSampler")
|
||||
__all__ = (
|
||||
# Samplers
|
||||
"GridGeoSampler",
|
||||
"RandomGeoSampler",
|
||||
# Batch samplers
|
||||
"RandomBatchGeoSampler",
|
||||
# Base classes
|
||||
"GeoSampler",
|
||||
"BatchGeoSampler",
|
||||
)
|
||||
|
||||
# https://stackoverflow.com/questions/40018681
|
||||
for module in __all__:
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
"""TorchGeo batch samplers."""
|
||||
|
||||
import abc
|
||||
import random
|
||||
from typing import Iterator, List, Optional, Tuple, Union
|
||||
|
||||
from rtree.index import Index
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from torchgeo.datasets import BoundingBox
|
||||
|
||||
from .utils import _to_tuple
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
Sampler.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
class BatchGeoSampler(Sampler[List[BoundingBox]], abc.ABC):
|
||||
"""Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`.
|
||||
|
||||
Unlike PyTorch's :class:`~torch.utils.data.BatchSampler`, :class:`BatchGeoSampler`
|
||||
returns enough geospatial information to uniquely index any
|
||||
:class:`~torchgeo.datasets.GeoDataset`. This includes things like latitude,
|
||||
longitude, height, width, projection, coordinate system, and time.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __iter__(self) -> Iterator[List[BoundingBox]]:
|
||||
"""Return a batch of indices of a dataset.
|
||||
|
||||
Returns:
|
||||
batch of (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
|
||||
"""
|
||||
|
||||
|
||||
class RandomBatchGeoSampler(BatchGeoSampler):
|
||||
"""Samples batches of elements from a region of interest randomly.
|
||||
|
||||
This is particularly useful during training when you want to maximize the size of
|
||||
the dataset and return as many random :term:`chips <chip>` as possible.
|
||||
|
||||
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``index`` should come
|
||||
from a tile-based dataset if possible.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index: Index,
|
||||
size: Union[Tuple[float, float], float],
|
||||
batch_size: int,
|
||||
length: int,
|
||||
roi: Optional[BoundingBox] = None,
|
||||
) -> None:
|
||||
"""Initialize a new Sampler instance.
|
||||
|
||||
The ``size`` argument can either be:
|
||||
|
||||
* a single ``float`` - in which case the same value is used for the height and
|
||||
width dimension
|
||||
* a ``tuple`` of two floats - in which case, the first *float* is used for the
|
||||
height dimension, and the second *float* for the width dimension
|
||||
|
||||
Args:
|
||||
index: index of a :class:`~torchgeo.datasets.GeoDataset`
|
||||
size: dimensions of each :term:`patch` in units of CRS
|
||||
batch_size: number of samples per batch
|
||||
length: number of samples per epoch
|
||||
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
|
||||
(defaults to the bounds of ``index``)
|
||||
"""
|
||||
self.index = index
|
||||
self.size = _to_tuple(size)
|
||||
self.batch_size = batch_size
|
||||
self.length = length
|
||||
if roi is None:
|
||||
roi = BoundingBox(*index.bounds)
|
||||
self.roi = roi
|
||||
self.hits = list(index.intersection(roi, objects=True))
|
||||
|
||||
def __iter__(self) -> Iterator[List[BoundingBox]]:
|
||||
"""Return the indices of a dataset.
|
||||
|
||||
Returns:
|
||||
batch of (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
|
||||
"""
|
||||
for _ in range(len(self)):
|
||||
# Choose a random tile
|
||||
hit = random.choice(self.hits)
|
||||
bounds = BoundingBox(*hit.bounds)
|
||||
|
||||
# Choose random indices within that tile
|
||||
batch = []
|
||||
for _ in range(self.batch_size):
|
||||
minx = random.uniform(bounds.minx, bounds.maxx - self.size[1])
|
||||
maxx = minx + self.size[1]
|
||||
|
||||
miny = random.uniform(bounds.miny, bounds.maxy - self.size[0])
|
||||
maxy = miny + self.size[0]
|
||||
|
||||
mint = bounds.mint
|
||||
maxt = bounds.maxt
|
||||
|
||||
batch.append(BoundingBox(minx, maxx, miny, maxy, mint, maxt))
|
||||
|
||||
yield batch
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of batches in a single epoch.
|
||||
|
||||
Returns:
|
||||
number of batches in an epoch
|
||||
"""
|
||||
return self.length // self.batch_size
|
|
@ -2,33 +2,21 @@
|
|||
|
||||
import abc
|
||||
import random
|
||||
from typing import Any, Iterator, Tuple, Union
|
||||
from typing import Iterator, Optional, Tuple, Union
|
||||
|
||||
from rtree.index import Index
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from torchgeo.datasets import BoundingBox
|
||||
|
||||
from .utils import _to_tuple
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
Sampler.__module__ = "torch.utils.data"
|
||||
|
||||
|
||||
def _to_tuple(value: Union[Tuple[Any, Any], Any]) -> Tuple[Any, Any]:
|
||||
"""Convert value to a tuple if it is not already a tuple.
|
||||
|
||||
Args:
|
||||
value: input value
|
||||
|
||||
Returns:
|
||||
value if value is a tuple, else (value, value)
|
||||
"""
|
||||
if isinstance(value, (float, int)):
|
||||
return (value, value)
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
class GeoSampler(Sampler[Tuple[Any, ...]], abc.ABC):
|
||||
class GeoSampler(Sampler[BoundingBox], abc.ABC):
|
||||
"""Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`.
|
||||
|
||||
Unlike PyTorch's :class:`~torch.utils.data.Sampler`, :class:`GeoSampler`
|
||||
|
@ -45,26 +33,25 @@ class GeoSampler(Sampler[Tuple[Any, ...]], abc.ABC):
|
|||
(minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of samples in a single epoch.
|
||||
|
||||
Returns:
|
||||
length of the epoch
|
||||
"""
|
||||
|
||||
|
||||
class RandomGeoSampler(GeoSampler):
|
||||
"""Samples elements from a region of interest randomly.
|
||||
|
||||
This is particularly useful during training when you want to maximize the size of
|
||||
the dataset and return as many random :term:`chips <chip>` as possible.
|
||||
|
||||
This sampler is not recommended for use with tile-based datasets. Use
|
||||
:class:`RandomBatchGeoSampler` instead.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, roi: BoundingBox, size: Union[Tuple[float, float], float], length: int
|
||||
self,
|
||||
index: Index,
|
||||
size: Union[Tuple[float, float], float],
|
||||
length: int,
|
||||
roi: Optional[BoundingBox] = None,
|
||||
) -> None:
|
||||
"""Initialize a new RandomGeoSampler.
|
||||
"""Initialize a new Sampler instance.
|
||||
|
||||
The ``size`` argument can either be:
|
||||
|
||||
|
@ -74,13 +61,19 @@ class RandomGeoSampler(GeoSampler):
|
|||
height dimension, and the second *float* for the width dimension
|
||||
|
||||
Args:
|
||||
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
|
||||
index: index of a :class:`~torchgeo.datasets.GeoDataset`
|
||||
size: dimensions of each :term:`patch` in units of CRS
|
||||
length: number of random samples to draw per epoch
|
||||
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
|
||||
(defaults to the bounds of ``index``)
|
||||
"""
|
||||
self.roi = roi
|
||||
self.index = index
|
||||
self.size = _to_tuple(size)
|
||||
self.length = length
|
||||
if roi is None:
|
||||
roi = BoundingBox(*index.bounds)
|
||||
self.roi = roi
|
||||
self.hits = list(index.intersection(roi, objects=True))
|
||||
|
||||
def __iter__(self) -> Iterator[BoundingBox]:
|
||||
"""Return the index of a dataset.
|
||||
|
@ -89,15 +82,19 @@ class RandomGeoSampler(GeoSampler):
|
|||
(minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
|
||||
"""
|
||||
for _ in range(len(self)):
|
||||
minx = random.uniform(self.roi.minx, self.roi.maxx - self.size[1])
|
||||
# Choose a random tile
|
||||
hit = random.choice(self.hits)
|
||||
bounds = BoundingBox(*hit.bounds)
|
||||
|
||||
# Choose a random index within that tile
|
||||
minx = random.uniform(bounds.minx, bounds.maxx - self.size[1])
|
||||
maxx = minx + self.size[1]
|
||||
|
||||
miny = random.uniform(self.roi.miny, self.roi.maxy - self.size[0])
|
||||
miny = random.uniform(bounds.miny, bounds.maxy - self.size[0])
|
||||
maxy = miny + self.size[0]
|
||||
|
||||
# TODO: figure out how to handle time
|
||||
mint = self.roi.mint
|
||||
maxt = self.roi.maxt
|
||||
mint = bounds.mint
|
||||
maxt = bounds.maxt
|
||||
|
||||
yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)
|
||||
|
||||
|
@ -123,15 +120,19 @@ class GridGeoSampler(GeoSampler):
|
|||
The overlap between each chip (``chip_size - stride``) should be approximately equal
|
||||
to the `receptive field <https://distill.pub/2019/computing-receptive-fields/>`_ of
|
||||
the CNN.
|
||||
|
||||
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``index`` should come
|
||||
from a non-tile-based dataset if possible.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
roi: BoundingBox,
|
||||
index: Index,
|
||||
size: Union[Tuple[float, float], float],
|
||||
stride: Union[Tuple[float, float], float],
|
||||
roi: Optional[BoundingBox] = None,
|
||||
) -> None:
|
||||
"""Initialize a new RandomGeoSampler.
|
||||
"""Initialize a new Sampler instance.
|
||||
|
||||
The ``size`` and ``stride`` arguments can either be:
|
||||
|
||||
|
@ -141,15 +142,18 @@ class GridGeoSampler(GeoSampler):
|
|||
height dimension, and the second *float* for the width dimension
|
||||
|
||||
Args:
|
||||
index: index of a :class:`~torchgeo.datasets.GeoDataset`
|
||||
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
|
||||
size: dimensions of each :term:`patch` in units of CRS
|
||||
stride: distance to skip between each patch
|
||||
"""
|
||||
self.roi = roi
|
||||
self.index = index
|
||||
self.size = _to_tuple(size)
|
||||
self.stride = _to_tuple(stride)
|
||||
self.rows = int((roi.maxy - roi.miny - self.size[0]) // self.stride[0]) + 1
|
||||
self.cols = int((roi.maxx - roi.minx - self.size[1]) // self.stride[1]) + 1
|
||||
if roi is None:
|
||||
roi = BoundingBox(*index.bounds)
|
||||
self.roi = roi
|
||||
self.hits = index.intersection(roi, objects=True)
|
||||
|
||||
def __iter__(self) -> Iterator[BoundingBox]:
|
||||
"""Return the index of a dataset.
|
||||
|
@ -157,23 +161,24 @@ class GridGeoSampler(GeoSampler):
|
|||
Returns:
|
||||
(minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
|
||||
"""
|
||||
for i in range(self.rows):
|
||||
miny = self.roi.miny + i * self.stride[0]
|
||||
maxy = miny + self.size[0]
|
||||
for j in range(self.cols):
|
||||
minx = self.roi.minx + j * self.stride[1]
|
||||
maxx = minx + self.size[1]
|
||||
# For each tile...
|
||||
for hit in self.hits:
|
||||
bounds = BoundingBox(*hit.bounds)
|
||||
|
||||
# TODO: figure out how to handle time
|
||||
mint = self.roi.mint
|
||||
maxt = self.roi.maxt
|
||||
rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
|
||||
cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
|
||||
|
||||
yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)
|
||||
mint = bounds.mint
|
||||
maxt = bounds.maxt
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of samples in a single epoch.
|
||||
# For each row...
|
||||
for i in range(rows):
|
||||
miny = bounds.miny + i * self.stride[0]
|
||||
maxy = miny + self.size[0]
|
||||
|
||||
Returns:
|
||||
length of the epoch
|
||||
"""
|
||||
return self.rows * self.cols
|
||||
# For each column...
|
||||
for j in range(cols):
|
||||
minx = bounds.minx + j * self.stride[1]
|
||||
maxx = minx + self.size[1]
|
||||
|
||||
yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)
|
|
@ -0,0 +1,18 @@
|
|||
"""Common sampler utilities."""
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
|
||||
def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]:
|
||||
"""Convert value to a tuple if it is not already a tuple.
|
||||
|
||||
Args:
|
||||
value: input value
|
||||
|
||||
Returns:
|
||||
value if value is a tuple, else (value, value)
|
||||
"""
|
||||
if isinstance(value, (float, int)):
|
||||
return (value, value)
|
||||
else:
|
||||
return value
|
Загрузка…
Ссылка в новой задаче