зеркало из https://github.com/microsoft/torchgeo.git
GeoDataset: ignore other bands for separate files (#2222)
* GeoDataset: ignore other bands for separate files
* Ruff
* Mypy
* Check VSI paths too

Co-authored-by: Sieger Falkena <siegerfalkena@hotmail.com>

* Fix formatting
* Add support for / in URLs on Windows where path separator is \

---------

Co-authored-by: Sieger Falkena <siegerfalkena@hotmail.com>
This commit is contained in:
Родитель
880593e7ef
Коммит
c26512e39d
|
@ -207,14 +207,27 @@ class TestGeoDataset:
|
|||
|
||||
|
||||
class TestRasterDataset:
|
||||
naip_dir = os.path.join('tests', 'data', 'naip')
|
||||
s2_dir = os.path.join(
|
||||
'tests',
|
||||
'data',
|
||||
'sentinel2',
|
||||
'S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE',
|
||||
'GRANULE',
|
||||
'L2A_T26EMU_A035569_20220414T110747',
|
||||
'IMG_DATA',
|
||||
'R10m',
|
||||
)
|
||||
|
||||
@pytest.fixture(params=zip([['R', 'G', 'B'], None], [True, False]))
|
||||
def naip(self, request: SubRequest) -> NAIP:
|
||||
root = os.path.join('tests', 'data', 'naip')
|
||||
bands = request.param[0]
|
||||
crs = CRS.from_epsg(4087)
|
||||
transforms = nn.Identity()
|
||||
cache = request.param[1]
|
||||
return NAIP(root, crs=crs, bands=bands, transforms=transforms, cache=cache)
|
||||
return NAIP(
|
||||
self.naip_dir, crs=crs, bands=bands, transforms=transforms, cache=cache
|
||||
)
|
||||
|
||||
@pytest.fixture(
|
||||
params=zip(
|
||||
|
@ -236,34 +249,55 @@ class TestRasterDataset:
|
|||
'paths',
|
||||
[
|
||||
# Single directory
|
||||
os.path.join('tests', 'data', 'naip'),
|
||||
naip_dir,
|
||||
# Multiple directories
|
||||
[
|
||||
os.path.join('tests', 'data', 'naip'),
|
||||
os.path.join('tests', 'data', 'naip'),
|
||||
],
|
||||
# Single file
|
||||
os.path.join('tests', 'data', 'naip', 'm_3807511_ne_18_060_20181104.tif'),
|
||||
[naip_dir, naip_dir],
|
||||
# Multiple files
|
||||
(
|
||||
os.path.join(
|
||||
'tests', 'data', 'naip', 'm_3807511_ne_18_060_20181104.tif'
|
||||
),
|
||||
os.path.join(
|
||||
'tests', 'data', 'naip', 'm_3807511_ne_18_060_20190605.tif'
|
||||
),
|
||||
os.path.join(naip_dir, 'm_3807511_ne_18_060_20181104.tif'),
|
||||
os.path.join(naip_dir, 'm_3807511_ne_18_060_20190605.tif'),
|
||||
),
|
||||
# Combination
|
||||
{
|
||||
os.path.join('tests', 'data', 'naip'),
|
||||
os.path.join(
|
||||
'tests', 'data', 'naip', 'm_3807511_ne_18_060_20181104.tif'
|
||||
),
|
||||
},
|
||||
{naip_dir, os.path.join(naip_dir, 'm_3807511_ne_18_060_20181104.tif')},
|
||||
],
|
||||
)
|
||||
def test_files(self, paths: str | Iterable[str]) -> None:
|
||||
assert 1 <= len(NAIP(paths).files) <= 2
|
||||
assert len(NAIP(paths).files) == 2
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'paths',
|
||||
[
|
||||
# Single directory
|
||||
s2_dir,
|
||||
# Multiple directories
|
||||
[s2_dir, s2_dir],
|
||||
# Multiple files (single band)
|
||||
[
|
||||
os.path.join(s2_dir, 'T26EMU_20190414T110751_B04_10m.jp2'),
|
||||
os.path.join(s2_dir, 'T26EMU_20220414T110751_B04_10m.jp2'),
|
||||
],
|
||||
# Multiple files (multiple bands)
|
||||
[
|
||||
os.path.join(s2_dir, 'T26EMU_20190414T110751_B04_10m.jp2'),
|
||||
os.path.join(s2_dir, 'T26EMU_20190414T110751_B03_10m.jp2'),
|
||||
os.path.join(s2_dir, 'T26EMU_20190414T110751_B02_10m.jp2'),
|
||||
os.path.join(s2_dir, 'T26EMU_20220414T110751_B04_10m.jp2'),
|
||||
os.path.join(s2_dir, 'T26EMU_20220414T110751_B03_10m.jp2'),
|
||||
os.path.join(s2_dir, 'T26EMU_20220414T110751_B02_10m.jp2'),
|
||||
],
|
||||
# Combination
|
||||
[
|
||||
s2_dir,
|
||||
os.path.join(s2_dir, 'T26EMU_20190414T110751_B04_10m.jp2'),
|
||||
os.path.join(s2_dir, 'T26EMU_20220414T110751_B04_10m.jp2'),
|
||||
os.path.join(s2_dir, 'T26EMU_20220414T110751_B03_10m.jp2'),
|
||||
os.path.join(s2_dir, 'T26EMU_20220414T110751_B02_10m.jp2'),
|
||||
],
|
||||
],
|
||||
)
|
||||
@pytest.mark.filterwarnings('ignore:Could not find any relevant files')
|
||||
def test_files_separate(self, paths: str | Iterable[str]) -> None:
|
||||
assert len(Sentinel2(paths, bands=Sentinel2.rgb_bands).files) == 2
|
||||
|
||||
def test_getitem_single_file(self, naip: NAIP) -> None:
|
||||
x = naip[naip.bounds]
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
"""Base classes for all :mod:`torchgeo` datasets."""
|
||||
|
||||
import abc
|
||||
import fnmatch
|
||||
import functools
|
||||
import glob
|
||||
import os
|
||||
|
@ -310,7 +311,9 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
|
|||
if os.path.isdir(path):
|
||||
pathname = os.path.join(path, '**', self.filename_glob)
|
||||
files |= set(glob.iglob(pathname, recursive=True))
|
||||
elif os.path.isfile(path) or path_is_vsi(path):
|
||||
elif (os.path.isfile(path) or path_is_vsi(path)) and fnmatch.fnmatch(
|
||||
str(path), f'*{self.filename_glob}'
|
||||
):
|
||||
files.add(path)
|
||||
elif not hasattr(self, 'download'):
|
||||
warnings.warn(
|
||||
|
|
Загрузка…
Ссылка в новой задаче