diff --git a/docs/tutorials/custom_raster_dataset.ipynb b/docs/tutorials/custom_raster_dataset.ipynb index 77b44898e..1da51ec62 100644 --- a/docs/tutorials/custom_raster_dataset.ipynb +++ b/docs/tutorials/custom_raster_dataset.ipynb @@ -336,6 +336,10 @@ "\n", "Defaults to float32 for `is_image == True` and long for `is_image == False`. This is what you want for 99% of datasets, but can be overridden for tasks like pixel-wise regression (where the target mask should be float32).\n", "\n", + "### `resampling`\n", + "\n", + "Defaults to bilinear for float Tensors and nearest for int Tensors. Can be overridden for custom resampling algorithms.\n", + "\n", "### `separate_files`\n", "\n", "If your data comes with each spectral band in a separate files, as is the case with Sentinel-2, use `separate_files = True`. If all spectral bands are stored in a single file, use `separate_files = False` instead.\n", diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index caf43e336..e3b11e7fc 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -5,12 +5,14 @@ import pickle import sys from collections.abc import Iterable from pathlib import Path +from typing import Any import pytest import torch import torch.nn as nn from _pytest.fixtures import SubRequest from rasterio.crs import CRS +from rasterio.enums import Resampling from torch.utils.data import ConcatDataset from torchgeo.datasets import ( @@ -49,6 +51,16 @@ class CustomGeoDataset(GeoDataset): return {'index': bounds} +class CustomRasterDataset(RasterDataset): + def __init__(self, dtype: torch.dtype, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._dtype = dtype + + @property + def dtype(self) -> torch.dtype: + return self._dtype + + class CustomVectorDataset(VectorDataset): filename_glob = '*.geojson' date_format = '%Y' @@ -274,6 +286,22 @@ class TestRasterDataset: assert isinstance(x['image'], torch.Tensor) assert x['image'].dtype == torch.float32 + @pytest.mark.parametrize('dtype', [torch.float, torch.double]) + def test_resampling_float_dtype(self, dtype: torch.dtype) -> None: + paths = os.path.join('tests', 'data', 'raster', 'uint16') + ds = CustomRasterDataset(dtype, paths) + x = ds[ds.bounds] + assert x['image'].dtype == dtype + assert ds.resampling == Resampling.bilinear + + @pytest.mark.parametrize('dtype', [torch.long, torch.bool]) + def test_resampling_int_dtype(self, dtype: torch.dtype) -> None: + paths = os.path.join('tests', 'data', 'raster', 'uint16') + ds = CustomRasterDataset(dtype, paths) + x = ds[ds.bounds] + assert x['image'].dtype == dtype + assert ds.resampling == Resampling.nearest + def test_invalid_query(self, sentinel: Sentinel2) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 6bd2209b3..9ee33356a 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -22,6 +22,7 @@ import rasterio.merge import shapely import torch from rasterio.crs import CRS +from rasterio.enums import Resampling from rasterio.io import DatasetReader from rasterio.vrt import WarpedVRT from rtree.index import Index, Property @@ -309,7 +310,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC): files |= set(glob.iglob(pathname, recursive=True)) elif os.path.isfile(path) or path_is_vsi(path): files.add(path) - elif not hasattr(self, "download"): + elif not hasattr(self, 'download'): warnings.warn( f"Could not find any relevant files for provided path '{path}'. " f'Path was ignored.', @@ -384,6 +385,23 @@ class RasterDataset(GeoDataset): else: return torch.long + @property + def resampling(self) -> Resampling: + """Resampling algorithm used when reading input files. + + Defaults to bilinear for float dtypes and nearest for int dtypes. + + Returns: + The resampling method to use. + + .. versionadded:: 0.6 + """ + # Based on torch.is_floating_point + if self.dtype in [torch.float64, torch.float32, torch.float16, torch.bfloat16]: + return Resampling.bilinear + else: + return Resampling.nearest + def __init__( self, paths: str | Iterable[str] = 'data', @@ -555,7 +573,9 @@ class RasterDataset(GeoDataset): vrt_fhs = [self._load_warp_file(fp) for fp in filepaths] bounds = (query.minx, query.miny, query.maxx, query.maxy) - dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res, indexes=band_indexes) + dest, _ = rasterio.merge.merge( + vrt_fhs, bounds, self.res, indexes=band_indexes, resampling=self.resampling + ) # Use array_to_tensor since merge may return uint16/uint32 arrays. tensor = array_to_tensor(dest) return tensor