Improve sampler performance for pixel-aligned files (#181)

* Improve sampler performance for pixel-aligned files

* Skip merge if only a single file

* Undo changes to hyperparams

* Fix shape, read all bands

* Remove manual single-file reading

* Always keep workers alive

* Various changes in a desperate attempt to improve performance

* Increase epoch size

* Add missing import, fix model name

* Fix tests

* Persistent workers not used unless entire dataset is consumed
Adam J. Stewart 2021-10-12 15:34:08 -05:00 committed by GitHub
Parent 142835cede
Commit e14980a3eb
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 78 additions and 67 deletions
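
A note on the last two commit bullets: they refer to PyTorch's persistent_workers option on DataLoader, which only reuses worker processes across epochs if each epoch's iterator is consumed to completion. A minimal, generic sketch of that behavior (the ToyDataset below is hypothetical and merely stands in for a torchgeo GeoDataset; it is not part of this commit):

import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for a GeoDataset; returns one tensor per index."""

    def __len__(self) -> int:
        return 64

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.full((3, 32, 32), float(idx))


if __name__ == "__main__":
    # persistent_workers=True keeps worker processes alive between epochs,
    # avoiding the cost of re-spawning them (and re-opening file handles)
    loader = DataLoader(
        ToyDataset(), batch_size=8, num_workers=2, persistent_workers=True
    )

    for epoch in range(2):
        # Workers are only reused if the previous epoch's iterator was fully
        # consumed; stopping early would force fresh workers next epoch
        for batch in loader:
            pass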

View file

@@ -12,7 +12,7 @@ import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.models import resnet34
from torchgeo.datasets import CDL, Landsat8
from torchgeo.samplers import GridGeoSampler, RandomBatchGeoSampler, RandomGeoSampler
@@ -150,10 +150,10 @@ def main(args: argparse.Namespace) -> None:
stride = args.stride * cdl.res
samplers = [
RandomGeoSampler(landsat.index, size=size, length=length),
GridGeoSampler(landsat.index, size=size, stride=stride),
RandomGeoSampler(landsat, size=size, length=length),
GridGeoSampler(landsat, size=size, stride=stride),
RandomBatchGeoSampler(
landsat.index, size=size, batch_size=args.batch_size, length=length
landsat, size=size, batch_size=args.batch_size, length=length
),
]
@@ -213,7 +213,7 @@ def main(args: argparse.Namespace) -> None:
)
# Benchmark model
model = resnet18()
model = resnet34()
# Change number of input channels to match Landsat
model.conv1 = nn.Conv2d( # type: ignore[attr-defined]
len(bands), 64, kernel_size=7, stride=2, padding=3, bias=False
@@ -248,7 +248,7 @@ def main(args: argparse.Namespace) -> None:
duration = toc - tic
if args.verbose:
print("\nResNet-18:")
print("\nResNet-34:")
print(f" duration: {duration:.3f} sec")
print(f" count: {num_total_patches} patches")
print(f" rate: {num_total_patches / duration:.3f} patches/sec")
@@ -260,7 +260,7 @@ def main(args: argparse.Namespace) -> None:
"duration": duration,
"count": num_total_patches,
"rate": num_total_patches / duration,
"sampler": "resnet18",
"sampler": "ResNet-34",
"batch_size": args.batch_size,
"num_workers": args.num_workers,
}
@@ -286,8 +286,6 @@ def main(args: argparse.Namespace) -> None:
if __name__ == "__main__":
os.environ["GDAL_CACHEMAX"] = "50%"
parser = set_up_parser()
args = parser.parse_args()

View file

@@ -16,7 +16,7 @@ Samplers are used to index a dataset, retrieving a single query at a time. For :
from torchgeo.samplers import RandomGeoSampler
dataset = Landsat(...)
sampler = RandomGeoSampler(dataset.index, size=1000, length=100)
sampler = RandomGeoSampler(dataset, size=1000, length=100)
dataloader = DataLoader(dataset, sampler=sampler)
@@ -43,7 +43,7 @@ When working with large tile-based datasets, randomly sampling patches from each
from torchgeo.samplers import RandomBatchGeoSampler
dataset = Landsat(...)
sampler = RandomBatchGeoSampler(dataset.index, size=1000, batch_size=10, length=100)
sampler = RandomBatchGeoSampler(dataset, size=1000, batch_size=10, length=100)
dataloader = DataLoader(dataset, batch_sampler=sampler)

View file

@@ -211,7 +211,7 @@
" chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n",
" naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n",
" dataset = chesapeake + naip\n",
" sampler = RandomGeoSampler(naip.index, size=1000, length=888)\n",
" sampler = RandomGeoSampler(naip, size=1000, length=888)\n",
" dataloader = DataLoader(dataset, batch_size=12, sampler=sampler)\n",
" duration, count = time_epoch(dataloader)\n",
" print(duration, count)"
@@ -262,7 +262,7 @@
" chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n",
" naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n",
" dataset = chesapeake + naip\n",
" sampler = GridGeoSampler(naip.index, size=1000, stride=500)\n",
" sampler = GridGeoSampler(naip, size=1000, stride=500)\n",
" dataloader = DataLoader(dataset, batch_size=12, sampler=sampler)\n",
" duration, count = time_epoch(dataloader)\n",
" print(duration, count)"
@@ -313,7 +313,7 @@
" chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n",
" naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n",
" dataset = chesapeake + naip\n",
" sampler = RandomBatchGeoSampler(naip.index, size=1000, batch_size=12, length=888)\n",
" sampler = RandomBatchGeoSampler(naip, size=1000, batch_size=12, length=888)\n",
" dataloader = DataLoader(dataset, batch_sampler=sampler)\n",
" duration, count = time_epoch(dataloader)\n",
" print(duration, count)"

View file

@@ -329,7 +329,7 @@
},
"outputs": [],
"source": [
"sampler = RandomGeoSampler(naip.index, size=1000, length=10)"
"sampler = RandomGeoSampler(naip, size=1000, length=10)"
]
},
{

View file

@@ -3,12 +3,14 @@
# Licensed under the MIT License.
"""Script for running the benchmark script over a sweep of different options."""
import itertools
import os
import subprocess
import time
from typing import List
EPOCH_SIZE = 2048
EPOCH_SIZE = 4096
SEED_OPTIONS = [0, 1, 2]
CACHE_OPTIONS = [True, False]
@@ -23,6 +25,9 @@ CDL_DATA_ROOT = ""
total_num_experiments = len(SEED_OPTIONS) * len(CACHE_OPTIONS) * len(BATCH_SIZE_OPTIONS)
if __name__ == "__main__":
# With 6 workers, this will use ~60% of available RAM
os.environ["GDAL_CACHEMAX"] = "10%"
tic = time.time()
for i, (cache, batch_size, seed) in enumerate(
itertools.product(CACHE_OPTIONS, BATCH_SIZE_OPTIONS, SEED_OPTIONS)
@@ -37,7 +42,7 @@ if __name__ == "__main__":
"--cdl-root",
CDL_DATA_ROOT,
"--num-workers",
"8",
"6",
"--batch-size",
str(batch_size),
"--epoch-size",

View file

@@ -7,7 +7,6 @@ from typing import Dict, Iterator, List
import pytest
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from rtree.index import Index, Property
from torch.utils.data import DataLoader
from torchgeo.datasets import BoundingBox, GeoDataset
@@ -29,12 +28,10 @@ class CustomBatchGeoSampler(BatchGeoSampler):
class CustomGeoDataset(GeoDataset):
def __init__(
self,
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
crs: CRS = CRS.from_epsg(3005),
res: float = 1,
) -> None:
super().__init__()
self.index.insert(0, bounds)
self.crs = crs
self.res = res
@@ -72,11 +69,11 @@ class TestBatchGeoSampler:
class TestRandomBatchGeoSampler:
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
def sampler(self, request: SubRequest) -> RandomBatchGeoSampler:
index = Index(interleaved=False, properties=Property(dimension=3))
index.insert(0, (0, 10, 20, 30, 40, 50))
index.insert(1, (0, 10, 20, 30, 40, 50))
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 20, 30, 40, 50))
ds.index.insert(1, (0, 10, 20, 30, 40, 50))
size = request.param
return RandomBatchGeoSampler(index, size, batch_size=2, length=10)
return RandomBatchGeoSampler(ds, size, batch_size=2, length=10)
def test_iter(self, sampler: RandomBatchGeoSampler) -> None:
for batch in sampler:

View file

@@ -7,7 +7,6 @@ from typing import Dict, Iterator
import pytest
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from rtree.index import Index, Property
from torch.utils.data import DataLoader
from torchgeo.datasets import BoundingBox, GeoDataset
@@ -29,12 +28,10 @@ class CustomGeoSampler(GeoSampler):
class CustomGeoDataset(GeoDataset):
def __init__(
self,
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
crs: CRS = CRS.from_epsg(3005),
res: float = 1,
) -> None:
super().__init__()
self.index.insert(0, bounds)
self.crs = crs
self.res = res
@@ -71,11 +68,11 @@ class TestGeoSampler:
class TestRandomGeoSampler:
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
def sampler(self, request: SubRequest) -> RandomGeoSampler:
index = Index(interleaved=False, properties=Property(dimension=3))
index.insert(0, (0, 10, 20, 30, 40, 50))
index.insert(1, (0, 10, 20, 30, 40, 50))
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 20, 30, 40, 50))
ds.index.insert(1, (0, 10, 20, 30, 40, 50))
size = request.param
return RandomGeoSampler(index, size, length=10)
return RandomGeoSampler(ds, size, length=10)
def test_iter(self, sampler: RandomGeoSampler) -> None:
for query in sampler:
@@ -116,11 +113,11 @@ class TestGridGeoSampler:
],
)
def sampler(self, request: SubRequest) -> GridGeoSampler:
index = Index(interleaved=False, properties=Property(dimension=3))
index.insert(0, (0, 20, 0, 10, 40, 50))
index.insert(1, (0, 20, 0, 10, 40, 50))
ds = CustomGeoDataset()
ds.index.insert(0, (0, 20, 0, 10, 40, 50))
ds.index.insert(1, (0, 20, 0, 10, 40, 50))
size, stride = request.param
return GridGeoSampler(index, size, stride)
return GridGeoSampler(ds, size, stride)
def test_iter(self, sampler: GridGeoSampler) -> None:
for query in sampler:

View file

@@ -22,6 +22,7 @@ import torch
from rasterio.crs import CRS
from rasterio.io import DatasetReader
from rasterio.vrt import WarpedVRT
from rasterio.windows import from_bounds
from rtree.index import Index, Property
from torch import Tensor
from torch.utils.data import Dataset
@@ -328,7 +329,16 @@ class RasterDataset(GeoDataset):
vrt_fhs = [self._load_warp_file(fp) for fp in filepaths]
bounds = (query.minx, query.miny, query.maxx, query.maxy)
dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res)
if len(vrt_fhs) == 1:
src = vrt_fhs[0]
out_width = int(round((query.maxx - query.minx) / self.res))
out_height = int(round((query.maxy - query.miny) / self.res))
out_shape = (src.count, out_height, out_width)
dest = src.read(
out_shape=out_shape, window=from_bounds(*bounds, src.transform)
)
else:
dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res)
dest = dest.astype(np.int32)
tensor: Tensor = torch.tensor(dest) # type: ignore[attr-defined]
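
For context, the new branch above skips rasterio.merge.merge when only one file intersects the query and instead reads a resampled window straight from that file. A standalone sketch of the same idea, with a hypothetical GeoTIFF path and made-up query bounds (assumed to be in the file's CRS; none of these values come from this commit):

import rasterio
from rasterio.windows import from_bounds

# Hypothetical inputs, for illustration only
path = "example.tif"
minx, miny, maxx, maxy = 500000.0, 4100000.0, 500960.0, 4100960.0
res = 30.0  # target resolution, in CRS units per pixel

with rasterio.open(path) as src:
    # Resample the window to the target resolution via out_shape
    out_width = int(round((maxx - minx) / res))
    out_height = int(round((maxy - miny) / res))
    out_shape = (src.count, out_height, out_width)

    # Read only the pixels covering the query instead of merging whole files
    window = from_bounds(minx, miny, maxx, maxy, src.transform)
    dest = src.read(out_shape=out_shape, window=window)

print(dest.shape)  # (src.count, out_height, out_width)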

View file

@@ -7,10 +7,9 @@ import abc
import random
from typing import Iterator, List, Optional, Tuple, Union
from rtree.index import Index
from torch.utils.data import Sampler
from torchgeo.datasets import BoundingBox
from torchgeo.datasets import BoundingBox, GeoDataset
from .utils import _to_tuple, get_random_bounding_box
@@ -43,13 +42,13 @@ class RandomBatchGeoSampler(BatchGeoSampler):
This is particularly useful during training when you want to maximize the size of
the dataset and return as many random :term:`chips <chip>` as possible.
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``index`` should come
from a tile-based dataset if possible.
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``dataset`` should be
a tile-based dataset if possible.
"""
def __init__(
self,
index: Index,
dataset: GeoDataset,
size: Union[Tuple[float, float], float],
batch_size: int,
length: int,
@@ -65,21 +64,22 @@ class RandomBatchGeoSampler(BatchGeoSampler):
height dimension, and the second *float* for the width dimension
Args:
index: index of a :class:`~torchgeo.datasets.GeoDataset`
dataset: dataset to index from
size: dimensions of each :term:`patch` in units of CRS
batch_size: number of samples per batch
length: number of samples per epoch
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``index``)
(defaults to the bounds of ``dataset.index``)
"""
self.index = index
self.index = dataset.index
self.res = dataset.res
self.size = _to_tuple(size)
self.batch_size = batch_size
self.length = length
if roi is None:
roi = BoundingBox(*index.bounds)
roi = BoundingBox(*self.index.bounds)
self.roi = roi
self.hits = list(index.intersection(roi, objects=True))
self.hits = list(self.index.intersection(roi, objects=True))
def __iter__(self) -> Iterator[List[BoundingBox]]:
"""Return the indices of a dataset.
@@ -96,7 +96,7 @@ class RandomBatchGeoSampler(BatchGeoSampler):
batch = []
for _ in range(self.batch_size):
bounding_box = get_random_bounding_box(bounds, self.size)
bounding_box = get_random_bounding_box(bounds, self.size, self.res)
batch.append(bounding_box)
yield batch

View file

@@ -7,10 +7,9 @@ import abc
import random
from typing import Iterator, Optional, Tuple, Union
from rtree.index import Index
from torch.utils.data import Sampler
from torchgeo.datasets import BoundingBox
from torchgeo.datasets import BoundingBox, GeoDataset
from .utils import _to_tuple, get_random_bounding_box
@@ -49,7 +48,7 @@ class RandomGeoSampler(GeoSampler):
def __init__(
self,
index: Index,
dataset: GeoDataset,
size: Union[Tuple[float, float], float],
length: int,
roi: Optional[BoundingBox] = None,
@@ -64,19 +63,20 @@ class RandomGeoSampler(GeoSampler):
height dimension, and the second *float* for the width dimension
Args:
index: index of a :class:`~torchgeo.datasets.GeoDataset`
dataset: dataset to index from
size: dimensions of each :term:`patch` in units of CRS
length: number of random samples to draw per epoch
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``index``)
(defaults to the bounds of ``dataset.index``)
"""
self.index = index
self.index = dataset.index
self.res = dataset.res
self.size = _to_tuple(size)
self.length = length
if roi is None:
roi = BoundingBox(*index.bounds)
roi = BoundingBox(*self.index.bounds)
self.roi = roi
self.hits = list(index.intersection(roi, objects=True))
self.hits = list(self.index.intersection(roi, objects=True))
def __iter__(self) -> Iterator[BoundingBox]:
"""Return the index of a dataset.
@@ -90,7 +90,7 @@ class RandomGeoSampler(GeoSampler):
bounds = BoundingBox(*hit.bounds)
# Choose a random index within that tile
bounding_box = get_random_bounding_box(bounds, self.size)
bounding_box = get_random_bounding_box(bounds, self.size, self.res)
yield bounding_box
@@ -117,13 +117,13 @@ class GridGeoSampler(GeoSampler):
to the `receptive field <https://distill.pub/2019/computing-receptive-fields/>`_ of
the CNN.
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``index`` should come
from a non-tile-based dataset if possible.
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``dataset`` should be
a non-tile-based dataset if possible.
"""
def __init__(
self,
index: Index,
dataset: GeoDataset,
size: Union[Tuple[float, float], float],
stride: Union[Tuple[float, float], float],
roi: Optional[BoundingBox] = None,
@@ -138,18 +138,19 @@ class GridGeoSampler(GeoSampler):
height dimension, and the second *float* for the width dimension
Args:
index: index of a :class:`~torchgeo.datasets.GeoDataset`
dataset: dataset to index from
size: dimensions of each :term:`patch` in units of CRS
stride: distance to skip between each patch
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
"""
self.index = index
self.index = dataset.index
self.size = _to_tuple(size)
self.stride = _to_tuple(stride)
if roi is None:
roi = BoundingBox(*index.bounds)
roi = BoundingBox(*self.index.bounds)
self.roi = roi
self.hits = list(index.intersection(roi, objects=True))
self.hits = list(self.index.intersection(roi, objects=True))
self.length: int = 0
for hit in self.hits:

View file

@@ -25,7 +25,7 @@ def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]:
def get_random_bounding_box(
bounds: BoundingBox, size: Union[Tuple[float, float], float]
bounds: BoundingBox, size: Union[Tuple[float, float], float], res: float
) -> BoundingBox:
"""Returns a random bounding box within a given bounding box.
@@ -45,13 +45,16 @@ def get_random_bounding_box(
"""
t_size: Tuple[float, float] = _to_tuple(size)
minx = random.uniform(bounds.minx, bounds.maxx - t_size[1])
width = (bounds.maxx - bounds.minx - t_size[1]) // res
minx = random.randrange(int(width)) * res + bounds.minx
maxx = minx + t_size[1]
miny = random.uniform(bounds.miny, bounds.maxy - t_size[0])
height = (bounds.maxy - bounds.miny - t_size[0]) // res
miny = random.randrange(int(height)) * res + bounds.miny
maxy = miny + t_size[0]
mint = bounds.mint
maxt = bounds.maxt
return BoundingBox(minx, maxx, miny, maxy, mint, maxt)
query = BoundingBox(minx, maxx, miny, maxy, mint, maxt)
return query
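
The rewritten get_random_bounding_box above draws the patch origin in whole multiples of res, so patches over pixel-aligned files start on pixel boundaries and can be read without resampling. A small worked sketch of that computation, with made-up tile bounds, patch size, and resolution (not values from this commit):

import random

# Made-up tile bounds, patch size, and resolution, for illustration only
tile_minx, tile_maxx = 0.0, 300.0
tile_miny, tile_maxy = 0.0, 300.0
patch = 90.0  # patch width/height in CRS units
res = 30.0    # pixel size in CRS units

# Slack along each axis, measured in whole pixels
width = (tile_maxx - tile_minx - patch) // res
height = (tile_maxy - tile_miny - patch) // res

# randrange draws a whole-pixel offset, so minx/miny land on pixel edges
minx = random.randrange(int(width)) * res + tile_minx
miny = random.randrange(int(height)) * res + tile_miny

print((minx, minx + patch, miny, miny + patch))  # e.g. (60.0, 150.0, 30.0, 120.0)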