From ef7a9ad79359053c1ba0d38343795e936ee7dabb Mon Sep 17 00:00:00 2001 From: Isaac Corley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 11 May 2023 11:40:53 -0500 Subject: [PATCH] Add SimCLR trainer (#1252) * add simclr and tests * add lightly to reqs * pyupgrade * Copy things from prior implementation * Add SimCLR v2 projection head * Remove kwargs * Call __init__ explicitly * Fix mypy and docs * Can't test newer setuptools * Default to output dim of model * Add memory bank * Ignore erroneous warning * Fix configs, test SSL4EO * Fix a few layer bugs * mypy fixes * kernel_size must be an integer * Fix SeCo in_channels * Get more coverage * Bump min lightly * Default logging * Test weights * mypy fix * Grab max_epochs from the trainer * max_epochs param removed * Use num_features * Remove classification head * SimCLR uses LARS, with Adam as a backup * Add warnings * Grab num features directly from model * Check if identity * Match timm model design * Capture warnings * Fix tests * Increase coverage * Fix method name * More typos * Escape regex * Newer setuptools now supported * New batch norm for every layer * Rename forward arg * Clarify usage of weights parameter Co-authored-by: Caleb Robinson * Fix flake8 * Check it * Use hydra * Track average L2 normed stdev over features * SimCLR decays lr to 0 * Add lr warmup * Fix version access * Fix LinearLR * isinstance supports tuples * Comment capitalization * Require lightly 1.4.3+ * Require lightly 1.4.3+ * Bump lightly version * Add RandomGrayscale * Flake8 fixes * Placate pydocstyle * Clarify docs * Pass correct weights --------- Co-authored-by: Adam J. Stewart Co-authored-by: Caleb Robinson --- docs/conf.py | 1 + environment.yml | 1 + requirements/min-reqs.old | 1 + requirements/required.txt | 1 + setup.cfg | 2 + tests/conf/chesapeake_cvpr_prior_simclr.yaml | 23 ++ tests/conf/seco_simclr_1.yaml | 17 ++ tests/conf/seco_simclr_2.yaml | 17 ++ tests/conf/ssl4eo_s12_simclr_1.yaml | 17 ++ tests/conf/ssl4eo_s12_simclr_2.yaml | 17 ++ tests/trainers/test_classification.py | 3 +- tests/trainers/test_simclr.py | 154 +++++++++++ torchgeo/trainers/__init__.py | 2 + torchgeo/trainers/simclr.py | 274 +++++++++++++++++++ torchgeo/trainers/utils.py | 5 +- train.py | 4 +- 16 files changed, 535 insertions(+), 4 deletions(-) create mode 100644 tests/conf/chesapeake_cvpr_prior_simclr.yaml create mode 100644 tests/conf/seco_simclr_1.yaml create mode 100644 tests/conf/seco_simclr_2.yaml create mode 100644 tests/conf/ssl4eo_s12_simclr_1.yaml create mode 100644 tests/conf/ssl4eo_s12_simclr_2.yaml create mode 100644 tests/trainers/test_simclr.py create mode 100644 torchgeo/trainers/simclr.py diff --git a/docs/conf.py b/docs/conf.py index 9b2ec3ff7..44514ff94 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -62,6 +62,7 @@ nitpick_ignore = [ ("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"), ("py:class", "timm.models.resnet.ResNet"), ("py:class", "timm.models.vision_transformer.VisionTransformer"), + ("py:class", "torch.optim.lr_scheduler.LRScheduler"), ("py:class", "torchvision.models._api.WeightsEnum"), ("py:class", "torchvision.models.resnet.ResNet"), ] diff --git a/environment.yml b/environment.yml index 625a33697..e99d50b75 100644 --- a/environment.yml +++ b/environment.yml @@ -26,6 +26,7 @@ dependencies: - isort[colors]>=5.8 - kornia>=0.6.5 - laspy>=2 + - lightly>=1.4.4 - lightning>=1.8 - mypy>=0.900 - nbmake>=1.3.3 diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index 38c842a9e..c1485fc34 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -5,6 +5,7 @@ setuptools==42.0.0 einops==0.3.0 fiona==1.8.19 kornia==0.6.5 +lightly==1.4.4 lightning==1.8.0 matplotlib==3.3.3 numpy==1.19.3 diff --git a/requirements/required.txt b/requirements/required.txt index 83a96915b..2a97f1732 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -6,6 +6,7 @@ einops==0.6.1 fiona==1.9.3 kornia==0.6.12 lightning==2.0.2 +lightly==1.4.4 matplotlib==3.7.1 numpy==1.24.3 pillow==9.5.0 diff --git a/setup.cfg b/setup.cfg index 3c24e2918..e5c5f9dc3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,6 +30,8 @@ install_requires = fiona>=1.8.19,<2 # kornia 0.6.5+ required due to change in kornia.augmentation API kornia>=0.6.5,<0.7 + # lightly 1.4.4+ required for MoCo v3 support + lightly>=1.4.4 # lightning 1.8+ is first release lightning>=1.8,<3 # matplotlib 3.3.3+ required for Python 3.9 wheels diff --git a/tests/conf/chesapeake_cvpr_prior_simclr.yaml b/tests/conf/chesapeake_cvpr_prior_simclr.yaml new file mode 100644 index 000000000..731e9bf8b --- /dev/null +++ b/tests/conf/chesapeake_cvpr_prior_simclr.yaml @@ -0,0 +1,23 @@ +module: + _target_: torchgeo.trainers.SimCLRTask + model: "resnet18" + in_channels: 4 + version: 1 + layers: 2 + memory_bank_size: 0 + +datamodule: + _target_: torchgeo.datamodules.ChesapeakeCVPRDataModule + root: "tests/data/chesapeake/cvpr" + download: false + train_splits: + - "de-test" + val_splits: + - "de-test" + test_splits: + - "de-test" + batch_size: 2 + patch_size: 64 + num_workers: 0 + class_set: 5 + use_prior_labels: True diff --git a/tests/conf/seco_simclr_1.yaml b/tests/conf/seco_simclr_1.yaml new file mode 100644 index 000000000..ec0fa60d0 --- /dev/null +++ b/tests/conf/seco_simclr_1.yaml @@ -0,0 +1,17 @@ +module: + _target_: torchgeo.trainers.SimCLRTask + model: "resnet18" + in_channels: 3 + version: 1 + layers: 2 + hidden_dim: 8 + output_dim: 8 + weight_decay: 1e-6 + memory_bank_size: 0 + +datamodule: + _target_: torchgeo.datamodules.SeasonalContrastS2DataModule + root: "tests/data/seco" + seasons: 1 + batch_size: 2 + num_workers: 0 diff --git a/tests/conf/seco_simclr_2.yaml b/tests/conf/seco_simclr_2.yaml new file mode 100644 index 000000000..22e00585c --- /dev/null +++ b/tests/conf/seco_simclr_2.yaml @@ -0,0 +1,17 @@ +module: + _target_: torchgeo.trainers.SimCLRTask + model: "resnet18" + in_channels: 3 + version: 2 + layers: 4 + hidden_dim: 8 + output_dim: 8 + weight_decay: 1e-4 + memory_bank_size: 10 + +datamodule: + _target_: torchgeo.datamodules.SeasonalContrastS2DataModule + root: "tests/data/seco" + seasons: 2 + batch_size: 2 + num_workers: 0 diff --git a/tests/conf/ssl4eo_s12_simclr_1.yaml b/tests/conf/ssl4eo_s12_simclr_1.yaml new file mode 100644 index 000000000..50636ad5e --- /dev/null +++ b/tests/conf/ssl4eo_s12_simclr_1.yaml @@ -0,0 +1,17 @@ +module: + _target_: torchgeo.trainers.SimCLRTask + model: "resnet18" + in_channels: 13 + version: 1 + layers: 2 + hidden_dim: 8 + output_dim: 8 + weight_decay: 1e-6 + memory_bank_size: 0 + +datamodule: + _target_: torchgeo.datamodules.SSL4EOS12DataModule + root: "tests/data/ssl4eo/s12" + seasons: 1 + batch_size: 2 + num_workers: 0 diff --git a/tests/conf/ssl4eo_s12_simclr_2.yaml b/tests/conf/ssl4eo_s12_simclr_2.yaml new file mode 100644 index 000000000..ecf72369c --- /dev/null +++ b/tests/conf/ssl4eo_s12_simclr_2.yaml @@ -0,0 +1,17 @@ +module: + _target_: torchgeo.trainers.SimCLRTask + model: "resnet18" + in_channels: 13 + version: 2 + layers: 3 + hidden_dim: 8 + output_dim: 8 + weight_decay: 1e-4 + memory_bank_size: 10 + +datamodule: + _target_: torchgeo.datamodules.SSL4EOS12DataModule + root: "tests/data/ssl4eo/s12" + seasons: 2 + batch_size: 2 + num_workers: 0 diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 49628d7a8..91c1ee9e5 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -33,7 +33,8 @@ class ClassificationTestModel(Module): super().__init__() self.conv1 = nn.Conv2d(in_channels=in_chans, out_channels=1, kernel_size=1) self.pool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(1, num_classes) + self.fc = nn.Linear(1, num_classes) if num_classes else nn.Identity() + self.num_features = 1 def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv1(x) diff --git a/tests/trainers/test_simclr.py b/tests/trainers/test_simclr.py new file mode 100644 index 000000000..e7af9e8dd --- /dev/null +++ b/tests/trainers/test_simclr.py @@ -0,0 +1,154 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +from pathlib import Path +from typing import Any + +import pytest +import timm +import torch +import torchvision +from _pytest.fixtures import SubRequest +from _pytest.monkeypatch import MonkeyPatch +from hydra.utils import instantiate +from lightning.pytorch import Trainer +from omegaconf import OmegaConf +from torch.nn import Module +from torchvision.models._api import WeightsEnum + +from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2 +from torchgeo.models import get_model_weights, list_models +from torchgeo.trainers import SimCLRTask + +from .test_classification import ClassificationTestModel + + +def create_model(*args: Any, **kwargs: Any) -> Module: + return ClassificationTestModel(**kwargs) + + +def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: + state_dict: dict[str, Any] = torch.load(url) + return state_dict + + +class TestSimCLRTask: + @pytest.mark.parametrize( + "name", + [ + "chesapeake_cvpr_prior_simclr", + "seco_simclr_1", + "seco_simclr_2", + "ssl4eo_s12_simclr_1", + "ssl4eo_s12_simclr_2", + ], + ) + def test_trainer( + self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool + ) -> None: + conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) + + if name.startswith("seco"): + monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2) + + if name.startswith("ssl4eo_s12"): + monkeypatch.setattr(SSL4EOS12, "__len__", lambda self: 2) + + # Instantiate datamodule + datamodule = instantiate(conf.datamodule) + + # Instantiate model + monkeypatch.setattr(timm, "create_model", create_model) + model = instantiate(conf.module) + + # Instantiate trainer + trainer = Trainer( + accelerator="cpu", + fast_dev_run=fast_dev_run, + log_every_n_steps=1, + max_epochs=1, + ) + trainer.fit(model=model, datamodule=datamodule) + + def test_version_warnings(self) -> None: + with pytest.warns(UserWarning, match="SimCLR v1 only uses 2 layers"): + SimCLRTask(version=1, layers=3) + with pytest.warns(UserWarning, match="SimCLR v1 does not use a memory bank"): + SimCLRTask(version=1, memory_bank_size=10) + with pytest.warns(UserWarning, match=r"SimCLR v2 uses 3\+ layers"): + SimCLRTask(version=2, layers=2) + with pytest.warns(UserWarning, match="SimCLR v2 uses a memory bank"): + SimCLRTask(version=2, memory_bank_size=0) + + @pytest.fixture( + params=[ + weights for model in list_models() for weights in get_model_weights(model) + ] + ) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param + + @pytest.fixture + def mocked_weights( + self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + ) -> WeightsEnum: + path = tmp_path / f"{weights}.pth" + model = timm.create_model( + weights.meta["model"], in_chans=weights.meta["in_chans"] + ) + torch.save(model.state_dict(), path) + try: + monkeypatch.setattr(weights.value, "url", str(path)) + except AttributeError: + monkeypatch.setattr(weights, "url", str(path)) + monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + return weights + + def test_weight_file(self, checkpoint: str) -> None: + model_kwargs: dict[str, Any] = {"model": "resnet18", "weights": checkpoint} + match = "num classes .* != num classes in pretrained model" + with pytest.warns(UserWarning, match=match): + SimCLRTask(**model_kwargs) + + def test_weight_enum(self, mocked_weights: WeightsEnum) -> None: + model_kwargs: dict[str, Any] = { + "model": mocked_weights.meta["model"], + "weights": mocked_weights, + "in_channels": mocked_weights.meta["in_chans"], + } + match = "num classes .* != num classes in pretrained model" + with pytest.warns(UserWarning, match=match): + SimCLRTask(**model_kwargs) + + def test_weight_str(self, mocked_weights: WeightsEnum) -> None: + model_kwargs: dict[str, Any] = { + "model": mocked_weights.meta["model"], + "weights": str(mocked_weights), + "in_channels": mocked_weights.meta["in_chans"], + } + match = "num classes .* != num classes in pretrained model" + with pytest.warns(UserWarning, match=match): + SimCLRTask(**model_kwargs) + + @pytest.mark.slow + def test_weight_enum_download(self, weights: WeightsEnum) -> None: + model_kwargs: dict[str, Any] = { + "model": weights.meta["model"], + "weights": weights, + "in_channels": weights.meta["in_chans"], + } + match = "num classes .* != num classes in pretrained model" + with pytest.warns(UserWarning, match=match): + SimCLRTask(**model_kwargs) + + @pytest.mark.slow + def test_weight_str_download(self, weights: WeightsEnum) -> None: + model_kwargs: dict[str, Any] = { + "model": weights.meta["model"], + "weights": str(weights), + "in_channels": weights.meta["in_chans"], + } + match = "num classes .* != num classes in pretrained model" + with pytest.warns(UserWarning, match=match): + SimCLRTask(**model_kwargs) diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index e1db43fd3..bc1f412d4 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -8,6 +8,7 @@ from .classification import ClassificationTask, MultiLabelClassificationTask from .detection import ObjectDetectionTask from .regression import PixelwiseRegressionTask, RegressionTask from .segmentation import SemanticSegmentationTask +from .simclr import SimCLRTask __all__ = ( "BYOLTask", @@ -17,4 +18,5 @@ __all__ = ( "PixelwiseRegressionTask", "RegressionTask", "SemanticSegmentationTask", + "SimCLRTask", ) diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py new file mode 100644 index 000000000..81fcafe30 --- /dev/null +++ b/torchgeo/trainers/simclr.py @@ -0,0 +1,274 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""SimCLR trainer for self-supervised learning (SSL).""" + +import os +import warnings +from typing import Optional, Union, cast + +import kornia.augmentation as K +import timm +import torch +import torch.nn as nn +import torch.nn.functional as F +from lightly.loss import NTXentLoss +from lightly.models.modules import SimCLRProjectionHead +from lightning import LightningModule +from torch import Tensor +from torch.optim import Adam, Optimizer +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR +from torchvision.models._api import WeightsEnum + +import torchgeo.transforms as T + +from ..models import get_weight +from . import utils + +try: + from torch.optim.lr_scheduler import LRScheduler +except ImportError: + from torch.optim.lr_scheduler import _LRScheduler as LRScheduler + + +def simclr_augmentations(size: int, weights: Tensor) -> nn.Module: + """Data augmentation used by SimCLR. + + Args: + size: Size of patch to crop. + weights: Weight vector for grayscale computation. + + Returns: + Data augmentation pipeline. + """ + # https://github.com/google-research/simclr/blob/master/data_util.py + ks = size // 10 // 2 * 2 + 1 + return K.AugmentationSequential( + K.RandomResizedCrop(size=(size, size), ratio=(0.75, 1.33)), + K.RandomHorizontalFlip(), + K.RandomVerticalFlip(), # added + # Not appropriate for multispectral imagery, seasonal contrast used instead + # K.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2, p=0.8) + T.RandomGrayscale(weights=weights, p=0.2), + K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2)), + data_keys=["input"], + ) + + +class SimCLRTask(LightningModule): # type: ignore[misc] + """SimCLR: a simple framework for contrastive learning of visual representations. + + Reference implementation: + + * https://github.com/google-research/simclr + + If you use this trainer in your research, please cite the following papers: + + * v1: https://arxiv.org/abs/2002.05709 + * v2: https://arxiv.org/abs/2006.10029 + + .. versionadded:: 0.5 + """ + + def __init__( + self, + model: str = "resnet50", + weights: Optional[Union[WeightsEnum, str, bool]] = None, + in_channels: int = 3, + version: int = 2, + layers: int = 3, + hidden_dim: Optional[int] = None, + output_dim: Optional[int] = None, + lr: float = 4.8, + weight_decay: float = 1e-4, + temperature: float = 0.07, + memory_bank_size: int = 64000, + gather_distributed: bool = False, + size: int = 224, + grayscale_weights: Optional[Tensor] = None, + augmentations: Optional[nn.Module] = None, + ) -> None: + """Initialize a new SimCLRTask instance. + + Args: + model: Name of the timm model to use. + weights: Initial model weights. Either a weight enum, the string + representation of a weight enum, True for ImageNet weights, False + or None for random weights, or the path to a saved model state dict. + in_channels: Number of input channels to model. + version: Version of SimCLR, 1--2. + layers: Number of layers in projection head (2 for v1, 3+ for v2). + hidden_dim: Number of hidden dimensions in projection head + (defaults to output dimension of model). + output_dim: Number of output dimensions in projection head + (defaults to output dimension of model). + lr: Learning rate + (0.3 x batch_size / 256 for v1, 0.3 x sqrt(batch size) for v2). + weight_decay: Weight decay coefficient (1e-6 for v1, 1e-4 for v2). + temperature: Temperature used in NT-Xent loss. + memory_bank_size: Size of memory bank (0 for v1, 64K for v2). + gather_distributed: Gather negatives from all GPUs during distributed + training (ignored if memory_bank_size > 0). + size: Size of patch to crop. + grayscale_weights: Weight vector for grayscale computation, see + :class:`~torchgeo.transforms.RandomGrayscale`. Only used when + ``augmentations=None``. Defaults to average of all bands. + augmentations: Data augmentation. Defaults to SimCLR augmentation. + + Raises: + AssertionError: If an invalid version of SimCLR is requested. + + Warns: + UserWarning: If hyperparameters do not match SimCLR version requested. + """ + super().__init__() + + # Validate hyperparameters + assert version in range(1, 3) + if version == 1: + if layers > 2: + warnings.warn("SimCLR v1 only uses 2 layers in its projection head") + if memory_bank_size > 0: + warnings.warn("SimCLR v1 does not use a memory bank") + elif version == 2: + if layers == 2: + warnings.warn("SimCLR v2 uses 3+ layers in its projection head") + if memory_bank_size == 0: + warnings.warn("SimCLR v2 uses a memory bank") + + self.save_hyperparameters(ignore=["augmentations"]) + + grayscale_weights = grayscale_weights or torch.ones(in_channels) + self.augmentations = augmentations or simclr_augmentations( + size, grayscale_weights + ) + + # Create backbone + self.backbone = timm.create_model( + model, in_chans=in_channels, num_classes=0, pretrained=weights is True + ) + + # Load weights + if weights and weights is not True: + if isinstance(weights, WeightsEnum): + state_dict = weights.get_state_dict(progress=True) + elif os.path.exists(weights): + _, state_dict = utils.extract_backbone(weights) + else: + state_dict = get_weight(weights).get_state_dict(progress=True) + self.backbone = utils.load_state_dict(self.backbone, state_dict) + + # Create projection head + input_dim = self.backbone.num_features + if hidden_dim is None: + hidden_dim = input_dim + if output_dim is None: + output_dim = input_dim + + self.projection_head = SimCLRProjectionHead( + input_dim, hidden_dim, output_dim, layers + ) + + # Define loss function + self.criterion = NTXentLoss(temperature, memory_bank_size, gather_distributed) + + # Initialize moving average of output + self.avg_output_std = 0.0 + + # TODO + # v1+: add global batch norm + # v2: add selective kernels, channel-wise attention mechanism + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + """Forward pass of the model. + + Args: + x: Mini-batch of images. + + Returns: + Output from the backbone and projection head. + """ + h = self.backbone(x) # shape of batch_size x num_features + z = self.projection_head(h) + return cast(Tensor, z), cast(Tensor, h) + + def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor: + """Compute the training loss and additional metrics. + + Args: + batch: The output of your DataLoader. + batch_idx: Integer displaying index of this batch. + + Returns: + The loss tensor. + """ + x = batch["image"] + + in_channels = self.hparams["in_channels"] + assert x.size(1) == in_channels or x.size(1) == 2 * in_channels + + if x.size(1) == in_channels: + x1 = x + x2 = x + else: + x1 = x[:, :in_channels] + x2 = x[:, in_channels:] + + with torch.no_grad(): + x1 = self.augmentations(x1) + x2 = self.augmentations(x2) + + z1, h1 = self(x1) + z2, h2 = self(x2) + + loss = self.criterion(z1, z2) + + # Calculate the mean normalized standard deviation over features dimensions. + # If this is << 1 / sqrt(h1.shape[1]), then the model is not learning anything. + output = h1.detach() + output = F.normalize(output, dim=1) + output_std = torch.std(output, dim=0) + output_std = torch.mean(output_std, dim=0) + self.avg_output_std = 0.9 * self.avg_output_std + (1 - 0.9) * output_std.item() + + self.log("train_ssl_std", self.avg_output_std) + self.log("train_loss", loss) + + return cast(Tensor, loss) + + def validation_step(self, batch: dict[str, Tensor], batch_idx: int) -> None: + """No-op, does nothing.""" + + def test_step(self, batch: dict[str, Tensor], batch_idx: int) -> None: + """No-op, does nothing.""" + # TODO + # v2: add distillation step + + def predict_step(self, batch: dict[str, Tensor], batch_idx: int) -> None: + """No-op, does nothing.""" + + def configure_optimizers(self) -> tuple[list[Optimizer], list[LRScheduler]]: + """Initialize the optimizer and learning rate scheduler. + + Returns: + Optimizer and learning rate scheduler. + """ + # Original paper uses LARS optimizer, but this is not defined in PyTorch + optimizer = Adam( + self.parameters(), + lr=self.hparams["lr"], + weight_decay=self.hparams["weight_decay"], + ) + if self.hparams["version"] == 1: + warmup_epochs = 10 + else: + warmup_epochs = int(self.trainer.max_epochs * 0.05) + lr_scheduler = SequentialLR( + optimizer, + schedulers=[ + LinearLR(optimizer, total_iters=warmup_epochs), + CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs), + ], + milestones=[warmup_epochs], + ) + return [optimizer], [lr_scheduler] diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index 530a8e33c..e1b6678ed 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -90,7 +90,10 @@ def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Mo expected_in_channels = state_dict[input_module_key + ".weight"].shape[1] output_module_key, output_module = list(model.named_children())[-1] - num_classes = output_module.out_features + if isinstance(output_module, nn.Identity): + num_classes = model.num_features + else: + num_classes = output_module.out_features expected_num_classes = None if output_module_key + ".weight" in state_dict: expected_num_classes = state_dict[output_module_key + ".weight"].shape[0] diff --git a/train.py b/train.py index 0722a3d7b..0647355f8 100755 --- a/train.py +++ b/train.py @@ -16,7 +16,7 @@ from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger from omegaconf import DictConfig, OmegaConf from torchgeo.datamodules import MisconfigurationException -from torchgeo.trainers import BYOLTask, ObjectDetectionTask +from torchgeo.trainers import BYOLTask, ObjectDetectionTask, SimCLRTask def set_up_omegaconf() -> DictConfig: @@ -95,7 +95,7 @@ def main(conf: DictConfig) -> None: if isinstance(task, ObjectDetectionTask): monitor_metric = "val_map" mode = "max" - elif isinstance(task, BYOLTask): + elif isinstance(task, (BYOLTask, SimCLRTask)): monitor_metric = "train_loss" mode = "min" else: