SimCLR: switch from Adam to LARS (#2196)

* SimCLR: switch from Adam to LARS

* Bump minimum lightly version
This commit is contained in:
Adam J. Stewart 2024-08-01 12:49:43 +02:00 коммит произвёл GitHub
Родитель 8088267861
Коммит ecb07ee898
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
3 изменённых файлов: 14 добавлений и 6 удалений

Просмотреть файл

@@ -42,10 +42,10 @@ dependencies = [
"fiona>=1.8.21",
# kornia 0.7.3+ required for instance segmentation support in AugmentationSequential
"kornia>=0.7.3",
-# lightly 1.4.4+ required for MoCo v3 support
+# lightly 1.4.5+ required for LARS optimizer
# lightly 1.4.26 is incompatible with the version of timm required by smp
# https://github.com/microsoft/torchgeo/issues/1824
-"lightly>=1.4.4,!=1.4.26",
+"lightly>=1.4.5,!=1.4.26",
# lightning 2+ required for LightningCLI args + sys.argv support
# lightning 2.3+ contains known bugs related to YAML parsing
# https://github.com/Lightning-AI/pytorch-lightning/issues/19977

Просмотреть файл

@@ -5,7 +5,7 @@ setuptools==61.0.0
einops==0.3.0
fiona==1.8.21
kornia==0.7.3
-lightly==1.4.4
+lightly==1.4.5
lightning[pytorch-extra]==2.0.0
matplotlib==3.5.0
numpy==1.21.2

Просмотреть файл

@@ -15,8 +15,8 @@ import torch.nn as nn
import torch.nn.functional as F
from lightly.loss import NTXentLoss
from lightly.models.modules import SimCLRProjectionHead
+from lightly.utils.lars import LARS
from torch import Tensor
-from torch.optim import Adam # type: ignore[attr-defined]
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torchvision.models._api import WeightsEnum
@@ -80,6 +80,7 @@ class SimCLRTask(BaseTask):
hidden_dim: int | None = None,
output_dim: int | None = None,
lr: float = 4.8,
+momentum: float = 0.9,
weight_decay: float = 1e-4,
temperature: float = 0.07,
memory_bank_size: int = 64000,
@@ -90,6 +91,9 @@ class SimCLRTask(BaseTask):
) -> None:
"""Initialize a new SimCLRTask instance.
+.. versionadded:: 0.6
+The *momentum* parameter.
Args:
model: Name of the `timm
<https://huggingface.co/docs/timm/reference/models>`__ model to use.
@@ -104,6 +108,7 @@ class SimCLRTask(BaseTask):
output_dim: Number of output dimensions in projection head
(defaults to output dimension of model).
lr: Learning rate (0.3 x batch_size / 256 is recommended).
+momentum: Momentum factor.
weight_decay: Weight decay coefficient (1e-6 for v1, 1e-4 for v2).
temperature: Temperature used in NT-Xent loss.
memory_bank_size: Size of memory bank (0 for v1, 64K for v2).
@@ -283,13 +288,16 @@ class SimCLRTask(BaseTask):
) -> 'lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig':
"""Initialize the optimizer and learning rate scheduler.
+.. versionchanged:: 0.6
+Changed from Adam to LARS optimizer.
Returns:
Optimizer and learning rate scheduler.
"""
-# Original paper uses LARS optimizer, but this is not defined in PyTorch
-optimizer = Adam(
+optimizer = LARS(
self.parameters(),
lr=self.hparams['lr'],
+momentum=self.hparams['momentum'],
weight_decay=self.hparams['weight_decay'],
)
max_epochs = 200