зеркало из https://github.com/microsoft/torchgeo.git
Add SimCLR trainer (#1252)
* add simclr and tests * add lightly to reqs * pyupgrade * Copy things from prior implementation * Add SimCLR v2 projection head * Remove kwargs * Call __init__ explicitly * Fix mypy and docs * Can't test newer setuptools * Default to output dim of model * Add memory bank * Ignore erroneous warning * Fix configs, test SSL4EO * Fix a few layer bugs * mypy fixes * kernel_size must be an integer * Fix SeCo in_channels * Get more coverage * Bump min lightly * Default logging * Test weights * mypy fix * Grab max_epochs from the trainer * max_epochs param removed * Use num_features * Remove classification head * SimCLR uses LARS, with Adam as a backup * Add warnings * Grab num features directly from model * Check if identity * Match timm model design * Capture warnings * Fix tests * Increase coverage * Fix method name * More typos * Escape regex * Newer setuptools now supported * New batch norm for every layer * Rename forward arg * Clarify usage of weights parameter Co-authored-by: Caleb Robinson <calebrob6@gmail.com> * Fix flake8 * Check it * Use hydra * Track average L2 normed stdev over features * SimCLR decays lr to 0 * Add lr warmup * Fix version access * Fix LinearLR * isinstance supports tuples * Comment capitalization * Require lightly 1.4.3+ * Require lightly 1.4.3+ * Bump lightly version * Add RandomGrayscale * Flake8 fixes * Placate pydocstyle * Clarify docs * Pass correct weights --------- Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
This commit is contained in:
Родитель
3cc1427c24
Коммит
ef7a9ad793
|
@ -62,6 +62,7 @@ nitpick_ignore = [
|
|||
("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"),
|
||||
("py:class", "timm.models.resnet.ResNet"),
|
||||
("py:class", "timm.models.vision_transformer.VisionTransformer"),
|
||||
("py:class", "torch.optim.lr_scheduler.LRScheduler"),
|
||||
("py:class", "torchvision.models._api.WeightsEnum"),
|
||||
("py:class", "torchvision.models.resnet.ResNet"),
|
||||
]
|
||||
|
|
|
@ -26,6 +26,7 @@ dependencies:
|
|||
- isort[colors]>=5.8
|
||||
- kornia>=0.6.5
|
||||
- laspy>=2
|
||||
- lightly>=1.4.4
|
||||
- lightning>=1.8
|
||||
- mypy>=0.900
|
||||
- nbmake>=1.3.3
|
||||
|
|
|
@ -5,6 +5,7 @@ setuptools==42.0.0
|
|||
einops==0.3.0
|
||||
fiona==1.8.19
|
||||
kornia==0.6.5
|
||||
lightly==1.4.4
|
||||
lightning==1.8.0
|
||||
matplotlib==3.3.3
|
||||
numpy==1.19.3
|
||||
|
|
|
@ -6,6 +6,7 @@ einops==0.6.1
|
|||
fiona==1.9.3
|
||||
kornia==0.6.12
|
||||
lightning==2.0.2
|
||||
lightly==1.4.4
|
||||
matplotlib==3.7.1
|
||||
numpy==1.24.3
|
||||
pillow==9.5.0
|
||||
|
|
|
@ -30,6 +30,8 @@ install_requires =
|
|||
fiona>=1.8.19,<2
|
||||
# kornia 0.6.5+ required due to change in kornia.augmentation API
|
||||
kornia>=0.6.5,<0.7
|
||||
# lightly 1.4.4+ required for MoCo v3 support
|
||||
lightly>=1.4.4
|
||||
# lightning 1.8+ is first release
|
||||
lightning>=1.8,<3
|
||||
# matplotlib 3.3.3+ required for Python 3.9 wheels
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
module:
|
||||
_target_: torchgeo.trainers.SimCLRTask
|
||||
model: "resnet18"
|
||||
in_channels: 4
|
||||
version: 1
|
||||
layers: 2
|
||||
memory_bank_size: 0
|
||||
|
||||
datamodule:
|
||||
_target_: torchgeo.datamodules.ChesapeakeCVPRDataModule
|
||||
root: "tests/data/chesapeake/cvpr"
|
||||
download: false
|
||||
train_splits:
|
||||
- "de-test"
|
||||
val_splits:
|
||||
- "de-test"
|
||||
test_splits:
|
||||
- "de-test"
|
||||
batch_size: 2
|
||||
patch_size: 64
|
||||
num_workers: 0
|
||||
class_set: 5
|
||||
use_prior_labels: True
|
|
@ -0,0 +1,17 @@
|
|||
module:
|
||||
_target_: torchgeo.trainers.SimCLRTask
|
||||
model: "resnet18"
|
||||
in_channels: 3
|
||||
version: 1
|
||||
layers: 2
|
||||
hidden_dim: 8
|
||||
output_dim: 8
|
||||
weight_decay: 1e-6
|
||||
memory_bank_size: 0
|
||||
|
||||
datamodule:
|
||||
_target_: torchgeo.datamodules.SeasonalContrastS2DataModule
|
||||
root: "tests/data/seco"
|
||||
seasons: 1
|
||||
batch_size: 2
|
||||
num_workers: 0
|
|
@ -0,0 +1,17 @@
|
|||
module:
|
||||
_target_: torchgeo.trainers.SimCLRTask
|
||||
model: "resnet18"
|
||||
in_channels: 3
|
||||
version: 2
|
||||
layers: 4
|
||||
hidden_dim: 8
|
||||
output_dim: 8
|
||||
weight_decay: 1e-4
|
||||
memory_bank_size: 10
|
||||
|
||||
datamodule:
|
||||
_target_: torchgeo.datamodules.SeasonalContrastS2DataModule
|
||||
root: "tests/data/seco"
|
||||
seasons: 2
|
||||
batch_size: 2
|
||||
num_workers: 0
|
|
@ -0,0 +1,17 @@
|
|||
module:
|
||||
_target_: torchgeo.trainers.SimCLRTask
|
||||
model: "resnet18"
|
||||
in_channels: 13
|
||||
version: 1
|
||||
layers: 2
|
||||
hidden_dim: 8
|
||||
output_dim: 8
|
||||
weight_decay: 1e-6
|
||||
memory_bank_size: 0
|
||||
|
||||
datamodule:
|
||||
_target_: torchgeo.datamodules.SSL4EOS12DataModule
|
||||
root: "tests/data/ssl4eo/s12"
|
||||
seasons: 1
|
||||
batch_size: 2
|
||||
num_workers: 0
|
|
@ -0,0 +1,17 @@
|
|||
module:
|
||||
_target_: torchgeo.trainers.SimCLRTask
|
||||
model: "resnet18"
|
||||
in_channels: 13
|
||||
version: 2
|
||||
layers: 3
|
||||
hidden_dim: 8
|
||||
output_dim: 8
|
||||
weight_decay: 1e-4
|
||||
memory_bank_size: 10
|
||||
|
||||
datamodule:
|
||||
_target_: torchgeo.datamodules.SSL4EOS12DataModule
|
||||
root: "tests/data/ssl4eo/s12"
|
||||
seasons: 2
|
||||
batch_size: 2
|
||||
num_workers: 0
|
|
@ -33,7 +33,8 @@ class ClassificationTestModel(Module):
|
|||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(in_channels=in_chans, out_channels=1, kernel_size=1)
|
||||
self.pool = nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = nn.Linear(1, num_classes)
|
||||
self.fc = nn.Linear(1, num_classes) if num_classes else nn.Identity()
|
||||
self.num_features = 1
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv1(x)
|
||||
|
|
|
@ -0,0 +1,154 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import timm
|
||||
import torch
|
||||
import torchvision
|
||||
from _pytest.fixtures import SubRequest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from hydra.utils import instantiate
|
||||
from lightning.pytorch import Trainer
|
||||
from omegaconf import OmegaConf
|
||||
from torch.nn import Module
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2
|
||||
from torchgeo.models import get_model_weights, list_models
|
||||
from torchgeo.trainers import SimCLRTask
|
||||
|
||||
from .test_classification import ClassificationTestModel
|
||||
|
||||
|
||||
def create_model(*args: Any, **kwargs: Any) -> Module:
|
||||
return ClassificationTestModel(**kwargs)
|
||||
|
||||
|
||||
def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
|
||||
state_dict: dict[str, Any] = torch.load(url)
|
||||
return state_dict
|
||||
|
||||
|
||||
class TestSimCLRTask:
|
||||
@pytest.mark.parametrize(
|
||||
"name",
|
||||
[
|
||||
"chesapeake_cvpr_prior_simclr",
|
||||
"seco_simclr_1",
|
||||
"seco_simclr_2",
|
||||
"ssl4eo_s12_simclr_1",
|
||||
"ssl4eo_s12_simclr_2",
|
||||
],
|
||||
)
|
||||
def test_trainer(
|
||||
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
|
||||
) -> None:
|
||||
conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
|
||||
|
||||
if name.startswith("seco"):
|
||||
monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2)
|
||||
|
||||
if name.startswith("ssl4eo_s12"):
|
||||
monkeypatch.setattr(SSL4EOS12, "__len__", lambda self: 2)
|
||||
|
||||
# Instantiate datamodule
|
||||
datamodule = instantiate(conf.datamodule)
|
||||
|
||||
# Instantiate model
|
||||
monkeypatch.setattr(timm, "create_model", create_model)
|
||||
model = instantiate(conf.module)
|
||||
|
||||
# Instantiate trainer
|
||||
trainer = Trainer(
|
||||
accelerator="cpu",
|
||||
fast_dev_run=fast_dev_run,
|
||||
log_every_n_steps=1,
|
||||
max_epochs=1,
|
||||
)
|
||||
trainer.fit(model=model, datamodule=datamodule)
|
||||
|
||||
def test_version_warnings(self) -> None:
|
||||
with pytest.warns(UserWarning, match="SimCLR v1 only uses 2 layers"):
|
||||
SimCLRTask(version=1, layers=3)
|
||||
with pytest.warns(UserWarning, match="SimCLR v1 does not use a memory bank"):
|
||||
SimCLRTask(version=1, memory_bank_size=10)
|
||||
with pytest.warns(UserWarning, match=r"SimCLR v2 uses 3\+ layers"):
|
||||
SimCLRTask(version=2, layers=2)
|
||||
with pytest.warns(UserWarning, match="SimCLR v2 uses a memory bank"):
|
||||
SimCLRTask(version=2, memory_bank_size=0)
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
weights for model in list_models() for weights in get_model_weights(model)
|
||||
]
|
||||
)
|
||||
def weights(self, request: SubRequest) -> WeightsEnum:
|
||||
return request.param
|
||||
|
||||
@pytest.fixture
|
||||
def mocked_weights(
|
||||
self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
|
||||
) -> WeightsEnum:
|
||||
path = tmp_path / f"{weights}.pth"
|
||||
model = timm.create_model(
|
||||
weights.meta["model"], in_chans=weights.meta["in_chans"]
|
||||
)
|
||||
torch.save(model.state_dict(), path)
|
||||
try:
|
||||
monkeypatch.setattr(weights.value, "url", str(path))
|
||||
except AttributeError:
|
||||
monkeypatch.setattr(weights, "url", str(path))
|
||||
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
|
||||
return weights
|
||||
|
||||
def test_weight_file(self, checkpoint: str) -> None:
|
||||
model_kwargs: dict[str, Any] = {"model": "resnet18", "weights": checkpoint}
|
||||
match = "num classes .* != num classes in pretrained model"
|
||||
with pytest.warns(UserWarning, match=match):
|
||||
SimCLRTask(**model_kwargs)
|
||||
|
||||
def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
|
||||
model_kwargs: dict[str, Any] = {
|
||||
"model": mocked_weights.meta["model"],
|
||||
"weights": mocked_weights,
|
||||
"in_channels": mocked_weights.meta["in_chans"],
|
||||
}
|
||||
match = "num classes .* != num classes in pretrained model"
|
||||
with pytest.warns(UserWarning, match=match):
|
||||
SimCLRTask(**model_kwargs)
|
||||
|
||||
def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
|
||||
model_kwargs: dict[str, Any] = {
|
||||
"model": mocked_weights.meta["model"],
|
||||
"weights": str(mocked_weights),
|
||||
"in_channels": mocked_weights.meta["in_chans"],
|
||||
}
|
||||
match = "num classes .* != num classes in pretrained model"
|
||||
with pytest.warns(UserWarning, match=match):
|
||||
SimCLRTask(**model_kwargs)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_weight_enum_download(self, weights: WeightsEnum) -> None:
|
||||
model_kwargs: dict[str, Any] = {
|
||||
"model": weights.meta["model"],
|
||||
"weights": weights,
|
||||
"in_channels": weights.meta["in_chans"],
|
||||
}
|
||||
match = "num classes .* != num classes in pretrained model"
|
||||
with pytest.warns(UserWarning, match=match):
|
||||
SimCLRTask(**model_kwargs)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_weight_str_download(self, weights: WeightsEnum) -> None:
|
||||
model_kwargs: dict[str, Any] = {
|
||||
"model": weights.meta["model"],
|
||||
"weights": str(weights),
|
||||
"in_channels": weights.meta["in_chans"],
|
||||
}
|
||||
match = "num classes .* != num classes in pretrained model"
|
||||
with pytest.warns(UserWarning, match=match):
|
||||
SimCLRTask(**model_kwargs)
|
|
@ -8,6 +8,7 @@ from .classification import ClassificationTask, MultiLabelClassificationTask
|
|||
from .detection import ObjectDetectionTask
|
||||
from .regression import PixelwiseRegressionTask, RegressionTask
|
||||
from .segmentation import SemanticSegmentationTask
|
||||
from .simclr import SimCLRTask
|
||||
|
||||
__all__ = (
|
||||
"BYOLTask",
|
||||
|
@ -17,4 +18,5 @@ __all__ = (
|
|||
"PixelwiseRegressionTask",
|
||||
"RegressionTask",
|
||||
"SemanticSegmentationTask",
|
||||
"SimCLRTask",
|
||||
)
|
||||
|
|
|
@ -0,0 +1,274 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""SimCLR trainer for self-supervised learning (SSL)."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import kornia.augmentation as K
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from lightly.loss import NTXentLoss
|
||||
from lightly.models.modules import SimCLRProjectionHead
|
||||
from lightning import LightningModule
|
||||
from torch import Tensor
|
||||
from torch.optim import Adam, Optimizer
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
|
||||
from torchvision.models._api import WeightsEnum
|
||||
|
||||
import torchgeo.transforms as T
|
||||
|
||||
from ..models import get_weight
|
||||
from . import utils
|
||||
|
||||
try:
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
except ImportError:
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
|
||||
|
||||
def simclr_augmentations(size: int, weights: Tensor) -> nn.Module:
|
||||
"""Data augmentation used by SimCLR.
|
||||
|
||||
Args:
|
||||
size: Size of patch to crop.
|
||||
weights: Weight vector for grayscale computation.
|
||||
|
||||
Returns:
|
||||
Data augmentation pipeline.
|
||||
"""
|
||||
# https://github.com/google-research/simclr/blob/master/data_util.py
|
||||
ks = size // 10 // 2 * 2 + 1
|
||||
return K.AugmentationSequential(
|
||||
K.RandomResizedCrop(size=(size, size), ratio=(0.75, 1.33)),
|
||||
K.RandomHorizontalFlip(),
|
||||
K.RandomVerticalFlip(), # added
|
||||
# Not appropriate for multispectral imagery, seasonal contrast used instead
|
||||
# K.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2, p=0.8)
|
||||
T.RandomGrayscale(weights=weights, p=0.2),
|
||||
K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2)),
|
||||
data_keys=["input"],
|
||||
)
|
||||
|
||||
|
||||
class SimCLRTask(LightningModule): # type: ignore[misc]
|
||||
"""SimCLR: a simple framework for contrastive learning of visual representations.
|
||||
|
||||
Reference implementation:
|
||||
|
||||
* https://github.com/google-research/simclr
|
||||
|
||||
If you use this trainer in your research, please cite the following papers:
|
||||
|
||||
* v1: https://arxiv.org/abs/2002.05709
|
||||
* v2: https://arxiv.org/abs/2006.10029
|
||||
|
||||
.. versionadded:: 0.5
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "resnet50",
|
||||
weights: Optional[Union[WeightsEnum, str, bool]] = None,
|
||||
in_channels: int = 3,
|
||||
version: int = 2,
|
||||
layers: int = 3,
|
||||
hidden_dim: Optional[int] = None,
|
||||
output_dim: Optional[int] = None,
|
||||
lr: float = 4.8,
|
||||
weight_decay: float = 1e-4,
|
||||
temperature: float = 0.07,
|
||||
memory_bank_size: int = 64000,
|
||||
gather_distributed: bool = False,
|
||||
size: int = 224,
|
||||
grayscale_weights: Optional[Tensor] = None,
|
||||
augmentations: Optional[nn.Module] = None,
|
||||
) -> None:
|
||||
"""Initialize a new SimCLRTask instance.
|
||||
|
||||
Args:
|
||||
model: Name of the timm model to use.
|
||||
weights: Initial model weights. Either a weight enum, the string
|
||||
representation of a weight enum, True for ImageNet weights, False
|
||||
or None for random weights, or the path to a saved model state dict.
|
||||
in_channels: Number of input channels to model.
|
||||
version: Version of SimCLR, 1--2.
|
||||
layers: Number of layers in projection head (2 for v1, 3+ for v2).
|
||||
hidden_dim: Number of hidden dimensions in projection head
|
||||
(defaults to output dimension of model).
|
||||
output_dim: Number of output dimensions in projection head
|
||||
(defaults to output dimension of model).
|
||||
lr: Learning rate
|
||||
(0.3 x batch_size / 256 for v1, 0.3 x sqrt(batch size) for v2).
|
||||
weight_decay: Weight decay coefficient (1e-6 for v1, 1e-4 for v2).
|
||||
temperature: Temperature used in NT-Xent loss.
|
||||
memory_bank_size: Size of memory bank (0 for v1, 64K for v2).
|
||||
gather_distributed: Gather negatives from all GPUs during distributed
|
||||
training (ignored if memory_bank_size > 0).
|
||||
size: Size of patch to crop.
|
||||
grayscale_weights: Weight vector for grayscale computation, see
|
||||
:class:`~torchgeo.transforms.RandomGrayscale`. Only used when
|
||||
``augmentations=None``. Defaults to average of all bands.
|
||||
augmentations: Data augmentation. Defaults to SimCLR augmentation.
|
||||
|
||||
Raises:
|
||||
AssertionError: If an invalid version of SimCLR is requested.
|
||||
|
||||
Warns:
|
||||
UserWarning: If hyperparameters do not match SimCLR version requested.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Validate hyperparameters
|
||||
assert version in range(1, 3)
|
||||
if version == 1:
|
||||
if layers > 2:
|
||||
warnings.warn("SimCLR v1 only uses 2 layers in its projection head")
|
||||
if memory_bank_size > 0:
|
||||
warnings.warn("SimCLR v1 does not use a memory bank")
|
||||
elif version == 2:
|
||||
if layers == 2:
|
||||
warnings.warn("SimCLR v2 uses 3+ layers in its projection head")
|
||||
if memory_bank_size == 0:
|
||||
warnings.warn("SimCLR v2 uses a memory bank")
|
||||
|
||||
self.save_hyperparameters(ignore=["augmentations"])
|
||||
|
||||
grayscale_weights = grayscale_weights or torch.ones(in_channels)
|
||||
self.augmentations = augmentations or simclr_augmentations(
|
||||
size, grayscale_weights
|
||||
)
|
||||
|
||||
# Create backbone
|
||||
self.backbone = timm.create_model(
|
||||
model, in_chans=in_channels, num_classes=0, pretrained=weights is True
|
||||
)
|
||||
|
||||
# Load weights
|
||||
if weights and weights is not True:
|
||||
if isinstance(weights, WeightsEnum):
|
||||
state_dict = weights.get_state_dict(progress=True)
|
||||
elif os.path.exists(weights):
|
||||
_, state_dict = utils.extract_backbone(weights)
|
||||
else:
|
||||
state_dict = get_weight(weights).get_state_dict(progress=True)
|
||||
self.backbone = utils.load_state_dict(self.backbone, state_dict)
|
||||
|
||||
# Create projection head
|
||||
input_dim = self.backbone.num_features
|
||||
if hidden_dim is None:
|
||||
hidden_dim = input_dim
|
||||
if output_dim is None:
|
||||
output_dim = input_dim
|
||||
|
||||
self.projection_head = SimCLRProjectionHead(
|
||||
input_dim, hidden_dim, output_dim, layers
|
||||
)
|
||||
|
||||
# Define loss function
|
||||
self.criterion = NTXentLoss(temperature, memory_bank_size, gather_distributed)
|
||||
|
||||
# Initialize moving average of output
|
||||
self.avg_output_std = 0.0
|
||||
|
||||
# TODO
|
||||
# v1+: add global batch norm
|
||||
# v2: add selective kernels, channel-wise attention mechanism
|
||||
|
||||
def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
|
||||
"""Forward pass of the model.
|
||||
|
||||
Args:
|
||||
x: Mini-batch of images.
|
||||
|
||||
Returns:
|
||||
Output from the backbone and projection head.
|
||||
"""
|
||||
h = self.backbone(x) # shape of batch_size x num_features
|
||||
z = self.projection_head(h)
|
||||
return cast(Tensor, z), cast(Tensor, h)
|
||||
|
||||
def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor:
|
||||
"""Compute the training loss and additional metrics.
|
||||
|
||||
Args:
|
||||
batch: The output of your DataLoader.
|
||||
batch_idx: Integer displaying index of this batch.
|
||||
|
||||
Returns:
|
||||
The loss tensor.
|
||||
"""
|
||||
x = batch["image"]
|
||||
|
||||
in_channels = self.hparams["in_channels"]
|
||||
assert x.size(1) == in_channels or x.size(1) == 2 * in_channels
|
||||
|
||||
if x.size(1) == in_channels:
|
||||
x1 = x
|
||||
x2 = x
|
||||
else:
|
||||
x1 = x[:, :in_channels]
|
||||
x2 = x[:, in_channels:]
|
||||
|
||||
with torch.no_grad():
|
||||
x1 = self.augmentations(x1)
|
||||
x2 = self.augmentations(x2)
|
||||
|
||||
z1, h1 = self(x1)
|
||||
z2, h2 = self(x2)
|
||||
|
||||
loss = self.criterion(z1, z2)
|
||||
|
||||
# Calculate the mean normalized standard deviation over features dimensions.
|
||||
# If this is << 1 / sqrt(h1.shape[1]), then the model is not learning anything.
|
||||
output = h1.detach()
|
||||
output = F.normalize(output, dim=1)
|
||||
output_std = torch.std(output, dim=0)
|
||||
output_std = torch.mean(output_std, dim=0)
|
||||
self.avg_output_std = 0.9 * self.avg_output_std + (1 - 0.9) * output_std.item()
|
||||
|
||||
self.log("train_ssl_std", self.avg_output_std)
|
||||
self.log("train_loss", loss)
|
||||
|
||||
return cast(Tensor, loss)
|
||||
|
||||
def validation_step(self, batch: dict[str, Tensor], batch_idx: int) -> None:
|
||||
"""No-op, does nothing."""
|
||||
|
||||
def test_step(self, batch: dict[str, Tensor], batch_idx: int) -> None:
|
||||
"""No-op, does nothing."""
|
||||
# TODO
|
||||
# v2: add distillation step
|
||||
|
||||
def predict_step(self, batch: dict[str, Tensor], batch_idx: int) -> None:
|
||||
"""No-op, does nothing."""
|
||||
|
||||
def configure_optimizers(self) -> tuple[list[Optimizer], list[LRScheduler]]:
|
||||
"""Initialize the optimizer and learning rate scheduler.
|
||||
|
||||
Returns:
|
||||
Optimizer and learning rate scheduler.
|
||||
"""
|
||||
# Original paper uses LARS optimizer, but this is not defined in PyTorch
|
||||
optimizer = Adam(
|
||||
self.parameters(),
|
||||
lr=self.hparams["lr"],
|
||||
weight_decay=self.hparams["weight_decay"],
|
||||
)
|
||||
if self.hparams["version"] == 1:
|
||||
warmup_epochs = 10
|
||||
else:
|
||||
warmup_epochs = int(self.trainer.max_epochs * 0.05)
|
||||
lr_scheduler = SequentialLR(
|
||||
optimizer,
|
||||
schedulers=[
|
||||
LinearLR(optimizer, total_iters=warmup_epochs),
|
||||
CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs),
|
||||
],
|
||||
milestones=[warmup_epochs],
|
||||
)
|
||||
return [optimizer], [lr_scheduler]
|
|
@ -90,6 +90,9 @@ def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Mo
|
|||
expected_in_channels = state_dict[input_module_key + ".weight"].shape[1]
|
||||
|
||||
output_module_key, output_module = list(model.named_children())[-1]
|
||||
if isinstance(output_module, nn.Identity):
|
||||
num_classes = model.num_features
|
||||
else:
|
||||
num_classes = output_module.out_features
|
||||
expected_num_classes = None
|
||||
if output_module_key + ".weight" in state_dict:
|
||||
|
|
4
train.py
4
train.py
|
@ -16,7 +16,7 @@ from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
|
|||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from torchgeo.datamodules import MisconfigurationException
|
||||
from torchgeo.trainers import BYOLTask, ObjectDetectionTask
|
||||
from torchgeo.trainers import BYOLTask, ObjectDetectionTask, SimCLRTask
|
||||
|
||||
|
||||
def set_up_omegaconf() -> DictConfig:
|
||||
|
@ -95,7 +95,7 @@ def main(conf: DictConfig) -> None:
|
|||
if isinstance(task, ObjectDetectionTask):
|
||||
monitor_metric = "val_map"
|
||||
mode = "max"
|
||||
elif isinstance(task, BYOLTask):
|
||||
elif isinstance(task, (BYOLTask, SimCLRTask)):
|
||||
monitor_metric = "train_loss"
|
||||
mode = "min"
|
||||
else:
|
||||
|
|
Загрузка…
Ссылка в новой задаче