Adding timm models to the classification task and refactoring (#210)

* Adding support for VGG models in the classification task
Refactoring the logic for replacing the first conv layer in a network

* Fix formatting

* Testing the stuff

* mypy with torch is such a waste of time

* Update torchgeo/trainers/utils.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Update torchgeo/trainers/utils.py

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>

* Adding timm dependency

* Incorporate timm into ClassificationTask

* Fix tests?

* Formatting

* Allow for overriding stride and padding in `reinit_initial_conv_layer`

* Putting back some stuff I accidentally overwrote in the rebase

* Bug

* Format

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Caleb Robinson 2021-11-03 15:34:33 -07:00 committed by GitHub
Parent de48bd54ad
Commit 60674cc200
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 124 additions and 61 deletions


@@ -41,4 +41,5 @@ dependencies:
 - segmentation-models-pytorch>=0.2
 - setuptools>=30.4
 - sphinx>=3
+- timm>=0.2.1
 - torchmetrics


@@ -77,6 +77,8 @@ train =
     scikit-learn>=0.18
     # segmentation-models-pytorch 0.2+ required for smp.losses module
    segmentation-models-pytorch>=0.2
+    # timm 0.2.1+ required for `features_only` option in create_model
+    timm>=0.2.1
     torchmetrics
 # Optional developer requirements
 style =
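
The version floor above comes from the `features_only` option in `timm.create_model`, which turns a classifier into a multi-scale feature extractor. A minimal sketch of what that option does (the backbone name and input shape are illustrative, not part of this diff):

    import timm
    import torch

    # Build a backbone that returns intermediate feature maps instead of logits.
    # features_only requires timm>=0.2.1, hence the pin above.
    backbone = timm.create_model("resnet18", features_only=True, pretrained=False)
    feats = backbone(torch.randn(1, 3, 224, 224))
    print([f.shape for f in feats])  # one entry per feature stage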


@@ -100,17 +100,21 @@ class TestClassificationTask:
     @pytest.fixture(
         scope="class",
-        params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]),
+        params=zip(
+            ["ce", "jaccard", "focal"],
+            ["imagenet", "random", "random"],
+            ["resnet18", "hrnet_w18_small_v2", "tf_efficientnet_b0"],
+        ),
     )
     def config(
         self, request: SubRequest, datamodule: DummyDataModule
     ) -> Dict[str, Any]:
+        loss, weights, model = request.param
         task_args = {}
-        task_args["classification_model"] = "resnet18"
-        task_args["learning_rate"] = 3e-4  # type: ignore[assignment]
-        task_args["learning_rate_schedule_patience"] = 6  # type: ignore[assignment]
-        task_args["in_channels"] = datamodule.num_channels  # type: ignore[assignment]
-        loss, weights = request.param
+        task_args["classification_model"] = model
+        task_args["learning_rate"] = 3e-4
+        task_args["learning_rate_schedule_patience"] = 6
+        task_args["in_channels"] = datamodule.num_channels
         task_args["loss"] = loss
         task_args["weights"] = weights
         return task_args
@@ -157,7 +161,7 @@ class TestClassificationTask:
     def test_invalid_model(self, config: Dict[str, Any]) -> None:
         config["classification_model"] = "invalid_model"
-        error_message = "Model type 'invalid_model' is not valid."
+        error_message = "Model type 'invalid_model' is not a valid timm model."
         with pytest.raises(ValueError, match=error_message):
             ClassificationTask(**config)


@@ -10,7 +10,11 @@ import torch
 import torch.nn as nn
 from torch.nn.modules import Module
 
-from torchgeo.trainers.utils import extract_encoder, load_state_dict
+from torchgeo.trainers.utils import (
+    extract_encoder,
+    load_state_dict,
+    reinit_initial_conv_layer,
+)
 
 
 class FakeExperiment(object):
@@ -87,3 +91,21 @@ def test_load_state_dict_unequal_classes(checkpoint: str, model: Module) -> None
     )
     with pytest.warns(UserWarning, match=warning):
         model = load_state_dict(model, state_dict)
+
+
+def test_reinit_initial_conv_layer() -> None:
+    conv_layer = nn.Conv2d(  # type: ignore[attr-defined]
+        3, 5, kernel_size=3, stride=2, padding=1, bias=True
+    )
+    initial_weights = conv_layer.weight.data.clone()
+    new_conv_layer = reinit_initial_conv_layer(conv_layer, 4, keep_rgb_weights=True)
+    out_channels, in_channels, k1, k2 = new_conv_layer.weight.shape
+    assert torch.allclose(  # type: ignore[attr-defined]
+        initial_weights, new_conv_layer.weight.data[:, :3, :, :]
+    )
+    assert out_channels == 5
+    assert in_channels == 4
+    assert k1 == 3 and k2 == 3
+    assert new_conv_layer.stride[0] == 2


@@ -7,10 +7,10 @@ import os
 from typing import Any, Dict, cast
 
 import pytorch_lightning as pl
+import timm
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torchvision.models
 from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss
 from torch import Tensor
 from torch.nn.modules import Conv2d, Linear
@@ -35,71 +35,46 @@ class ClassificationTask(pl.LightningModule):
     def config_model(self) -> None:
         """Configures the model based on kwargs parameters passed to the constructor."""
         in_channels = self.hparams["in_channels"]
+        classification_model = self.hparams["classification_model"]
 
-        pretrained = False
+        imagenet_pretrained = False
+        custom_pretrained = False
         if not os.path.exists(self.hparams["weights"]):
             if self.hparams["weights"] == "imagenet":
-                pretrained = True
+                imagenet_pretrained = True
             elif self.hparams["weights"] == "random":
-                pretrained = False
+                imagenet_pretrained = False
             else:
                 raise ValueError(
                     f"Weight type '{self.hparams['weights']}' is not valid."
                 )
+            custom_pretrained = False
+        else:
+            custom_pretrained = True
 
         # Create the model
-        if "resnet" in self.hparams["classification_model"]:
-            self.model = getattr(
-                torchvision.models.resnet, self.hparams["classification_model"]
-            )(pretrained=pretrained)
-            in_features = self.model.fc.in_features
-            self.model.fc = Linear(in_features, out_features=self.num_classes)
-
-            # Update first layer
-            if in_channels != 3:
-                w_old = torch.empty(0)  # type: ignore[attr-defined]
-                if pretrained:
-                    w_old = torch.clone(  # type: ignore[attr-defined]
-                        self.model.conv1.weight
-                    ).detach()
-                # Create the new layer
-                self.model.conv1 = Conv2d(
-                    in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
-                )
-                nn.init.kaiming_normal_(  # type: ignore[no-untyped-call]
-                    self.model.conv1.weight, mode="fan_out", nonlinearity="relu"
-                )
-                # We copy over the pretrained RGB weights
-                if pretrained:
-                    w_new = torch.clone(  # type: ignore[attr-defined]
-                        self.model.conv1.weight
-                    ).detach()
-                    if in_channels > 3:
-                        w_new[:, :3, :, :] = w_old
-                    else:
-                        w_old = w_old[:, :in_channels, :, :]
-                        w_new[:, :in_channels, :, :] = w_old
-                    self.model.conv1.weight = nn.Parameter(  # type: ignore[attr-defined]  # noqa: E501
-                        w_new
-                    )
+        valid_models = timm.list_models(pretrained=True)
+        if classification_model in valid_models:
+            self.model = timm.create_model(
+                classification_model,
+                num_classes=self.num_classes,
+                in_chans=in_channels,
+                pretrained=imagenet_pretrained,
+            )
         else:
             raise ValueError(
-                f"Model type '{self.hparams['classification_model']}' is not valid."
+                f"Model type '{classification_model}' is not a valid timm model."
             )
 
-        # Load pretrained weights checkpoint weights
-        if "resnet" in self.hparams["classification_model"]:
-            if os.path.exists(self.hparams["weights"]):
-                name, state_dict = utils.extract_encoder(self.hparams["weights"])
-
-                if self.hparams["classification_model"] != name:
-                    raise ValueError(
-                        f"Trying to load {name} weights into a "
-                        f"{self.hparams['classification_model']}"
-                    )
-                self.model = utils.load_state_dict(self.model, state_dict)
+        if custom_pretrained:
+            name, state_dict = utils.extract_encoder(self.hparams["weights"])
+
+            if self.hparams["classification_model"] != name:
+                raise ValueError(
+                    f"Trying to load {name} weights into a "
+                    f"{self.hparams['classification_model']}"
+                )
+            self.model = utils.load_state_dict(self.model, state_dict)
 
     def config_task(self) -> None:
         """Configures the task based on kwargs parameters passed to the constructor."""


@@ -5,15 +5,17 @@
 import warnings
 from collections import OrderedDict
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 from torch import Tensor
-from torch.nn.modules import Module
+from torch.nn.modules import Conv2d, Module
 
 # https://github.com/pytorch/pytorch/issues/60979
 # https://github.com/pytorch/pytorch/pull/61045
 Module.__module__ = "nn.Module"
+Conv2d.__module__ = "nn.Conv2d"
 
 
 def extract_encoder(path: str) -> Tuple[str, Dict[str, Tensor]]:
@@ -94,3 +96,60 @@ def load_state_dict(model: Module, state_dict: Dict[str, Tensor]) -> Module:
     model.load_state_dict(state_dict, strict=False)  # type: ignore[arg-type]
 
     return model
+
+
+def reinit_initial_conv_layer(
+    layer: Conv2d,
+    new_in_channels: int,
+    keep_rgb_weights: bool,
+    new_stride: Optional[Union[int, Tuple[int, int]]] = None,
+    new_padding: Optional[Union[str, Union[int, Tuple[int, int]]]] = None,
+) -> Conv2d:
+    """Clones a Conv2d layer while optionally retaining some of the original weights.
+
+    When replacing the first convolutional layer in a model with one that operates
+    over a different number of input channels, we sometimes want to keep a subset of
+    the kernel weights the same (e.g. the RGB weights of an ImageNet pretrained
+    model). This is a convenience function that performs that operation.
+
+    Args:
+        layer: the Conv2d layer to initialize
+        new_in_channels: the new number of input channels
+        keep_rgb_weights: flag indicating whether to keep the original weights of the
+            first 3 channels
+        new_stride: optionally, overwrites the ``layer``'s stride with this value
+        new_padding: optionally, overwrites the ``layer``'s padding with this value
+
+    Returns:
+        a Conv2d layer with new kernel weights
+    """
+    use_bias = layer.bias is not None
+    if keep_rgb_weights:
+        w_old = layer.weight.data[:, :3, :, :].clone()
+        if use_bias:
+            # mypy doesn't realize that bias isn't None here...
+            b_old = layer.bias.data.clone()  # type: ignore[union-attr]
+
+    updated_stride = layer.stride if new_stride is None else new_stride
+    updated_padding = layer.padding if new_padding is None else new_padding
+
+    new_layer = Conv2d(
+        new_in_channels,
+        layer.out_channels,
+        kernel_size=layer.kernel_size,  # type: ignore[arg-type]
+        stride=updated_stride,  # type: ignore[arg-type]
+        padding=updated_padding,  # type: ignore[arg-type]
+        dilation=layer.dilation,  # type: ignore[arg-type]
+        groups=layer.groups,
+        bias=use_bias,
+        padding_mode=layer.padding_mode,
+    )
+    nn.init.kaiming_normal_(  # type: ignore[no-untyped-call]
+        new_layer.weight, mode="fan_out", nonlinearity="relu"
+    )
+
+    if keep_rgb_weights:
+        new_layer.weight.data[:, :3, :, :] = w_old
+        if use_bias:
+            new_layer.bias.data = b_old  # type: ignore[union-attr]
+
+    return new_layer
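
A brief usage sketch of the new helper (model choice illustrative), adapting the stem of a pretrained timm ResNet, whose first conv is named `conv1`, to extra input channels while keeping its RGB weights:

    import timm
    from torchgeo.trainers.utils import reinit_initial_conv_layer

    model = timm.create_model("resnet18", pretrained=True)
    # Re-create the stem for 13 input channels; the first 3 channel slices keep
    # the pretrained RGB weights, the rest are Kaiming-initialized.
    model.conv1 = reinit_initial_conv_layer(model.conv1, 13, keep_rgb_weights=True)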