зеркало из https://github.com/microsoft/torchgeo.git
RasterDataset: add control over resampling algorithm (#2015)
* RasterDataset: add control over resampling algorithm * Fix type hints * cubic -> bilinear * Ruff: single quotes
This commit is contained in:
Родитель
5976bd15bf
Коммит
25fb9ccfb9
|
@ -336,6 +336,10 @@
|
||||||
"\n",
|
"\n",
|
||||||
"Defaults to float32 for `is_image == True` and long for `is_image == False`. This is what you want for 99% of datasets, but can be overridden for tasks like pixel-wise regression (where the target mask should be float32).\n",
|
"Defaults to float32 for `is_image == True` and long for `is_image == False`. This is what you want for 99% of datasets, but can be overridden for tasks like pixel-wise regression (where the target mask should be float32).\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"### `resampling`\n",
|
||||||
|
"\n",
|
||||||
|
"Defaults to bilinear for float Tensors and nearest for int Tensors. Can be overridden for custom resampling algorithms.\n",
|
||||||
|
"\n",
|
||||||
"### `separate_files`\n",
|
"### `separate_files`\n",
|
||||||
"\n",
|
"\n",
|
||||||
"If your data comes with each spectral band in a separate files, as is the case with Sentinel-2, use `separate_files = True`. If all spectral bands are stored in a single file, use `separate_files = False` instead.\n",
|
"If your data comes with each spectral band in a separate files, as is the case with Sentinel-2, use `separate_files = True`. If all spectral bands are stored in a single file, use `separate_files = False` instead.\n",
|
||||||
|
|
|
@ -5,12 +5,14 @@ import pickle
|
||||||
import sys
|
import sys
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from _pytest.fixtures import SubRequest
|
from _pytest.fixtures import SubRequest
|
||||||
from rasterio.crs import CRS
|
from rasterio.crs import CRS
|
||||||
|
from rasterio.enums import Resampling
|
||||||
from torch.utils.data import ConcatDataset
|
from torch.utils.data import ConcatDataset
|
||||||
|
|
||||||
from torchgeo.datasets import (
|
from torchgeo.datasets import (
|
||||||
|
@ -49,6 +51,16 @@ class CustomGeoDataset(GeoDataset):
|
||||||
return {'index': bounds}
|
return {'index': bounds}
|
||||||
|
|
||||||
|
|
||||||
|
class CustomRasterDataset(RasterDataset):
|
||||||
|
def __init__(self, dtype: torch.dtype, *args: Any, **kwargs: Any) -> None:
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._dtype = dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self) -> torch.dtype:
|
||||||
|
return self._dtype
|
||||||
|
|
||||||
|
|
||||||
class CustomVectorDataset(VectorDataset):
|
class CustomVectorDataset(VectorDataset):
|
||||||
filename_glob = '*.geojson'
|
filename_glob = '*.geojson'
|
||||||
date_format = '%Y'
|
date_format = '%Y'
|
||||||
|
@ -274,6 +286,22 @@ class TestRasterDataset:
|
||||||
assert isinstance(x['image'], torch.Tensor)
|
assert isinstance(x['image'], torch.Tensor)
|
||||||
assert x['image'].dtype == torch.float32
|
assert x['image'].dtype == torch.float32
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('dtype', [torch.float, torch.double])
|
||||||
|
def test_resampling_float_dtype(self, dtype: torch.dtype) -> None:
|
||||||
|
paths = os.path.join('tests', 'data', 'raster', 'uint16')
|
||||||
|
ds = CustomRasterDataset(dtype, paths)
|
||||||
|
x = ds[ds.bounds]
|
||||||
|
assert x['image'].dtype == dtype
|
||||||
|
assert ds.resampling == Resampling.bilinear
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('dtype', [torch.long, torch.bool])
|
||||||
|
def test_resampling_int_dtype(self, dtype: torch.dtype) -> None:
|
||||||
|
paths = os.path.join('tests', 'data', 'raster', 'uint16')
|
||||||
|
ds = CustomRasterDataset(dtype, paths)
|
||||||
|
x = ds[ds.bounds]
|
||||||
|
assert x['image'].dtype == dtype
|
||||||
|
assert ds.resampling == Resampling.nearest
|
||||||
|
|
||||||
def test_invalid_query(self, sentinel: Sentinel2) -> None:
|
def test_invalid_query(self, sentinel: Sentinel2) -> None:
|
||||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
|
|
|
@ -22,6 +22,7 @@ import rasterio.merge
|
||||||
import shapely
|
import shapely
|
||||||
import torch
|
import torch
|
||||||
from rasterio.crs import CRS
|
from rasterio.crs import CRS
|
||||||
|
from rasterio.enums import Resampling
|
||||||
from rasterio.io import DatasetReader
|
from rasterio.io import DatasetReader
|
||||||
from rasterio.vrt import WarpedVRT
|
from rasterio.vrt import WarpedVRT
|
||||||
from rtree.index import Index, Property
|
from rtree.index import Index, Property
|
||||||
|
@ -309,7 +310,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
|
||||||
files |= set(glob.iglob(pathname, recursive=True))
|
files |= set(glob.iglob(pathname, recursive=True))
|
||||||
elif os.path.isfile(path) or path_is_vsi(path):
|
elif os.path.isfile(path) or path_is_vsi(path):
|
||||||
files.add(path)
|
files.add(path)
|
||||||
elif not hasattr(self, "download"):
|
elif not hasattr(self, 'download'):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"Could not find any relevant files for provided path '{path}'. "
|
f"Could not find any relevant files for provided path '{path}'. "
|
||||||
f'Path was ignored.',
|
f'Path was ignored.',
|
||||||
|
@ -384,6 +385,23 @@ class RasterDataset(GeoDataset):
|
||||||
else:
|
else:
|
||||||
return torch.long
|
return torch.long
|
||||||
|
|
||||||
|
@property
|
||||||
|
def resampling(self) -> Resampling:
|
||||||
|
"""Resampling algorithm used when reading input files.
|
||||||
|
|
||||||
|
Defaults to bilinear for float dtypes and nearest for int dtypes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The resampling method to use.
|
||||||
|
|
||||||
|
.. versionadded:: 0.6
|
||||||
|
"""
|
||||||
|
# Based on torch.is_floating_point
|
||||||
|
if self.dtype in [torch.float64, torch.float32, torch.float16, torch.bfloat16]:
|
||||||
|
return Resampling.bilinear
|
||||||
|
else:
|
||||||
|
return Resampling.nearest
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
paths: str | Iterable[str] = 'data',
|
paths: str | Iterable[str] = 'data',
|
||||||
|
@ -555,7 +573,9 @@ class RasterDataset(GeoDataset):
|
||||||
vrt_fhs = [self._load_warp_file(fp) for fp in filepaths]
|
vrt_fhs = [self._load_warp_file(fp) for fp in filepaths]
|
||||||
|
|
||||||
bounds = (query.minx, query.miny, query.maxx, query.maxy)
|
bounds = (query.minx, query.miny, query.maxx, query.maxy)
|
||||||
dest, _ = rasterio.merge.merge(vrt_fhs, bounds, self.res, indexes=band_indexes)
|
dest, _ = rasterio.merge.merge(
|
||||||
|
vrt_fhs, bounds, self.res, indexes=band_indexes, resampling=self.resampling
|
||||||
|
)
|
||||||
# Use array_to_tensor since merge may return uint16/uint32 arrays.
|
# Use array_to_tensor since merge may return uint16/uint32 arrays.
|
||||||
tensor = array_to_tensor(dest)
|
tensor = array_to_tensor(dest)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
Загрузка…
Ссылка в новой задаче