* add simclr and tests

* add lightly to reqs

* pyupgrade

* Copy things from prior implementation

* Add SimCLR v2 projection head

* Remove kwargs

* Call __init__ explicitly

* Fix mypy and docs

* Can't test newer setuptools

* Default to output dim of model

* Add memory bank

* Ignore erroneous warning

* Fix configs, test SSL4EO

* Fix a few layer bugs

* mypy fixes

* kernel_size must be an integer

* Fix SeCo in_channels

* Get more coverage

* Bump min lightly

* Default logging

* Test weights

* mypy fix

* Grab max_epochs from the trainer

* max_epochs param removed

* Use num_features

* Remove classification head

* SimCLR uses LARS, with Adam as a backup

* Add warnings

* Grab num features directly from model

* Check if identity

* Match timm model design

* Capture warnings

* Fix tests

* Increase coverage

* Fix method name

* More typos

* Escape regex

* Newer setuptools now supported

* New batch norm for every layer

* Rename forward arg

* Clarify usage of weights parameter

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>

* Fix flake8

* Check it

* Use hydra

* Track average L2 normed stdev over features

* SimCLR decays lr to 0

* Add lr warmup

* Fix version access

* Fix LinearLR

* isinstance supports tuples

* Comment capitalization

* Require lightly 1.4.3+

* Require lightly 1.4.3+

* Bump lightly version

* Add RandomGrayscale

* Flake8 fixes

* Placate pydocstyle

* Clarify docs

* Pass correct weights

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
Isaac Corley 2023-05-11 11:40:53 -05:00 committed by GitHub
Parent 3cc1427c24
Commit ef7a9ad793
16 changed files with 535 additions and 4 deletions
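A minimal usage sketch of what this PR adds (not part of the diff; the datamodule choice and hyperparameters here are illustrative):

# Hypothetical example: pretrain a backbone with the new SimCLRTask.
from lightning.pytorch import Trainer

from torchgeo.datamodules import SSL4EOS12DataModule
from torchgeo.trainers import SimCLRTask

task = SimCLRTask(model="resnet18", in_channels=13, version=2, layers=3)
datamodule = SSL4EOS12DataModule(root="data/ssl4eo/s12", batch_size=64, num_workers=4)
trainer = Trainer(max_epochs=100)
trainer.fit(model=task, datamodule=datamodule)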


@@ -62,6 +62,7 @@ nitpick_ignore = [
     ("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"),
     ("py:class", "timm.models.resnet.ResNet"),
     ("py:class", "timm.models.vision_transformer.VisionTransformer"),
+    ("py:class", "torch.optim.lr_scheduler.LRScheduler"),
     ("py:class", "torchvision.models._api.WeightsEnum"),
     ("py:class", "torchvision.models.resnet.ResNet"),
 ]


@@ -26,6 +26,7 @@ dependencies:
   - isort[colors]>=5.8
   - kornia>=0.6.5
   - laspy>=2
+  - lightly>=1.4.4
   - lightning>=1.8
   - mypy>=0.900
   - nbmake>=1.3.3


@@ -5,6 +5,7 @@ setuptools==42.0.0
 einops==0.3.0
 fiona==1.8.19
 kornia==0.6.5
+lightly==1.4.4
 lightning==1.8.0
 matplotlib==3.3.3
 numpy==1.19.3


@@ -6,6 +6,7 @@ einops==0.6.1
 fiona==1.9.3
 kornia==0.6.12
 lightning==2.0.2
+lightly==1.4.4
 matplotlib==3.7.1
 numpy==1.24.3
 pillow==9.5.0


@@ -30,6 +30,8 @@ install_requires =
     fiona>=1.8.19,<2
     # kornia 0.6.5+ required due to change in kornia.augmentation API
     kornia>=0.6.5,<0.7
+    # lightly 1.4.4+ required for MoCo v3 support
+    lightly>=1.4.4
     # lightning 1.8+ is first release
     lightning>=1.8,<3
     # matplotlib 3.3.3+ required for Python 3.9 wheels


@@ -0,0 +1,23 @@
module:
  _target_: torchgeo.trainers.SimCLRTask
  model: "resnet18"
  in_channels: 4
  version: 1
  layers: 2
  memory_bank_size: 0

datamodule:
  _target_: torchgeo.datamodules.ChesapeakeCVPRDataModule
  root: "tests/data/chesapeake/cvpr"
  download: false
  train_splits:
    - "de-test"
  val_splits:
    - "de-test"
  test_splits:
    - "de-test"
  batch_size: 2
  patch_size: 64
  num_workers: 0
  class_set: 5
  use_prior_labels: True
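Each of these test configs is loaded with OmegaConf and instantiated via Hydra, as the new test module further down does. A sketch (the file name is inferred from the test parametrization below):

from hydra.utils import instantiate
from omegaconf import OmegaConf

conf = OmegaConf.load("tests/conf/chesapeake_cvpr_prior_simclr.yaml")
datamodule = instantiate(conf.datamodule)  # ChesapeakeCVPRDataModule
model = instantiate(conf.module)  # SimCLRTask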


@@ -0,0 +1,17 @@
module:
  _target_: torchgeo.trainers.SimCLRTask
  model: "resnet18"
  in_channels: 3
  version: 1
  layers: 2
  hidden_dim: 8
  output_dim: 8
  weight_decay: 1e-6
  memory_bank_size: 0

datamodule:
  _target_: torchgeo.datamodules.SeasonalContrastS2DataModule
  root: "tests/data/seco"
  seasons: 1
  batch_size: 2
  num_workers: 0


@@ -0,0 +1,17 @@
module:
  _target_: torchgeo.trainers.SimCLRTask
  model: "resnet18"
  in_channels: 3
  version: 2
  layers: 4
  hidden_dim: 8
  output_dim: 8
  weight_decay: 1e-4
  memory_bank_size: 10

datamodule:
  _target_: torchgeo.datamodules.SeasonalContrastS2DataModule
  root: "tests/data/seco"
  seasons: 2
  batch_size: 2
  num_workers: 0


@@ -0,0 +1,17 @@
module:
  _target_: torchgeo.trainers.SimCLRTask
  model: "resnet18"
  in_channels: 13
  version: 1
  layers: 2
  hidden_dim: 8
  output_dim: 8
  weight_decay: 1e-6
  memory_bank_size: 0

datamodule:
  _target_: torchgeo.datamodules.SSL4EOS12DataModule
  root: "tests/data/ssl4eo/s12"
  seasons: 1
  batch_size: 2
  num_workers: 0


@@ -0,0 +1,17 @@
module:
  _target_: torchgeo.trainers.SimCLRTask
  model: "resnet18"
  in_channels: 13
  version: 2
  layers: 3
  hidden_dim: 8
  output_dim: 8
  weight_decay: 1e-4
  memory_bank_size: 10

datamodule:
  _target_: torchgeo.datamodules.SSL4EOS12DataModule
  root: "tests/data/ssl4eo/s12"
  seasons: 2
  batch_size: 2
  num_workers: 0


@@ -33,7 +33,8 @@ class ClassificationTestModel(Module):
         super().__init__()
         self.conv1 = nn.Conv2d(in_channels=in_chans, out_channels=1, kernel_size=1)
         self.pool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = nn.Linear(1, num_classes)
+        self.fc = nn.Linear(1, num_classes) if num_classes else nn.Identity()
+        self.num_features = 1

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv1(x)


@@ -0,0 +1,154 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from pathlib import Path
from typing import Any

import pytest
import timm
import torch
import torchvision
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from omegaconf import OmegaConf
from torch.nn import Module
from torchvision.models._api import WeightsEnum

from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2
from torchgeo.models import get_model_weights, list_models
from torchgeo.trainers import SimCLRTask

from .test_classification import ClassificationTestModel


def create_model(*args: Any, **kwargs: Any) -> Module:
    return ClassificationTestModel(**kwargs)


def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
    state_dict: dict[str, Any] = torch.load(url)
    return state_dict


class TestSimCLRTask:
    @pytest.mark.parametrize(
        "name",
        [
            "chesapeake_cvpr_prior_simclr",
            "seco_simclr_1",
            "seco_simclr_2",
            "ssl4eo_s12_simclr_1",
            "ssl4eo_s12_simclr_2",
        ],
    )
    def test_trainer(
        self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
    ) -> None:
        conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))

        if name.startswith("seco"):
            monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2)

        if name.startswith("ssl4eo_s12"):
            monkeypatch.setattr(SSL4EOS12, "__len__", lambda self: 2)

        # Instantiate datamodule
        datamodule = instantiate(conf.datamodule)

        # Instantiate model
        monkeypatch.setattr(timm, "create_model", create_model)
        model = instantiate(conf.module)

        # Instantiate trainer
        trainer = Trainer(
            accelerator="cpu",
            fast_dev_run=fast_dev_run,
            log_every_n_steps=1,
            max_epochs=1,
        )
        trainer.fit(model=model, datamodule=datamodule)

    def test_version_warnings(self) -> None:
        with pytest.warns(UserWarning, match="SimCLR v1 only uses 2 layers"):
            SimCLRTask(version=1, layers=3)
        with pytest.warns(UserWarning, match="SimCLR v1 does not use a memory bank"):
            SimCLRTask(version=1, memory_bank_size=10)
        with pytest.warns(UserWarning, match=r"SimCLR v2 uses 3\+ layers"):
            SimCLRTask(version=2, layers=2)
        with pytest.warns(UserWarning, match="SimCLR v2 uses a memory bank"):
            SimCLRTask(version=2, memory_bank_size=0)

    @pytest.fixture(
        params=[
            weights for model in list_models() for weights in get_model_weights(model)
        ]
    )
    def weights(self, request: SubRequest) -> WeightsEnum:
        return request.param

    @pytest.fixture
    def mocked_weights(
        self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum
    ) -> WeightsEnum:
        path = tmp_path / f"{weights}.pth"
        model = timm.create_model(
            weights.meta["model"], in_chans=weights.meta["in_chans"]
        )
        torch.save(model.state_dict(), path)
        try:
            monkeypatch.setattr(weights.value, "url", str(path))
        except AttributeError:
            monkeypatch.setattr(weights, "url", str(path))
        monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
        return weights

    def test_weight_file(self, checkpoint: str) -> None:
        model_kwargs: dict[str, Any] = {"model": "resnet18", "weights": checkpoint}
        match = "num classes .* != num classes in pretrained model"
        with pytest.warns(UserWarning, match=match):
            SimCLRTask(**model_kwargs)

    def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
        model_kwargs: dict[str, Any] = {
            "model": mocked_weights.meta["model"],
            "weights": mocked_weights,
            "in_channels": mocked_weights.meta["in_chans"],
        }
        match = "num classes .* != num classes in pretrained model"
        with pytest.warns(UserWarning, match=match):
            SimCLRTask(**model_kwargs)

    def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
        model_kwargs: dict[str, Any] = {
            "model": mocked_weights.meta["model"],
            "weights": str(mocked_weights),
            "in_channels": mocked_weights.meta["in_chans"],
        }
        match = "num classes .* != num classes in pretrained model"
        with pytest.warns(UserWarning, match=match):
            SimCLRTask(**model_kwargs)

    @pytest.mark.slow
    def test_weight_enum_download(self, weights: WeightsEnum) -> None:
        model_kwargs: dict[str, Any] = {
            "model": weights.meta["model"],
            "weights": weights,
            "in_channels": weights.meta["in_chans"],
        }
        match = "num classes .* != num classes in pretrained model"
        with pytest.warns(UserWarning, match=match):
            SimCLRTask(**model_kwargs)

    @pytest.mark.slow
    def test_weight_str_download(self, weights: WeightsEnum) -> None:
        model_kwargs: dict[str, Any] = {
            "model": weights.meta["model"],
            "weights": str(weights),
            "in_channels": weights.meta["in_chans"],
        }
        match = "num classes .* != num classes in pretrained model"
        with pytest.warns(UserWarning, match=match):
            SimCLRTask(**model_kwargs)


@@ -8,6 +8,7 @@ from .classification import ClassificationTask, MultiLabelClassificationTask
 from .detection import ObjectDetectionTask
 from .regression import PixelwiseRegressionTask, RegressionTask
 from .segmentation import SemanticSegmentationTask
+from .simclr import SimCLRTask

 __all__ = (
     "BYOLTask",
@@ -17,4 +18,5 @@ __all__ = (
     "PixelwiseRegressionTask",
     "RegressionTask",
     "SemanticSegmentationTask",
+    "SimCLRTask",
 )

torchgeo/trainers/simclr.py (new file, 274 lines)

@@ -0,0 +1,274 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""SimCLR trainer for self-supervised learning (SSL)."""

import os
import warnings
from typing import Optional, Union, cast

import kornia.augmentation as K
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightly.loss import NTXentLoss
from lightly.models.modules import SimCLRProjectionHead
from lightning import LightningModule
from torch import Tensor
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torchvision.models._api import WeightsEnum

import torchgeo.transforms as T

from ..models import get_weight
from . import utils

try:
    from torch.optim.lr_scheduler import LRScheduler
except ImportError:
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler


def simclr_augmentations(size: int, weights: Tensor) -> nn.Module:
    """Data augmentation used by SimCLR.

    Args:
        size: Size of patch to crop.
        weights: Weight vector for grayscale computation.

    Returns:
        Data augmentation pipeline.
    """
    # https://github.com/google-research/simclr/blob/master/data_util.py
    ks = size // 10 // 2 * 2 + 1
    return K.AugmentationSequential(
        K.RandomResizedCrop(size=(size, size), ratio=(0.75, 1.33)),
        K.RandomHorizontalFlip(),
        K.RandomVerticalFlip(),  # added
        # Not appropriate for multispectral imagery, seasonal contrast used instead
        # K.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2, p=0.8)
        T.RandomGrayscale(weights=weights, p=0.2),
        K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2)),
        data_keys=["input"],
    )
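A quick sanity check of the augmentation pipeline above (illustrative, not part of the file). The kernel-size formula rounds size / 10 down to the nearest odd integer, as a Gaussian blur kernel requires, and the pipeline preserves input shape:

import torch

# For size=224: 224 // 10 // 2 * 2 + 1 = 23 (odd blur kernel)
aug = simclr_augmentations(size=224, weights=torch.ones(3))
x = torch.rand(2, 3, 224, 224)
print(aug(x).shape)  # torch.Size([2, 3, 224, 224])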
class SimCLRTask(LightningModule):  # type: ignore[misc]
    """SimCLR: a simple framework for contrastive learning of visual representations.

    Reference implementation:

    * https://github.com/google-research/simclr

    If you use this trainer in your research, please cite the following papers:

    * v1: https://arxiv.org/abs/2002.05709
    * v2: https://arxiv.org/abs/2006.10029

    .. versionadded:: 0.5
    """

    def __init__(
        self,
        model: str = "resnet50",
        weights: Optional[Union[WeightsEnum, str, bool]] = None,
        in_channels: int = 3,
        version: int = 2,
        layers: int = 3,
        hidden_dim: Optional[int] = None,
        output_dim: Optional[int] = None,
        lr: float = 4.8,
        weight_decay: float = 1e-4,
        temperature: float = 0.07,
        memory_bank_size: int = 64000,
        gather_distributed: bool = False,
        size: int = 224,
        grayscale_weights: Optional[Tensor] = None,
        augmentations: Optional[nn.Module] = None,
    ) -> None:
        """Initialize a new SimCLRTask instance.

        Args:
            model: Name of the timm model to use.
            weights: Initial model weights. Either a weight enum, the string
                representation of a weight enum, True for ImageNet weights, False
                or None for random weights, or the path to a saved model state dict.
            in_channels: Number of input channels to model.
            version: Version of SimCLR, 1--2.
            layers: Number of layers in projection head (2 for v1, 3+ for v2).
            hidden_dim: Number of hidden dimensions in projection head
                (defaults to output dimension of model).
            output_dim: Number of output dimensions in projection head
                (defaults to output dimension of model).
            lr: Learning rate
                (0.3 x batch_size / 256 for v1, 0.3 x sqrt(batch_size) for v2).
            weight_decay: Weight decay coefficient (1e-6 for v1, 1e-4 for v2).
            temperature: Temperature used in NT-Xent loss.
            memory_bank_size: Size of memory bank (0 for v1, 64K for v2).
            gather_distributed: Gather negatives from all GPUs during distributed
                training (ignored if memory_bank_size > 0).
            size: Size of patch to crop.
            grayscale_weights: Weight vector for grayscale computation, see
                :class:`~torchgeo.transforms.RandomGrayscale`. Only used when
                ``augmentations=None``. Defaults to average of all bands.
            augmentations: Data augmentation. Defaults to SimCLR augmentation.

        Raises:
            AssertionError: If an invalid version of SimCLR is requested.

        Warns:
            UserWarning: If hyperparameters do not match SimCLR version requested.
        """
        super().__init__()

        # Validate hyperparameters
        assert version in range(1, 3)
        if version == 1:
            if layers > 2:
                warnings.warn("SimCLR v1 only uses 2 layers in its projection head")
            if memory_bank_size > 0:
                warnings.warn("SimCLR v1 does not use a memory bank")
        elif version == 2:
            if layers == 2:
                warnings.warn("SimCLR v2 uses 3+ layers in its projection head")
            if memory_bank_size == 0:
                warnings.warn("SimCLR v2 uses a memory bank")

        self.save_hyperparameters(ignore=["augmentations"])

        grayscale_weights = grayscale_weights or torch.ones(in_channels)
        self.augmentations = augmentations or simclr_augmentations(
            size, grayscale_weights
        )

        # Create backbone
        self.backbone = timm.create_model(
            model, in_chans=in_channels, num_classes=0, pretrained=weights is True
        )

        # Load weights
        if weights and weights is not True:
            if isinstance(weights, WeightsEnum):
                state_dict = weights.get_state_dict(progress=True)
            elif os.path.exists(weights):
                _, state_dict = utils.extract_backbone(weights)
            else:
                state_dict = get_weight(weights).get_state_dict(progress=True)
            self.backbone = utils.load_state_dict(self.backbone, state_dict)

        # Create projection head
        input_dim = self.backbone.num_features
        if hidden_dim is None:
            hidden_dim = input_dim
        if output_dim is None:
            output_dim = input_dim
        self.projection_head = SimCLRProjectionHead(
            input_dim, hidden_dim, output_dim, layers
        )

        # Define loss function
        self.criterion = NTXentLoss(temperature, memory_bank_size, gather_distributed)

        # Initialize moving average of output
        self.avg_output_std = 0.0

        # TODO
        # v1+: add global batch norm
        # v2: add selective kernels, channel-wise attention mechanism

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Forward pass of the model.

        Args:
            x: Mini-batch of images.

        Returns:
            Output from the backbone and projection head.
        """
        h = self.backbone(x)  # shape of batch_size x num_features
        z = self.projection_head(h)
        return cast(Tensor, z), cast(Tensor, h)
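An illustrative shape check of forward() (not part of the file): with a resnet18 backbone, h carries num_features=512 backbone features, and z has output_dim columns (also 512 here, since output_dim defaults to the backbone width):

import torch
from torchgeo.trainers import SimCLRTask

task = SimCLRTask(
    model="resnet18", in_channels=3, version=1, layers=2, memory_bank_size=0
)
z, h = task(torch.rand(2, 3, 224, 224))
print(z.shape, h.shape)  # torch.Size([2, 512]) torch.Size([2, 512])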
    def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor:
        """Compute the training loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.

        Returns:
            The loss tensor.
        """
        x = batch["image"]

        in_channels = self.hparams["in_channels"]
        assert x.size(1) == in_channels or x.size(1) == 2 * in_channels

        if x.size(1) == in_channels:
            x1 = x
            x2 = x
        else:
            x1 = x[:, :in_channels]
            x2 = x[:, in_channels:]

        with torch.no_grad():
            x1 = self.augmentations(x1)
            x2 = self.augmentations(x2)

        z1, h1 = self(x1)
        z2, h2 = self(x2)

        loss = self.criterion(z1, z2)

        # Calculate the mean normalized standard deviation over features dimensions.
        # If this is << 1 / sqrt(h1.shape[1]), then the model is not learning anything.
        output = h1.detach()
        output = F.normalize(output, dim=1)
        output_std = torch.std(output, dim=0)
        output_std = torch.mean(output_std, dim=0)
        self.avg_output_std = 0.9 * self.avg_output_std + (1 - 0.9) * output_std.item()

        self.log("train_ssl_std", self.avg_output_std)
        self.log("train_loss", loss)

        return cast(Tensor, loss)
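As a rough reference for the collapse metric logged above: with a resnet18 backbone (512 features), 1 / sqrt(512) ≈ 0.044, so a train_ssl_std stuck far below that value suggests the learned representations have collapsed.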
    def validation_step(self, batch: dict[str, Tensor], batch_idx: int) -> None:
        """No-op, does nothing."""

    def test_step(self, batch: dict[str, Tensor], batch_idx: int) -> None:
        """No-op, does nothing."""

    # TODO
    # v2: add distillation step

    def predict_step(self, batch: dict[str, Tensor], batch_idx: int) -> None:
        """No-op, does nothing."""

    def configure_optimizers(self) -> tuple[list[Optimizer], list[LRScheduler]]:
        """Initialize the optimizer and learning rate scheduler.

        Returns:
            Optimizer and learning rate scheduler.
        """
        # Original paper uses LARS optimizer, but this is not defined in PyTorch
        optimizer = Adam(
            self.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams["weight_decay"],
        )
        if self.hparams["version"] == 1:
            warmup_epochs = 10
        else:
            warmup_epochs = int(self.trainer.max_epochs * 0.05)
        lr_scheduler = SequentialLR(
            optimizer,
            schedulers=[
                LinearLR(optimizer, total_iters=warmup_epochs),
                CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs),
            ],
            milestones=[warmup_epochs],
        )
        return [optimizer], [lr_scheduler]
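A rough standalone check (not part of the file) of the warmup-plus-cosine schedule that configure_optimizers builds, assuming a 100-epoch run with 5 warmup epochs; the learning rate ramps up linearly, then decays to 0:

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = Adam(params, lr=4.8)
scheduler = SequentialLR(
    optimizer,
    schedulers=[
        LinearLR(optimizer, total_iters=5),
        CosineAnnealingLR(optimizer, T_max=100),
    ],
    milestones=[5],
)
for _ in range(105):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # [0.0] once the cosine phase completes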


@@ -90,6 +90,9 @@ def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Mo
     expected_in_channels = state_dict[input_module_key + ".weight"].shape[1]

     output_module_key, output_module = list(model.named_children())[-1]
-    num_classes = output_module.out_features
+    if isinstance(output_module, nn.Identity):
+        num_classes = model.num_features
+    else:
+        num_classes = output_module.out_features
     expected_num_classes = None
     if output_module_key + ".weight" in state_dict:
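Context for the change above (a sketch, not part of the diff): timm models created with num_classes=0 — as the new SimCLR trainer does — replace the classification head with nn.Identity, which has no out_features attribute, hence the isinstance check:

import timm
import torch.nn as nn

backbone = timm.create_model("resnet18", num_classes=0)
_, head = list(backbone.named_children())[-1]
assert isinstance(head, nn.Identity)
print(backbone.num_features)  # 512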


@@ -16,7 +16,7 @@ from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
 from omegaconf import DictConfig, OmegaConf

 from torchgeo.datamodules import MisconfigurationException
-from torchgeo.trainers import BYOLTask, ObjectDetectionTask
+from torchgeo.trainers import BYOLTask, ObjectDetectionTask, SimCLRTask


 def set_up_omegaconf() -> DictConfig:
@@ -95,7 +95,7 @@ def main(conf: DictConfig) -> None:
     if isinstance(task, ObjectDetectionTask):
         monitor_metric = "val_map"
         mode = "max"
-    elif isinstance(task, BYOLTask):
+    elif isinstance(task, (BYOLTask, SimCLRTask)):
         monitor_metric = "train_loss"
         mode = "min"
     else:
else: