Datasets: improve lazy import error msg for missing deps (#2054)

* Datasets: improve lazy import error msg for missing deps

* Add type annotation

* Use lazy imports throughout datasets

* Fix support for older scipy

* Fix support for older scipy

* CI: test optional datasets on every commit

* Update minversion and fix tests

* Double quotes preferred over single quotes

* Undo for now

* Fast-fail during dataset initialization

* Remove extraneous space

* MissingDependencyError -> DependencyNotFoundError
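
Taken together, these changes replace the hand-rolled try/except guard that every optional-dependency dataset carried with a single helper. A condensed before/after sketch using h5py as the running example (the helper itself, lazy_import, is added to torchgeo/datasets/utils.py in the diff below):

# Before: each dataset guarded its optional import by hand.
try:
    import h5py  # noqa: F401
except ImportError:
    raise ImportError(
        'h5py is not installed and is required to use this dataset'
    )

# After: one helper imports the module and fails fast with an install hint.
from torchgeo.datasets.utils import lazy_import

h5py = lazy_import('h5py')  # raises DependencyNotFoundError if h5py is absent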
Adam J. Stewart 2024-05-15 18:03:54 +02:00 committed by GitHub
Parent ac16f4968a
Commit 189dabd0b6
No known key found for this signature
GPG key ID: B5690EEEBB952194
33 changed files with 235 additions and 470 deletions

.github/workflows/release.yaml (28 lines changed)

@@ -7,34 +7,6 @@ on:
branches:
- release**
jobs:
datasets:
name: datasets
runs-on: ubuntu-latest
steps:
- name: Clone repo
uses: actions/checkout@v4.1.5
- name: Set up python
uses: actions/setup-python@v5.1.0
with:
python-version: "3.12"
- name: Cache dependencies
uses: actions/cache@v4.0.2
id: cache
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-datasets
- name: Install pip dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install .[tests]
pip cache purge
- name: List pip dependencies
run: pip list
- name: Run pytest checks
run: |
pytest --cov=torchgeo --cov-report=xml --durations=10
python -m torchgeo --help
torchgeo --help
integration:
name: integration
runs-on: ubuntu-latest

.github/workflows/tests.yaml (33 lines changed)

@@ -99,6 +99,39 @@ jobs:
uses: codecov/codecov-action@v4.3.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
datasets:
name: datasets
runs-on: ubuntu-latest
env:
MPLBACKEND: Agg
steps:
- name: Clone repo
uses: actions/checkout@v4.1.4
- name: Set up python
uses: actions/setup-python@v5.1.0
with:
python-version: "3.12"
- name: Cache dependencies
uses: actions/cache@v4.0.2
id: cache
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('requirements/required.txt') }}-${{ hashFiles('requirements/tests.txt') }}
- name: Install pip dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install -r requirements/required.txt -r requirements/tests.txt
pip cache purge
- name: List pip dependencies
run: pip list
- name: Run pytest checks
run: |
pytest --cov=torchgeo --cov-report=xml --durations=10
python3 -m torchgeo --help
- name: Report coverage
uses: codecov/codecov-action@v4.3.0
with:
token: ${{ secrets.CODECOV_TOKEN }}
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.head.label || github.head_ref || github.ref }}
cancel-in-progress: true

docs/api/datasets.rst

@@ -535,4 +535,5 @@ Errors
------
.. autoclass:: DatasetNotFoundError
.. autoclass:: DependencyNotFoundError
.. autoclass:: RGBBandsMissingError

tests/datamodules/test_usavars.py

@@ -14,7 +14,6 @@ from torchgeo.datasets import unbind_samples
class TestUSAVarsDataModule:
@pytest.fixture
def datamodule(self, request: SubRequest) -> USAVarsDataModule:
pytest.importorskip('pandas', minversion='1.1.3')
root = os.path.join('tests', 'data', 'usavars')
batch_size = 1
num_workers = 0

tests/datasets/test_advance.py

@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import builtins
import os
import shutil
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -16,6 +14,8 @@ from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import ADVANCE, DatasetNotFoundError
pytest.importorskip('scipy', minversion='1.7.2')
def download_url(url: str, root: str, *args: str) -> None:
shutil.copy(url, root)
@@ -37,19 +37,7 @@ class TestADVANCE:
transforms = nn.Identity()
return ADVANCE(root, transforms, download=True, checksum=True)
@pytest.fixture
def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
import_orig = builtins.__import__
def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == 'scipy.io':
raise ImportError()
return import_orig(name, *args, **kwargs)
monkeypatch.setattr(builtins, '__import__', mocked_import)
def test_getitem(self, dataset: ADVANCE) -> None:
pytest.importorskip('scipy', minversion='1.6.2')
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x['image'], torch.Tensor)
@@ -71,17 +59,7 @@ class TestADVANCE:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
ADVANCE(str(tmp_path))
def test_mock_missing_module(
self, dataset: ADVANCE, mock_missing_module: None
) -> None:
with pytest.raises(
ImportError,
match='scipy is not installed and is required to use this dataset',
):
dataset[0]
def test_plot(self, dataset: ADVANCE) -> None:
pytest.importorskip('scipy', minversion='1.6.2')
x = dataset[0].copy()
dataset.plot(x, suptitle='Test')
plt.close()

tests/datasets/test_chabud.py

@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import builtins
import os
import shutil
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -17,7 +15,7 @@ from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import ChaBuD, DatasetNotFoundError
pytest.importorskip('h5py', minversion='3')
pytest.importorskip('h5py', minversion='3.6')
def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None:
@@ -47,17 +45,6 @@ class TestChaBuD:
checksum=True,
)
@pytest.fixture
def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
import_orig = builtins.__import__
def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == 'h5py':
raise ImportError()
return import_orig(name, *args, **kwargs)
monkeypatch.setattr(builtins, '__import__', mocked_import)
def test_getitem(self, dataset: ChaBuD) -> None:
x = dataset[0]
assert isinstance(x, dict)
@@ -85,15 +72,6 @@ class TestChaBuD:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
ChaBuD(str(tmp_path))
def test_mock_missing_module(
self, dataset: ChaBuD, tmp_path: Path, mock_missing_module: None
) -> None:
with pytest.raises(
ImportError,
match='h5py is not installed and is required to use this dataset',
):
ChaBuD(dataset.root, download=True, checksum=True)
def test_invalid_bands(self) -> None:
with pytest.raises(AssertionError):
ChaBuD(bands=['OK', 'BK'])

tests/datasets/test_chesapeake.py

@@ -23,14 +23,14 @@ from torchgeo.datasets import (
UnionDataset,
)
pytest.importorskip('zipfile_deflate64')
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
class TestChesapeake13:
pytest.importorskip('zipfile_deflate64')
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> Chesapeake13:
monkeypatch.setattr(torchgeo.datasets.chesapeake, 'download_url', download_url)

tests/datasets/test_cropharvest.py

@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import builtins
import os
import shutil
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -16,7 +14,7 @@ from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import CropHarvest, DatasetNotFoundError
pytest.importorskip('h5py', minversion='3')
pytest.importorskip('h5py', minversion='3.6')
def download_url(url: str, root: str, filename: str, md5: str) -> None:
@@ -24,17 +22,6 @@ def download_url(url: str, root: str, filename: str, md5: str) -> None:
class TestCropHarvest:
@pytest.fixture
def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
import_orig = builtins.__import__
def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == 'h5py':
raise ImportError()
return import_orig(name, *args, **kwargs)
monkeypatch.setattr(builtins, '__import__', mocked_import)
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CropHarvest:
monkeypatch.setattr(torchgeo.datasets.cropharvest, 'download_url', download_url)
@@ -89,12 +76,3 @@ class TestCropHarvest:
x = dataset[0].copy()
dataset.plot(x, suptitle='Test')
plt.close()
def test_mock_missing_module(
self, dataset: CropHarvest, tmp_path: Path, mock_missing_module: None
) -> None:
with pytest.raises(
ImportError,
match='h5py is not installed and is required to use this dataset',
):
CropHarvest(root=str(tmp_path), download=True)[0]

tests/datasets/test_errors.py

@@ -6,7 +6,11 @@ from typing import Any
import pytest
from torch.utils.data import Dataset
from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError
from torchgeo.datasets import (
DatasetNotFoundError,
DependencyNotFoundError,
RGBBandsMissingError,
)
class TestDatasetNotFoundError:
@@ -55,6 +59,11 @@ class TestDatasetNotFoundError:
raise DatasetNotFoundError(ds)
def test_missing_dependency() -> None:
with pytest.raises(DependencyNotFoundError, match='pip install foo'):
raise DependencyNotFoundError('foo')
def test_rgb_bands_missing() -> None:
match = 'Dataset does not contain some of the RGB bands'
with pytest.raises(RGBBandsMissingError, match=match):

tests/datasets/test_idtrees.py

@@ -1,12 +1,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import builtins
import glob
import os
import shutil
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -50,19 +48,6 @@ class TestIDTReeS:
transforms = nn.Identity()
return IDTReeS(root, split, task, transforms, download=True, checksum=True)
@pytest.fixture(params=['laspy', 'pyvista'])
def mock_missing_module(self, monkeypatch: MonkeyPatch, request: SubRequest) -> str:
import_orig = builtins.__import__
package = str(request.param)
def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == package:
raise ImportError()
return import_orig(name, *args, **kwargs)
monkeypatch.setattr(builtins, '__import__', mocked_import)
return package
def test_getitem(self, dataset: IDTReeS) -> None:
x = dataset[0]
assert isinstance(x, dict)
@@ -101,24 +86,6 @@
shutil.copy(zipfile, root)
IDTReeS(root)
def test_mock_missing_module(
self, dataset: IDTReeS, mock_missing_module: str
) -> None:
package = mock_missing_module
if package == 'laspy':
with pytest.raises(
ImportError,
match=f'{package} is not installed and is required to use this dataset',
):
IDTReeS(dataset.root, dataset.split, dataset.task)
elif package == 'pyvista':
with pytest.raises(
ImportError,
match=f'{package} is not installed and is required to plot point cloud',
):
dataset.plot_las(0)
def test_plot(self, dataset: IDTReeS) -> None:
x = dataset[0].copy()
dataset.plot(x, suptitle='Test')

tests/datasets/test_landcoverai.py

@@ -76,11 +76,12 @@ class TestLandCoverAIGeo:
class TestLandCoverAI:
pytest.importorskip('cv2', minversion='4.5.4')
@pytest.fixture(params=['train', 'val', 'test'])
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> LandCoverAI:
pytest.importorskip('cv2', minversion='4.4.0')
monkeypatch.setattr(torchgeo.datasets.landcoverai, 'download_url', download_url)
md5 = 'ff8998857cc8511f644d3f7d0f3688d0'
monkeypatch.setattr(LandCoverAI, 'md5', md5)
@@ -111,7 +112,6 @@ class TestLandCoverAI:
LandCoverAI(root=dataset.root, download=True)
def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
pytest.importorskip('cv2', minversion='4.4.0')
sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')

tests/datasets/test_quakeset.py

@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import builtins
import os
import shutil
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -17,6 +15,8 @@ from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, QuakeSet
pytest.importorskip('h5py', minversion='3.6')
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@@ -39,26 +39,6 @@ class TestQuakeSet:
root, split, transforms=transforms, download=True, checksum=True
)
@pytest.fixture
def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
import_orig = builtins.__import__
def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == 'h5py':
raise ImportError()
return import_orig(name, *args, **kwargs)
monkeypatch.setattr(builtins, '__import__', mocked_import)
def test_mock_missing_module(
self, dataset: QuakeSet, tmp_path: Path, mock_missing_module: None
) -> None:
with pytest.raises(
ImportError,
match='h5py is not installed and is required to use this dataset',
):
QuakeSet(dataset.root, download=True, checksum=True)
def test_getitem(self, dataset: QuakeSet) -> None:
x = dataset[0]
assert isinstance(x, dict)

tests/datasets/test_resisc45.py

@@ -15,6 +15,8 @@ from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import RESISC45, DatasetNotFoundError
pytest.importorskip('rarfile', minversion='4')
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@@ -25,8 +27,6 @@ class TestRESISC45:
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> RESISC45:
pytest.importorskip('rarfile', minversion='4')
monkeypatch.setattr(torchgeo.datasets.resisc45, 'download_url', download_url)
md5 = '5895dea3757ba88707d52f5521c444d3'
monkeypatch.setattr(RESISC45, 'md5', md5)

tests/datasets/test_skippd.py

@@ -1,12 +1,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import builtins
import os
import shutil
from itertools import product
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -18,7 +16,7 @@ from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import SKIPPD, DatasetNotFoundError
pytest.importorskip('h5py', minversion='3')
pytest.importorskip('h5py', minversion='3.6')
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -53,26 +51,6 @@ class TestSKIPPD:
checksum=True,
)
@pytest.fixture
def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
import_orig = builtins.__import__
def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == 'h5py':
raise ImportError()
return import_orig(name, *args, **kwargs)
monkeypatch.setattr(builtins, '__import__', mocked_import)
def test_mock_missing_module(
self, dataset: SKIPPD, tmp_path: Path, mock_missing_module: None
) -> None:
with pytest.raises(
ImportError,
match='h5py is not installed and is required to use this dataset',
):
SKIPPD(dataset.root, download=True, checksum=True)
def test_already_extracted(self, dataset: SKIPPD) -> None:
SKIPPD(root=dataset.root, download=True)

tests/datasets/test_so2sat.py

@@ -1,10 +1,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import builtins
import os
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -15,7 +13,7 @@ from pytest import MonkeyPatch
from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, So2Sat
pytest.importorskip('h5py', minversion='3')
pytest.importorskip('h5py', minversion='3.6')
class TestSo2Sat:
@@ -35,17 +33,6 @@ class TestSo2Sat:
transforms = nn.Identity()
return So2Sat(root=root, split=split, transforms=transforms, checksum=True)
@pytest.fixture
def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
import_orig = builtins.__import__
def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == 'h5py':
raise ImportError()
return import_orig(name, *args, **kwargs)
monkeypatch.setattr(builtins, '__import__', mocked_import)
def test_getitem(self, dataset: So2Sat) -> None:
x = dataset[0]
assert isinstance(x, dict)
@@ -89,12 +76,3 @@ class TestSo2Sat:
RGBBandsMissingError, match='Dataset does not contain some of the RGB bands'
):
dataset.plot(dataset[0], suptitle='Single Band')
def test_mock_missing_module(
self, dataset: So2Sat, mock_missing_module: None
) -> None:
with pytest.raises(
ImportError,
match='h5py is not installed and is required to use this dataset',
):
So2Sat(dataset.root)

tests/datasets/test_utils.py

@@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import builtins
import glob
import math
import os
@@ -20,8 +19,8 @@ from pytest import MonkeyPatch
from rasterio.crs import CRS
import torchgeo.datasets.utils
from torchgeo.datasets import BoundingBox, DependencyNotFoundError
from torchgeo.datasets.utils import (
BoundingBox,
array_to_tensor,
concat_samples,
disambiguate_timestamp,
@@ -29,6 +28,7 @@ from torchgeo.datasets.utils import (
download_radiant_mlhub_collection,
download_radiant_mlhub_dataset,
extract_archive,
lazy_import,
merge_samples,
percentile_normalization,
stack_samples,
@@ -37,18 +37,6 @@
)
@pytest.fixture
def mock_missing_module(monkeypatch: MonkeyPatch) -> None:
import_orig = builtins.__import__
def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name in ['radiant_mlhub', 'rarfile', 'zipfile_deflate64']:
raise ImportError()
return import_orig(name, *args, **kwargs)
monkeypatch.setattr(builtins, '__import__', mocked_import)
class MLHubDataset:
def download(self, output_dir: str, **kwargs: str) -> None:
glob_path = os.path.join(
@@ -79,10 +67,6 @@ def download_url(url: str, root: str, *args: str) -> None:
shutil.copy(url, root)
def test_mock_missing_module(mock_missing_module: None) -> None:
import sys # noqa: F401
@pytest.mark.parametrize(
'src',
[
@@ -102,21 +86,6 @@ def test_extract_archive(src: str, tmp_path: Path) -> None:
extract_archive(os.path.join('tests', 'data', src), str(tmp_path))
def test_missing_rarfile(mock_missing_module: None) -> None:
with pytest.raises(
ImportError,
match='rarfile is not installed and is required to extract this dataset',
):
extract_archive(
os.path.join('tests', 'data', 'vhr10', 'NWPU VHR-10 dataset.rar')
)
def test_missing_zipfile_deflate64(mock_missing_module: None) -> None:
# Should fallback on Python builtin zipfile
extract_archive(os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip'))
def test_unsupported_scheme() -> None:
with pytest.raises(
RuntimeError, match='src file has unknown archival/compression scheme'
@@ -148,21 +117,6 @@ def test_download_radiant_mlhub_collection(
download_radiant_mlhub_collection('', str(tmp_path))
def test_missing_radiant_mlhub(mock_missing_module: None) -> None:
with pytest.raises(
ImportError,
match='radiant_mlhub is not installed and is required to download this dataset',
):
download_radiant_mlhub_dataset('', '')
with pytest.raises(
ImportError,
match='radiant_mlhub is not installed and is required to download this'
+ ' collection',
):
download_radiant_mlhub_collection('', '')
class TestBoundingBox:
def test_repr_str(self) -> None:
bbox = BoundingBox(0, 1, 2.0, 3.0, -5, -4)
@@ -625,3 +579,14 @@ def test_array_to_tensor(array_dtype: 'np.typing.DTypeLike') -> None:
# values equal even if they differ.
assert array[0].item() == tensor[0].item()
assert array[1].item() == tensor[1].item()
@pytest.mark.parametrize('name', ['collections', 'collections.abc'])
def test_lazy_import(name: str) -> None:
lazy_import(name)
@pytest.mark.parametrize('name', ['foo_bar', 'foo_bar.baz'])
def test_lazy_import_missing(name: str) -> None:
with pytest.raises(DependencyNotFoundError, match='pip install foo-bar\n'):
lazy_import(name)
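
The pip install foo-bar pattern matched above exercises the helper's underscore-to-hyphen normalization. For a real dependency the behavior looks like this (hypothetical interactive session; assumes torchgeo is installed but OpenCV is not):

from torchgeo.datasets import DependencyNotFoundError
from torchgeo.datasets.utils import lazy_import

try:
    lazy_import('cv2')  # import name 'cv2' maps to PyPI package 'opencv-python'
except DependencyNotFoundError as err:
    print(err)  # message recommends: $ pip install opencv-python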

tests/datasets/test_vhr10.py

@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import builtins
import os
import shutil
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -19,6 +17,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import VHR10, DatasetNotFoundError
pytest.importorskip('pycocotools')
pytest.importorskip('rarfile', minversion='4')
def download_url(url: str, root: str, *args: str) -> None:
@@ -30,7 +29,6 @@ class TestVHR10:
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> VHR10:
pytest.importorskip('rarfile', minversion='4')
monkeypatch.setattr(torchgeo.datasets.vhr10, 'download_url', download_url)
monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url)
url = os.path.join('tests', 'data', 'vhr10', 'NWPU VHR-10 dataset.rar')
@@ -46,17 +44,6 @@
transforms = nn.Identity()
return VHR10(root, split, transforms, download=True, checksum=True)
@pytest.fixture
def mock_missing_modules(self, monkeypatch: MonkeyPatch) -> None:
import_orig = builtins.__import__
def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name in {'pycocotools.coco', 'skimage.measure'}:
raise ImportError()
return import_orig(name, *args, **kwargs)
monkeypatch.setattr(builtins, '__import__', mocked_import)
def test_getitem(self, dataset: VHR10) -> None:
for i in range(2):
x = dataset[i]
@@ -93,25 +80,8 @@ class TestVHR10:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
VHR10(str(tmp_path))
def test_mock_missing_module(
self, dataset: VHR10, mock_missing_modules: None
) -> None:
if dataset.split == 'positive':
with pytest.raises(
ImportError,
match='pycocotools is not installed and is required to use this datase',
):
VHR10(dataset.root, dataset.split)
with pytest.raises(
ImportError,
match='scikit-image is not installed and is required to plot masks',
):
x = dataset[0]
dataset.plot(x)
def test_plot(self, dataset: VHR10) -> None:
pytest.importorskip('skimage', minversion='0.18')
pytest.importorskip('skimage', minversion='0.19')
x = dataset[1].copy()
dataset.plot(x, suptitle='Test')
plt.close()

tests/datasets/test_zuericrop.py

@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import builtins
import os
import shutil
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -16,7 +14,7 @@ from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, ZueriCrop
pytest.importorskip('h5py', minversion='3')
pytest.importorskip('h5py', minversion='3.6')
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -39,17 +37,6 @@ class TestZueriCrop:
transforms = nn.Identity()
return ZueriCrop(root=root, transforms=transforms, download=True, checksum=True)
@pytest.fixture
def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
import_orig = builtins.__import__
def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == 'h5py':
raise ImportError()
return import_orig(name, *args, **kwargs)
monkeypatch.setattr(builtins, '__import__', mocked_import)
def test_getitem(self, dataset: ZueriCrop) -> None:
x = dataset[0]
assert isinstance(x, dict)
@@ -82,15 +69,6 @@
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
ZueriCrop(str(tmp_path))
def test_mock_missing_module(
self, dataset: ZueriCrop, tmp_path: Path, mock_missing_module: None
) -> None:
with pytest.raises(
ImportError,
match='h5py is not installed and is required to use this dataset',
):
ZueriCrop(dataset.root, download=True, checksum=True)
def test_invalid_bands(self) -> None:
with pytest.raises(ValueError):
ZueriCrop(bands=('OK', 'BK'))

tests/trainers/test_classification.py

@@ -89,7 +89,7 @@ class TestClassificationTask:
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
if name.startswith('so2sat') or name == 'quakeset':
pytest.importorskip('h5py', minversion='3')
pytest.importorskip('h5py', minversion='3.6')
config = os.path.join('tests', 'conf', name + '.yaml')

tests/trainers/test_regression.py

@@ -71,7 +71,7 @@ class TestRegressionTask:
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
if name == 'skippd':
pytest.importorskip('h5py', minversion='3')
pytest.importorskip('h5py', minversion='3.6')
config = os.path.join('tests', 'conf', name + '.yaml')

tests/trainers/test_segmentation.py

@@ -87,12 +87,16 @@ class TestSemanticSegmentationTask:
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
if name == 'naipchesapeake':
pytest.importorskip('zipfile_deflate64')
if name == 'landcoverai':
sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
match name:
case 'chabud':
pytest.importorskip('h5py', minversion='3.6')
case 'landcoverai':
sha256 = (
'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
)
monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
case 'naipchesapeake':
pytest.importorskip('zipfile_deflate64')
config = os.path.join('tests', 'conf', name + '.yaml')

torchgeo/datasets/__init__.py

@@ -37,7 +37,7 @@ from .deepglobelandcover import DeepGlobeLandCover
from .dfc2022 import DFC2022
from .eddmaps import EDDMapS
from .enviroatlas import EnviroAtlas
from .errors import DatasetNotFoundError, RGBBandsMissingError
from .errors import DatasetNotFoundError, DependencyNotFoundError, RGBBandsMissingError
from .esri2020 import Esri2020
from .etci2021 import ETCI2021
from .eudem import EUDEM
@@ -280,5 +280,6 @@ __all__ = (
'time_series_split',
# Errors
'DatasetNotFoundError',
'DependencyNotFoundError',
'RGBBandsMissingError',
)

torchgeo/datasets/advance.py

@@ -17,7 +17,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_and_extract_archive
from .utils import download_and_extract_archive, lazy_import
class ADVANCE(NonGeoDataset):
@@ -104,7 +104,10 @@ class ADVANCE(NonGeoDataset):
Raises:
DatasetNotFoundError: If dataset is not found and *download* is False.
DependencyNotFoundError: If scipy is not installed.
"""
lazy_import('scipy.io.wavfile')
self.root = root
self.transforms = transforms
self.checksum = checksum
@@ -191,14 +194,8 @@
Returns:
the target audio
"""
try:
from scipy.io import wavfile
except ImportError:
raise ImportError(
'scipy is not installed and is required to use this dataset'
)
array = wavfile.read(path, mmap=True)[1]
siw = lazy_import('scipy.io.wavfile')
array = siw.read(path, mmap=True)[1]
tensor = torch.from_numpy(array)
tensor = tensor.unsqueeze(0)
return tensor

torchgeo/datasets/chabud.py

@@ -14,7 +14,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_url, percentile_normalization
from .utils import download_url, lazy_import, percentile_normalization
class ChaBuD(NonGeoDataset):
@@ -96,7 +96,10 @@ class ChaBuD(NonGeoDataset):
Raises:
AssertionError: If ``split`` or ``bands`` arguments are invalid.
DatasetNotFoundError: If dataset is not found and *download* is False.
DependencyNotFoundError: If h5py is not installed.
"""
lazy_import('h5py')
assert split in self.folds
assert set(bands) <= set(self.all_bands)
@@ -111,13 +114,6 @@
self._verify()
try:
import h5py # noqa: F401
except ImportError:
raise ImportError(
'h5py is not installed and is required to use this dataset'
)
self.uuids = self._load_uuids()
def __getitem__(self, index: int) -> dict[str, Tensor]:
@@ -153,8 +149,7 @@
Returns:
the image uuids
"""
import h5py
h5py = lazy_import('h5py')
uuids = []
with h5py.File(self.filepath, 'r') as f:
for k, v in f.items():
@@ -173,8 +168,7 @@
Returns:
the image
"""
import h5py
h5py = lazy_import('h5py')
uuid = self.uuids[index]
with h5py.File(self.filepath, 'r') as f:
pre_array = f[uuid]['pre_fire'][:]
@@ -199,8 +193,7 @@
Returns:
the target mask
"""
import h5py
h5py = lazy_import('h5py')
uuid = self.uuids[index]
with h5py.File(self.filepath, 'r') as f:
array = f[uuid]['mask'][:].astype(np.int32).squeeze(axis=-1)

torchgeo/datasets/cropharvest.py

@@ -17,7 +17,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_url, extract_archive
from .utils import download_url, extract_archive, lazy_import
class CropHarvest(NonGeoDataset):
@@ -112,14 +112,9 @@
Raises:
DatasetNotFoundError: If dataset is not found and *download* is False.
ImportError: If h5py is not installed
DependencyNotFoundError: If h5py is not installed.
"""
try:
import h5py # noqa: F401
except ImportError:
raise ImportError(
'h5py is not installed and is required to use this dataset'
)
lazy_import('h5py')
self.root = root
self.transforms = transforms
@@ -210,8 +205,7 @@
Returns:
the image
"""
import h5py
h5py = lazy_import('h5py')
filename = os.path.join(path)
with h5py.File(filename, 'r') as f:
array = f.get('array')[()]

torchgeo/datasets/errors.py

@@ -49,6 +49,32 @@ class DatasetNotFoundError(FileNotFoundError):
super().__init__(msg)
class DependencyNotFoundError(ModuleNotFoundError):
"""Raised when an optional dataset dependency is not installed.
.. versionadded:: 0.6
"""
def __init__(self, name: str) -> None:
"""Initialize a new DependencyNotFoundError instance.
Args:
name: Name of missing dependency.
"""
msg = f"""\
{name} is not installed and is required to use this dataset. Either run:
$ pip install {name}
to install just this dependency, or:
$ pip install torchgeo[datasets]
to install all optional dataset dependencies."""
super().__init__(msg)
class RGBBandsMissingError(ValueError):
"""Raised when a dataset is missing RGB bands for plotting.

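
For reference, here is roughly what the DependencyNotFoundError defined above renders to for a missing dependency named foo (a sketch based on the f-string template; exact blank-line spacing may differ from the flattened diff):

from torchgeo.datasets import DependencyNotFoundError

try:
    raise DependencyNotFoundError('foo')
except DependencyNotFoundError as err:
    print(err)
# foo is not installed and is required to use this dataset. Either run:
#
# $ pip install foo
#
# to install just this dependency, or:
#
# $ pip install torchgeo[datasets]
#
# to install all optional dataset dependencies.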
torchgeo/datasets/idtrees.py

@@ -22,7 +22,7 @@ from torchvision.utils import draw_bounding_boxes
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_url, extract_archive
from .utils import download_url, extract_archive, lazy_import
class IDTReeS(NonGeoDataset):
@@ -92,6 +92,11 @@
* https://doi.org/10.1101/2021.08.06.453503
This dataset requires the following additional libraries to be installed:
* `laspy <https://pypi.org/project/laspy/>`_ to read lidar point clouds
* `pyvista <https://pypi.org/project/pyvista/>`_ to plot lidar point clouds
.. versionadded:: 0.2
"""
@@ -167,11 +172,14 @@
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
ImportError: if laspy is not installed
DatasetNotFoundError: If dataset is not found and *download* is False.
DependencyNotFoundError: If laspy is not installed.
"""
lazy_import('laspy')
assert split in ['train', 'test']
assert task in ['task1', 'task2']
self.root = root
self.split = split
self.task = task
@@ -182,14 +190,6 @@
self.idx2class = {i: c for i, c in enumerate(self.classes)}
self.num_classes = len(self.classes)
self._verify()
try:
import laspy # noqa: F401
except ImportError:
raise ImportError(
'laspy is not installed and is required to use this dataset'
)
self.images, self.geometries, self.labels = self._load(root)
def __getitem__(self, index: int) -> dict[str, Tensor]:
@@ -263,8 +263,7 @@
Returns:
the point cloud
"""
import laspy
laspy = lazy_import('laspy')
las = laspy.read(path)
array: 'np.typing.NDArray[np.int_]' = np.stack([las.x, las.y, las.z], axis=0)
tensor = torch.from_numpy(array)
@@ -561,19 +560,13 @@
pyvista.PolyData object. Run pyvista.plot(point_cloud, ...) to display
Raises:
ImportError: if pyvista is not installed
DependencyNotFoundError: If laspy or pyvista are not installed.
.. versionchanged:: 0.4
Ported from Open3D to PyVista, *colormap* parameter removed.
"""
try:
import pyvista # noqa: F401
except ImportError:
raise ImportError(
'pyvista is not installed and is required to plot point clouds'
)
import laspy
laspy = lazy_import('laspy')
pyvista = lazy_import('pyvista')
path = self.images[index]
path = path.replace('RGB', 'LAS').replace('.tif', '.las')
las = laspy.read(path)

torchgeo/datasets/quakeset.py

@@ -15,7 +15,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_url, percentile_normalization
from .utils import download_url, lazy_import, percentile_normalization
class QuakeSet(NonGeoDataset):
@@ -85,8 +85,10 @@
Raises:
AssertionError: If ``split`` argument is invalid.
DatasetNotFoundError: If dataset is not found and *download* is False.
ImportError: if h5py is not installed
DependencyNotFoundError: If h5py is not installed.
"""
lazy_import('h5py')
assert split in self.splits
self.root = root
@@ -95,16 +97,7 @@
self.download = download
self.checksum = checksum
self.filepath = os.path.join(root, self.filename)
self._verify()
try:
import h5py # noqa: F401
except ImportError:
raise ImportError(
'h5py is not installed and is required to use this dataset'
)
self.data = self._load_data()
def __getitem__(self, index: int) -> dict[str, Tensor]:
@@ -141,8 +134,7 @@
Returns:
the sample keys, patches, images, labels, and magnitudes
"""
import h5py
h5py = lazy_import('h5py')
data = []
with h5py.File(self.filepath) as f:
for k in sorted(f.keys()):
@@ -185,7 +177,7 @@
Returns:
the image
"""
import h5py
h5py = lazy_import('h5py')
key = self.data[index]['key']
patch = self.data[index]['patch']

torchgeo/datasets/skippd.py

@@ -16,7 +16,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_url, extract_archive
from .utils import download_url, extract_archive, lazy_import
class SKIPPD(NonGeoDataset):
@@ -53,6 +53,12 @@
* https://doi.org/10.48550/arXiv.2207.00913
.. note::
This dataset requires the following additional library to be installed:
* `h5py <https://pypi.org/project/h5py/>`_ to load the dataset
.. versionadded:: 0.5
"""
@@ -94,8 +100,10 @@
Raises:
AssertionError: if ``task`` or ``split`` is invalid
DatasetNotFoundError: If dataset is not found and *download* is False.
ImportError: if h5py is not installed
DependencyNotFoundError: If h5py is not installed.
"""
lazy_import('h5py')
assert (
split in self.valid_splits
), f'Please choose one of these valid data splits {self.valid_splits}.'
@@ -110,14 +118,6 @@
self.transforms = transforms
self.download = download
self.checksum = checksum
try:
import h5py # noqa: F401
except ImportError:
raise ImportError(
'h5py is not installed and is required to use this dataset'
)
self._verify()
def __len__(self) -> int:
@@ -126,8 +126,7 @@
Returns:
length of the dataset
"""
import h5py
h5py = lazy_import('h5py')
with h5py.File(
os.path.join(self.root, self.data_file_name.format(self.task)), 'r'
) as f:
@@ -161,8 +160,7 @@
Returns:
image tensor at index
"""
import h5py
h5py = lazy_import('h5py')
with h5py.File(
os.path.join(self.root, self.data_file_name.format(self.task)), 'r'
) as f:
@@ -187,8 +185,7 @@
Returns:
label tensor at index
"""
import h5py
h5py = lazy_import('h5py')
with h5py.File(
os.path.join(self.root, self.data_file_name.format(self.task)), 'r'
) as f:

torchgeo/datasets/so2sat.py

@@ -15,7 +15,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import NonGeoDataset
from .utils import check_integrity, percentile_normalization
from .utils import check_integrity, lazy_import, percentile_normalization
class So2Sat(NonGeoDataset):
@@ -97,6 +97,12 @@
done
or manually downloaded from https://mediatum.ub.tum.de/1613658
.. note::
This dataset requires the following additional library to be installed:
* `h5py <https://pypi.org/project/h5py/>`_ to load the dataset
""" # noqa: E501
versions = ['2', '3_random', '3_block', '3_culture_10']
@@ -210,6 +216,7 @@
Raises:
AssertionError: if ``split`` argument is invalid
DatasetNotFoundError: If dataset is not found.
DependencyNotFoundError: If h5py is not installed.
.. versionadded:: 0.3
The *bands* parameter.
@@ -217,12 +224,8 @@
.. versionadded:: 0.5
The *version* parameter.
"""
try:
import h5py # noqa: F401
except ImportError:
raise ImportError(
'h5py is not installed and is required to use this dataset'
)
h5py = lazy_import('h5py')
assert version in self.versions
assert split in self.filenames_by_version[version]
@@ -272,8 +275,7 @@
Returns:
data and label at that index
"""
import h5py
h5py = lazy_import('h5py')
with h5py.File(self.fn, 'r') as f:
s1 = f['sen1'][index].astype(np.float64) # convert from <f8 to float64
s1 = np.take(s1, indices=self.s1_band_indices, axis=2)

torchgeo/datasets/utils.py

@@ -10,6 +10,7 @@ import bz2
import collections
import contextlib
import gzip
import importlib
import lzma
import os
import sys
@@ -26,6 +27,8 @@ from torch import Tensor
from torchvision.datasets.utils import check_integrity, download_url
from torchvision.utils import draw_segmentation_masks
from .errors import DependencyNotFoundError
# Only include import redirects
__all__ = ('check_integrity', 'download_url')
@@ -37,13 +40,7 @@ class _rarfile:
self.kwargs = kwargs
def __enter__(self) -> Any:
try:
import rarfile
except ImportError:
raise ImportError(
'rarfile is not installed and is required to extract this dataset'
)
rarfile = lazy_import('rarfile')
# TODO: catch exception for when rarfile is installed but not
# unrar/unar/bsdtar
return rarfile.RarFile(*self.args, **self.kwargs)
@@ -157,14 +154,11 @@ def download_radiant_mlhub_dataset(
api_key: the API key to use for all requests from the session. Can also be
passed in via the ``MLHUB_API_KEY`` environment variable, or configured in
``~/.mlhub/profiles``.
"""
try:
import radiant_mlhub
except ImportError:
raise ImportError(
'radiant_mlhub is not installed and is required to download this dataset'
)
Raises:
DependencyNotFoundError: If radiant_mlhub is not installed.
"""
radiant_mlhub = lazy_import('radiant_mlhub')
dataset = radiant_mlhub.Dataset.fetch(dataset_id, api_key=api_key)
dataset.download(output_dir=download_root, api_key=api_key)
@@ -180,14 +174,11 @@
api_key: the API key to use for all requests from the session. Can also be
passed in via the ``MLHUB_API_KEY`` environment variable, or configured in
``~/.mlhub/profiles``.
"""
try:
import radiant_mlhub
except ImportError:
raise ImportError(
'radiant_mlhub is not installed and is required to download this collection'
)
Raises:
DependencyNotFoundError: If radiant_mlhub is not installed.
"""
radiant_mlhub = lazy_import('radiant_mlhub')
collection = radiant_mlhub.Collection.fetch(collection_id, api_key=api_key)
collection.download(output_dir=download_root, api_key=api_key)
@@ -773,3 +764,25 @@ def array_to_tensor(array: np.typing.NDArray[Any]) -> Tensor:
elif array.dtype == np.uint32:
array = array.astype(np.int64)
return torch.tensor(array)
def lazy_import(name: str) -> Any:
"""Lazy import of *name*.
Args:
name: Name of module to import.
Raises:
DependencyNotFoundError: If *name* is not installed.
.. versionadded:: 0.6
"""
try:
return importlib.import_module(name)
except ModuleNotFoundError:
# Map from import name to package name on PyPI
name = name.split('.')[0].replace('_', '-')
module_to_pypi: dict[str, str] = collections.defaultdict(lambda: name)
module_to_pypi |= {'cv2': 'opencv-python', 'skimage': 'scikit-image'}
name = module_to_pypi[name]
raise DependencyNotFoundError(name) from None
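
The except branch above normalizes the failed name before raising: it keeps only the top-level module, swaps underscores for hyphens, and consults a small override table for packages whose PyPI name differs from their import name. The two steps, spelled out for one name, plus the resulting install hints for names used elsewhere in this commit:

name = 'skimage.measure'
name = name.split('.')[0].replace('_', '-')  # -> 'skimage'
# module_to_pypi then maps 'skimage' -> 'scikit-image'

# 'h5py'              -> pip install h5py
# 'scipy.io.wavfile'  -> pip install scipy
# 'zipfile_deflate64' -> pip install zipfile-deflate64
# 'cv2'               -> pip install opencv-python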

torchgeo/datasets/vhr10.py

@@ -17,7 +17,12 @@ from torch import Tensor
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import check_integrity, download_and_extract_archive, download_url
from .utils import (
check_integrity,
download_and_extract_archive,
download_url,
lazy_import,
)
def convert_coco_poly_to_mask(
@@ -32,13 +37,15 @@
Returns:
Tensor: Mask tensor
"""
from pycocotools import mask as coco_mask # noqa: F401
Raises:
DependencyNotFoundError: If pycocotools is not installed.
"""
pycocotools = lazy_import('pycocotools')
masks = []
for polygons in segmentations:
rles = coco_mask.frPyObjects(polygons, height, width)
mask = coco_mask.decode(rles)
rles = pycocotools.mask.frPyObjects(polygons, height, width)
mask = pycocotools.mask.decode(rles)
mask = torch.as_tensor(mask, dtype=torch.uint8)
mask = mask.any(dim=2)
masks.append(mask)
@@ -196,8 +203,9 @@ class VHR10(NonGeoDataset):
Raises:
AssertionError: if ``split`` argument is invalid
ImportError: if ``split="positive"`` and pycocotools is not installed
DatasetNotFoundError: If dataset is not found and *download* is False.
DependencyNotFoundError: if ``split="positive"`` and pycocotools is
not installed.
"""
assert split in ['positive', 'negative']
@@ -213,20 +221,12 @@
raise DatasetNotFoundError(self)
if split == 'positive':
# Must be installed to parse annotations file
try:
from pycocotools.coco import COCO # noqa: F401
except ImportError:
raise ImportError(
'pycocotools is not installed and is required to use this dataset'
)
self.coco = COCO(
pc = lazy_import('pycocotools.coco')
self.coco = pc.COCO(
os.path.join(
self.root, 'NWPU VHR-10 dataset', self.target_meta['filename']
)
)
self.coco_convert = ConvertCocoAnnotations()
self.ids = list(sorted(self.coco.imgs.keys()))
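
Worth noting: lazy_import accepts dotted paths, and importlib.import_module returns the leaf submodule, so pc above is the pycocotools.coco module and pc.COCO is the usual COCO API class. A minimal sketch (annotations.json is a hypothetical COCO-format file):

from torchgeo.datasets.utils import lazy_import

annotation_file = 'annotations.json'  # hypothetical path
pc = lazy_import('pycocotools.coco')  # returns the pycocotools.coco submodule
coco = pc.COCO(annotation_file)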
@@ -381,7 +381,7 @@
Raises:
AssertionError: if ``show_feats`` argument is invalid
ImportError: if plotting masks and scikit-image is not installed
DependencyNotFoundError: If plotting masks and scikit-image is not installed.
.. versionadded:: 0.4
"""
@@ -397,12 +397,7 @@
return fig
if show_feats != 'boxes':
try:
from skimage.measure import find_contours # noqa: F401
except ImportError:
raise ImportError(
'scikit-image is not installed and is required to plot masks.'
)
skimage = lazy_import('skimage')
image = sample['image'].permute(1, 2, 0).numpy()
boxes = sample['boxes'].cpu().numpy()
@@ -465,7 +460,7 @@
# Add masks
if show_feats in {'masks', 'both'} and 'masks' in sample:
mask = masks[i]
contours = find_contours(mask, 0.5) # type: ignore[no-untyped-call]
contours = skimage.measure.find_contours(mask, 0.5)
for verts in contours:
verts = np.fliplr(verts)
p = patches.Polygon(
@@ -517,7 +512,7 @@
# Add masks
if show_pred_masks:
mask = prediction_masks[i]
contours = find_contours(mask, 0.5) # type: ignore[no-untyped-call]
contours = skimage.measure.find_contours(mask, 0.5)
for verts in contours:
verts = np.fliplr(verts)
p = patches.Polygon(

torchgeo/datasets/zuericrop.py

@@ -13,7 +13,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import NonGeoDataset
from .utils import download_url, percentile_normalization
from .utils import download_url, lazy_import, percentile_normalization
class ZueriCrop(NonGeoDataset):
@@ -82,7 +82,10 @@
Raises:
DatasetNotFoundError: If dataset is not found and *download* is False.
DependencyNotFoundError: If h5py is not installed.
"""
lazy_import('h5py')
self._validate_bands(bands)
self.band_indices = torch.tensor(
[self.band_names.index(b) for b in bands]
@@ -97,13 +100,6 @@
self._verify()
try:
import h5py # noqa: F401
except ImportError:
raise ImportError(
'h5py is not installed and is required to use this dataset'
)
def __getitem__(self, index: int) -> dict[str, Tensor]:
"""Return an index within the dataset.
@@ -129,8 +125,7 @@
Returns:
length of the dataset
"""
import h5py
h5py = lazy_import('h5py')
with h5py.File(self.filepath, 'r') as f:
length: int = f['data'].shape[0]
return length
@@ -144,8 +139,7 @@
Returns:
the image
"""
import h5py
h5py = lazy_import('h5py')
with h5py.File(self.filepath, 'r') as f:
array = f['data'][index, ...]