зеркало из https://github.com/microsoft/torchgeo.git
Родитель
e3e8d7d042
Коммит
9b27bd705b
|
@ -2,7 +2,6 @@
|
|||
# Licensed under the MIT License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
import timm
|
||||
|
@ -14,11 +13,6 @@ from torchvision.models._api import WeightsEnum
|
|||
from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
|
||||
|
||||
|
||||
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
state_dict: Dict[str, Any] = torch.load(url)
|
||||
return state_dict
|
||||
|
||||
|
||||
class TestResNet18:
|
||||
@pytest.fixture(params=[*ResNet18_Weights])
|
||||
def weights(self, request: SubRequest) -> WeightsEnum:
|
||||
|
@ -31,8 +25,7 @@ class TestResNet18:
|
|||
path = tmp_path / f"{weights}.pth"
|
||||
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
|
||||
torch.save(model.state_dict(), path)
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torch.hub, "load_state_dict_from_url", load)
|
||||
monkeypatch.setattr(weights, "url", path.as_uri())
|
||||
return weights
|
||||
|
||||
def test_resnet(self) -> None:
|
||||
|
@ -58,8 +51,7 @@ class TestResNet50:
|
|||
path = tmp_path / f"{weights}.pth"
|
||||
model = timm.create_model("resnet50", in_chans=weights.meta["in_chans"])
|
||||
torch.save(model.state_dict(), path)
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torch.hub, "load_state_dict_from_url", load)
|
||||
monkeypatch.setattr(weights, "url", path.as_uri())
|
||||
return weights
|
||||
|
||||
def test_resnet(self) -> None:
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
# Licensed under the MIT License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
import timm
|
||||
|
@ -14,11 +13,6 @@ from torchvision.models._api import WeightsEnum
|
|||
from torchgeo.models import ViTSmall16_Weights, vit_small_patch16_224
|
||||
|
||||
|
||||
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
state_dict: Dict[str, Any] = torch.load(url)
|
||||
return state_dict
|
||||
|
||||
|
||||
class TestViTSmall16:
|
||||
@pytest.fixture(params=[*ViTSmall16_Weights])
|
||||
def weights(self, request: SubRequest) -> WeightsEnum:
|
||||
|
@ -33,8 +27,7 @@ class TestViTSmall16:
|
|||
weights.meta["model"], in_chans=weights.meta["in_chans"]
|
||||
)
|
||||
torch.save(model.state_dict(), path)
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torch.hub, "load_state_dict_from_url", load)
|
||||
monkeypatch.setattr(weights, "url", path.as_uri())
|
||||
return weights
|
||||
|
||||
def test_vit(self) -> None:
|
||||
|
|
|
@ -23,11 +23,6 @@ from torchgeo.trainers.byol import BYOL, SimCLRAugmentation
|
|||
from .test_utils import SegmentationTestModel
|
||||
|
||||
|
||||
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
state_dict: Dict[str, Any] = torch.load(url)
|
||||
return state_dict
|
||||
|
||||
|
||||
class TestBYOL:
|
||||
def test_custom_augment_fn(self) -> None:
|
||||
backbone = resnet18()
|
||||
|
@ -84,8 +79,7 @@ class TestBYOLTask:
|
|||
path = tmp_path / f"{weights}.pth"
|
||||
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
|
||||
torch.save(model.state_dict(), path)
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torch.hub, "load_state_dict_from_url", load)
|
||||
monkeypatch.setattr(weights, "url", path.as_uri())
|
||||
return weights
|
||||
|
||||
def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None:
|
||||
|
|
|
@ -31,11 +31,6 @@ def create_model(*args: Any, **kwargs: Any) -> Module:
|
|||
return ClassificationTestModel(**kwargs)
|
||||
|
||||
|
||||
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
state_dict: Dict[str, Any] = torch.load(url)
|
||||
return state_dict
|
||||
|
||||
|
||||
class TestClassificationTask:
|
||||
@pytest.mark.parametrize(
|
||||
"name,classname",
|
||||
|
@ -107,8 +102,7 @@ class TestClassificationTask:
|
|||
path = tmp_path / f"{weights}.pth"
|
||||
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
|
||||
torch.save(model.state_dict(), path)
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torch.hub, "load_state_dict_from_url", load)
|
||||
monkeypatch.setattr(weights, "url", path.as_uri())
|
||||
return weights
|
||||
|
||||
def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None:
|
||||
|
|
|
@ -20,11 +20,6 @@ from torchgeo.trainers import RegressionTask
|
|||
from .test_utils import RegressionTestModel
|
||||
|
||||
|
||||
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
state_dict: Dict[str, Any] = torch.load(url)
|
||||
return state_dict
|
||||
|
||||
|
||||
class TestRegressionTask:
|
||||
@pytest.mark.parametrize(
|
||||
"name,classname",
|
||||
|
@ -88,8 +83,7 @@ class TestRegressionTask:
|
|||
path = tmp_path / f"{weights}.pth"
|
||||
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
|
||||
torch.save(model.state_dict(), path)
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torch.hub, "load_state_dict_from_url", load)
|
||||
monkeypatch.setattr(weights, "url", path.as_uri())
|
||||
return weights
|
||||
|
||||
def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None:
|
||||
|
|
Загрузка…
Ссылка в новой задаче