зеркало из https://github.com/microsoft/torchgeo.git
Improve sampler performance for pixel-aligned files (#181)
* Improve sampler performance for pixel-aligned files * Skip merge if only a single file * Undo changes to hyperparams * Fix shape, read all bands * Remove manual single-file reading * Always keep workers alive * Various changes in a desperate attempt to improve performance * Increase epoch size * Add missing import, fix model name * Fix tests * Persistent workers not used unless entire dataset is consumed
This commit is contained in:
Родитель
142835cede
Коммит
e14980a3eb
16
benchmark.py
16
benchmark.py
|
@ -12,7 +12,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.models import resnet18
|
||||
from torchvision.models import resnet34
|
||||
|
||||
from torchgeo.datasets import CDL, Landsat8
|
||||
from torchgeo.samplers import GridGeoSampler, RandomBatchGeoSampler, RandomGeoSampler
|
||||
|
@ -150,10 +150,10 @@ def main(args: argparse.Namespace) -> None:
|
|||
stride = args.stride * cdl.res
|
||||
|
||||
samplers = [
|
||||
RandomGeoSampler(landsat.index, size=size, length=length),
|
||||
GridGeoSampler(landsat.index, size=size, stride=stride),
|
||||
RandomGeoSampler(landsat, size=size, length=length),
|
||||
GridGeoSampler(landsat, size=size, stride=stride),
|
||||
RandomBatchGeoSampler(
|
||||
landsat.index, size=size, batch_size=args.batch_size, length=length
|
||||
landsat, size=size, batch_size=args.batch_size, length=length
|
||||
),
|
||||
]
|
||||
|
||||
|
@ -213,7 +213,7 @@ def main(args: argparse.Namespace) -> None:
|
|||
)
|
||||
|
||||
# Benchmark model
|
||||
model = resnet18()
|
||||
model = resnet34()
|
||||
# Change number of input channels to match Landsat
|
||||
model.conv1 = nn.Conv2d( # type: ignore[attr-defined]
|
||||
len(bands), 64, kernel_size=7, stride=2, padding=3, bias=False
|
||||
|
@ -248,7 +248,7 @@ def main(args: argparse.Namespace) -> None:
|
|||
duration = toc - tic
|
||||
|
||||
if args.verbose:
|
||||
print("\nResNet-18:")
|
||||
print("\nResNet-34:")
|
||||
print(f" duration: {duration:.3f} sec")
|
||||
print(f" count: {num_total_patches} patches")
|
||||
print(f" rate: {num_total_patches / duration:.3f} patches/sec")
|
||||
|
@ -260,7 +260,7 @@ def main(args: argparse.Namespace) -> None:
|
|||
"duration": duration,
|
||||
"count": num_total_patches,
|
||||
"rate": num_total_patches / duration,
|
||||
"sampler": "resnet18",
|
||||
"sampler": "ResNet-34",
|
||||
"batch_size": args.batch_size,
|
||||
"num_workers": args.num_workers,
|
||||
}
|
||||
|
@ -286,8 +286,6 @@ def main(args: argparse.Namespace) -> None:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.environ["GDAL_CACHEMAX"] = "50%"
|
||||
|
||||
parser = set_up_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ Samplers are used to index a dataset, retrieving a single query at a time. For :
|
|||
from torchgeo.samplers import RandomGeoSampler
|
||||
|
||||
dataset = Landsat(...)
|
||||
sampler = RandomGeoSampler(dataset.index, size=1000, length=100)
|
||||
sampler = RandomGeoSampler(dataset, size=1000, length=100)
|
||||
dataloader = DataLoader(dataset, sampler=sampler)
|
||||
|
||||
|
||||
|
@ -43,7 +43,7 @@ When working with large tile-based datasets, randomly sampling patches from each
|
|||
from torchgeo.samplers import RandomBatchGeoSampler
|
||||
|
||||
dataset = Landsat(...)
|
||||
sampler = RandomBatchGeoSampler(dataset.index, size=1000, batch_size=10, length=100)
|
||||
sampler = RandomBatchGeoSampler(dataset, size=1000, batch_size=10, length=100)
|
||||
dataloader = DataLoader(dataset, batch_sampler=sampler)
|
||||
|
||||
|
||||
|
|
|
@ -211,7 +211,7 @@
|
|||
" chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n",
|
||||
" naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n",
|
||||
" dataset = chesapeake + naip\n",
|
||||
" sampler = RandomGeoSampler(naip.index, size=1000, length=888)\n",
|
||||
" sampler = RandomGeoSampler(naip, size=1000, length=888)\n",
|
||||
" dataloader = DataLoader(dataset, batch_size=12, sampler=sampler)\n",
|
||||
" duration, count = time_epoch(dataloader)\n",
|
||||
" print(duration, count)"
|
||||
|
@ -262,7 +262,7 @@
|
|||
" chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n",
|
||||
" naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n",
|
||||
" dataset = chesapeake + naip\n",
|
||||
" sampler = GridGeoSampler(naip.index, size=1000, stride=500)\n",
|
||||
" sampler = GridGeoSampler(naip, size=1000, stride=500)\n",
|
||||
" dataloader = DataLoader(dataset, batch_size=12, sampler=sampler)\n",
|
||||
" duration, count = time_epoch(dataloader)\n",
|
||||
" print(duration, count)"
|
||||
|
@ -313,7 +313,7 @@
|
|||
" chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n",
|
||||
" naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n",
|
||||
" dataset = chesapeake + naip\n",
|
||||
" sampler = RandomBatchGeoSampler(naip.index, size=1000, batch_size=12, length=888)\n",
|
||||
" sampler = RandomBatchGeoSampler(naip, size=1000, batch_size=12, length=888)\n",
|
||||
" dataloader = DataLoader(dataset, batch_sampler=sampler)\n",
|
||||
" duration, count = time_epoch(dataloader)\n",
|
||||
" print(duration, count)"
|
||||
|
|
|
@ -329,7 +329,7 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sampler = RandomGeoSampler(naip.index, size=1000, length=10)"
|
||||
"sampler = RandomGeoSampler(naip, size=1000, length=10)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -3,12 +3,14 @@
|
|||
# Licensed under the MIT License.
|
||||
|
||||
"""Script for running the benchmark script over a sweep of different options."""
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
EPOCH_SIZE = 2048
|
||||
EPOCH_SIZE = 4096
|
||||
|
||||
SEED_OPTIONS = [0, 1, 2]
|
||||
CACHE_OPTIONS = [True, False]
|
||||
|
@ -23,6 +25,9 @@ CDL_DATA_ROOT = ""
|
|||
total_num_experiments = len(SEED_OPTIONS) * len(CACHE_OPTIONS) * len(BATCH_SIZE_OPTIONS)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# With 6 workers, this will use ~60% of available RAM
|
||||
os.environ["GDAL_CACHEMAX"] = "10%"
|
||||
|
||||
tic = time.time()
|
||||
for i, (cache, batch_size, seed) in enumerate(
|
||||
itertools.product(CACHE_OPTIONS, BATCH_SIZE_OPTIONS, SEED_OPTIONS)
|
||||
|
@ -37,7 +42,7 @@ if __name__ == "__main__":
|
|||
"--cdl-root",
|
||||
CDL_DATA_ROOT,
|
||||
"--num-workers",
|
||||
"8",
|
||||
"6",
|
||||
"--batch-size",
|
||||
str(batch_size),
|
||||
"--epoch-size",
|
||||
|
|
|
@ -7,7 +7,6 @@ from typing import Dict, Iterator, List
|
|||
import pytest
|
||||
from _pytest.fixtures import SubRequest
|
||||
from rasterio.crs import CRS
|
||||
from rtree.index import Index, Property
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from torchgeo.datasets import BoundingBox, GeoDataset
|
||||
|
@ -29,12 +28,10 @@ class CustomBatchGeoSampler(BatchGeoSampler):
|
|||
class CustomGeoDataset(GeoDataset):
|
||||
def __init__(
|
||||
self,
|
||||
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
|
||||
crs: CRS = CRS.from_epsg(3005),
|
||||
res: float = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.index.insert(0, bounds)
|
||||
self.crs = crs
|
||||
self.res = res
|
||||
|
||||
|
@ -72,11 +69,11 @@ class TestBatchGeoSampler:
|
|||
class TestRandomBatchGeoSampler:
|
||||
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
|
||||
def sampler(self, request: SubRequest) -> RandomBatchGeoSampler:
|
||||
index = Index(interleaved=False, properties=Property(dimension=3))
|
||||
index.insert(0, (0, 10, 20, 30, 40, 50))
|
||||
index.insert(1, (0, 10, 20, 30, 40, 50))
|
||||
ds = CustomGeoDataset()
|
||||
ds.index.insert(0, (0, 10, 20, 30, 40, 50))
|
||||
ds.index.insert(1, (0, 10, 20, 30, 40, 50))
|
||||
size = request.param
|
||||
return RandomBatchGeoSampler(index, size, batch_size=2, length=10)
|
||||
return RandomBatchGeoSampler(ds, size, batch_size=2, length=10)
|
||||
|
||||
def test_iter(self, sampler: RandomBatchGeoSampler) -> None:
|
||||
for batch in sampler:
|
||||
|
|
|
@ -7,7 +7,6 @@ from typing import Dict, Iterator
|
|||
import pytest
|
||||
from _pytest.fixtures import SubRequest
|
||||
from rasterio.crs import CRS
|
||||
from rtree.index import Index, Property
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from torchgeo.datasets import BoundingBox, GeoDataset
|
||||
|
@ -29,12 +28,10 @@ class CustomGeoSampler(GeoSampler):
|
|||
class CustomGeoDataset(GeoDataset):
|
||||
def __init__(
|
||||
self,
|
||||
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
|
||||
crs: CRS = CRS.from_epsg(3005),
|
||||
res: float = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.index.insert(0, bounds)
|
||||
self.crs = crs
|
||||
self.res = res
|
||||
|
||||
|
@ -71,11 +68,11 @@ class TestGeoSampler:
|
|||
class TestRandomGeoSampler:
|
||||
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
|
||||
def sampler(self, request: SubRequest) -> RandomGeoSampler:
|
||||
index = Index(interleaved=False, properties=Property(dimension=3))
|
||||
index.insert(0, (0, 10, 20, 30, 40, 50))
|
||||
index.insert(1, (0, 10, 20, 30, 40, 50))
|
||||
ds = CustomGeoDataset()
|
||||
ds.index.insert(0, (0, 10, 20, 30, 40, 50))
|
||||
ds.index.insert(1, (0, 10, 20, 30, 40, 50))
|
||||
size = request.param
|
||||
return RandomGeoSampler(index, size, length=10)
|
||||
return RandomGeoSampler(ds, size, length=10)
|
||||
|
||||
def test_iter(self, sampler: RandomGeoSampler) -> None:
|
||||
for query in sampler:
|
||||
|
@ -116,11 +113,11 @@ class TestGridGeoSampler:
|
|||
],
|
||||
)
|
||||
def sampler(self, request: SubRequest) -> GridGeoSampler:
|
||||
index = Index(interleaved=False, properties=Property(dimension=3))
|
||||
index.insert(0, (0, 20, 0, 10, 40, 50))
|
||||
index.insert(1, (0, 20, 0, 10, 40, 50))
|
||||
ds = CustomGeoDataset()
|
||||
ds.index.insert(0, (0, 20, 0, 10, 40, 50))
|
||||
ds.index.insert(1, (0, 20, 0, 10, 40, 50))
|
||||
size, stride = request.param
|
||||
return GridGeoSampler(index, size, stride)
|
||||
return GridGeoSampler(ds, size, stride)
|
||||
|
||||
def test_iter(self, sampler: GridGeoSampler) -> None:
|
||||
for query in sampler:
|
||||
|
|
|
@ -22,6 +22,7 @@ import torch
|
|||
from rasterio.crs import CRS
|
||||
from rasterio.io import DatasetReader
|
||||
from rasterio.vrt import WarpedVRT
|
||||
from rasterio.windows import from_bounds
|
||||
from rtree.index import Index, Property
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
|
@ -328,7 +329,16 @@ class RasterDataset(GeoDataset):
|
|||
vrt_fhs = [self._load_warp_file(fp) for fp in filepaths]
|
||||
|
||||
bounds = (query.minx, query.miny, query.maxx, query.maxy)
|
||||
dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res)
|
||||
if len(vrt_fhs) == 1:
|
||||
src = vrt_fhs[0]
|
||||
out_width = int(round((query.maxx - query.minx) / self.res))
|
||||
out_height = int(round((query.maxy - query.miny) / self.res))
|
||||
out_shape = (src.count, out_height, out_width)
|
||||
dest = src.read(
|
||||
out_shape=out_shape, window=from_bounds(*bounds, src.transform)
|
||||
)
|
||||
else:
|
||||
dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res)
|
||||
dest = dest.astype(np.int32)
|
||||
|
||||
tensor: Tensor = torch.tensor(dest) # type: ignore[attr-defined]
|
||||
|
|
|
@ -7,10 +7,9 @@ import abc
|
|||
import random
|
||||
from typing import Iterator, List, Optional, Tuple, Union
|
||||
|
||||
from rtree.index import Index
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from torchgeo.datasets import BoundingBox
|
||||
from torchgeo.datasets import BoundingBox, GeoDataset
|
||||
|
||||
from .utils import _to_tuple, get_random_bounding_box
|
||||
|
||||
|
@ -43,13 +42,13 @@ class RandomBatchGeoSampler(BatchGeoSampler):
|
|||
This is particularly useful during training when you want to maximize the size of
|
||||
the dataset and return as many random :term:`chips <chip>` as possible.
|
||||
|
||||
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``index`` should come
|
||||
from a tile-based dataset if possible.
|
||||
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``dataset`` should be
|
||||
a tile-based dataset if possible.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index: Index,
|
||||
dataset: GeoDataset,
|
||||
size: Union[Tuple[float, float], float],
|
||||
batch_size: int,
|
||||
length: int,
|
||||
|
@ -65,21 +64,22 @@ class RandomBatchGeoSampler(BatchGeoSampler):
|
|||
height dimension, and the second *float* for the width dimension
|
||||
|
||||
Args:
|
||||
index: index of a :class:`~torchgeo.datasets.GeoDataset`
|
||||
dataset: dataset to index from
|
||||
size: dimensions of each :term:`patch` in units of CRS
|
||||
batch_size: number of samples per batch
|
||||
length: number of samples per epoch
|
||||
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
|
||||
(defaults to the bounds of ``index``)
|
||||
(defaults to the bounds of ``dataset.index``)
|
||||
"""
|
||||
self.index = index
|
||||
self.index = dataset.index
|
||||
self.res = dataset.res
|
||||
self.size = _to_tuple(size)
|
||||
self.batch_size = batch_size
|
||||
self.length = length
|
||||
if roi is None:
|
||||
roi = BoundingBox(*index.bounds)
|
||||
roi = BoundingBox(*self.index.bounds)
|
||||
self.roi = roi
|
||||
self.hits = list(index.intersection(roi, objects=True))
|
||||
self.hits = list(self.index.intersection(roi, objects=True))
|
||||
|
||||
def __iter__(self) -> Iterator[List[BoundingBox]]:
|
||||
"""Return the indices of a dataset.
|
||||
|
@ -96,7 +96,7 @@ class RandomBatchGeoSampler(BatchGeoSampler):
|
|||
batch = []
|
||||
for _ in range(self.batch_size):
|
||||
|
||||
bounding_box = get_random_bounding_box(bounds, self.size)
|
||||
bounding_box = get_random_bounding_box(bounds, self.size, self.res)
|
||||
batch.append(bounding_box)
|
||||
|
||||
yield batch
|
||||
|
|
|
@ -7,10 +7,9 @@ import abc
|
|||
import random
|
||||
from typing import Iterator, Optional, Tuple, Union
|
||||
|
||||
from rtree.index import Index
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from torchgeo.datasets import BoundingBox
|
||||
from torchgeo.datasets import BoundingBox, GeoDataset
|
||||
|
||||
from .utils import _to_tuple, get_random_bounding_box
|
||||
|
||||
|
@ -49,7 +48,7 @@ class RandomGeoSampler(GeoSampler):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
index: Index,
|
||||
dataset: GeoDataset,
|
||||
size: Union[Tuple[float, float], float],
|
||||
length: int,
|
||||
roi: Optional[BoundingBox] = None,
|
||||
|
@ -64,19 +63,20 @@ class RandomGeoSampler(GeoSampler):
|
|||
height dimension, and the second *float* for the width dimension
|
||||
|
||||
Args:
|
||||
index: index of a :class:`~torchgeo.datasets.GeoDataset`
|
||||
dataset: dataset to index from
|
||||
size: dimensions of each :term:`patch` in units of CRS
|
||||
length: number of random samples to draw per epoch
|
||||
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
|
||||
(defaults to the bounds of ``index``)
|
||||
(defaults to the bounds of ``dataset.index``)
|
||||
"""
|
||||
self.index = index
|
||||
self.index = dataset.index
|
||||
self.res = dataset.res
|
||||
self.size = _to_tuple(size)
|
||||
self.length = length
|
||||
if roi is None:
|
||||
roi = BoundingBox(*index.bounds)
|
||||
roi = BoundingBox(*self.index.bounds)
|
||||
self.roi = roi
|
||||
self.hits = list(index.intersection(roi, objects=True))
|
||||
self.hits = list(self.index.intersection(roi, objects=True))
|
||||
|
||||
def __iter__(self) -> Iterator[BoundingBox]:
|
||||
"""Return the index of a dataset.
|
||||
|
@ -90,7 +90,7 @@ class RandomGeoSampler(GeoSampler):
|
|||
bounds = BoundingBox(*hit.bounds)
|
||||
|
||||
# Choose a random index within that tile
|
||||
bounding_box = get_random_bounding_box(bounds, self.size)
|
||||
bounding_box = get_random_bounding_box(bounds, self.size, self.res)
|
||||
|
||||
yield bounding_box
|
||||
|
||||
|
@ -117,13 +117,13 @@ class GridGeoSampler(GeoSampler):
|
|||
to the `receptive field <https://distill.pub/2019/computing-receptive-fields/>`_ of
|
||||
the CNN.
|
||||
|
||||
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``index`` should come
|
||||
from a non-tile-based dataset if possible.
|
||||
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``dataset`` should be
|
||||
a non-tile-based dataset if possible.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index: Index,
|
||||
dataset: GeoDataset,
|
||||
size: Union[Tuple[float, float], float],
|
||||
stride: Union[Tuple[float, float], float],
|
||||
roi: Optional[BoundingBox] = None,
|
||||
|
@ -138,18 +138,19 @@ class GridGeoSampler(GeoSampler):
|
|||
height dimension, and the second *float* for the width dimension
|
||||
|
||||
Args:
|
||||
index: index of a :class:`~torchgeo.datasets.GeoDataset`
|
||||
dataset: dataset to index from
|
||||
size: dimensions of each :term:`patch` in units of CRS
|
||||
stride: distance to skip between each patch
|
||||
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
|
||||
(defaults to the bounds of ``dataset.index``)
|
||||
"""
|
||||
self.index = index
|
||||
self.index = dataset.index
|
||||
self.size = _to_tuple(size)
|
||||
self.stride = _to_tuple(stride)
|
||||
if roi is None:
|
||||
roi = BoundingBox(*index.bounds)
|
||||
roi = BoundingBox(*self.index.bounds)
|
||||
self.roi = roi
|
||||
self.hits = list(index.intersection(roi, objects=True))
|
||||
self.hits = list(self.index.intersection(roi, objects=True))
|
||||
|
||||
self.length: int = 0
|
||||
for hit in self.hits:
|
||||
|
|
|
@ -25,7 +25,7 @@ def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]:
|
|||
|
||||
|
||||
def get_random_bounding_box(
|
||||
bounds: BoundingBox, size: Union[Tuple[float, float], float]
|
||||
bounds: BoundingBox, size: Union[Tuple[float, float], float], res: float
|
||||
) -> BoundingBox:
|
||||
"""Returns a random bounding box within a given bounding box.
|
||||
|
||||
|
@ -45,13 +45,16 @@ def get_random_bounding_box(
|
|||
"""
|
||||
t_size: Tuple[float, float] = _to_tuple(size)
|
||||
|
||||
minx = random.uniform(bounds.minx, bounds.maxx - t_size[1])
|
||||
width = (bounds.maxx - bounds.minx - t_size[1]) // res
|
||||
minx = random.randrange(int(width)) * res + bounds.minx
|
||||
maxx = minx + t_size[1]
|
||||
|
||||
miny = random.uniform(bounds.miny, bounds.maxy - t_size[0])
|
||||
height = (bounds.maxy - bounds.miny - t_size[0]) // res
|
||||
miny = random.randrange(int(height)) * res + bounds.miny
|
||||
maxy = miny + t_size[0]
|
||||
|
||||
mint = bounds.mint
|
||||
maxt = bounds.maxt
|
||||
|
||||
return BoundingBox(minx, maxx, miny, maxy, mint, maxt)
|
||||
query = BoundingBox(minx, maxx, miny, maxy, mint, maxt)
|
||||
return query
|
||||
|
|
Загрузка…
Ссылка в новой задаче