GeoDataset: ignore other bands for separate files (#2222)

* GeoDataset: ignore other bands for separate files

* Ruff

* Mypy

* Check VSI paths too

Co-authored-by: Sieger Falkena <siegerfalkena@hotmail.com>

* Fix formatting

* Add support for / in URLs on Windows where path separator is \

---------

Co-authored-by: Sieger Falkena <siegerfalkena@hotmail.com>
This commit is contained in:
Adam J. Stewart 2024-08-19 12:05:26 +02:00 коммит произвёл GitHub
Родитель 880593e7ef
Коммит c26512e39d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
2 изменённых файлов: 60 добавлений и 23 удалений

Просмотреть файл

@ -207,14 +207,27 @@ class TestGeoDataset:
class TestRasterDataset: class TestRasterDataset:
naip_dir = os.path.join('tests', 'data', 'naip')
s2_dir = os.path.join(
'tests',
'data',
'sentinel2',
'S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE',
'GRANULE',
'L2A_T26EMU_A035569_20220414T110747',
'IMG_DATA',
'R10m',
)
@pytest.fixture(params=zip([['R', 'G', 'B'], None], [True, False])) @pytest.fixture(params=zip([['R', 'G', 'B'], None], [True, False]))
def naip(self, request: SubRequest) -> NAIP: def naip(self, request: SubRequest) -> NAIP:
root = os.path.join('tests', 'data', 'naip')
bands = request.param[0] bands = request.param[0]
crs = CRS.from_epsg(4087) crs = CRS.from_epsg(4087)
transforms = nn.Identity() transforms = nn.Identity()
cache = request.param[1] cache = request.param[1]
return NAIP(root, crs=crs, bands=bands, transforms=transforms, cache=cache) return NAIP(
self.naip_dir, crs=crs, bands=bands, transforms=transforms, cache=cache
)
@pytest.fixture( @pytest.fixture(
params=zip( params=zip(
@ -236,34 +249,55 @@ class TestRasterDataset:
'paths', 'paths',
[ [
# Single directory # Single directory
os.path.join('tests', 'data', 'naip'), naip_dir,
# Multiple directories # Multiple directories
[ [naip_dir, naip_dir],
os.path.join('tests', 'data', 'naip'),
os.path.join('tests', 'data', 'naip'),
],
# Single file
os.path.join('tests', 'data', 'naip', 'm_3807511_ne_18_060_20181104.tif'),
# Multiple files # Multiple files
( (
os.path.join( os.path.join(naip_dir, 'm_3807511_ne_18_060_20181104.tif'),
'tests', 'data', 'naip', 'm_3807511_ne_18_060_20181104.tif' os.path.join(naip_dir, 'm_3807511_ne_18_060_20190605.tif'),
),
os.path.join(
'tests', 'data', 'naip', 'm_3807511_ne_18_060_20190605.tif'
),
), ),
# Combination # Combination
{ {naip_dir, os.path.join(naip_dir, 'm_3807511_ne_18_060_20181104.tif')},
os.path.join('tests', 'data', 'naip'),
os.path.join(
'tests', 'data', 'naip', 'm_3807511_ne_18_060_20181104.tif'
),
},
], ],
) )
def test_files(self, paths: str | Iterable[str]) -> None: def test_files(self, paths: str | Iterable[str]) -> None:
assert 1 <= len(NAIP(paths).files) <= 2 assert len(NAIP(paths).files) == 2
@pytest.mark.parametrize(
'paths',
[
# Single directory
s2_dir,
# Multiple directories
[s2_dir, s2_dir],
# Multiple files (single band)
[
os.path.join(s2_dir, 'T26EMU_20190414T110751_B04_10m.jp2'),
os.path.join(s2_dir, 'T26EMU_20220414T110751_B04_10m.jp2'),
],
# Multiple files (multiple bands)
[
os.path.join(s2_dir, 'T26EMU_20190414T110751_B04_10m.jp2'),
os.path.join(s2_dir, 'T26EMU_20190414T110751_B03_10m.jp2'),
os.path.join(s2_dir, 'T26EMU_20190414T110751_B02_10m.jp2'),
os.path.join(s2_dir, 'T26EMU_20220414T110751_B04_10m.jp2'),
os.path.join(s2_dir, 'T26EMU_20220414T110751_B03_10m.jp2'),
os.path.join(s2_dir, 'T26EMU_20220414T110751_B02_10m.jp2'),
],
# Combination
[
s2_dir,
os.path.join(s2_dir, 'T26EMU_20190414T110751_B04_10m.jp2'),
os.path.join(s2_dir, 'T26EMU_20220414T110751_B04_10m.jp2'),
os.path.join(s2_dir, 'T26EMU_20220414T110751_B03_10m.jp2'),
os.path.join(s2_dir, 'T26EMU_20220414T110751_B02_10m.jp2'),
],
],
)
@pytest.mark.filterwarnings('ignore:Could not find any relevant files')
def test_files_separate(self, paths: str | Iterable[str]) -> None:
assert len(Sentinel2(paths, bands=Sentinel2.rgb_bands).files) == 2
def test_getitem_single_file(self, naip: NAIP) -> None: def test_getitem_single_file(self, naip: NAIP) -> None:
x = naip[naip.bounds] x = naip[naip.bounds]

Просмотреть файл

@ -4,6 +4,7 @@
"""Base classes for all :mod:`torchgeo` datasets.""" """Base classes for all :mod:`torchgeo` datasets."""
import abc import abc
import fnmatch
import functools import functools
import glob import glob
import os import os
@ -310,7 +311,9 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
if os.path.isdir(path): if os.path.isdir(path):
pathname = os.path.join(path, '**', self.filename_glob) pathname = os.path.join(path, '**', self.filename_glob)
files |= set(glob.iglob(pathname, recursive=True)) files |= set(glob.iglob(pathname, recursive=True))
elif os.path.isfile(path) or path_is_vsi(path): elif (os.path.isfile(path) or path_is_vsi(path)) and fnmatch.fnmatch(
str(path), f'*{self.filename_glob}'
):
files.add(path) files.add(path)
elif not hasattr(self, 'download'): elif not hasattr(self, 'download'):
warnings.warn( warnings.warn(