GridGeoSampler: change stride of last patch to sample entire ROI (#630)

* Adjust minx/miny with a smaller stride for the last sample per row/col and issue warning

* style and mypy fixes

* black test fix

* Adjust minx/miny with a smaller stride for the last sample per row/col and issue warning

* style and mypy fixes

* black test fix

* single.py: adapt GridGeoSampler to sample beyond the limit of the ROI for a partial patch (to be padded)
test_single.py: add tests for multiple limit cases (see issue #448)

* format for black and flake8

* format for black and flake8

* once again, format for black and flake8

* Revert "Adjust minx/miny with a smaller stride for the last sample per row/col and issue warning"

This reverts commit cb554c67

* adapt unit tests, remove warnings

* flake8: remove warnings import

* Address some comments

* Simplify computation of # rows/cols

* Document this new feature

* Fix size of ceiling symbol

* Simplify tests

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Rémi Tavon 2022-09-03 00:11:14 -04:00 коммит произвёл GitHub
Родитель f41619a435
Коммит 7fa0fd429e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 71 добавлений и 14 удалений

Просмотреть файл

@ -182,9 +182,9 @@ class TestGridGeoSampler:
)
def test_len(self, sampler: GridGeoSampler) -> None:
rows = ((100 - sampler.size[0]) // sampler.stride[0]) + 1
cols = ((100 - sampler.size[1]) // sampler.stride[1]) + 1
length = rows * cols * 2
rows = math.ceil((100 - sampler.size[0]) / sampler.stride[0]) + 1
cols = math.ceil((100 - sampler.size[1]) / sampler.stride[1]) + 1
length = rows * cols * 2 # two items in dataset
assert len(sampler) == length
def test_roi(self, dataset: CustomGeoDataset) -> None:
@ -195,11 +195,34 @@ class TestGridGeoSampler:
def test_small_area(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (20, 21, 20, 21, 20, 21))
ds.index.insert(0, (0, 1, 0, 1, 0, 1))
sampler = GridGeoSampler(ds, 2, 10)
for _ in sampler:
continue
assert len(sampler) == 0
def test_tiles_side_by_side(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(0, (0, 10, 10, 20, 0, 10))
sampler = GridGeoSampler(ds, 2, 10)
for bbox in sampler:
assert bbox.area > 0
def test_integer_multiple(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = GridGeoSampler(ds, 10, 10, units=Units.CRS)
iterator = iter(sampler)
assert len(sampler) == 1
assert next(iterator) == BoundingBox(0, 10, 0, 10, 0, 10)
def test_float_multiple(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 6, 0, 5, 0, 10))
sampler = GridGeoSampler(ds, 5, 5, units=Units.CRS)
iterator = iter(sampler)
assert len(sampler) == 2
assert next(iterator) == BoundingBox(0, 5, 0, 5, 0, 10)
assert next(iterator) == BoundingBox(1, 6, 0, 5, 0, 10)
@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])

Просмотреть файл

@ -4,6 +4,7 @@
"""TorchGeo samplers."""
import abc
import math
from typing import Callable, Iterable, Iterator, Optional, Tuple, Union
import torch
@ -146,7 +147,7 @@ class RandomGeoSampler(GeoSampler):
class GridGeoSampler(GeoSampler):
"""Samples elements in a grid-like fashion.
r"""Samples elements in a grid-like fashion.
This is particularly useful during evaluation when you want to make predictions for
an entire region of interest. You want to minimize the amount of redundant
@ -158,6 +159,21 @@ class GridGeoSampler(GeoSampler):
The overlap between each chip (``chip_size - stride``) should be approximately equal
to the `receptive field <https://distill.pub/2019/computing-receptive-fields/>`_ of
the CNN.
Note that the stride of the final set of chips in each row/column may be adjusted so
that the entire :term:`tile` is sampled without exceeding the bounds of the dataset.
Let :math:`i` be the size of the input tile. Let :math:`k` be the requested size of
the output patch. Let :math:`s` be the requested stride. Let :math:`o` be the number
of output rows/columns sampled from each tile. :math:`o` can then be computed as:
.. math::
o = \left\lceil \frac{i - k}{s} \right\rceil + 1
This is almost identical to relationship 5 in
https://doi.org/10.48550/arXiv.1603.07285. However, we use ceiling instead of floor
because we want to include the final remaining chip.
"""
def __init__(
@ -200,8 +216,8 @@ class GridGeoSampler(GeoSampler):
for hit in self.index.intersection(tuple(self.roi), objects=True):
bounds = BoundingBox(*hit.bounds)
if (
bounds.maxx - bounds.minx > self.size[1]
and bounds.maxy - bounds.miny > self.size[0]
bounds.maxx - bounds.minx >= self.size[1]
and bounds.maxy - bounds.miny >= self.size[0]
):
self.hits.append(hit)
@ -209,8 +225,14 @@ class GridGeoSampler(GeoSampler):
for hit in self.hits:
bounds = BoundingBox(*hit.bounds)
rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
rows = (
math.ceil((bounds.maxy - bounds.miny - self.size[0]) / self.stride[0])
+ 1
)
cols = (
math.ceil((bounds.maxx - bounds.minx - self.size[1]) / self.stride[1])
+ 1
)
self.length += rows * cols
def __iter__(self) -> Iterator[BoundingBox]:
@ -223,8 +245,14 @@ class GridGeoSampler(GeoSampler):
for hit in self.hits:
bounds = BoundingBox(*hit.bounds)
rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1
cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1
rows = (
math.ceil((bounds.maxy - bounds.miny - self.size[0]) / self.stride[0])
+ 1
)
cols = (
math.ceil((bounds.maxx - bounds.minx - self.size[1]) / self.stride[1])
+ 1
)
mint = bounds.mint
maxt = bounds.maxt
@ -233,11 +261,17 @@ class GridGeoSampler(GeoSampler):
for i in range(rows):
miny = bounds.miny + i * self.stride[0]
maxy = miny + self.size[0]
if maxy > bounds.maxy:
maxy = bounds.maxy
miny = bounds.maxy - self.size[0]
# For each column...
for j in range(cols):
minx = bounds.minx + j * self.stride[1]
maxx = minx + self.size[1]
if maxx > bounds.maxx:
maxx = bounds.maxx
minx = bounds.maxx - self.size[1]
yield BoundingBox(minx, maxx, miny, maxy, mint, maxt)