зеркало из https://github.com/microsoft/torchgeo.git
Test fewer models in trainers to avoid exceeding RAM (#1377)
* Stop the madness * isort * flake8 * Repeat for other trainers * Parentheses not needed --------- Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
38f59d1efe
Коммит
108c94bb9b
|
@ -10,7 +10,6 @@ import timm
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from _pytest.fixtures import SubRequest
|
||||
from hydra.utils import instantiate
|
||||
from lightning.pytorch import Trainer
|
||||
from omegaconf import OmegaConf
|
||||
|
@ -19,7 +18,7 @@ from torchvision.models import resnet18
|
|||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2
|
||||
from torchgeo.models import get_model_weights, list_models
|
||||
from torchgeo.models import ResNet18_Weights
|
||||
from torchgeo.trainers import BYOLTask
|
||||
from torchgeo.trainers.byol import BYOL, SimCLRAugmentation
|
||||
|
||||
|
@ -98,13 +97,9 @@ class TestBYOLTask:
|
|||
"weights": None,
|
||||
}
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
weights for model in list_models() for weights in get_model_weights(model)
|
||||
]
|
||||
)
|
||||
def weights(self, request: SubRequest) -> WeightsEnum:
|
||||
return request.param
|
||||
@pytest.fixture
|
||||
def weights(self) -> WeightsEnum:
|
||||
return ResNet18_Weights.SENTINEL2_ALL_MOCO
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_weights(
|
||||
|
|
|
@ -10,7 +10,6 @@ import timm
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from _pytest.fixtures import SubRequest
|
||||
from hydra.utils import instantiate
|
||||
from lightning.pytorch import Trainer
|
||||
from omegaconf import OmegaConf
|
||||
|
@ -24,7 +23,7 @@ from torchgeo.datamodules import (
|
|||
MisconfigurationException,
|
||||
)
|
||||
from torchgeo.datasets import BigEarthNet, EuroSAT
|
||||
from torchgeo.models import get_model_weights, list_models
|
||||
from torchgeo.models import ResNet18_Weights
|
||||
from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask
|
||||
|
||||
|
||||
|
@ -124,13 +123,9 @@ class TestClassificationTask:
|
|||
"weights": None,
|
||||
}
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
weights for model in list_models() for weights in get_model_weights(model)
|
||||
]
|
||||
)
|
||||
def weights(self, request: SubRequest) -> WeightsEnum:
|
||||
return request.param
|
||||
@pytest.fixture
|
||||
def weights(self) -> WeightsEnum:
|
||||
return ResNet18_Weights.SENTINEL2_ALL_MOCO
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_weights(
|
||||
|
|
|
@ -9,7 +9,6 @@ import pytest
|
|||
import timm
|
||||
import torch
|
||||
import torchvision
|
||||
from _pytest.fixtures import SubRequest
|
||||
from hydra.utils import instantiate
|
||||
from lightning.pytorch import Trainer
|
||||
from omegaconf import OmegaConf
|
||||
|
@ -18,7 +17,7 @@ from torch.nn import Module
|
|||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2
|
||||
from torchgeo.models import get_model_weights, list_models
|
||||
from torchgeo.models import ResNet18_Weights
|
||||
from torchgeo.trainers import MoCoTask
|
||||
|
||||
from .test_classification import ClassificationTestModel
|
||||
|
@ -85,13 +84,9 @@ class TestMoCoTask:
|
|||
with pytest.warns(UserWarning, match="MoCo v3 does not use a memory bank"):
|
||||
MoCoTask(version=3, layers=3, memory_bank_size=10)
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
weights for model in list_models() for weights in get_model_weights(model)
|
||||
]
|
||||
)
|
||||
def weights(self, request: SubRequest) -> WeightsEnum:
|
||||
return request.param
|
||||
@pytest.fixture
|
||||
def weights(self) -> WeightsEnum:
|
||||
return ResNet18_Weights.SENTINEL2_ALL_MOCO
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_weights(
|
||||
|
|
|
@ -11,7 +11,6 @@ import timm
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from _pytest.fixtures import SubRequest
|
||||
from hydra.utils import instantiate
|
||||
from lightning.pytorch import Trainer
|
||||
from omegaconf import OmegaConf
|
||||
|
@ -21,7 +20,7 @@ from torchvision.models._api import WeightsEnum
|
|||
|
||||
from torchgeo.datamodules import MisconfigurationException, TropicalCycloneDataModule
|
||||
from torchgeo.datasets import TropicalCyclone
|
||||
from torchgeo.models import get_model_weights, list_models
|
||||
from torchgeo.models import ResNet18_Weights
|
||||
from torchgeo.trainers import PixelwiseRegressionTask, RegressionTask
|
||||
|
||||
from .test_classification import ClassificationTestModel
|
||||
|
@ -106,13 +105,9 @@ class TestRegressionTask:
|
|||
"loss": "mse",
|
||||
}
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
weights for model in list_models() for weights in get_model_weights(model)
|
||||
]
|
||||
)
|
||||
def weights(self, request: SubRequest) -> WeightsEnum:
|
||||
return request.param
|
||||
@pytest.fixture
|
||||
def weights(self) -> WeightsEnum:
|
||||
return ResNet18_Weights.SENTINEL2_ALL_MOCO
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_weights(
|
||||
|
@ -297,16 +292,9 @@ class TestPixelwiseRegressionTask:
|
|||
"learning_rate_schedule_patience": 6,
|
||||
}
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
weights
|
||||
for model in list_models()
|
||||
for weights in get_model_weights(model)
|
||||
if "resnet" in weights.meta["model"]
|
||||
]
|
||||
)
|
||||
def weights(self, request: SubRequest) -> WeightsEnum:
|
||||
return request.param
|
||||
@pytest.fixture
|
||||
def weights(self) -> WeightsEnum:
|
||||
return ResNet18_Weights.SENTINEL2_ALL_MOCO
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_weights(
|
||||
|
|
|
@ -11,7 +11,6 @@ import timm
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from _pytest.fixtures import SubRequest
|
||||
from hydra.utils import instantiate
|
||||
from lightning.pytorch import Trainer
|
||||
from omegaconf import OmegaConf
|
||||
|
@ -21,7 +20,7 @@ from torchvision.models._api import WeightsEnum
|
|||
|
||||
from torchgeo.datamodules import MisconfigurationException, SEN12MSDataModule
|
||||
from torchgeo.datasets import LandCoverAI
|
||||
from torchgeo.models import get_model_weights, list_models
|
||||
from torchgeo.models import ResNet18_Weights
|
||||
from torchgeo.trainers import SemanticSegmentationTask
|
||||
|
||||
|
||||
|
@ -124,16 +123,9 @@ class TestSemanticSegmentationTask:
|
|||
"ignore_index": 0,
|
||||
}
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
weights
|
||||
for model in list_models()
|
||||
for weights in get_model_weights(model)
|
||||
if "resnet" in weights.meta["model"]
|
||||
]
|
||||
)
|
||||
def weights(self, request: SubRequest) -> WeightsEnum:
|
||||
return request.param
|
||||
@pytest.fixture
|
||||
def weights(self) -> WeightsEnum:
|
||||
return ResNet18_Weights.SENTINEL2_ALL_MOCO
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_weights(
|
||||
|
|
|
@ -9,7 +9,6 @@ import pytest
|
|||
import timm
|
||||
import torch
|
||||
import torchvision
|
||||
from _pytest.fixtures import SubRequest
|
||||
from hydra.utils import instantiate
|
||||
from lightning.pytorch import Trainer
|
||||
from omegaconf import OmegaConf
|
||||
|
@ -18,7 +17,7 @@ from torch.nn import Module
|
|||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2
|
||||
from torchgeo.models import get_model_weights, list_models
|
||||
from torchgeo.models import ResNet18_Weights
|
||||
from torchgeo.trainers import SimCLRTask
|
||||
|
||||
from .test_classification import ClassificationTestModel
|
||||
|
@ -83,13 +82,9 @@ class TestSimCLRTask:
|
|||
with pytest.warns(UserWarning, match="SimCLR v2 uses a memory bank"):
|
||||
SimCLRTask(version=2, memory_bank_size=0)
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
weights for model in list_models() for weights in get_model_weights(model)
|
||||
]
|
||||
)
|
||||
def weights(self, request: SubRequest) -> WeightsEnum:
|
||||
return request.param
|
||||
@pytest.fixture
|
||||
def weights(self) -> WeightsEnum:
|
||||
return ResNet18_Weights.SENTINEL2_ALL_MOCO
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_weights(
|
||||
|
|
Загрузка…
Ссылка в новой задаче