* Add check if path is vsi

* Add url to reference for apache vsi syntax

* Add missing check to if

* Copy rasterio SCHEMES definition into torchgeo

* Check all schemes, not only last

* Simplify method path_is_vsi

* Add tests

* Remove print

* Update test names

* Add missing comma in list

* Update torchgeo/datasets/utils.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Update torchgeo/datasets/utils.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Use pytest tmp_path for test

* Warn if some of input paths are invalid

* Update docstring for mocked class

* Handle tests failing due to UserWarning

* Remove unnecessary filterwarning

* Test CustomGeoDataset instead of MockRasterDataset

* Merge two similar tests

* str instead of as_posix

Wait with pathlib syntax

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

---------

Co-authored-by: Adrian Tofting <adrian@vake.ai>
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Adrian Tofting 2023-10-31 15:07:41 +01:00 коммит произвёл GitHub
Родитель 49cf861470
Коммит 32a7307329
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 59 добавлений и 3 удалений

Просмотреть файл

@ -4,7 +4,7 @@ import os
import pickle
from collections.abc import Iterable
from pathlib import Path
from typing import Union
from typing import Optional, Union
import pytest
import torch
@ -33,11 +33,13 @@ class CustomGeoDataset(GeoDataset):
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
crs: CRS = CRS.from_epsg(4087),
res: float = 1,
paths: Optional[Union[str, Iterable[str]]] = None,
) -> None:
super().__init__()
self.index.insert(0, tuple(bounds))
self._crs = crs
self.res = res
self.paths = paths or []
def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
hits = self.index.intersection(tuple(query), objects=True)
@ -152,6 +154,23 @@ class TestGeoDataset:
):
dataset & ds2 # type: ignore[operator]
def test_files_property_for_non_existing_file_or_dir(self, tmp_path: Path) -> None:
paths = [str(tmp_path), str(tmp_path / "non_existing_file.tif")]
with pytest.warns(UserWarning, match="Path was ignored."):
assert len(CustomGeoDataset(paths=paths).files) == 0
def test_files_property_for_virtual_files(self) -> None:
# Tests only a subset of schemes and combinations.
paths = [
"file://directory/file.tif",
"zip://archive.zip!folder/file.tif",
"az://azure_bucket/prefix/file.tif",
"/vsiaz/azure_bucket/prefix/file.tif",
"zip+az://azure_bucket/prefix/archive.zip!folder_in_archive/file.tif",
"/vsizip//vsiaz/azure_bucket/prefix/archive.zip/folder_in_archive/file.tif",
]
assert len(CustomGeoDataset(paths=paths).files) == len(paths)
class TestRasterDataset:
@pytest.fixture(params=zip([["R", "G", "B"], None], [True, False]))

Просмотреть файл

@ -9,6 +9,7 @@ import glob
import os
import re
import sys
import warnings
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, Union, cast
@ -29,7 +30,13 @@ from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import default_loader as pil_loader
from .utils import BoundingBox, concat_samples, disambiguate_timestamp, merge_samples
from .utils import (
BoundingBox,
concat_samples,
disambiguate_timestamp,
merge_samples,
path_is_vsi,
)
class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
@ -298,8 +305,14 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
if os.path.isdir(path):
pathname = os.path.join(path, "**", self.filename_glob)
files |= set(glob.iglob(pathname, recursive=True))
else:
elif os.path.isfile(path) or path_is_vsi(path):
files.add(path)
else:
warnings.warn(
f"Could not find any relevant files for provided path '{path}'. "
f"Path was ignored.",
UserWarning,
)
return files

Просмотреть файл

@ -737,3 +737,27 @@ def percentile_normalization(
(img - lower_percentile) / (upper_percentile - lower_percentile + 1e-5), 0, 1
)
return img_normalized
def path_is_vsi(path: str) -> bool:
"""Checks if the given path is pointing to a Virtual File System.
.. note::
Does not check if the path exists, or if it is a dir or file.
VSI can for instance be Cloud Storage Blobs or zip-archives.
They will start with a prefix indicating this.
For examples of these, see references for the two accepted syntaxes.
* https://gdal.org/user/virtual_file_systems.html
* https://rasterio.readthedocs.io/en/latest/topics/datasets.html
Args:
path: string representing a directory or file
Returns:
True if path is on a virtual file system, else False
.. versionadded:: 0.6
"""
return "://" in path or path.startswith("/vsi")