зеркало из https://github.com/microsoft/torchgeo.git
Datasets: add CLI support (#2064)
* Datasets: add CLI support
* Return completed process
* Fix return type
* More powerful azcopy mock
This commit is contained in:
Родитель
39cc9b6ed4
Коммит
91bbb83bf9
|
@ -0,0 +1,27 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Basic mock-up of the azcopy CLI."""
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
|
||||
def main(argv: 'list[str] | None' = None) -> None:
    """Run the mock azcopy command line.

    Supports the ``copy`` and ``sync`` subcommands. Both take a source and a
    destination path and copy the source to the destination on the local
    filesystem. Options other than ``--recursive`` are accepted and ignored.

    Args:
        argv: Command-line arguments without the program name. If None,
            ``sys.argv[1:]`` is parsed.
    """
    parser = argparse.ArgumentParser()
    # A subcommand is mandatory: without one the parsed namespace has no
    # ``recursive``, ``source`` or ``destination`` attribute, so report a
    # usage error instead of failing later with an AttributeError.
    subparsers = parser.add_subparsers(dest='command', required=True)

    # The two subcommands differ only in their default for ``--recursive``.
    for command, recursive_default in (('copy', 'false'), ('sync', 'true')):
        subparser = subparsers.add_parser(command)
        subparser.add_argument('source')
        subparser.add_argument('destination')
        subparser.add_argument('--recursive', default=recursive_default)

    # parse_known_args tolerates options this mock does not implement.
    args, _ = parser.parse_known_args(argv)

    if args.recursive == 'true':
        shutil.copytree(args.source, args.destination, dirs_exist_ok=True)
    else:
        shutil.copy(args.source, args.destination)


if __name__ == '__main__':
    main()
|
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets.utils import Executable, which
|
||||
|
||||
|
||||
@pytest.fixture
def azcopy(monkeypatch: MonkeyPatch) -> Executable:
    """Look up ``azcopy`` with this file's directory first on ``PATH``.

    Args:
        monkeypatch: pytest fixture used to modify ``PATH``.

    Returns:
        The executable found by :func:`which`.
    """
    this_file = os.path.realpath(__file__)
    this_dir = os.path.dirname(this_file)
    # Prepend, so this directory is searched before the existing PATH entries.
    monkeypatch.setenv('PATH', this_dir, prepend=os.pathsep)
    return which('azcopy')
|
|
@ -59,9 +59,10 @@ class TestDatasetNotFoundError:
|
|||
raise DatasetNotFoundError(ds)
|
||||
|
||||
|
||||
def test_dependency_not_found() -> None:
    """DependencyNotFoundError reports the message it was given."""
    # The exception no longer builds a "pip install" hint from a package
    # name; callers pass the complete message, so match it as given.
    msg = 'foo not installed'
    with pytest.raises(DependencyNotFoundError, match=msg):
        raise DependencyNotFoundError(msg)
|
||||
|
||||
|
||||
def test_rgb_bands_missing() -> None:
|
||||
|
|
|
@ -21,6 +21,7 @@ from rasterio.crs import CRS
|
|||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import BoundingBox, DependencyNotFoundError
|
||||
from torchgeo.datasets.utils import (
|
||||
Executable,
|
||||
array_to_tensor,
|
||||
concat_samples,
|
||||
disambiguate_timestamp,
|
||||
|
@ -33,6 +34,7 @@ from torchgeo.datasets.utils import (
|
|||
percentile_normalization,
|
||||
stack_samples,
|
||||
unbind_samples,
|
||||
which,
|
||||
working_dir,
|
||||
)
|
||||
|
||||
|
@ -590,3 +592,14 @@ def test_lazy_import(name: str) -> None:
|
|||
def test_lazy_import_missing(name: str) -> None:
    """lazy_import of *name* raises an error containing the pip hint."""
    expected = 'pip install foo-bar\n'
    with pytest.raises(DependencyNotFoundError, match=expected):
        lazy_import(name)
|
||||
|
||||
|
||||
def test_azcopy(tmp_path: Path, azcopy: Executable) -> None:
    """Sync a test data directory into *tmp_path* using azcopy."""
    data_dir = os.path.join('tests', 'data', 'cyclone')
    azcopy('sync', data_dir, tmp_path, '--recursive=true')

    expected = tmp_path / 'nasa_tropical_storm_competition_test_labels'
    assert os.path.exists(expected)
|
||||
|
||||
|
||||
def test_which() -> None:
    """Looking up an executable that does not exist raises an error."""
    expected = 'foo is not installed'
    with pytest.raises(DependencyNotFoundError, match=expected):
        which('foo')
|
||||
|
|
|
@ -49,31 +49,12 @@ class DatasetNotFoundError(FileNotFoundError):
|
|||
super().__init__(msg)
|
||||
|
||||
|
||||
class DependencyNotFoundError(Exception):
    """Raised when an optional dataset dependency is not installed.

    The exception does not format its message: whoever raises it passes the
    complete, user-facing text.

    .. versionadded:: 0.6
    """
|
||||
|
||||
|
||||
class RGBBandsMissingError(ValueError):
|
||||
"""Raised when a dataset is missing RGB bands for plotting.
|
||||
|
|
|
@ -13,6 +13,8 @@ import gzip
|
|||
import importlib
|
||||
import lzma
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tarfile
|
||||
from collections.abc import Iterable, Iterator, Sequence
|
||||
|
@ -402,6 +404,34 @@ class BoundingBox:
|
|||
return bbox1, bbox2
|
||||
|
||||
|
||||
class Executable:
    """Callable wrapper that runs a command-line program.

    .. versionadded:: 0.6
    """

    def __init__(self, name: str) -> None:
        """Create a wrapper for the command *name*.

        Args:
            name: Command name.
        """
        self.name = name

    def __call__(self, *args: Any, **kwargs: Any) -> subprocess.CompletedProcess[bytes]:
        """Invoke the command with :func:`subprocess.run`.

        Args:
            args: Arguments to pass to the command.
            kwargs: Keyword arguments to pass to :func:`subprocess.run`.
                ``check`` is always set to True, replacing any value given.

        Returns:
            The completed process.
        """
        command = (self.name, *args)
        # 'check' comes last so it overrides a caller-supplied value.
        options = {**kwargs, 'check': True}
        return subprocess.run(command, **options)
|
||||
|
||||
|
||||
def disambiguate_timestamp(date_str: str, format: str) -> tuple[float, float]:
|
||||
"""Disambiguate partial timestamps.
|
||||
|
||||
|
@ -772,6 +802,9 @@ def lazy_import(name: str) -> Any:
|
|||
Args:
|
||||
name: Name of module to import.
|
||||
|
||||
Returns:
|
||||
Module import.
|
||||
|
||||
Raises:
|
||||
DependencyNotFoundError: If *name* is not installed.
|
||||
|
||||
|
@ -785,4 +818,35 @@ def lazy_import(name: str) -> Any:
|
|||
module_to_pypi: dict[str, str] = collections.defaultdict(lambda: name)
|
||||
module_to_pypi |= {'cv2': 'opencv-python', 'skimage': 'scikit-image'}
|
||||
name = module_to_pypi[name]
|
||||
raise DependencyNotFoundError(name) from None
|
||||
msg = f"""\
|
||||
{name} is not installed and is required to use this dataset. Either run:
|
||||
|
||||
$ pip install {name}
|
||||
|
||||
to install just this dependency, or:
|
||||
|
||||
$ pip install torchgeo[datasets]
|
||||
|
||||
to install all optional dataset dependencies."""
|
||||
raise DependencyNotFoundError(msg) from None
|
||||
|
||||
|
||||
def which(name: str) -> Executable:
    """Locate the executable *name* and wrap it for calling.

    Args:
        name: Name of executable to search for.

    Returns:
        Callable executable instance.

    Raises:
        DependencyNotFoundError: If *name* is not installed.

    .. versionadded:: 0.6
    """
    # Guard clause: fail first when shutil.which finds nothing.
    if not shutil.which(name):
        msg = f'{name} is not installed and is required to use this dataset.'
        raise DependencyNotFoundError(msg) from None

    return Executable(name)
|
||||
|
|
Загрузка…
Ссылка в новой задаче