RasterDataset: add control over resampling algorithm (#2015)

* RasterDataset: add control over resampling algorithm

* Fix type hints

* cubic -> bilinear

* Ruff: single quotes
This commit is contained in:
Adam J. Stewart 2024-05-13 17:00:45 +02:00 коммит произвёл GitHub
Родитель 5976bd15bf
Коммит 25fb9ccfb9
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
3 изменённых файлов: 54 добавлений и 2 удалений

Просмотреть файл

@ -336,6 +336,10 @@
"\n",
"Defaults to float32 for `is_image == True` and long for `is_image == False`. This is what you want for 99% of datasets, but can be overridden for tasks like pixel-wise regression (where the target mask should be float32).\n",
"\n",
"### `resampling`\n",
"\n",
"Defaults to bilinear for float Tensors and nearest for int Tensors. Can be overridden for custom resampling algorithms.\n",
"\n",
"### `separate_files`\n",
"\n",
"If your data comes with each spectral band in a separate files, as is the case with Sentinel-2, use `separate_files = True`. If all spectral bands are stored in a single file, use `separate_files = False` instead.\n",

Просмотреть файл

@ -5,12 +5,14 @@ import pickle
import sys
from collections.abc import Iterable
from pathlib import Path
from typing import Any
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from rasterio.enums import Resampling
from torch.utils.data import ConcatDataset
from torchgeo.datasets import (
@ -49,6 +51,16 @@ class CustomGeoDataset(GeoDataset):
return {'index': bounds}
class CustomRasterDataset(RasterDataset):
def __init__(self, dtype: torch.dtype, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._dtype = dtype
@property
def dtype(self) -> torch.dtype:
return self._dtype
class CustomVectorDataset(VectorDataset):
filename_glob = '*.geojson'
date_format = '%Y'
@ -274,6 +286,22 @@ class TestRasterDataset:
assert isinstance(x['image'], torch.Tensor)
assert x['image'].dtype == torch.float32
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
def test_resampling_float_dtype(self, dtype: torch.dtype) -> None:
paths = os.path.join('tests', 'data', 'raster', 'uint16')
ds = CustomRasterDataset(dtype, paths)
x = ds[ds.bounds]
assert x['image'].dtype == dtype
assert ds.resampling == Resampling.bilinear
@pytest.mark.parametrize('dtype', [torch.long, torch.bool])
def test_resampling_int_dtype(self, dtype: torch.dtype) -> None:
paths = os.path.join('tests', 'data', 'raster', 'uint16')
ds = CustomRasterDataset(dtype, paths)
x = ds[ds.bounds]
assert x['image'].dtype == dtype
assert ds.resampling == Resampling.nearest
def test_invalid_query(self, sentinel: Sentinel2) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(

Просмотреть файл

@ -22,6 +22,7 @@ import rasterio.merge
import shapely
import torch
from rasterio.crs import CRS
from rasterio.enums import Resampling
from rasterio.io import DatasetReader
from rasterio.vrt import WarpedVRT
from rtree.index import Index, Property
@ -309,7 +310,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
files |= set(glob.iglob(pathname, recursive=True))
elif os.path.isfile(path) or path_is_vsi(path):
files.add(path)
elif not hasattr(self, "download"):
elif not hasattr(self, 'download'):
warnings.warn(
f"Could not find any relevant files for provided path '{path}'. "
f'Path was ignored.',
@ -384,6 +385,23 @@ class RasterDataset(GeoDataset):
else:
return torch.long
@property
def resampling(self) -> Resampling:
"""Resampling algorithm used when reading input files.
Defaults to bilinear for float dtypes and nearest for int dtypes.
Returns:
The resampling method to use.
.. versionadded:: 0.6
"""
# Based on torch.is_floating_point
if self.dtype in [torch.float64, torch.float32, torch.float16, torch.bfloat16]:
return Resampling.bilinear
else:
return Resampling.nearest
def __init__(
self,
paths: str | Iterable[str] = 'data',
@ -555,7 +573,9 @@ class RasterDataset(GeoDataset):
vrt_fhs = [self._load_warp_file(fp) for fp in filepaths]
bounds = (query.minx, query.miny, query.maxx, query.maxy)
dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res, indexes=band_indexes)
dest, _ = rasterio.merge.merge(
vrt_fhs, bounds, self.res, indexes=band_indexes, resampling=self.resampling
)
# Use array_to_tensor since merge may return uint16/uint32 arrays.
tensor = array_to_tensor(dest)
return tensor