зеркало из https://github.com/microsoft/torchgeo.git
Rework list of required dependencies (#287)
* Rework list of required dependencies * Update open3d import error msg * Style fixes * Remove extra empty line * Increase test coverage * Fix idtrees tests
This commit is contained in:
Родитель
ae11f10502
Коммит
01ae2db287
|
@ -37,7 +37,7 @@ jobs:
|
|||
run: sudo apt-get install pandoc
|
||||
- name: Install pip dependencies
|
||||
run: |
|
||||
pip install .[train]
|
||||
pip install .
|
||||
pip install -r docs/requirements.txt
|
||||
- name: Run sphinx checks
|
||||
run: cd docs && make html
|
||||
|
|
|
@ -25,7 +25,7 @@ jobs:
|
|||
- name: Install pip dependencies
|
||||
run: |
|
||||
pip install gdal tqdm # TODO: these deps shouldn't be needed
|
||||
pip install .[datasets,tests,train]
|
||||
pip install .[datasets,tests]
|
||||
pip install -r docs/requirements.txt
|
||||
- name: Run notebook checks
|
||||
env:
|
||||
|
@ -42,6 +42,6 @@ jobs:
|
|||
with:
|
||||
python-version: 3.9
|
||||
- name: Install pip dependencies
|
||||
run: pip install .[datasets,tests,train]
|
||||
run: pip install .[datasets,tests]
|
||||
- name: Run integration checks
|
||||
run: pytest -m slow
|
||||
|
|
|
@ -20,9 +20,25 @@ jobs:
|
|||
- name: Install pip dependencies
|
||||
run: |
|
||||
pip install cython numpy # needed for pycocotools
|
||||
pip install .[datasets,tests,train]
|
||||
pip install .[datasets,tests]
|
||||
- name: Run mypy checks
|
||||
run: mypy .
|
||||
datasets:
|
||||
name: datasets
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Clone repo
|
||||
uses: actions/checkout@v2
|
||||
- name: Set up python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.9
|
||||
- name: Install pip dependencies
|
||||
run: |
|
||||
pip install cython numpy # needed for pycocotools
|
||||
pip install .[tests]
|
||||
- name: Run pytest checks
|
||||
run: pytest --cov=torchgeo --cov-report=xml
|
||||
pytest:
|
||||
name: pytest
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
@ -64,7 +80,7 @@ jobs:
|
|||
- name: Install pip dependencies
|
||||
run: |
|
||||
pip install cython numpy # needed for pycocotools
|
||||
pip install .[datasets,tests,train]
|
||||
pip install .[datasets,tests]
|
||||
- name: Run pytest checks
|
||||
run: pytest --cov=torchgeo --cov-report=xml
|
||||
- name: Report coverage
|
||||
|
|
17
setup.cfg
17
setup.cfg
|
@ -35,6 +35,8 @@ install_requires =
|
|||
kornia>=0.5.4
|
||||
matplotlib
|
||||
numpy
|
||||
# omegaconf 2.1+ required for to_object method
|
||||
omegaconf>=2.1
|
||||
# pillow 2.9+ required for height attribute
|
||||
pillow>=2.9
|
||||
# pyproj 2.2+ required for CRS object
|
||||
|
@ -49,8 +51,13 @@ install_requires =
|
|||
scikit-learn>=0.18
|
||||
# shapely 1.3+ required for Python 3 support
|
||||
shapely>=1.3
|
||||
# segmentation-models-pytorch 0.2+ required for smp.losses module
|
||||
segmentation-models-pytorch>=0.2
|
||||
# timm 0.2.1+ required for `features_only` option in create_model
|
||||
timm>=0.2.1
|
||||
# torch 1.7+ required for typing
|
||||
torch>=1.7
|
||||
torchmetrics
|
||||
# torchvision 0.3+ required for download_file_from_google_drive
|
||||
torchvision>=0.3
|
||||
python_requires = >= 3.6
|
||||
|
@ -72,15 +79,6 @@ datasets =
|
|||
rarfile>=3
|
||||
# scipy 0.9+ required for scipy.io.wavfile.read
|
||||
scipy>=0.9
|
||||
# Optional trainer requirements
|
||||
train =
|
||||
# omegaconf 2.1+ required for to_object method
|
||||
omegaconf>=2.1
|
||||
# segmentation-models-pytorch 0.2+ required for smp.losses module
|
||||
segmentation-models-pytorch>=0.2
|
||||
# timm 0.2.1+ required for `features_only` option in create_model
|
||||
timm>=0.2.1
|
||||
torchmetrics
|
||||
# Optional developer requirements
|
||||
style =
|
||||
# black 21+ required for Python 3.9 support
|
||||
|
@ -92,6 +90,7 @@ style =
|
|||
isort[colors]>=5.8
|
||||
# pydocstyle 6.1+ required for pyproject.toml support
|
||||
pydocstyle[toml]>=6.1
|
||||
# Optional testing requirements
|
||||
tests =
|
||||
# mypy 0.900+ required for pyproject.toml support
|
||||
mypy>=0.900
|
||||
|
|
|
@ -25,7 +25,6 @@ class TestADVANCE:
|
|||
def dataset(
|
||||
self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path
|
||||
) -> ADVANCE:
|
||||
pytest.importorskip("scipy", minversion="0.9.0")
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
torchgeo.datasets.utils, "download_url", download_url
|
||||
)
|
||||
|
@ -57,6 +56,7 @@ class TestADVANCE:
|
|||
)
|
||||
|
||||
def test_getitem(self, dataset: ADVANCE) -> None:
|
||||
pytest.importorskip("scipy", minversion="0.9.0")
|
||||
x = dataset[0]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x["image"], torch.Tensor)
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import builtins
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
from typing import Any, Generator
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
@ -17,7 +18,6 @@ from torch.utils.data import ConcatDataset
|
|||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import VHR10
|
||||
|
||||
pytest.importorskip("rarfile")
|
||||
pytest.importorskip("pycocotools")
|
||||
|
||||
|
||||
|
@ -35,6 +35,7 @@ class TestVHR10:
|
|||
tmp_path: Path,
|
||||
request: SubRequest,
|
||||
) -> VHR10:
|
||||
pytest.importorskip("rarfile", minversion="3")
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
torchgeo.datasets.nwpu, "download_url", download_url
|
||||
)
|
||||
|
@ -54,6 +55,21 @@ class TestVHR10:
|
|||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return VHR10(root, split, transforms, download=True, checksum=True)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_missing_module(
|
||||
self, monkeypatch: Generator[MonkeyPatch, None, None]
|
||||
) -> None:
|
||||
import_orig = builtins.__import__
|
||||
|
||||
def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
|
||||
if name == "pycocotools.coco":
|
||||
raise ImportError()
|
||||
return import_orig(name, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
builtins, "__import__", mocked_import
|
||||
)
|
||||
|
||||
def test_getitem(self, dataset: VHR10) -> None:
|
||||
x = dataset[0]
|
||||
assert isinstance(x, dict)
|
||||
|
@ -84,3 +100,13 @@ class TestVHR10:
|
|||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
VHR10(str(tmp_path))
|
||||
|
||||
def test_mock_missing_module(
|
||||
self, dataset: VHR10, mock_missing_module: None
|
||||
) -> None:
|
||||
if dataset.split == "positive":
|
||||
with pytest.raises(
|
||||
ImportError,
|
||||
match="pycocotools is not installed and is required to use this datase",
|
||||
):
|
||||
VHR10(dataset.root, dataset.split)
|
||||
|
|
|
@ -16,8 +16,6 @@ from _pytest.monkeypatch import MonkeyPatch
|
|||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import RESISC45, RESISC45DataModule
|
||||
|
||||
pytest.importorskip("rarfile")
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
@ -32,6 +30,8 @@ class TestRESISC45:
|
|||
tmp_path: Path,
|
||||
request: SubRequest,
|
||||
) -> RESISC45:
|
||||
pytest.importorskip("rarfile", minversion="3")
|
||||
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
torchgeo.datasets.resisc45, "download_url", download_url
|
||||
)
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import builtins
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
from typing import Any, Generator
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
@ -13,6 +14,8 @@ from _pytest.monkeypatch import MonkeyPatch
|
|||
|
||||
from torchgeo.datasets import So2Sat, So2SatDataModule
|
||||
|
||||
pytest.importorskip("h5py")
|
||||
|
||||
|
||||
class TestSo2Sat:
|
||||
@pytest.fixture(params=["train", "validation", "test"])
|
||||
|
@ -31,6 +34,21 @@ class TestSo2Sat:
|
|||
transforms = nn.Identity() # type: ignore[attr-defined]
|
||||
return So2Sat(root, split, transforms, checksum=True)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_missing_module(
|
||||
self, monkeypatch: Generator[MonkeyPatch, None, None]
|
||||
) -> None:
|
||||
import_orig = builtins.__import__
|
||||
|
||||
def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
|
||||
if name == "h5py":
|
||||
raise ImportError()
|
||||
return import_orig(name, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
builtins, "__import__", mocked_import
|
||||
)
|
||||
|
||||
def test_getitem(self, dataset: So2Sat) -> None:
|
||||
x = dataset[0]
|
||||
assert isinstance(x, dict)
|
||||
|
|
|
@ -93,7 +93,7 @@ def test_mock_missing_module(mock_missing_module: None) -> None:
|
|||
],
|
||||
)
|
||||
def test_extract_archive(src: str, tmp_path: Path) -> None:
|
||||
pytest.importorskip("rarfile")
|
||||
pytest.importorskip("rarfile", minversion="3")
|
||||
extract_archive(os.path.join("tests", "data", src), str(tmp_path))
|
||||
|
||||
|
||||
|
|
|
@ -117,7 +117,12 @@ class VHR10(VisionDataset):
|
|||
|
||||
if split == "positive":
|
||||
# Must be installed to parse annotations file
|
||||
from pycocotools.coco import COCO
|
||||
try:
|
||||
from pycocotools.coco import COCO # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"pycocotools is not installed and is required to use this dataset"
|
||||
)
|
||||
|
||||
self.coco = COCO(
|
||||
os.path.join(
|
||||
|
|
|
@ -92,7 +92,12 @@ class So2Sat(VisionDataset):
|
|||
AssertionError: if ``split`` argument is invalid
|
||||
RuntimeError: if data is not found in ``root``, or checksums don't match
|
||||
"""
|
||||
import h5py
|
||||
try:
|
||||
import h5py # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"h5py is not installed and is required to use this dataset"
|
||||
)
|
||||
|
||||
assert split in ["train", "validation", "test"]
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче