Revert "Monkeypatch everything"

This reverts commit e3e8d7d042.
This commit is contained in:
Adam J. Stewart 2023-01-21 13:54:00 -06:00
Родитель e3e8d7d042
Коммит 9b27bd705b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: C66C0675661156FC
5 изменённых файлов: 6 добавлений и 39 удалений

Просмотреть файл

@ -2,7 +2,6 @@
# Licensed under the MIT License.
from pathlib import Path
from typing import Any, Dict
import pytest
import timm
@ -14,11 +13,6 @@ from torchvision.models._api import WeightsEnum
from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
state_dict: Dict[str, Any] = torch.load(url)
return state_dict
class TestResNet18:
@pytest.fixture(params=[*ResNet18_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
@ -31,8 +25,7 @@ class TestResNet18:
path = tmp_path / f"{weights}.pth"
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
torch.save(model.state_dict(), path)
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torch.hub, "load_state_dict_from_url", load)
monkeypatch.setattr(weights, "url", path.as_uri())
return weights
def test_resnet(self) -> None:
@ -58,8 +51,7 @@ class TestResNet50:
path = tmp_path / f"{weights}.pth"
model = timm.create_model("resnet50", in_chans=weights.meta["in_chans"])
torch.save(model.state_dict(), path)
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torch.hub, "load_state_dict_from_url", load)
monkeypatch.setattr(weights, "url", path.as_uri())
return weights
def test_resnet(self) -> None:

Просмотреть файл

@ -2,7 +2,6 @@
# Licensed under the MIT License.
from pathlib import Path
from typing import Any, Dict
import pytest
import timm
@ -14,11 +13,6 @@ from torchvision.models._api import WeightsEnum
from torchgeo.models import ViTSmall16_Weights, vit_small_patch16_224
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
state_dict: Dict[str, Any] = torch.load(url)
return state_dict
class TestViTSmall16:
@pytest.fixture(params=[*ViTSmall16_Weights])
def weights(self, request: SubRequest) -> WeightsEnum:
@ -33,8 +27,7 @@ class TestViTSmall16:
weights.meta["model"], in_chans=weights.meta["in_chans"]
)
torch.save(model.state_dict(), path)
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torch.hub, "load_state_dict_from_url", load)
monkeypatch.setattr(weights, "url", path.as_uri())
return weights
def test_vit(self) -> None:

Просмотреть файл

@ -23,11 +23,6 @@ from torchgeo.trainers.byol import BYOL, SimCLRAugmentation
from .test_utils import SegmentationTestModel
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
state_dict: Dict[str, Any] = torch.load(url)
return state_dict
class TestBYOL:
def test_custom_augment_fn(self) -> None:
backbone = resnet18()
@ -84,8 +79,7 @@ class TestBYOLTask:
path = tmp_path / f"{weights}.pth"
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
torch.save(model.state_dict(), path)
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torch.hub, "load_state_dict_from_url", load)
monkeypatch.setattr(weights, "url", path.as_uri())
return weights
def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None:

Просмотреть файл

@ -31,11 +31,6 @@ def create_model(*args: Any, **kwargs: Any) -> Module:
return ClassificationTestModel(**kwargs)
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
state_dict: Dict[str, Any] = torch.load(url)
return state_dict
class TestClassificationTask:
@pytest.mark.parametrize(
"name,classname",
@ -107,8 +102,7 @@ class TestClassificationTask:
path = tmp_path / f"{weights}.pth"
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
torch.save(model.state_dict(), path)
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torch.hub, "load_state_dict_from_url", load)
monkeypatch.setattr(weights, "url", path.as_uri())
return weights
def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None:

Просмотреть файл

@ -20,11 +20,6 @@ from torchgeo.trainers import RegressionTask
from .test_utils import RegressionTestModel
def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
state_dict: Dict[str, Any] = torch.load(url)
return state_dict
class TestRegressionTask:
@pytest.mark.parametrize(
"name,classname",
@ -88,8 +83,7 @@ class TestRegressionTask:
path = tmp_path / f"{weights}.pth"
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"])
torch.save(model.state_dict(), path)
monkeypatch.setattr(weights, "url", str(path))
monkeypatch.setattr(torch.hub, "load_state_dict_from_url", load)
monkeypatch.setattr(weights, "url", path.as_uri())
return weights
def test_weight_file(self, model_kwargs: Dict[str, Any], checkpoint: str) -> None: