* Datasets: add CLI support

* Return completed process

* Fix return type

* More powerful azcopy mock
This commit is contained in:
Adam J. Stewart 2024-05-23 12:01:15 +02:00 коммит произвёл GitHub
Родитель 39cc9b6ed4
Коммит 91bbb83bf9
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
6 изменённых файлов: 126 добавлений и 24 удалений

27
tests/datasets/azcopy Executable file
Просмотреть файл

@ -0,0 +1,27 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Basic mock-up of the azcopy CLI."""
import argparse
import shutil
if __name__ == '__main__':
    # Minimal stand-in for the real azcopy CLI: only the `copy` and `sync`
    # subcommands are understood, and both simply copy local paths.
    parser = argparse.ArgumentParser()

    # A subcommand is mandatory. Without `required=True`, invoking the script
    # with no subcommand yields a namespace lacking `recursive`, `source` and
    # `destination`, and the script crashes with an AttributeError instead of
    # printing a usage message.
    subparsers = parser.add_subparsers(dest='command', required=True)

    copy = subparsers.add_parser('copy')
    copy.add_argument('source')
    copy.add_argument('destination')
    copy.add_argument('--recursive', default='false')

    sync = subparsers.add_parser('sync')
    sync.add_argument('source')
    sync.add_argument('destination')
    sync.add_argument('--recursive', default='true')

    # Flags this mock does not model are accepted and ignored.
    args, _ = parser.parse_known_args()

    if args.recursive == 'true':
        # Directory copy; merging into an existing destination is allowed.
        shutil.copytree(args.source, args.destination, dirs_exist_ok=True)
    else:
        # Single-file copy.
        shutil.copy(args.source, args.destination)

Просмотреть файл

@ -0,0 +1,16 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import pytest
from pytest import MonkeyPatch
from torchgeo.datasets.utils import Executable, which
@pytest.fixture
def azcopy(monkeypatch: MonkeyPatch) -> Executable:
    """Return an ``azcopy`` executable resolved with this directory on PATH.

    The directory containing this file is prepended to the ``PATH``
    environment variable for the duration of the test, and ``azcopy`` is
    then looked up with :func:`torchgeo.datasets.utils.which`.

    Args:
        monkeypatch: pytest fixture used to modify the environment.

    Returns:
        The executable returned by ``which('azcopy')``.
    """
    this_file = os.path.realpath(__file__)
    search_dir = os.path.dirname(this_file)
    monkeypatch.setenv('PATH', search_dir, prepend=os.pathsep)
    executable = which('azcopy')
    return executable

Просмотреть файл

@ -59,9 +59,10 @@ class TestDatasetNotFoundError:
raise DatasetNotFoundError(ds)
def test_missing_dependency() -> None:
    """Check that the error raised for 'foo' matches 'pip install foo'."""
    expected = 'pip install foo'
    with pytest.raises(DependencyNotFoundError, match=expected):
        raise DependencyNotFoundError('foo')
def test_dependency_not_found() -> None:
    """Check that DependencyNotFoundError carries the message it was given."""
    with pytest.raises(DependencyNotFoundError, match='foo not installed'):
        raise DependencyNotFoundError('foo not installed')
def test_rgb_bands_missing() -> None:

Просмотреть файл

@ -21,6 +21,7 @@ from rasterio.crs import CRS
import torchgeo.datasets.utils
from torchgeo.datasets import BoundingBox, DependencyNotFoundError
from torchgeo.datasets.utils import (
Executable,
array_to_tensor,
concat_samples,
disambiguate_timestamp,
@ -33,6 +34,7 @@ from torchgeo.datasets.utils import (
percentile_normalization,
stack_samples,
unbind_samples,
which,
working_dir,
)
@ -590,3 +592,14 @@ def test_lazy_import(name: str) -> None:
def test_lazy_import_missing(name: str) -> None:
with pytest.raises(DependencyNotFoundError, match='pip install foo-bar\n'):
lazy_import(name)
def test_azcopy(tmp_path: Path, azcopy: Executable) -> None:
    """Sync the cyclone test data into a temp dir and check a known entry.

    Args:
        tmp_path: Temporary directory provided by pytest.
        azcopy: Executable provided by the ``azcopy`` fixture.
    """
    data_dir = os.path.join('tests', 'data', 'cyclone')
    expected = tmp_path / 'nasa_tropical_storm_competition_test_labels'
    azcopy('sync', data_dir, tmp_path, '--recursive=true')
    assert os.path.exists(expected)
def test_which() -> None:
    """Check that looking up 'foo' raises DependencyNotFoundError."""
    expected = 'foo is not installed'
    with pytest.raises(DependencyNotFoundError, match=expected):
        which('foo')

Просмотреть файл

@ -49,31 +49,12 @@ class DatasetNotFoundError(FileNotFoundError):
super().__init__(msg)
class DependencyNotFoundError(ModuleNotFoundError):
class DependencyNotFoundError(Exception):
"""Raised when an optional dataset dependency is not installed.
.. versionadded:: 0.6
"""
def __init__(self, name: str) -> None:
"""Initialize a new DependencyNotFoundError instance.
Args:
name: Name of missing dependency.
"""
msg = f"""\
{name} is not installed and is required to use this dataset. Either run:
$ pip install {name}
to install just this dependency, or:
$ pip install torchgeo[datasets]
to install all optional dataset dependencies."""
super().__init__(msg)
class RGBBandsMissingError(ValueError):
"""Raised when a dataset is missing RGB bands for plotting.

Просмотреть файл

@ -13,6 +13,8 @@ import gzip
import importlib
import lzma
import os
import shutil
import subprocess
import sys
import tarfile
from collections.abc import Iterable, Iterator, Sequence
@ -402,6 +404,34 @@ class BoundingBox:
return bbox1, bbox2
class Executable:
    """Callable wrapper around a command-line program.

    Calling an instance runs the stored command through
    :func:`subprocess.run` with the given arguments appended.

    .. versionadded:: 0.6
    """

    def __init__(self, name: str) -> None:
        """Store the command to run.

        Args:
            name: Command name.
        """
        self.name = name

    def __call__(self, *args: Any, **kwargs: Any) -> subprocess.CompletedProcess[bytes]:
        """Invoke the command with the given arguments.

        Args:
            args: Arguments to pass to the command.
            kwargs: Keyword arguments to pass to :func:`subprocess.run`.
                ``check`` is always set to True, overriding any value
                supplied by the caller.

        Returns:
            The completed process.
        """
        command = (self.name, *args)
        # 'check' comes last so that it wins over a caller-supplied value.
        options = {**kwargs, 'check': True}
        return subprocess.run(command, **options)
def disambiguate_timestamp(date_str: str, format: str) -> tuple[float, float]:
"""Disambiguate partial timestamps.
@ -772,6 +802,9 @@ def lazy_import(name: str) -> Any:
Args:
name: Name of module to import.
Returns:
Module import.
Raises:
DependencyNotFoundError: If *name* is not installed.
@ -785,4 +818,35 @@ def lazy_import(name: str) -> Any:
module_to_pypi: dict[str, str] = collections.defaultdict(lambda: name)
module_to_pypi |= {'cv2': 'opencv-python', 'skimage': 'scikit-image'}
name = module_to_pypi[name]
raise DependencyNotFoundError(name) from None
msg = f"""\
{name} is not installed and is required to use this dataset. Either run:
$ pip install {name}
to install just this dependency, or:
$ pip install torchgeo[datasets]
to install all optional dataset dependencies."""
raise DependencyNotFoundError(msg) from None
def which(name: str) -> Executable:
    """Look up the executable *name* and wrap it for calling.

    Args:
        name: Name of executable to search for.

    Returns:
        Callable executable instance.

    Raises:
        DependencyNotFoundError: If *name* is not installed.

    .. versionadded:: 0.6
    """
    # Guard clause: bail out when shutil.which() finds nothing.
    if not shutil.which(name):
        msg = f'{name} is not installed and is required to use this dataset.'
        raise DependencyNotFoundError(msg) from None

    return Executable(name)