Mirror of https://github.com/microsoft/torchgeo.git
Datasets: improve lazy import error msg for missing deps (#2054)
* Datasets: improve lazy import error msg for missing deps
* Add type annotation
* Use lazy imports throughout datasets
* Fix support for older scipy
* Fix support for older scipy
* CI: test optional datasets on every commit
* Update minversion and fix tests
* Double quotes preferred over single quotes
* Undo for now
* Fast-fail during dataset initialization
* Remove extraneous space
* MissingDependencyError -> DependencyNotFoundError
This commit is contained in:
Parent
ac16f4968a
Commit
189dabd0b6
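In short, this commit replaces the hand-rolled try/except ImportError guards scattered across the datasets with a single lazy_import() helper that fails fast during dataset initialization and raises DependencyNotFoundError with a copy-pasteable pip install hint. A minimal sketch of the before/after pattern, condensed from the diffs below (the h5py example is illustrative):

    # Before: every accessor guarded its own import with a hand-written message.
    try:
        import h5py  # noqa: F401
    except ImportError:
        raise ImportError('h5py is not installed and is required to use this dataset')

    # After: one helper resolves the module or raises DependencyNotFoundError
    # (with a `pip install h5py` hint) as soon as the dataset is constructed.
    from torchgeo.datasets.utils import lazy_import

    h5py = lazy_import('h5py')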
@@ -7,34 +7,6 @@ on:
    branches:
      - release**
jobs:
  datasets:
    name: datasets
    runs-on: ubuntu-latest
    steps:
      - name: Clone repo
        uses: actions/checkout@v4.1.5
      - name: Set up python
        uses: actions/setup-python@v5.1.0
        with:
          python-version: "3.12"
      - name: Cache dependencies
        uses: actions/cache@v4.0.2
        id: cache
        with:
          path: ${{ env.pythonLocation }}
          key: ${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-datasets
      - name: Install pip dependencies
        if: steps.cache.outputs.cache-hit != 'true'
        run: |
          pip install .[tests]
          pip cache purge
      - name: List pip dependencies
        run: pip list
      - name: Run pytest checks
        run: |
          pytest --cov=torchgeo --cov-report=xml --durations=10
          python -m torchgeo --help
          torchgeo --help
  integration:
    name: integration
    runs-on: ubuntu-latest

@@ -99,6 +99,39 @@ jobs:
        uses: codecov/codecov-action@v4.3.1
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
  datasets:
    name: datasets
    runs-on: ubuntu-latest
    env:
      MPLBACKEND: Agg
    steps:
      - name: Clone repo
        uses: actions/checkout@v4.1.4
      - name: Set up python
        uses: actions/setup-python@v5.1.0
        with:
          python-version: "3.12"
      - name: Cache dependencies
        uses: actions/cache@v4.0.2
        id: cache
        with:
          path: ${{ env.pythonLocation }}
          key: ${{ env.pythonLocation }}-${{ hashFiles('requirements/required.txt') }}-${{ hashFiles('requirements/tests.txt') }}
      - name: Install pip dependencies
        if: steps.cache.outputs.cache-hit != 'true'
        run: |
          pip install -r requirements/required.txt -r requirements/tests.txt
          pip cache purge
      - name: List pip dependencies
        run: pip list
      - name: Run pytest checks
        run: |
          pytest --cov=torchgeo --cov-report=xml --durations=10
          python3 -m torchgeo --help
      - name: Report coverage
        uses: codecov/codecov-action@v4.3.0
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.head.label || github.head_ref || github.ref }}
  cancel-in-progress: true

@@ -535,4 +535,5 @@ Errors
------

.. autoclass:: DatasetNotFoundError
.. autoclass:: DependencyNotFoundError
.. autoclass:: RGBBandsMissingError

@@ -14,7 +14,6 @@ from torchgeo.datasets import unbind_samples
class TestUSAVarsDataModule:
    @pytest.fixture
    def datamodule(self, request: SubRequest) -> USAVarsDataModule:
        pytest.importorskip('pandas', minversion='1.1.3')
        root = os.path.join('tests', 'data', 'usavars')
        batch_size = 1
        num_workers = 0

@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
@@ -16,6 +14,8 @@ from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import ADVANCE, DatasetNotFoundError

pytest.importorskip('scipy', minversion='1.7.2')


def download_url(url: str, root: str, *args: str) -> None:
    shutil.copy(url, root)
@@ -37,19 +37,7 @@ class TestADVANCE:
        transforms = nn.Identity()
        return ADVANCE(root, transforms, download=True, checksum=True)

    @pytest.fixture
    def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
        import_orig = builtins.__import__

        def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name == 'scipy.io':
                raise ImportError()
            return import_orig(name, *args, **kwargs)

        monkeypatch.setattr(builtins, '__import__', mocked_import)

    def test_getitem(self, dataset: ADVANCE) -> None:
        pytest.importorskip('scipy', minversion='1.6.2')
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x['image'], torch.Tensor)
@@ -71,17 +59,7 @@ class TestADVANCE:
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            ADVANCE(str(tmp_path))

    def test_mock_missing_module(
        self, dataset: ADVANCE, mock_missing_module: None
    ) -> None:
        with pytest.raises(
            ImportError,
            match='scipy is not installed and is required to use this dataset',
        ):
            dataset[0]

    def test_plot(self, dataset: ADVANCE) -> None:
        pytest.importorskip('scipy', minversion='1.6.2')
        x = dataset[0].copy()
        dataset.plot(x, suptitle='Test')
        plt.close()

@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
@@ -17,7 +15,7 @@ from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import ChaBuD, DatasetNotFoundError

pytest.importorskip('h5py', minversion='3')
pytest.importorskip('h5py', minversion='3.6')


def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None:
@@ -47,17 +45,6 @@ class TestChaBuD:
            checksum=True,
        )

    @pytest.fixture
    def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
        import_orig = builtins.__import__

        def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name == 'h5py':
                raise ImportError()
            return import_orig(name, *args, **kwargs)

        monkeypatch.setattr(builtins, '__import__', mocked_import)

    def test_getitem(self, dataset: ChaBuD) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
@@ -85,15 +72,6 @@ class TestChaBuD:
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            ChaBuD(str(tmp_path))

    def test_mock_missing_module(
        self, dataset: ChaBuD, tmp_path: Path, mock_missing_module: None
    ) -> None:
        with pytest.raises(
            ImportError,
            match='h5py is not installed and is required to use this dataset',
        ):
            ChaBuD(dataset.root, download=True, checksum=True)

    def test_invalid_bands(self) -> None:
        with pytest.raises(AssertionError):
            ChaBuD(bands=['OK', 'BK'])

@@ -23,14 +23,14 @@ from torchgeo.datasets import (
    UnionDataset,
)

pytest.importorskip('zipfile_deflate64')


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    shutil.copy(url, root)


class TestChesapeake13:
    pytest.importorskip('zipfile_deflate64')

    @pytest.fixture
    def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> Chesapeake13:
        monkeypatch.setattr(torchgeo.datasets.chesapeake, 'download_url', download_url)

@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
@@ -16,7 +14,7 @@ from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import CropHarvest, DatasetNotFoundError

pytest.importorskip('h5py', minversion='3')
pytest.importorskip('h5py', minversion='3.6')


def download_url(url: str, root: str, filename: str, md5: str) -> None:
@@ -24,17 +22,6 @@ def download_url(url: str, root: str, filename: str, md5: str) -> None:


class TestCropHarvest:
    @pytest.fixture
    def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
        import_orig = builtins.__import__

        def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name == 'h5py':
                raise ImportError()
            return import_orig(name, *args, **kwargs)

        monkeypatch.setattr(builtins, '__import__', mocked_import)

    @pytest.fixture
    def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CropHarvest:
        monkeypatch.setattr(torchgeo.datasets.cropharvest, 'download_url', download_url)
@@ -89,12 +76,3 @@ class TestCropHarvest:
        x = dataset[0].copy()
        dataset.plot(x, suptitle='Test')
        plt.close()

    def test_mock_missing_module(
        self, dataset: CropHarvest, tmp_path: Path, mock_missing_module: None
    ) -> None:
        with pytest.raises(
            ImportError,
            match='h5py is not installed and is required to use this dataset',
        ):
            CropHarvest(root=str(tmp_path), download=True)[0]

@@ -6,7 +6,11 @@ from typing import Any
import pytest
from torch.utils.data import Dataset

from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError
from torchgeo.datasets import (
    DatasetNotFoundError,
    DependencyNotFoundError,
    RGBBandsMissingError,
)


class TestDatasetNotFoundError:
@@ -55,6 +59,11 @@ class TestDatasetNotFoundError:
            raise DatasetNotFoundError(ds)


def test_missing_dependency() -> None:
    with pytest.raises(DependencyNotFoundError, match='pip install foo'):
        raise DependencyNotFoundError('foo')


def test_rgb_bands_missing() -> None:
    match = 'Dataset does not contain some of the RGB bands'
    with pytest.raises(RGBBandsMissingError, match=match):

@@ -1,12 +1,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import glob
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
@@ -50,19 +48,6 @@ class TestIDTReeS:
        transforms = nn.Identity()
        return IDTReeS(root, split, task, transforms, download=True, checksum=True)

    @pytest.fixture(params=['laspy', 'pyvista'])
    def mock_missing_module(self, monkeypatch: MonkeyPatch, request: SubRequest) -> str:
        import_orig = builtins.__import__
        package = str(request.param)

        def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name == package:
                raise ImportError()
            return import_orig(name, *args, **kwargs)

        monkeypatch.setattr(builtins, '__import__', mocked_import)
        return package

    def test_getitem(self, dataset: IDTReeS) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
@@ -101,24 +86,6 @@ class TestIDTReeS:
        shutil.copy(zipfile, root)
        IDTReeS(root)

    def test_mock_missing_module(
        self, dataset: IDTReeS, mock_missing_module: str
    ) -> None:
        package = mock_missing_module

        if package == 'laspy':
            with pytest.raises(
                ImportError,
                match=f'{package} is not installed and is required to use this dataset',
            ):
                IDTReeS(dataset.root, dataset.split, dataset.task)
        elif package == 'pyvista':
            with pytest.raises(
                ImportError,
                match=f'{package} is not installed and is required to plot point cloud',
            ):
                dataset.plot_las(0)

    def test_plot(self, dataset: IDTReeS) -> None:
        x = dataset[0].copy()
        dataset.plot(x, suptitle='Test')

@@ -76,11 +76,12 @@ class TestLandCoverAIGeo:


class TestLandCoverAI:
    pytest.importorskip('cv2', minversion='4.5.4')

    @pytest.fixture(params=['train', 'val', 'test'])
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> LandCoverAI:
        pytest.importorskip('cv2', minversion='4.4.0')
        monkeypatch.setattr(torchgeo.datasets.landcoverai, 'download_url', download_url)
        md5 = 'ff8998857cc8511f644d3f7d0f3688d0'
        monkeypatch.setattr(LandCoverAI, 'md5', md5)
@@ -111,7 +112,6 @@ class TestLandCoverAI:
            LandCoverAI(root=dataset.root, download=True)

    def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
        pytest.importorskip('cv2', minversion='4.4.0')
        sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
        monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
        url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')

@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
@@ -17,6 +15,8 @@ from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, QuakeSet

pytest.importorskip('h5py', minversion='3.6')


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    shutil.copy(url, root)
@@ -39,26 +39,6 @@ class TestQuakeSet:
            root, split, transforms=transforms, download=True, checksum=True
        )

    @pytest.fixture
    def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
        import_orig = builtins.__import__

        def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name == 'h5py':
                raise ImportError()
            return import_orig(name, *args, **kwargs)

        monkeypatch.setattr(builtins, '__import__', mocked_import)

    def test_mock_missing_module(
        self, dataset: QuakeSet, tmp_path: Path, mock_missing_module: None
    ) -> None:
        with pytest.raises(
            ImportError,
            match='h5py is not installed and is required to use this dataset',
        ):
            QuakeSet(dataset.root, download=True, checksum=True)

    def test_getitem(self, dataset: QuakeSet) -> None:
        x = dataset[0]
        assert isinstance(x, dict)

@@ -15,6 +15,8 @@ from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import RESISC45, DatasetNotFoundError

pytest.importorskip('rarfile', minversion='4')


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
    shutil.copy(url, root)
@@ -25,8 +27,6 @@ class TestRESISC45:
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> RESISC45:
        pytest.importorskip('rarfile', minversion='4')

        monkeypatch.setattr(torchgeo.datasets.resisc45, 'download_url', download_url)
        md5 = '5895dea3757ba88707d52f5521c444d3'
        monkeypatch.setattr(RESISC45, 'md5', md5)

@@ -1,12 +1,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from itertools import product
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
@@ -18,7 +16,7 @@ from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import SKIPPD, DatasetNotFoundError

pytest.importorskip('h5py', minversion='3')
pytest.importorskip('h5py', minversion='3.6')


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -53,26 +51,6 @@ class TestSKIPPD:
            checksum=True,
        )

    @pytest.fixture
    def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
        import_orig = builtins.__import__

        def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name == 'h5py':
                raise ImportError()
            return import_orig(name, *args, **kwargs)

        monkeypatch.setattr(builtins, '__import__', mocked_import)

    def test_mock_missing_module(
        self, dataset: SKIPPD, tmp_path: Path, mock_missing_module: None
    ) -> None:
        with pytest.raises(
            ImportError,
            match='h5py is not installed and is required to use this dataset',
        ):
            SKIPPD(dataset.root, download=True, checksum=True)

    def test_already_extracted(self, dataset: SKIPPD) -> None:
        SKIPPD(root=dataset.root, download=True)

@@ -1,10 +1,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
@@ -15,7 +13,7 @@ from pytest import MonkeyPatch

from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, So2Sat

pytest.importorskip('h5py', minversion='3')
pytest.importorskip('h5py', minversion='3.6')


class TestSo2Sat:
@@ -35,17 +33,6 @@ class TestSo2Sat:
        transforms = nn.Identity()
        return So2Sat(root=root, split=split, transforms=transforms, checksum=True)

    @pytest.fixture
    def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
        import_orig = builtins.__import__

        def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name == 'h5py':
                raise ImportError()
            return import_orig(name, *args, **kwargs)

        monkeypatch.setattr(builtins, '__import__', mocked_import)

    def test_getitem(self, dataset: So2Sat) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
@@ -89,12 +76,3 @@ class TestSo2Sat:
            RGBBandsMissingError, match='Dataset does not contain some of the RGB bands'
        ):
            dataset.plot(dataset[0], suptitle='Single Band')

    def test_mock_missing_module(
        self, dataset: So2Sat, mock_missing_module: None
    ) -> None:
        with pytest.raises(
            ImportError,
            match='h5py is not installed and is required to use this dataset',
        ):
            So2Sat(dataset.root)

@@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import glob
import math
import os
@@ -20,8 +19,8 @@ from pytest import MonkeyPatch
from rasterio.crs import CRS

import torchgeo.datasets.utils
from torchgeo.datasets import BoundingBox, DependencyNotFoundError
from torchgeo.datasets.utils import (
    BoundingBox,
    array_to_tensor,
    concat_samples,
    disambiguate_timestamp,
@@ -29,6 +28,7 @@ from torchgeo.datasets.utils import (
    download_radiant_mlhub_collection,
    download_radiant_mlhub_dataset,
    extract_archive,
    lazy_import,
    merge_samples,
    percentile_normalization,
    stack_samples,
@@ -37,18 +37,6 @@ from torchgeo.datasets.utils import (
)


@pytest.fixture
def mock_missing_module(monkeypatch: MonkeyPatch) -> None:
    import_orig = builtins.__import__

    def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
        if name in ['radiant_mlhub', 'rarfile', 'zipfile_deflate64']:
            raise ImportError()
        return import_orig(name, *args, **kwargs)

    monkeypatch.setattr(builtins, '__import__', mocked_import)


class MLHubDataset:
    def download(self, output_dir: str, **kwargs: str) -> None:
        glob_path = os.path.join(
@@ -79,10 +67,6 @@ def download_url(url: str, root: str, *args: str) -> None:
    shutil.copy(url, root)


def test_mock_missing_module(mock_missing_module: None) -> None:
    import sys  # noqa: F401


@pytest.mark.parametrize(
    'src',
    [
@@ -102,21 +86,6 @@ def test_extract_archive(src: str, tmp_path: Path) -> None:
    extract_archive(os.path.join('tests', 'data', src), str(tmp_path))


def test_missing_rarfile(mock_missing_module: None) -> None:
    with pytest.raises(
        ImportError,
        match='rarfile is not installed and is required to extract this dataset',
    ):
        extract_archive(
            os.path.join('tests', 'data', 'vhr10', 'NWPU VHR-10 dataset.rar')
        )


def test_missing_zipfile_deflate64(mock_missing_module: None) -> None:
    # Should fallback on Python builtin zipfile
    extract_archive(os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip'))


def test_unsupported_scheme() -> None:
    with pytest.raises(
        RuntimeError, match='src file has unknown archival/compression scheme'
@@ -148,21 +117,6 @@ def test_download_radiant_mlhub_collection(
    download_radiant_mlhub_collection('', str(tmp_path))


def test_missing_radiant_mlhub(mock_missing_module: None) -> None:
    with pytest.raises(
        ImportError,
        match='radiant_mlhub is not installed and is required to download this dataset',
    ):
        download_radiant_mlhub_dataset('', '')

    with pytest.raises(
        ImportError,
        match='radiant_mlhub is not installed and is required to download this'
        + ' collection',
    ):
        download_radiant_mlhub_collection('', '')


class TestBoundingBox:
    def test_repr_str(self) -> None:
        bbox = BoundingBox(0, 1, 2.0, 3.0, -5, -4)
@@ -625,3 +579,14 @@ def test_array_to_tensor(array_dtype: 'np.typing.DTypeLike') -> None:
    # values equal even if they differ.
    assert array[0].item() == tensor[0].item()
    assert array[1].item() == tensor[1].item()


@pytest.mark.parametrize('name', ['collections', 'collections.abc'])
def test_lazy_import(name: str) -> None:
    lazy_import(name)


@pytest.mark.parametrize('name', ['foo_bar', 'foo_bar.baz'])
def test_lazy_import_missing(name: str) -> None:
    with pytest.raises(DependencyNotFoundError, match='pip install foo-bar\n'):
        lazy_import(name)

@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
@@ -19,6 +17,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import VHR10, DatasetNotFoundError

pytest.importorskip('pycocotools')
pytest.importorskip('rarfile', minversion='4')


def download_url(url: str, root: str, *args: str) -> None:
@@ -30,7 +29,6 @@ class TestVHR10:
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> VHR10:
        pytest.importorskip('rarfile', minversion='4')
        monkeypatch.setattr(torchgeo.datasets.vhr10, 'download_url', download_url)
        monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url)
        url = os.path.join('tests', 'data', 'vhr10', 'NWPU VHR-10 dataset.rar')
@@ -46,17 +44,6 @@ class TestVHR10:
        transforms = nn.Identity()
        return VHR10(root, split, transforms, download=True, checksum=True)

    @pytest.fixture
    def mock_missing_modules(self, monkeypatch: MonkeyPatch) -> None:
        import_orig = builtins.__import__

        def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name in {'pycocotools.coco', 'skimage.measure'}:
                raise ImportError()
            return import_orig(name, *args, **kwargs)

        monkeypatch.setattr(builtins, '__import__', mocked_import)

    def test_getitem(self, dataset: VHR10) -> None:
        for i in range(2):
            x = dataset[i]
@@ -93,25 +80,8 @@ class TestVHR10:
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            VHR10(str(tmp_path))

    def test_mock_missing_module(
        self, dataset: VHR10, mock_missing_modules: None
    ) -> None:
        if dataset.split == 'positive':
            with pytest.raises(
                ImportError,
                match='pycocotools is not installed and is required to use this datase',
            ):
                VHR10(dataset.root, dataset.split)

            with pytest.raises(
                ImportError,
                match='scikit-image is not installed and is required to plot masks',
            ):
                x = dataset[0]
                dataset.plot(x)

    def test_plot(self, dataset: VHR10) -> None:
        pytest.importorskip('skimage', minversion='0.18')
        pytest.importorskip('skimage', minversion='0.19')
        x = dataset[1].copy()
        dataset.plot(x, suptitle='Test')
        plt.close()

@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
@@ -16,7 +14,7 @@ from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, ZueriCrop

pytest.importorskip('h5py', minversion='3')
pytest.importorskip('h5py', minversion='3.6')


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -39,17 +37,6 @@ class TestZueriCrop:
        transforms = nn.Identity()
        return ZueriCrop(root=root, transforms=transforms, download=True, checksum=True)

    @pytest.fixture
    def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
        import_orig = builtins.__import__

        def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
            if name == 'h5py':
                raise ImportError()
            return import_orig(name, *args, **kwargs)

        monkeypatch.setattr(builtins, '__import__', mocked_import)

    def test_getitem(self, dataset: ZueriCrop) -> None:
        x = dataset[0]
        assert isinstance(x, dict)
@@ -82,15 +69,6 @@ class TestZueriCrop:
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            ZueriCrop(str(tmp_path))

    def test_mock_missing_module(
        self, dataset: ZueriCrop, tmp_path: Path, mock_missing_module: None
    ) -> None:
        with pytest.raises(
            ImportError,
            match='h5py is not installed and is required to use this dataset',
        ):
            ZueriCrop(dataset.root, download=True, checksum=True)

    def test_invalid_bands(self) -> None:
        with pytest.raises(ValueError):
            ZueriCrop(bands=('OK', 'BK'))

@@ -89,7 +89,7 @@ class TestClassificationTask:
        self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
    ) -> None:
        if name.startswith('so2sat') or name == 'quakeset':
            pytest.importorskip('h5py', minversion='3')
            pytest.importorskip('h5py', minversion='3.6')

        config = os.path.join('tests', 'conf', name + '.yaml')

@@ -71,7 +71,7 @@ class TestRegressionTask:
        self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
    ) -> None:
        if name == 'skippd':
            pytest.importorskip('h5py', minversion='3')
            pytest.importorskip('h5py', minversion='3.6')

        config = os.path.join('tests', 'conf', name + '.yaml')

@@ -87,12 +87,16 @@ class TestSemanticSegmentationTask:
    def test_trainer(
        self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
    ) -> None:
        if name == 'naipchesapeake':
            pytest.importorskip('zipfile_deflate64')

        if name == 'landcoverai':
            sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
            monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
        match name:
            case 'chabud':
                pytest.importorskip('h5py', minversion='3.6')
            case 'landcoverai':
                sha256 = (
                    'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
                )
                monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
            case 'naipchesapeake':
                pytest.importorskip('zipfile_deflate64')

        config = os.path.join('tests', 'conf', name + '.yaml')

@@ -37,7 +37,7 @@ from .deepglobelandcover import DeepGlobeLandCover
from .dfc2022 import DFC2022
from .eddmaps import EDDMapS
from .enviroatlas import EnviroAtlas
from .errors import DatasetNotFoundError, RGBBandsMissingError
from .errors import DatasetNotFoundError, DependencyNotFoundError, RGBBandsMissingError
from .esri2020 import Esri2020
from .etci2021 import ETCI2021
from .eudem import EUDEM
@@ -280,5 +280,6 @@ __all__ = (
    'time_series_split',
    # Errors
    'DatasetNotFoundError',
    'DependencyNotFoundError',
    'RGBBandsMissingError',
)

@@ -17,7 +17,7 @@ from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_and_extract_archive
from .utils import download_and_extract_archive, lazy_import


class ADVANCE(NonGeoDataset):
@@ -104,7 +104,10 @@ class ADVANCE(NonGeoDataset):

        Raises:
            DatasetNotFoundError: If dataset is not found and *download* is False.
            DependencyNotFoundError: If scipy is not installed.
        """
        lazy_import('scipy.io.wavfile')

        self.root = root
        self.transforms = transforms
        self.checksum = checksum
@@ -191,14 +194,8 @@ class ADVANCE(NonGeoDataset):
        Returns:
            the target audio
        """
        try:
            from scipy.io import wavfile
        except ImportError:
            raise ImportError(
                'scipy is not installed and is required to use this dataset'
            )

        array = wavfile.read(path, mmap=True)[1]
        siw = lazy_import('scipy.io.wavfile')
        array = siw.read(path, mmap=True)[1]
        tensor = torch.from_numpy(array)
        tensor = tensor.unsqueeze(0)
        return tensor

@@ -14,7 +14,7 @@ from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_url, percentile_normalization
from .utils import download_url, lazy_import, percentile_normalization


class ChaBuD(NonGeoDataset):
@@ -96,7 +96,10 @@ class ChaBuD(NonGeoDataset):
        Raises:
            AssertionError: If ``split`` or ``bands`` arguments are invalid.
            DatasetNotFoundError: If dataset is not found and *download* is False.
            DependencyNotFoundError: If h5py is not installed.
        """
        lazy_import('h5py')

        assert split in self.folds
        assert set(bands) <= set(self.all_bands)

@@ -111,13 +114,6 @@ class ChaBuD(NonGeoDataset):

        self._verify()

        try:
            import h5py  # noqa: F401
        except ImportError:
            raise ImportError(
                'h5py is not installed and is required to use this dataset'
            )

        self.uuids = self._load_uuids()

    def __getitem__(self, index: int) -> dict[str, Tensor]:
@@ -153,8 +149,7 @@
        Returns:
            the image uuids
        """
        import h5py

        h5py = lazy_import('h5py')
        uuids = []
        with h5py.File(self.filepath, 'r') as f:
            for k, v in f.items():
@@ -173,8 +168,7 @@
        Returns:
            the image
        """
        import h5py

        h5py = lazy_import('h5py')
        uuid = self.uuids[index]
        with h5py.File(self.filepath, 'r') as f:
            pre_array = f[uuid]['pre_fire'][:]
@@ -199,8 +193,7 @@
        Returns:
            the target mask
        """
        import h5py

        h5py = lazy_import('h5py')
        uuid = self.uuids[index]
        with h5py.File(self.filepath, 'r') as f:
            array = f[uuid]['mask'][:].astype(np.int32).squeeze(axis=-1)

@@ -17,7 +17,7 @@ from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_url, extract_archive
from .utils import download_url, extract_archive, lazy_import


class CropHarvest(NonGeoDataset):
@@ -112,14 +112,9 @@ class CropHarvest(NonGeoDataset):

        Raises:
            DatasetNotFoundError: If dataset is not found and *download* is False.
            ImportError: If h5py is not installed
            DependencyNotFoundError: If h5py is not installed.
        """
        try:
            import h5py  # noqa: F401
        except ImportError:
            raise ImportError(
                'h5py is not installed and is required to use this dataset'
            )
        lazy_import('h5py')

        self.root = root
        self.transforms = transforms
@@ -210,8 +205,7 @@
        Returns:
            the image
        """
        import h5py

        h5py = lazy_import('h5py')
        filename = os.path.join(path)
        with h5py.File(filename, 'r') as f:
            array = f.get('array')[()]

@@ -49,6 +49,32 @@ class DatasetNotFoundError(FileNotFoundError):
        super().__init__(msg)


class DependencyNotFoundError(ModuleNotFoundError):
    """Raised when an optional dataset dependency is not installed.

    .. versionadded:: 0.6
    """

    def __init__(self, name: str) -> None:
        """Initialize a new DependencyNotFoundError instance.

        Args:
            name: Name of missing dependency.
        """
        msg = f"""\
{name} is not installed and is required to use this dataset. Either run:

$ pip install {name}

to install just this dependency, or:

$ pip install torchgeo[datasets]

to install all optional dataset dependencies."""

        super().__init__(msg)


class RGBBandsMissingError(ValueError):
    """Raised when a dataset is missing RGB bands for plotting.

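A minimal sketch of how the new error surfaces to users, grounded in the class defined above (output abbreviated):

    from torchgeo.datasets import DependencyNotFoundError

    try:
        raise DependencyNotFoundError('h5py')
    except DependencyNotFoundError as err:
        # The message suggests `pip install h5py` for just this dependency,
        # or `pip install torchgeo[datasets]` for all optional dataset deps.
        print(err)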
@@ -22,7 +22,7 @@ from torchvision.utils import draw_bounding_boxes

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_url, extract_archive
from .utils import download_url, extract_archive, lazy_import


class IDTReeS(NonGeoDataset):
@@ -92,6 +92,11 @@ class IDTReeS(NonGeoDataset):

    * https://doi.org/10.1101/2021.08.06.453503

    This dataset requires the following additional libraries to be installed:

    * `laspy <https://pypi.org/project/laspy/>`_ to read lidar point clouds
    * `pyvista <https://pypi.org/project/pyvista/>`_ to plot lidar point clouds

    .. versionadded:: 0.2
    """

@@ -167,11 +172,14 @@ class IDTReeS(NonGeoDataset):
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            ImportError: if laspy is not installed
            DatasetNotFoundError: If dataset is not found and *download* is False.
            DependencyNotFoundError: If laspy is not installed.
        """
        lazy_import('laspy')

        assert split in ['train', 'test']
        assert task in ['task1', 'task2']

        self.root = root
        self.split = split
        self.task = task
@@ -182,14 +190,6 @@ class IDTReeS(NonGeoDataset):
        self.idx2class = {i: c for i, c in enumerate(self.classes)}
        self.num_classes = len(self.classes)
        self._verify()

        try:
            import laspy  # noqa: F401
        except ImportError:
            raise ImportError(
                'laspy is not installed and is required to use this dataset'
            )

        self.images, self.geometries, self.labels = self._load(root)

    def __getitem__(self, index: int) -> dict[str, Tensor]:
@@ -263,8 +263,7 @@
        Returns:
            the point cloud
        """
        import laspy

        laspy = lazy_import('laspy')
        las = laspy.read(path)
        array: 'np.typing.NDArray[np.int_]' = np.stack([las.x, las.y, las.z], axis=0)
        tensor = torch.from_numpy(array)
@@ -561,19 +560,13 @@
            pyvista.PolyData object. Run pyvista.plot(point_cloud, ...) to display

        Raises:
            ImportError: if pyvista is not installed
            DependencyNotFoundError: If laspy or pyvista are not installed.

        .. versionchanged:: 0.4
            Ported from Open3D to PyVista, *colormap* parameter removed.
        """
        try:
            import pyvista  # noqa: F401
        except ImportError:
            raise ImportError(
                'pyvista is not installed and is required to plot point clouds'
            )
        import laspy

        laspy = lazy_import('laspy')
        pyvista = lazy_import('pyvista')
        path = self.images[index]
        path = path.replace('RGB', 'LAS').replace('.tif', '.las')
        las = laspy.read(path)

@@ -15,7 +15,7 @@ from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_url, percentile_normalization
from .utils import download_url, lazy_import, percentile_normalization


class QuakeSet(NonGeoDataset):
@@ -85,8 +85,10 @@ class QuakeSet(NonGeoDataset):
        Raises:
            AssertionError: If ``split`` argument is invalid.
            DatasetNotFoundError: If dataset is not found and *download* is False.
            ImportError: if h5py is not installed
            DependencyNotFoundError: If h5py is not installed.
        """
        lazy_import('h5py')

        assert split in self.splits

        self.root = root
@@ -95,16 +97,7 @@ class QuakeSet(NonGeoDataset):
        self.download = download
        self.checksum = checksum
        self.filepath = os.path.join(root, self.filename)

        self._verify()

        try:
            import h5py  # noqa: F401
        except ImportError:
            raise ImportError(
                'h5py is not installed and is required to use this dataset'
            )

        self.data = self._load_data()

    def __getitem__(self, index: int) -> dict[str, Tensor]:
@@ -141,8 +134,7 @@
        Returns:
            the sample keys, patches, images, labels, and magnitudes
        """
        import h5py

        h5py = lazy_import('h5py')
        data = []
        with h5py.File(self.filepath) as f:
            for k in sorted(f.keys()):
@@ -185,7 +177,7 @@
        Returns:
            the image
        """
        import h5py
        h5py = lazy_import('h5py')

        key = self.data[index]['key']
        patch = self.data[index]['patch']

@@ -16,7 +16,7 @@ from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_url, extract_archive
from .utils import download_url, extract_archive, lazy_import


class SKIPPD(NonGeoDataset):
@@ -53,6 +53,12 @@ class SKIPPD(NonGeoDataset):

    * https://doi.org/10.48550/arXiv.2207.00913

    .. note::

       This dataset requires the following additional library to be installed:

       * `<https://pypi.org/project/h5py/>`_ to load the dataset

    .. versionadded:: 0.5
    """

@@ -94,8 +100,10 @@
        Raises:
            AssertionError: if ``task`` or ``split`` is invalid
            DatasetNotFoundError: If dataset is not found and *download* is False.
            ImportError: if h5py is not installed
            DependencyNotFoundError: If h5py is not installed.
        """
        lazy_import('h5py')

        assert (
            split in self.valid_splits
        ), f'Please choose one of these valid data splits {self.valid_splits}.'
@@ -110,14 +118,6 @@
        self.transforms = transforms
        self.download = download
        self.checksum = checksum

        try:
            import h5py  # noqa: F401
        except ImportError:
            raise ImportError(
                'h5py is not installed and is required to use this dataset'
            )

        self._verify()

    def __len__(self) -> int:
@@ -126,8 +126,7 @@
        Returns:
            length of the dataset
        """
        import h5py

        h5py = lazy_import('h5py')
        with h5py.File(
            os.path.join(self.root, self.data_file_name.format(self.task)), 'r'
        ) as f:
@@ -161,8 +160,7 @@
        Returns:
            image tensor at index
        """
        import h5py

        h5py = lazy_import('h5py')
        with h5py.File(
            os.path.join(self.root, self.data_file_name.format(self.task)), 'r'
        ) as f:
@@ -187,8 +185,7 @@
        Returns:
            label tensor at index
        """
        import h5py

        h5py = lazy_import('h5py')
        with h5py.File(
            os.path.join(self.root, self.data_file_name.format(self.task)), 'r'
        ) as f:

@@ -15,7 +15,7 @@ from torch import Tensor

from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import NonGeoDataset
from .utils import check_integrity, percentile_normalization
from .utils import check_integrity, lazy_import, percentile_normalization


class So2Sat(NonGeoDataset):
@@ -97,6 +97,12 @@ class So2Sat(NonGeoDataset):
            done

    or manually downloaded from https://mediatum.ub.tum.de/1613658

    .. note::

       This dataset requires the following additional library to be installed:

       * `<https://pypi.org/project/h5py/>`_ to load the dataset
    """  # noqa: E501

    versions = ['2', '3_random', '3_block', '3_culture_10']
@@ -210,6 +216,7 @@
        Raises:
            AssertionError: if ``split`` argument is invalid
            DatasetNotFoundError: If dataset is not found.
            DependencyNotFoundError: If h5py is not installed.

        .. versionadded:: 0.3
            The *bands* parameter.
@@ -217,12 +224,8 @@
        .. versionadded:: 0.5
            The *version* parameter.
        """
        try:
            import h5py  # noqa: F401
        except ImportError:
            raise ImportError(
                'h5py is not installed and is required to use this dataset'
            )
        h5py = lazy_import('h5py')

        assert version in self.versions
        assert split in self.filenames_by_version[version]

@@ -272,8 +275,7 @@
        Returns:
            data and label at that index
        """
        import h5py

        h5py = lazy_import('h5py')
        with h5py.File(self.fn, 'r') as f:
            s1 = f['sen1'][index].astype(np.float64)  # convert from <f8 to float64
            s1 = np.take(s1, indices=self.s1_band_indices, axis=2)

@@ -10,6 +10,7 @@ import bz2
import collections
import contextlib
import gzip
import importlib
import lzma
import os
import sys
@@ -26,6 +27,8 @@ from torch import Tensor
from torchvision.datasets.utils import check_integrity, download_url
from torchvision.utils import draw_segmentation_masks

from .errors import DependencyNotFoundError

# Only include import redirects
__all__ = ('check_integrity', 'download_url')

@@ -37,13 +40,7 @@ class _rarfile:
        self.kwargs = kwargs

    def __enter__(self) -> Any:
        try:
            import rarfile
        except ImportError:
            raise ImportError(
                'rarfile is not installed and is required to extract this dataset'
            )

        rarfile = lazy_import('rarfile')
        # TODO: catch exception for when rarfile is installed but not
        # unrar/unar/bsdtar
        return rarfile.RarFile(*self.args, **self.kwargs)
@@ -157,14 +154,11 @@ def download_radiant_mlhub_dataset(
        api_key: the API key to use for all requests from the session. Can also be
            passed in via the ``MLHUB_API_KEY`` environment variable, or configured in
            ``~/.mlhub/profiles``.
    """
    try:
        import radiant_mlhub
    except ImportError:
        raise ImportError(
            'radiant_mlhub is not installed and is required to download this dataset'
        )

    Raises:
        DependencyNotFoundError: If radiant_mlhub is not installed.
    """
    radiant_mlhub = lazy_import('radiant_mlhub')
    dataset = radiant_mlhub.Dataset.fetch(dataset_id, api_key=api_key)
    dataset.download(output_dir=download_root, api_key=api_key)

@@ -180,14 +174,11 @@ def download_radiant_mlhub_collection(
        api_key: the API key to use for all requests from the session. Can also be
            passed in via the ``MLHUB_API_KEY`` environment variable, or configured in
            ``~/.mlhub/profiles``.
    """
    try:
        import radiant_mlhub
    except ImportError:
        raise ImportError(
            'radiant_mlhub is not installed and is required to download this collection'
        )

    Raises:
        DependencyNotFoundError: If radiant_mlhub is not installed.
    """
    radiant_mlhub = lazy_import('radiant_mlhub')
    collection = radiant_mlhub.Collection.fetch(collection_id, api_key=api_key)
    collection.download(output_dir=download_root, api_key=api_key)

@@ -773,3 +764,25 @@ def array_to_tensor(array: np.typing.NDArray[Any]) -> Tensor:
    elif array.dtype == np.uint32:
        array = array.astype(np.int64)
    return torch.tensor(array)


def lazy_import(name: str) -> Any:
    """Lazy import of *name*.

    Args:
        name: Name of module to import.

    Raises:
        DependencyNotFoundError: If *name* is not installed.

    .. versionadded:: 0.6
    """
    try:
        return importlib.import_module(name)
    except ModuleNotFoundError:
        # Map from import name to package name on PyPI
        name = name.split('.')[0].replace('_', '-')
        module_to_pypi: dict[str, str] = collections.defaultdict(lambda: name)
        module_to_pypi |= {'cv2': 'opencv-python', 'skimage': 'scikit-image'}
        name = module_to_pypi[name]
        raise DependencyNotFoundError(name) from None

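A short usage sketch of lazy_import() based on the implementation above; note how import names are normalized to PyPI package names so the install hint in DependencyNotFoundError is copy-pasteable:

    from torchgeo.datasets.utils import lazy_import

    h5py = lazy_import('h5py')                 # returns the imported module
    wavfile = lazy_import('scipy.io.wavfile')  # dotted submodules work too

    # On failure the top-level module name is mapped to its PyPI name:
    #   lazy_import('cv2')      -> DependencyNotFoundError: pip install opencv-python
    #   lazy_import('skimage')  -> DependencyNotFoundError: pip install scikit-image
    #   lazy_import('foo_bar')  -> DependencyNotFoundError: pip install foo-bar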
@@ -17,7 +17,12 @@ from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import check_integrity, download_and_extract_archive, download_url
from .utils import (
    check_integrity,
    download_and_extract_archive,
    download_url,
    lazy_import,
)


def convert_coco_poly_to_mask(
@@ -32,13 +37,15 @@ def convert_coco_poly_to_mask(

    Returns:
        Tensor: Mask tensor
    """
    from pycocotools import mask as coco_mask  # noqa: F401

    Raises:
        DependencyNotFoundError: If pycocotools is not installed.
    """
    pycocotools = lazy_import('pycocotools')
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        rles = pycocotools.mask.frPyObjects(polygons, height, width)
        mask = pycocotools.mask.decode(rles)
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
@@ -196,8 +203,9 @@ class VHR10(NonGeoDataset):

        Raises:
            AssertionError: if ``split`` argument is invalid
            ImportError: if ``split="positive"`` and pycocotools is not installed
            DatasetNotFoundError: If dataset is not found and *download* is False.
            DependencyNotFoundError: if ``split="positive"`` and pycocotools is
                not installed.
        """
        assert split in ['positive', 'negative']

@@ -213,20 +221,12 @@
            raise DatasetNotFoundError(self)

        if split == 'positive':
            # Must be installed to parse annotations file
            try:
                from pycocotools.coco import COCO  # noqa: F401
            except ImportError:
                raise ImportError(
                    'pycocotools is not installed and is required to use this dataset'
                )

            self.coco = COCO(
            pc = lazy_import('pycocotools.coco')
            self.coco = pc.COCO(
                os.path.join(
                    self.root, 'NWPU VHR-10 dataset', self.target_meta['filename']
                )
            )

            self.coco_convert = ConvertCocoAnnotations()
            self.ids = list(sorted(self.coco.imgs.keys()))

@@ -381,7 +381,7 @@

        Raises:
            AssertionError: if ``show_feats`` argument is invalid
            ImportError: if plotting masks and scikit-image is not installed
            DependencyNotFoundError: If plotting masks and scikit-image is not installed.

        .. versionadded:: 0.4
        """
@@ -397,12 +397,7 @@
            return fig

        if show_feats != 'boxes':
            try:
                from skimage.measure import find_contours  # noqa: F401
            except ImportError:
                raise ImportError(
                    'scikit-image is not installed and is required to plot masks.'
                )
            skimage = lazy_import('skimage')

        image = sample['image'].permute(1, 2, 0).numpy()
        boxes = sample['boxes'].cpu().numpy()
@@ -465,7 +460,7 @@
            # Add masks
            if show_feats in {'masks', 'both'} and 'masks' in sample:
                mask = masks[i]
                contours = find_contours(mask, 0.5)  # type: ignore[no-untyped-call]
                contours = skimage.measure.find_contours(mask, 0.5)
                for verts in contours:
                    verts = np.fliplr(verts)
                    p = patches.Polygon(
@@ -517,7 +512,7 @@
            # Add masks
            if show_pred_masks:
                mask = prediction_masks[i]
                contours = find_contours(mask, 0.5)  # type: ignore[no-untyped-call]
                contours = skimage.measure.find_contours(mask, 0.5)
                for verts in contours:
                    verts = np.fliplr(verts)
                    p = patches.Polygon(

@@ -13,7 +13,7 @@ from torch import Tensor

from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import NonGeoDataset
from .utils import download_url, percentile_normalization
from .utils import download_url, lazy_import, percentile_normalization


class ZueriCrop(NonGeoDataset):
@@ -82,7 +82,10 @@ class ZueriCrop(NonGeoDataset):

        Raises:
            DatasetNotFoundError: If dataset is not found and *download* is False.
            DependencyNotFoundError: If h5py is not installed.
        """
        lazy_import('h5py')

        self._validate_bands(bands)
        self.band_indices = torch.tensor(
            [self.band_names.index(b) for b in bands]
@@ -97,13 +100,6 @@

        self._verify()

        try:
            import h5py  # noqa: F401
        except ImportError:
            raise ImportError(
                'h5py is not installed and is required to use this dataset'
            )

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """Return an index within the dataset.

@@ -129,8 +125,7 @@
        Returns:
            length of the dataset
        """
        import h5py

        h5py = lazy_import('h5py')
        with h5py.File(self.filepath, 'r') as f:
            length: int = f['data'].shape[0]
        return length
@@ -144,8 +139,7 @@
        Returns:
            the image
        """
        import h5py

        h5py = lazy_import('h5py')
        with h5py.File(self.filepath, 'r') as f:
            array = f['data'][index, ...]
