зеркало из https://github.com/microsoft/torchgeo.git
Add check if path is vsi (#1612)
* Add check if path is vsi
* Add url to reference for apache vsi syntax
* Add missing check to if
* Copy rasterio SCHEMES definition into torchgeo
* Check all schemes, not only last
* Simplify method path_is_vsi
* Add tests
* Remove print
* Update test names
* Add missing comma in list
* Update torchgeo/datasets/utils.py
  Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
* Update torchgeo/datasets/utils.py
  Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
* Use pytest tmp_path for test
* Warn if some of input paths are invalid
* Update docstring for mocked class
* Handle tests failing due to UserWarning
* Remove unnecessary filterwarning
* Test CustomGeoDataset instead of MockRasterDataset
* Merge two similar tests
* str instead of as_posix
  Wait with pathlib syntax
  Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

---------

Co-authored-by: Adrian Tofting <adrian@vake.ai>
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
49cf861470
Коммит
32a7307329
|
@ -4,7 +4,7 @@ import os
|
|||
import pickle
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
@ -33,11 +33,13 @@ class CustomGeoDataset(GeoDataset):
|
|||
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
|
||||
crs: CRS = CRS.from_epsg(4087),
|
||||
res: float = 1,
|
||||
paths: Optional[Union[str, Iterable[str]]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.index.insert(0, tuple(bounds))
|
||||
self._crs = crs
|
||||
self.res = res
|
||||
self.paths = paths or []
|
||||
|
||||
def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
|
||||
hits = self.index.intersection(tuple(query), objects=True)
|
||||
|
@ -152,6 +154,23 @@ class TestGeoDataset:
|
|||
):
|
||||
dataset & ds2 # type: ignore[operator]
|
||||
|
||||
def test_files_property_for_non_existing_file_or_dir(self, tmp_path: Path) -> None:
    """Expect no files and a "Path was ignored." warning for unusable paths.

    The dataset is given the temporary directory itself and a file name
    inside it that is never created.
    """
    missing_file = tmp_path / "non_existing_file.tif"
    candidates = [str(tmp_path), str(missing_file)]
    # Construction and property access both stay inside the ``warns`` context
    # so the warning is captured regardless of which step emits it.
    with pytest.warns(UserWarning, match="Path was ignored."):
        dataset = CustomGeoDataset(paths=candidates)
        assert len(dataset.files) == 0
|
||||
|
||||
def test_files_property_for_virtual_files(self) -> None:
    """Expect every virtual-file-system style path to be kept in ``files``.

    None of these paths exist locally; the assertion only checks that the
    number of reported files equals the number of paths supplied.
    """
    # Tests only a subset of schemes and combinations.
    virtual_paths = [
        "file://directory/file.tif",
        "zip://archive.zip!folder/file.tif",
        "az://azure_bucket/prefix/file.tif",
        "/vsiaz/azure_bucket/prefix/file.tif",
        "zip+az://azure_bucket/prefix/archive.zip!folder_in_archive/file.tif",
        "/vsizip//vsiaz/azure_bucket/prefix/archive.zip/folder_in_archive/file.tif",
    ]
    dataset = CustomGeoDataset(paths=virtual_paths)
    assert len(dataset.files) == len(virtual_paths)
|
||||
|
||||
|
||||
class TestRasterDataset:
|
||||
@pytest.fixture(params=zip([["R", "G", "B"], None], [True, False]))
|
||||
|
|
|
@ -9,6 +9,7 @@ import glob
|
|||
import os
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
|
||||
|
@ -29,7 +30,13 @@ from torch.utils.data import Dataset
|
|||
from torchvision.datasets import ImageFolder
|
||||
from torchvision.datasets.folder import default_loader as pil_loader
|
||||
|
||||
from .utils import BoundingBox, concat_samples, disambiguate_timestamp, merge_samples
|
||||
from .utils import (
|
||||
BoundingBox,
|
||||
concat_samples,
|
||||
disambiguate_timestamp,
|
||||
merge_samples,
|
||||
path_is_vsi,
|
||||
)
|
||||
|
||||
|
||||
class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
|
||||
|
@ -298,8 +305,14 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
|
|||
if os.path.isdir(path):
|
||||
pathname = os.path.join(path, "**", self.filename_glob)
|
||||
files |= set(glob.iglob(pathname, recursive=True))
|
||||
else:
|
||||
elif os.path.isfile(path) or path_is_vsi(path):
|
||||
files.add(path)
|
||||
else:
|
||||
warnings.warn(
|
||||
f"Could not find any relevant files for provided path '{path}'. "
|
||||
f"Path was ignored.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
return files
|
||||
|
||||
|
|
|
@ -737,3 +737,27 @@ def percentile_normalization(
|
|||
(img - lower_percentile) / (upper_percentile - lower_percentile + 1e-5), 0, 1
|
||||
)
|
||||
return img_normalized
|
||||
|
||||
|
||||
def path_is_vsi(path: str) -> bool:
    """Report whether a path uses Virtual File System (VSI) syntax.

    A path is treated as virtual when it starts with ``/vsi`` or contains a
    ``://`` scheme separator anywhere in the string. Cloud storage blobs and
    zip archives are typical examples. The two accepted syntaxes are
    described in:

    * https://gdal.org/user/virtual_file_systems.html
    * https://rasterio.readthedocs.io/en/latest/topics/datasets.html

    .. note::
       This is a purely textual check. It does not verify that the path
       exists, nor whether it refers to a directory or a file.

    Args:
        path: string representing a directory or file

    Returns:
        True if path is on a virtual file system, else False

    .. versionadded:: 0.6
    """
    # GDAL-style prefix, e.g. "/vsiaz/..." or chained "/vsizip//vsiaz/...".
    if path.startswith("/vsi"):
        return True
    # URI-style scheme, e.g. "az://..." or "zip+az://...".
    return "://" in path
|
||||
|
|
Загрузка…
Ссылка в новой задаче