diff --git a/benchmark.py b/benchmark.py index 7b8262977..b9a4bd5da 100755 --- a/benchmark.py +++ b/benchmark.py @@ -12,7 +12,7 @@ import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader -from torchvision.models import resnet18 +from torchvision.models import resnet34 from torchgeo.datasets import CDL, Landsat8 from torchgeo.samplers import GridGeoSampler, RandomBatchGeoSampler, RandomGeoSampler @@ -150,10 +150,10 @@ def main(args: argparse.Namespace) -> None: stride = args.stride * cdl.res samplers = [ - RandomGeoSampler(landsat.index, size=size, length=length), - GridGeoSampler(landsat.index, size=size, stride=stride), + RandomGeoSampler(landsat, size=size, length=length), + GridGeoSampler(landsat, size=size, stride=stride), RandomBatchGeoSampler( - landsat.index, size=size, batch_size=args.batch_size, length=length + landsat, size=size, batch_size=args.batch_size, length=length ), ] @@ -213,7 +213,7 @@ def main(args: argparse.Namespace) -> None: ) # Benchmark model - model = resnet18() + model = resnet34() # Change number of input channels to match Landsat model.conv1 = nn.Conv2d( # type: ignore[attr-defined] len(bands), 64, kernel_size=7, stride=2, padding=3, bias=False @@ -248,7 +248,7 @@ def main(args: argparse.Namespace) -> None: duration = toc - tic if args.verbose: - print("\nResNet-18:") + print("\nResNet-34:") print(f" duration: {duration:.3f} sec") print(f" count: {num_total_patches} patches") print(f" rate: {num_total_patches / duration:.3f} patches/sec") @@ -260,7 +260,7 @@ def main(args: argparse.Namespace) -> None: "duration": duration, "count": num_total_patches, "rate": num_total_patches / duration, - "sampler": "resnet18", + "sampler": "ResNet-34", "batch_size": args.batch_size, "num_workers": args.num_workers, } @@ -286,8 +286,6 @@ def main(args: argparse.Namespace) -> None: if __name__ == "__main__": - os.environ["GDAL_CACHEMAX"] = "50%" - parser = set_up_parser() args = parser.parse_args() diff --git a/docs/api/samplers.rst b/docs/api/samplers.rst index e26bd8233..fa1ca6522 100644 --- a/docs/api/samplers.rst +++ b/docs/api/samplers.rst @@ -16,7 +16,7 @@ Samplers are used to index a dataset, retrieving a single query at a time. For : from torchgeo.samplers import RandomGeoSampler dataset = Landsat(...) - sampler = RandomGeoSampler(dataset.index, size=1000, length=100) + sampler = RandomGeoSampler(dataset, size=1000, length=100) dataloader = DataLoader(dataset, sampler=sampler) @@ -43,7 +43,7 @@ When working with large tile-based datasets, randomly sampling patches from each from torchgeo.samplers import RandomBatchGeoSampler dataset = Landsat(...) - sampler = RandomBatchGeoSampler(dataset.index, size=1000, batch_size=10, length=100) + sampler = RandomBatchGeoSampler(dataset, size=1000, batch_size=10, length=100) dataloader = DataLoader(dataset, batch_sampler=sampler) diff --git a/docs/tutorials/benchmarking.ipynb b/docs/tutorials/benchmarking.ipynb index 3208283d1..d3bf3aa32 100644 --- a/docs/tutorials/benchmarking.ipynb +++ b/docs/tutorials/benchmarking.ipynb @@ -211,7 +211,7 @@ " chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n", " naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n", " dataset = chesapeake + naip\n", - " sampler = RandomGeoSampler(naip.index, size=1000, length=888)\n", + " sampler = RandomGeoSampler(naip, size=1000, length=888)\n", " dataloader = DataLoader(dataset, batch_size=12, sampler=sampler)\n", " duration, count = time_epoch(dataloader)\n", " print(duration, count)" @@ -262,7 +262,7 @@ " chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n", " naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n", " dataset = chesapeake + naip\n", - " sampler = GridGeoSampler(naip.index, size=1000, stride=500)\n", + " sampler = GridGeoSampler(naip, size=1000, stride=500)\n", " dataloader = DataLoader(dataset, batch_size=12, sampler=sampler)\n", " duration, count = time_epoch(dataloader)\n", " print(duration, count)" @@ -313,7 +313,7 @@ " chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n", " naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n", " dataset = chesapeake + naip\n", - " sampler = RandomBatchGeoSampler(naip.index, size=1000, batch_size=12, length=888)\n", + " sampler = RandomBatchGeoSampler(naip, size=1000, batch_size=12, length=888)\n", " dataloader = DataLoader(dataset, batch_sampler=sampler)\n", " duration, count = time_epoch(dataloader)\n", " print(duration, count)" diff --git a/docs/tutorials/getting_started.ipynb b/docs/tutorials/getting_started.ipynb index 17ccc0389..347a1f495 100644 --- a/docs/tutorials/getting_started.ipynb +++ b/docs/tutorials/getting_started.ipynb @@ -329,7 +329,7 @@ }, "outputs": [], "source": [ - "sampler = RandomGeoSampler(naip.index, size=1000, length=10)" + "sampler = RandomGeoSampler(naip, size=1000, length=10)" ] }, { diff --git a/experiments/run_benchmarks_experiments.py b/experiments/run_benchmarks_experiments.py index 76ba59eae..7f3a95c09 100755 --- a/experiments/run_benchmarks_experiments.py +++ b/experiments/run_benchmarks_experiments.py @@ -3,12 +3,14 @@ # Licensed under the MIT License. """Script for running the benchmark script over a sweep of different options.""" + import itertools +import os import subprocess import time from typing import List -EPOCH_SIZE = 2048 +EPOCH_SIZE = 4096 SEED_OPTIONS = [0, 1, 2] CACHE_OPTIONS = [True, False] @@ -23,6 +25,9 @@ CDL_DATA_ROOT = "" total_num_experiments = len(SEED_OPTIONS) * len(CACHE_OPTIONS) * len(BATCH_SIZE_OPTIONS) if __name__ == "__main__": + # With 6 workers, this will use ~60% of available RAM + os.environ["GDAL_CACHEMAX"] = "10%" + tic = time.time() for i, (cache, batch_size, seed) in enumerate( itertools.product(CACHE_OPTIONS, BATCH_SIZE_OPTIONS, SEED_OPTIONS) @@ -37,7 +42,7 @@ if __name__ == "__main__": "--cdl-root", CDL_DATA_ROOT, "--num-workers", - "8", + "6", "--batch-size", str(batch_size), "--epoch-size", diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 7818d9316..a3825acc5 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -7,7 +7,6 @@ from typing import Dict, Iterator, List import pytest from _pytest.fixtures import SubRequest from rasterio.crs import CRS -from rtree.index import Index, Property from torch.utils.data import DataLoader from torchgeo.datasets import BoundingBox, GeoDataset @@ -29,12 +28,10 @@ class CustomBatchGeoSampler(BatchGeoSampler): class CustomGeoDataset(GeoDataset): def __init__( self, - bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5), crs: CRS = CRS.from_epsg(3005), res: float = 1, ) -> None: super().__init__() - self.index.insert(0, bounds) self.crs = crs self.res = res @@ -72,11 +69,11 @@ class TestBatchGeoSampler: class TestRandomBatchGeoSampler: @pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)]) def sampler(self, request: SubRequest) -> RandomBatchGeoSampler: - index = Index(interleaved=False, properties=Property(dimension=3)) - index.insert(0, (0, 10, 20, 30, 40, 50)) - index.insert(1, (0, 10, 20, 30, 40, 50)) + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 20, 30, 40, 50)) + ds.index.insert(1, (0, 10, 20, 30, 40, 50)) size = request.param - return RandomBatchGeoSampler(index, size, batch_size=2, length=10) + return RandomBatchGeoSampler(ds, size, batch_size=2, length=10) def test_iter(self, sampler: RandomBatchGeoSampler) -> None: for batch in sampler: diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index b369a4b5d..8583a206b 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -7,7 +7,6 @@ from typing import Dict, Iterator import pytest from _pytest.fixtures import SubRequest from rasterio.crs import CRS -from rtree.index import Index, Property from torch.utils.data import DataLoader from torchgeo.datasets import BoundingBox, GeoDataset @@ -29,12 +28,10 @@ class CustomGeoSampler(GeoSampler): class CustomGeoDataset(GeoDataset): def __init__( self, - bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5), crs: CRS = CRS.from_epsg(3005), res: float = 1, ) -> None: super().__init__() - self.index.insert(0, bounds) self.crs = crs self.res = res @@ -71,11 +68,11 @@ class TestGeoSampler: class TestRandomGeoSampler: @pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)]) def sampler(self, request: SubRequest) -> RandomGeoSampler: - index = Index(interleaved=False, properties=Property(dimension=3)) - index.insert(0, (0, 10, 20, 30, 40, 50)) - index.insert(1, (0, 10, 20, 30, 40, 50)) + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 20, 30, 40, 50)) + ds.index.insert(1, (0, 10, 20, 30, 40, 50)) size = request.param - return RandomGeoSampler(index, size, length=10) + return RandomGeoSampler(ds, size, length=10) def test_iter(self, sampler: RandomGeoSampler) -> None: for query in sampler: @@ -116,11 +113,11 @@ class TestGridGeoSampler: ], ) def sampler(self, request: SubRequest) -> GridGeoSampler: - index = Index(interleaved=False, properties=Property(dimension=3)) - index.insert(0, (0, 20, 0, 10, 40, 50)) - index.insert(1, (0, 20, 0, 10, 40, 50)) + ds = CustomGeoDataset() + ds.index.insert(0, (0, 20, 0, 10, 40, 50)) + ds.index.insert(1, (0, 20, 0, 10, 40, 50)) size, stride = request.param - return GridGeoSampler(index, size, stride) + return GridGeoSampler(ds, size, stride) def test_iter(self, sampler: GridGeoSampler) -> None: for query in sampler: diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 31be05758..368e27539 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -22,6 +22,7 @@ import torch from rasterio.crs import CRS from rasterio.io import DatasetReader from rasterio.vrt import WarpedVRT +from rasterio.windows import from_bounds from rtree.index import Index, Property from torch import Tensor from torch.utils.data import Dataset @@ -328,7 +329,16 @@ class RasterDataset(GeoDataset): vrt_fhs = [self._load_warp_file(fp) for fp in filepaths] bounds = (query.minx, query.miny, query.maxx, query.maxy) - dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res) + if len(vrt_fhs) == 1: + src = vrt_fhs[0] + out_width = int(round((query.maxx - query.minx) / self.res)) + out_height = int(round((query.maxy - query.miny) / self.res)) + out_shape = (src.count, out_height, out_width) + dest = src.read( + out_shape=out_shape, window=from_bounds(*bounds, src.transform) + ) + else: + dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res) dest = dest.astype(np.int32) tensor: Tensor = torch.tensor(dest) # type: ignore[attr-defined] diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index f65044b31..d03281ba1 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -7,10 +7,9 @@ import abc import random from typing import Iterator, List, Optional, Tuple, Union -from rtree.index import Index from torch.utils.data import Sampler -from torchgeo.datasets import BoundingBox +from torchgeo.datasets import BoundingBox, GeoDataset from .utils import _to_tuple, get_random_bounding_box @@ -43,13 +42,13 @@ class RandomBatchGeoSampler(BatchGeoSampler): This is particularly useful during training when you want to maximize the size of the dataset and return as many random :term:`chips ` as possible. - When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``index`` should come - from a tile-based dataset if possible. + When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``dataset`` should be + a tile-based dataset if possible. """ def __init__( self, - index: Index, + dataset: GeoDataset, size: Union[Tuple[float, float], float], batch_size: int, length: int, @@ -65,21 +64,22 @@ class RandomBatchGeoSampler(BatchGeoSampler): height dimension, and the second *float* for the width dimension Args: - index: index of a :class:`~torchgeo.datasets.GeoDataset` + dataset: dataset to index from size: dimensions of each :term:`patch` in units of CRS batch_size: number of samples per batch length: number of samples per epoch roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) - (defaults to the bounds of ``index``) + (defaults to the bounds of ``dataset.index``) """ - self.index = index + self.index = dataset.index + self.res = dataset.res self.size = _to_tuple(size) self.batch_size = batch_size self.length = length if roi is None: - roi = BoundingBox(*index.bounds) + roi = BoundingBox(*self.index.bounds) self.roi = roi - self.hits = list(index.intersection(roi, objects=True)) + self.hits = list(self.index.intersection(roi, objects=True)) def __iter__(self) -> Iterator[List[BoundingBox]]: """Return the indices of a dataset. @@ -96,7 +96,7 @@ class RandomBatchGeoSampler(BatchGeoSampler): batch = [] for _ in range(self.batch_size): - bounding_box = get_random_bounding_box(bounds, self.size) + bounding_box = get_random_bounding_box(bounds, self.size, self.res) batch.append(bounding_box) yield batch diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 66bf7d0ed..341600370 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -7,10 +7,9 @@ import abc import random from typing import Iterator, Optional, Tuple, Union -from rtree.index import Index from torch.utils.data import Sampler -from torchgeo.datasets import BoundingBox +from torchgeo.datasets import BoundingBox, GeoDataset from .utils import _to_tuple, get_random_bounding_box @@ -49,7 +48,7 @@ class RandomGeoSampler(GeoSampler): def __init__( self, - index: Index, + dataset: GeoDataset, size: Union[Tuple[float, float], float], length: int, roi: Optional[BoundingBox] = None, @@ -64,19 +63,20 @@ class RandomGeoSampler(GeoSampler): height dimension, and the second *float* for the width dimension Args: - index: index of a :class:`~torchgeo.datasets.GeoDataset` + dataset: dataset to index from size: dimensions of each :term:`patch` in units of CRS length: number of random samples to draw per epoch roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) - (defaults to the bounds of ``index``) + (defaults to the bounds of ``dataset.index``) """ - self.index = index + self.index = dataset.index + self.res = dataset.res self.size = _to_tuple(size) self.length = length if roi is None: - roi = BoundingBox(*index.bounds) + roi = BoundingBox(*self.index.bounds) self.roi = roi - self.hits = list(index.intersection(roi, objects=True)) + self.hits = list(self.index.intersection(roi, objects=True)) def __iter__(self) -> Iterator[BoundingBox]: """Return the index of a dataset. @@ -90,7 +90,7 @@ class RandomGeoSampler(GeoSampler): bounds = BoundingBox(*hit.bounds) # Choose a random index within that tile - bounding_box = get_random_bounding_box(bounds, self.size) + bounding_box = get_random_bounding_box(bounds, self.size, self.res) yield bounding_box @@ -117,13 +117,13 @@ class GridGeoSampler(GeoSampler): to the `receptive field `_ of the CNN. - When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``index`` should come - from a non-tile-based dataset if possible. + When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``dataset`` should be + a non-tile-based dataset if possible. """ def __init__( self, - index: Index, + dataset: GeoDataset, size: Union[Tuple[float, float], float], stride: Union[Tuple[float, float], float], roi: Optional[BoundingBox] = None, @@ -138,18 +138,19 @@ class GridGeoSampler(GeoSampler): height dimension, and the second *float* for the width dimension Args: - index: index of a :class:`~torchgeo.datasets.GeoDataset` + dataset: dataset to index from size: dimensions of each :term:`patch` in units of CRS stride: distance to skip between each patch roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) + (defaults to the bounds of ``dataset.index``) """ - self.index = index + self.index = dataset.index self.size = _to_tuple(size) self.stride = _to_tuple(stride) if roi is None: - roi = BoundingBox(*index.bounds) + roi = BoundingBox(*self.index.bounds) self.roi = roi - self.hits = list(index.intersection(roi, objects=True)) + self.hits = list(self.index.intersection(roi, objects=True)) self.length: int = 0 for hit in self.hits: diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index e1d906ac5..b8aecd85a 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -25,7 +25,7 @@ def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: def get_random_bounding_box( - bounds: BoundingBox, size: Union[Tuple[float, float], float] + bounds: BoundingBox, size: Union[Tuple[float, float], float], res: float ) -> BoundingBox: """Returns a random bounding box within a given bounding box. @@ -45,13 +45,16 @@ def get_random_bounding_box( """ t_size: Tuple[float, float] = _to_tuple(size) - minx = random.uniform(bounds.minx, bounds.maxx - t_size[1]) + width = (bounds.maxx - bounds.minx - t_size[1]) // res + minx = random.randrange(int(width)) * res + bounds.minx maxx = minx + t_size[1] - miny = random.uniform(bounds.miny, bounds.maxy - t_size[0]) + height = (bounds.maxy - bounds.miny - t_size[0]) // res + miny = random.randrange(int(height)) * res + bounds.miny maxy = miny + t_size[0] mint = bounds.mint maxt = bounds.maxt - return BoundingBox(minx, maxx, miny, maxy, mint, maxt) + query = BoundingBox(minx, maxx, miny, maxy, mint, maxt) + return query