Mirror of https://github.com/microsoft/torchgeo.git
SimCLR: switch from Adam to LARS (#2196)
* SimCLR: switch from Adam to LARS
* Bump minimum lightly version
This commit is contained in:
Parent: 8088267861
Commit: ecb07ee898
@@ -42,10 +42,10 @@ dependencies = [
     "fiona>=1.8.21",
     # kornia 0.7.3+ required for instance segmentation support in AugmentationSequential
     "kornia>=0.7.3",
-    # lightly 1.4.4+ required for MoCo v3 support
+    # lightly 1.4.5+ required for LARS optimizer
     # lightly 1.4.26 is incompatible with the version of timm required by smp
     # https://github.com/microsoft/torchgeo/issues/1824
-    "lightly>=1.4.4,!=1.4.26",
+    "lightly>=1.4.5,!=1.4.26",
     # lightning 2+ required for LightningCLI args + sys.argv support
     # lightning 2.3+ contains known bugs related to YAML parsing
     # https://github.com/Lightning-AI/pytorch-lightning/issues/19977
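A side note on the new constraint: it raises the floor to 1.4.5, the first lightly release that ships the LARS optimizer, while still excluding 1.4.26. A minimal sketch, assuming the `packaging` library, for sanity-checking versions against the specifier (not part of the commit):

from packaging.specifiers import SpecifierSet
from packaging.version import Version

# The updated constraint from the hunk above.
spec = SpecifierSet('>=1.4.5,!=1.4.26')

assert Version('1.4.5') in spec       # new minimum, first release with LARS
assert Version('1.4.4') not in spec   # old minimum, predates LARS
assert Version('1.4.26') not in spec  # excluded: incompatible with smp's timm pin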
@@ -5,7 +5,7 @@ setuptools==61.0.0
 einops==0.3.0
 fiona==1.8.21
 kornia==0.7.3
-lightly==1.4.4
+lightly==1.4.5
 lightning[pytorch-extra]==2.0.0
 matplotlib==3.5.0
 numpy==1.21.2
@@ -15,8 +15,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from lightly.loss import NTXentLoss
 from lightly.models.modules import SimCLRProjectionHead
+from lightly.utils.lars import LARS
 from torch import Tensor
-from torch.optim import Adam  # type: ignore[attr-defined]
 from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
 from torchvision.models._api import WeightsEnum
 
@@ -80,6 +80,7 @@ class SimCLRTask(BaseTask):
         hidden_dim: int | None = None,
         output_dim: int | None = None,
         lr: float = 4.8,
+        momentum: float = 0.9,
         weight_decay: float = 1e-4,
         temperature: float = 0.07,
         memory_bank_size: int = 64000,
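With the new *momentum* hyperparameter in place, a hypothetical instantiation might look like the sketch below; the argument values mirror the defaults above, and 'resnet50' is assumed as the backbone name:

from torchgeo.trainers import SimCLRTask

task = SimCLRTask(
    model='resnet50',
    lr=4.8,
    momentum=0.9,       # new in 0.6; forwarded to the LARS optimizer
    weight_decay=1e-4,
)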
@@ -90,6 +91,9 @@ class SimCLRTask(BaseTask):
     ) -> None:
         """Initialize a new SimCLRTask instance.
 
+        .. versionadded:: 0.6
+           The *momentum* parameter.
+
         Args:
             model: Name of the `timm
                 <https://huggingface.co/docs/timm/reference/models>`__ model to use.
@@ -104,6 +108,7 @@ class SimCLRTask(BaseTask):
             output_dim: Number of output dimensions in projection head
                 (defaults to output dimension of model).
             lr: Learning rate (0.3 x batch_size / 256 is recommended).
+            momentum: Momentum factor.
             weight_decay: Weight decay coefficient (1e-6 for v1, 1e-4 for v2).
             temperature: Temperature used in NT-Xent loss.
             memory_bank_size: Size of memory bank (0 for v1, 64K for v2).
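A quick worked example of the learning-rate rule quoted above (0.3 x batch_size / 256): a batch size of 4096 yields 0.3 x 4096 / 256 = 4.8, which is exactly the default `lr`; a batch size of 256 yields 0.3. As a sketch, with the helper name purely illustrative:

# Linear scaling rule from the docstring: lr = 0.3 * batch_size / 256.
def suggested_lr(batch_size: int) -> float:
    return 0.3 * batch_size / 256

assert suggested_lr(4096) == 4.8  # matches the default lr above
assert suggested_lr(256) == 0.3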
@@ -283,13 +288,16 @@ class SimCLRTask(BaseTask):
     ) -> 'lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig':
         """Initialize the optimizer and learning rate scheduler.
 
+        .. versionchanged:: 0.6
+           Changed from Adam to LARS optimizer.
+
         Returns:
             Optimizer and learning rate scheduler.
         """
-        # Original paper uses LARS optimizer, but this is not defined in PyTorch
-        optimizer = Adam(
+        optimizer = LARS(
             self.parameters(),
             lr=self.hparams['lr'],
+            momentum=self.hparams['momentum'],
             weight_decay=self.hparams['weight_decay'],
         )
         max_epochs = 200
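For context, lightly's LARS is a drop-in `torch.optim.Optimizer`, so it composes with the existing warmup-plus-cosine schedule that this diff leaves untouched. A rough standalone sketch of the resulting setup, with a toy module standing in for the backbone and `warmup_epochs = 10` assumed from the surrounding code (not shown in the hunk):

import torch.nn as nn
from lightly.utils.lars import LARS
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

# Stand-in for the SimCLR backbone + projection head.
toy_model = nn.Linear(32, 32)

# Same hyperparameters as the new defaults in SimCLRTask.
optimizer = LARS(toy_model.parameters(), lr=4.8, momentum=0.9, weight_decay=1e-4)

# Linear warmup, then cosine decay over the remaining epochs.
max_epochs = 200
warmup_epochs = 10  # assumption: value used by the unmodified scheduler code
scheduler = SequentialLR(
    optimizer,
    schedulers=[
        LinearLR(optimizer, total_iters=warmup_epochs),
        CosineAnnealingLR(optimizer, T_max=max_epochs),
    ],
    milestones=[warmup_epochs],
)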