зеркало из https://github.com/microsoft/torchgeo.git
Parametrize tests
This commit is contained in:
Родитель
9ee482e88f
Коммит
fa8399eb6a
|
@ -3,6 +3,7 @@ from pathlib import Path
|
|||
import shutil
|
||||
from typing import Generator
|
||||
|
||||
from _pytest.fixtures import SubRequest
|
||||
import pytest
|
||||
from pytest import MonkeyPatch
|
||||
import torch
|
||||
|
@ -25,9 +26,12 @@ class TestCOWC:
|
|||
|
||||
|
||||
class TestCOWCCounting:
|
||||
@pytest.fixture
|
||||
@pytest.fixture(params=["train", "test"])
|
||||
def dataset(
|
||||
self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path
|
||||
self,
|
||||
monkeypatch: Generator[MonkeyPatch, None, None],
|
||||
tmp_path: Path,
|
||||
request: SubRequest,
|
||||
) -> _COWC:
|
||||
monkeypatch.setattr( # type: ignore[attr-defined]
|
||||
torchgeo.datasets.cowc, "download_url", download_url
|
||||
|
@ -49,7 +53,7 @@ class TestCOWCCounting:
|
|||
monkeypatch.setattr(COWCCounting, "md5s", md5s) # type: ignore[attr-defined]
|
||||
(tmp_path / "cowc_counting").mkdir()
|
||||
root = str(tmp_path)
|
||||
split = "train"
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
return COWCCounting(root, split, transforms, download=True, checksum=True)
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ import os
|
|||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
from _pytest.fixtures import SubRequest
|
||||
import pytest
|
||||
from pytest import MonkeyPatch
|
||||
import torch
|
||||
|
@ -11,8 +12,8 @@ from torchgeo.transforms import Identity
|
|||
|
||||
|
||||
class TestSEN12MS:
|
||||
@pytest.fixture
|
||||
def dataset(self, monkeypatch: Generator[MonkeyPatch, None, None]) -> SEN12MS:
|
||||
@pytest.fixture(params=["train", "test"])
|
||||
def dataset(self, monkeypatch: Generator[MonkeyPatch, None, None], request: SubRequest) -> SEN12MS:
|
||||
md5s = [
|
||||
"3079d1c5038fa101ec2072657f2cb1ab",
|
||||
"f11487a4b2e641b64ed80a031c4d121d",
|
||||
|
@ -32,7 +33,7 @@ class TestSEN12MS:
|
|||
|
||||
monkeypatch.setattr(SEN12MS, "md5s", md5s) # type: ignore[attr-defined]
|
||||
root = os.path.join("tests", "data")
|
||||
split = "train"
|
||||
split = request.param
|
||||
transforms = Identity()
|
||||
return SEN12MS(root, split, transforms, checksum=True)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче