Test fewer models in trainers to avoid exceeding RAM (#1377)

* Stop the madness

* isort

* flake8

* Repeat for other trainers

* Parentheses not needed

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Caleb Robinson 2023-05-29 09:28:36 -07:00 коммит произвёл GitHub
Родитель 38f59d1efe
Коммит 108c94bb9b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 27 добавлений и 67 удалений

Просмотреть файл

@ -10,7 +10,6 @@ import timm
import torch
import torch.nn as nn
import torchvision
from _pytest.fixtures import SubRequest
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from omegaconf import OmegaConf
@ -19,7 +18,7 @@ from torchvision.models import resnet18
from torchvision.models._api import WeightsEnum
from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2
from torchgeo.models import get_model_weights, list_models
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import BYOLTask
from torchgeo.trainers.byol import BYOL, SimCLRAugmentation
@ -98,13 +97,9 @@ class TestBYOLTask:
"weights": None,
}
@pytest.fixture(
params=[
weights for model in list_models() for weights in get_model_weights(model)
]
)
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param
@pytest.fixture
def weights(self) -> WeightsEnum:
return ResNet18_Weights.SENTINEL2_ALL_MOCO
@pytest.fixture
def mocked_weights(

Просмотреть файл

@ -10,7 +10,6 @@ import timm
import torch
import torch.nn as nn
import torchvision
from _pytest.fixtures import SubRequest
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from omegaconf import OmegaConf
@ -24,7 +23,7 @@ from torchgeo.datamodules import (
MisconfigurationException,
)
from torchgeo.datasets import BigEarthNet, EuroSAT
from torchgeo.models import get_model_weights, list_models
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask
@ -124,13 +123,9 @@ class TestClassificationTask:
"weights": None,
}
@pytest.fixture(
params=[
weights for model in list_models() for weights in get_model_weights(model)
]
)
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param
@pytest.fixture
def weights(self) -> WeightsEnum:
return ResNet18_Weights.SENTINEL2_ALL_MOCO
@pytest.fixture
def mocked_weights(

Просмотреть файл

@ -9,7 +9,6 @@ import pytest
import timm
import torch
import torchvision
from _pytest.fixtures import SubRequest
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from omegaconf import OmegaConf
@ -18,7 +17,7 @@ from torch.nn import Module
from torchvision.models._api import WeightsEnum
from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2
from torchgeo.models import get_model_weights, list_models
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import MoCoTask
from .test_classification import ClassificationTestModel
@ -85,13 +84,9 @@ class TestMoCoTask:
with pytest.warns(UserWarning, match="MoCo v3 does not use a memory bank"):
MoCoTask(version=3, layers=3, memory_bank_size=10)
@pytest.fixture(
params=[
weights for model in list_models() for weights in get_model_weights(model)
]
)
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param
@pytest.fixture
def weights(self) -> WeightsEnum:
return ResNet18_Weights.SENTINEL2_ALL_MOCO
@pytest.fixture
def mocked_weights(

Просмотреть файл

@ -11,7 +11,6 @@ import timm
import torch
import torch.nn as nn
import torchvision
from _pytest.fixtures import SubRequest
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from omegaconf import OmegaConf
@ -21,7 +20,7 @@ from torchvision.models._api import WeightsEnum
from torchgeo.datamodules import MisconfigurationException, TropicalCycloneDataModule
from torchgeo.datasets import TropicalCyclone
from torchgeo.models import get_model_weights, list_models
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import PixelwiseRegressionTask, RegressionTask
from .test_classification import ClassificationTestModel
@ -106,13 +105,9 @@ class TestRegressionTask:
"loss": "mse",
}
@pytest.fixture(
params=[
weights for model in list_models() for weights in get_model_weights(model)
]
)
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param
@pytest.fixture
def weights(self) -> WeightsEnum:
return ResNet18_Weights.SENTINEL2_ALL_MOCO
@pytest.fixture
def mocked_weights(
@ -297,16 +292,9 @@ class TestPixelwiseRegressionTask:
"learning_rate_schedule_patience": 6,
}
@pytest.fixture(
params=[
weights
for model in list_models()
for weights in get_model_weights(model)
if "resnet" in weights.meta["model"]
]
)
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param
@pytest.fixture
def weights(self) -> WeightsEnum:
return ResNet18_Weights.SENTINEL2_ALL_MOCO
@pytest.fixture
def mocked_weights(

Просмотреть файл

@ -11,7 +11,6 @@ import timm
import torch
import torch.nn as nn
import torchvision
from _pytest.fixtures import SubRequest
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from omegaconf import OmegaConf
@ -21,7 +20,7 @@ from torchvision.models._api import WeightsEnum
from torchgeo.datamodules import MisconfigurationException, SEN12MSDataModule
from torchgeo.datasets import LandCoverAI
from torchgeo.models import get_model_weights, list_models
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import SemanticSegmentationTask
@ -124,16 +123,9 @@ class TestSemanticSegmentationTask:
"ignore_index": 0,
}
@pytest.fixture(
params=[
weights
for model in list_models()
for weights in get_model_weights(model)
if "resnet" in weights.meta["model"]
]
)
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param
@pytest.fixture
def weights(self) -> WeightsEnum:
return ResNet18_Weights.SENTINEL2_ALL_MOCO
@pytest.fixture
def mocked_weights(

Просмотреть файл

@ -9,7 +9,6 @@ import pytest
import timm
import torch
import torchvision
from _pytest.fixtures import SubRequest
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from omegaconf import OmegaConf
@ -18,7 +17,7 @@ from torch.nn import Module
from torchvision.models._api import WeightsEnum
from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2
from torchgeo.models import get_model_weights, list_models
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import SimCLRTask
from .test_classification import ClassificationTestModel
@ -83,13 +82,9 @@ class TestSimCLRTask:
with pytest.warns(UserWarning, match="SimCLR v2 uses a memory bank"):
SimCLRTask(version=2, memory_bank_size=0)
@pytest.fixture(
params=[
weights for model in list_models() for weights in get_model_weights(model)
]
)
def weights(self, request: SubRequest) -> WeightsEnum:
return request.param
@pytest.fixture
def weights(self) -> WeightsEnum:
return ResNet18_Weights.SENTINEL2_ALL_MOCO
@pytest.fixture
def mocked_weights(