зеркало из https://github.com/microsoft/torchgeo.git
GeoDataset: ignore other bands for separate files (#2222)
* GeoDataset: ignore other bands for separate files * Ruff * Mypy * Check VSI paths too Co-authored-by: Sieger Falkena <siegerfalkena@hotmail.com> * Fix formatting * Add support for / in URLs on Windows where path separator is \ --------- Co-authored-by: Sieger Falkena <siegerfalkena@hotmail.com>
This commit is contained in:
Родитель
880593e7ef
Коммит
c26512e39d
|
@ -207,14 +207,27 @@ class TestGeoDataset:
|
||||||
|
|
||||||
|
|
||||||
class TestRasterDataset:
|
class TestRasterDataset:
|
||||||
|
naip_dir = os.path.join('tests', 'data', 'naip')
|
||||||
|
s2_dir = os.path.join(
|
||||||
|
'tests',
|
||||||
|
'data',
|
||||||
|
'sentinel2',
|
||||||
|
'S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE',
|
||||||
|
'GRANULE',
|
||||||
|
'L2A_T26EMU_A035569_20220414T110747',
|
||||||
|
'IMG_DATA',
|
||||||
|
'R10m',
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.fixture(params=zip([['R', 'G', 'B'], None], [True, False]))
|
@pytest.fixture(params=zip([['R', 'G', 'B'], None], [True, False]))
|
||||||
def naip(self, request: SubRequest) -> NAIP:
|
def naip(self, request: SubRequest) -> NAIP:
|
||||||
root = os.path.join('tests', 'data', 'naip')
|
|
||||||
bands = request.param[0]
|
bands = request.param[0]
|
||||||
crs = CRS.from_epsg(4087)
|
crs = CRS.from_epsg(4087)
|
||||||
transforms = nn.Identity()
|
transforms = nn.Identity()
|
||||||
cache = request.param[1]
|
cache = request.param[1]
|
||||||
return NAIP(root, crs=crs, bands=bands, transforms=transforms, cache=cache)
|
return NAIP(
|
||||||
|
self.naip_dir, crs=crs, bands=bands, transforms=transforms, cache=cache
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
params=zip(
|
params=zip(
|
||||||
|
@ -236,34 +249,55 @@ class TestRasterDataset:
|
||||||
'paths',
|
'paths',
|
||||||
[
|
[
|
||||||
# Single directory
|
# Single directory
|
||||||
os.path.join('tests', 'data', 'naip'),
|
naip_dir,
|
||||||
# Multiple directories
|
# Multiple directories
|
||||||
[
|
[naip_dir, naip_dir],
|
||||||
os.path.join('tests', 'data', 'naip'),
|
|
||||||
os.path.join('tests', 'data', 'naip'),
|
|
||||||
],
|
|
||||||
# Single file
|
|
||||||
os.path.join('tests', 'data', 'naip', 'm_3807511_ne_18_060_20181104.tif'),
|
|
||||||
# Multiple files
|
# Multiple files
|
||||||
(
|
(
|
||||||
os.path.join(
|
os.path.join(naip_dir, 'm_3807511_ne_18_060_20181104.tif'),
|
||||||
'tests', 'data', 'naip', 'm_3807511_ne_18_060_20181104.tif'
|
os.path.join(naip_dir, 'm_3807511_ne_18_060_20190605.tif'),
|
||||||
),
|
|
||||||
os.path.join(
|
|
||||||
'tests', 'data', 'naip', 'm_3807511_ne_18_060_20190605.tif'
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
# Combination
|
# Combination
|
||||||
{
|
{naip_dir, os.path.join(naip_dir, 'm_3807511_ne_18_060_20181104.tif')},
|
||||||
os.path.join('tests', 'data', 'naip'),
|
|
||||||
os.path.join(
|
|
||||||
'tests', 'data', 'naip', 'm_3807511_ne_18_060_20181104.tif'
|
|
||||||
),
|
|
||||||
},
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_files(self, paths: str | Iterable[str]) -> None:
|
def test_files(self, paths: str | Iterable[str]) -> None:
|
||||||
assert 1 <= len(NAIP(paths).files) <= 2
|
assert len(NAIP(paths).files) == 2
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'paths',
|
||||||
|
[
|
||||||
|
# Single directory
|
||||||
|
s2_dir,
|
||||||
|
# Multiple directories
|
||||||
|
[s2_dir, s2_dir],
|
||||||
|
# Multiple files (single band)
|
||||||
|
[
|
||||||
|
os.path.join(s2_dir, 'T26EMU_20190414T110751_B04_10m.jp2'),
|
||||||
|
os.path.join(s2_dir, 'T26EMU_20220414T110751_B04_10m.jp2'),
|
||||||
|
],
|
||||||
|
# Multiple files (multiple bands)
|
||||||
|
[
|
||||||
|
os.path.join(s2_dir, 'T26EMU_20190414T110751_B04_10m.jp2'),
|
||||||
|
os.path.join(s2_dir, 'T26EMU_20190414T110751_B03_10m.jp2'),
|
||||||
|
os.path.join(s2_dir, 'T26EMU_20190414T110751_B02_10m.jp2'),
|
||||||
|
os.path.join(s2_dir, 'T26EMU_20220414T110751_B04_10m.jp2'),
|
||||||
|
os.path.join(s2_dir, 'T26EMU_20220414T110751_B03_10m.jp2'),
|
||||||
|
os.path.join(s2_dir, 'T26EMU_20220414T110751_B02_10m.jp2'),
|
||||||
|
],
|
||||||
|
# Combination
|
||||||
|
[
|
||||||
|
s2_dir,
|
||||||
|
os.path.join(s2_dir, 'T26EMU_20190414T110751_B04_10m.jp2'),
|
||||||
|
os.path.join(s2_dir, 'T26EMU_20220414T110751_B04_10m.jp2'),
|
||||||
|
os.path.join(s2_dir, 'T26EMU_20220414T110751_B03_10m.jp2'),
|
||||||
|
os.path.join(s2_dir, 'T26EMU_20220414T110751_B02_10m.jp2'),
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.filterwarnings('ignore:Could not find any relevant files')
|
||||||
|
def test_files_separate(self, paths: str | Iterable[str]) -> None:
|
||||||
|
assert len(Sentinel2(paths, bands=Sentinel2.rgb_bands).files) == 2
|
||||||
|
|
||||||
def test_getitem_single_file(self, naip: NAIP) -> None:
|
def test_getitem_single_file(self, naip: NAIP) -> None:
|
||||||
x = naip[naip.bounds]
|
x = naip[naip.bounds]
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
"""Base classes for all :mod:`torchgeo` datasets."""
|
"""Base classes for all :mod:`torchgeo` datasets."""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
|
import fnmatch
|
||||||
import functools
|
import functools
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
|
@ -310,7 +311,9 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
pathname = os.path.join(path, '**', self.filename_glob)
|
pathname = os.path.join(path, '**', self.filename_glob)
|
||||||
files |= set(glob.iglob(pathname, recursive=True))
|
files |= set(glob.iglob(pathname, recursive=True))
|
||||||
elif os.path.isfile(path) or path_is_vsi(path):
|
elif (os.path.isfile(path) or path_is_vsi(path)) and fnmatch.fnmatch(
|
||||||
|
str(path), f'*{self.filename_glob}'
|
||||||
|
):
|
||||||
files.add(path)
|
files.add(path)
|
||||||
elif not hasattr(self, 'download'):
|
elif not hasattr(self, 'download'):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
|
Загрузка…
Ссылка в новой задаче