зеркало из https://github.com/microsoft/torchgeo.git
Adding timm models to the classification task and refactoring (#210)
* Adding support for VGG models in the classification task Refactoring the logic for replacing the first conv layer in a network * Fix formatting * Testing the stuff * mypy with torch is such a waste of time * Update torchgeo/trainers/utils.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/trainers/utils.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Adding timm dependency * Incorporate timm into ClassificationTask * Fix tests? * Formatting * Allow for overriding stride and padding in `reinit_initial_conv_layer` * Putting back some stuff I accidentally overwrote in the rebase * Bug * Format Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
de48bd54ad
Коммит
60674cc200
|
@ -41,4 +41,5 @@ dependencies:
|
|||
- segmentation-models-pytorch>=0.2
|
||||
- setuptools>=30.4
|
||||
- sphinx>=3
|
||||
- timm>=0.2.1
|
||||
- torchmetrics
|
||||
|
|
|
@ -77,6 +77,8 @@ train =
|
|||
scikit-learn>=0.18
|
||||
# segmentation-models-pytorch 0.2+ required for smp.losses module
|
||||
segmentation-models-pytorch>=0.2
|
||||
# timm 0.2.1+ required for `features_only` option in create_model
|
||||
timm>=0.2.1
|
||||
torchmetrics
|
||||
# Optional developer requirements
|
||||
style =
|
||||
|
|
|
@ -100,17 +100,21 @@ class TestClassificationTask:
|
|||
|
||||
@pytest.fixture(
|
||||
scope="class",
|
||||
params=zip(["ce", "jaccard", "focal"], ["imagenet", "random", "random"]),
|
||||
params=zip(
|
||||
["ce", "jaccard", "focal"],
|
||||
["imagenet", "random", "random"],
|
||||
["resnet18", "hrnet_w18_small_v2", "tf_efficientnet_b0"],
|
||||
),
|
||||
)
|
||||
def config(
|
||||
self, request: SubRequest, datamodule: DummyDataModule
|
||||
) -> Dict[str, Any]:
|
||||
loss, weights, model = request.param
|
||||
task_args = {}
|
||||
task_args["classification_model"] = "resnet18"
|
||||
task_args["learning_rate"] = 3e-4 # type: ignore[assignment]
|
||||
task_args["learning_rate_schedule_patience"] = 6 # type: ignore[assignment]
|
||||
task_args["in_channels"] = datamodule.num_channels # type: ignore[assignment]
|
||||
loss, weights = request.param
|
||||
task_args["classification_model"] = model
|
||||
task_args["learning_rate"] = 3e-4
|
||||
task_args["learning_rate_schedule_patience"] = 6
|
||||
task_args["in_channels"] = datamodule.num_channels
|
||||
task_args["loss"] = loss
|
||||
task_args["weights"] = weights
|
||||
return task_args
|
||||
|
@ -157,7 +161,7 @@ class TestClassificationTask:
|
|||
|
||||
def test_invalid_model(self, config: Dict[str, Any]) -> None:
|
||||
config["classification_model"] = "invalid_model"
|
||||
error_message = "Model type 'invalid_model' is not valid."
|
||||
error_message = "Model type 'invalid_model' is not a valid timm model."
|
||||
with pytest.raises(ValueError, match=error_message):
|
||||
ClassificationTask(**config)
|
||||
|
||||
|
|
|
@ -10,7 +10,11 @@ import torch
|
|||
import torch.nn as nn
|
||||
from torch.nn.modules import Module
|
||||
|
||||
from torchgeo.trainers.utils import extract_encoder, load_state_dict
|
||||
from torchgeo.trainers.utils import (
|
||||
extract_encoder,
|
||||
load_state_dict,
|
||||
reinit_initial_conv_layer,
|
||||
)
|
||||
|
||||
|
||||
class FakeExperiment(object):
|
||||
|
@ -87,3 +91,21 @@ def test_load_state_dict_unequal_classes(checkpoint: str, model: Module) -> None
|
|||
)
|
||||
with pytest.warns(UserWarning, match=warning):
|
||||
model = load_state_dict(model, state_dict)
|
||||
|
||||
|
||||
def test_reinit_initial_conv_layer() -> None:
|
||||
conv_layer = nn.Conv2d( # type: ignore[attr-defined]
|
||||
3, 5, kernel_size=3, stride=2, padding=1, bias=True
|
||||
)
|
||||
initial_weights = conv_layer.weight.data.clone()
|
||||
|
||||
new_conv_layer = reinit_initial_conv_layer(conv_layer, 4, keep_rgb_weights=True)
|
||||
|
||||
out_channels, in_channels, k1, k2 = new_conv_layer.weight.shape
|
||||
assert torch.allclose( # type: ignore[attr-defined]
|
||||
initial_weights, new_conv_layer.weight.data[:, :3, :, :]
|
||||
)
|
||||
assert out_channels == 5
|
||||
assert in_channels == 4
|
||||
assert k1 == 3 and k2 == 3
|
||||
assert new_conv_layer.stride[0] == 2
|
||||
|
|
|
@ -7,10 +7,10 @@ import os
|
|||
from typing import Any, Dict, cast
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.models
|
||||
from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss
|
||||
from torch import Tensor
|
||||
from torch.nn.modules import Conv2d, Linear
|
||||
|
@ -35,71 +35,46 @@ class ClassificationTask(pl.LightningModule):
|
|||
def config_model(self) -> None:
|
||||
"""Configures the model based on kwargs parameters passed to the constructor."""
|
||||
in_channels = self.hparams["in_channels"]
|
||||
classification_model = self.hparams["classification_model"]
|
||||
|
||||
pretrained = False
|
||||
imagenet_pretrained = False
|
||||
custom_pretrained = False
|
||||
if not os.path.exists(self.hparams["weights"]):
|
||||
if self.hparams["weights"] == "imagenet":
|
||||
pretrained = True
|
||||
imagenet_pretrained = True
|
||||
elif self.hparams["weights"] == "random":
|
||||
pretrained = False
|
||||
imagenet_pretrained = False
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Weight type '{self.hparams['weights']}' is not valid."
|
||||
)
|
||||
custom_pretrained = False
|
||||
else:
|
||||
custom_pretrained = True
|
||||
|
||||
# Create the model
|
||||
if "resnet" in self.hparams["classification_model"]:
|
||||
self.model = getattr(
|
||||
torchvision.models.resnet, self.hparams["classification_model"]
|
||||
)(pretrained=pretrained)
|
||||
in_features = self.model.fc.in_features
|
||||
self.model.fc = Linear(in_features, out_features=self.num_classes)
|
||||
|
||||
# Update first layer
|
||||
if in_channels != 3:
|
||||
w_old = torch.empty(0) # type: ignore[attr-defined]
|
||||
if pretrained:
|
||||
w_old = torch.clone( # type: ignore[attr-defined]
|
||||
self.model.conv1.weight
|
||||
).detach()
|
||||
# Create the new layer
|
||||
self.model.conv1 = Conv2d(
|
||||
in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
|
||||
)
|
||||
nn.init.kaiming_normal_( # type: ignore[no-untyped-call]
|
||||
self.model.conv1.weight, mode="fan_out", nonlinearity="relu"
|
||||
)
|
||||
|
||||
# We copy over the pretrained RGB weights
|
||||
if pretrained:
|
||||
w_new = torch.clone( # type: ignore[attr-defined]
|
||||
self.model.conv1.weight
|
||||
).detach()
|
||||
if in_channels > 3:
|
||||
w_new[:, :3, :, :] = w_old
|
||||
else:
|
||||
w_old = w_old[:, :in_channels, :, :]
|
||||
w_new[:, :in_channels, :, :] = w_old
|
||||
self.model.conv1.weight = nn.Parameter( # type: ignore[attr-defined] # noqa: E501
|
||||
w_new
|
||||
)
|
||||
valid_models = timm.list_models(pretrained=True)
|
||||
if classification_model in valid_models:
|
||||
self.model = timm.create_model(
|
||||
classification_model,
|
||||
num_classes=self.num_classes,
|
||||
in_chans=in_channels,
|
||||
pretrained=imagenet_pretrained,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Model type '{self.hparams['classification_model']}' is not valid."
|
||||
f"Model type '{classification_model}' is not a valid timm model."
|
||||
)
|
||||
|
||||
# Load pretrained weights checkpoint weights
|
||||
if "resnet" in self.hparams["classification_model"]:
|
||||
if os.path.exists(self.hparams["weights"]):
|
||||
name, state_dict = utils.extract_encoder(self.hparams["weights"])
|
||||
if custom_pretrained:
|
||||
name, state_dict = utils.extract_encoder(self.hparams["weights"])
|
||||
|
||||
if self.hparams["classification_model"] != name:
|
||||
raise ValueError(
|
||||
f"Trying to load {name} weights into a "
|
||||
f"{self.hparams['classification_model']}"
|
||||
)
|
||||
|
||||
self.model = utils.load_state_dict(self.model, state_dict)
|
||||
if self.hparams["classification_model"] != name:
|
||||
raise ValueError(
|
||||
f"Trying to load {name} weights into a "
|
||||
f"{self.hparams['classification_model']}"
|
||||
)
|
||||
self.model = utils.load_state_dict(self.model, state_dict)
|
||||
|
||||
def config_task(self) -> None:
|
||||
"""Configures the task based on kwargs parameters passed to the constructor."""
|
||||
|
|
|
@ -5,15 +5,17 @@
|
|||
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Tuple
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn.modules import Module
|
||||
from torch.nn.modules import Conv2d, Module
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/60979
|
||||
# https://github.com/pytorch/pytorch/pull/61045
|
||||
Module.__module__ = "nn.Module"
|
||||
Conv2d.__module__ = "nn.Conv2d"
|
||||
|
||||
|
||||
def extract_encoder(path: str) -> Tuple[str, Dict[str, Tensor]]:
|
||||
|
@ -94,3 +96,60 @@ def load_state_dict(model: Module, state_dict: Dict[str, Tensor]) -> Module:
|
|||
model.load_state_dict(state_dict, strict=False) # type: ignore[arg-type]
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def reinit_initial_conv_layer(
|
||||
layer: Conv2d,
|
||||
new_in_channels: int,
|
||||
keep_rgb_weights: bool,
|
||||
new_stride: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
new_padding: Optional[Union[str, Union[int, Tuple[int, int]]]] = None,
|
||||
) -> Conv2d:
|
||||
"""Clones a Conv2d layer while optionally retaining some of the original weights.
|
||||
|
||||
When replacing the first convolutional layer in a model with one that operates over
|
||||
different number of input channels, we sometimes want to keep a subset of the kernel
|
||||
weights the same (e.g. the RGB weights of an ImageNet pretrained model). This is a
|
||||
convenience function that performs that function.
|
||||
|
||||
Args:
|
||||
layer: the Conv2d layer to initialize
|
||||
new_in_channels: the new number of input channels
|
||||
keep_rgb_weights: flag indicating whether to re-initialize the first 3 channels
|
||||
new_stride: optionally, overwrites the ``layer``'s stride with this value
|
||||
new_padding: optionally, overwrites the ``layers``'s padding with this value
|
||||
|
||||
Returns:
|
||||
a Conv2d layer with new kernel weights
|
||||
"""
|
||||
use_bias = layer.bias is not None
|
||||
if keep_rgb_weights:
|
||||
w_old = layer.weight.data[:, :3, :, :].clone()
|
||||
if use_bias:
|
||||
# mypy doesn't realize that bias isn't None here...
|
||||
b_old = layer.bias.data.clone() # type: ignore[union-attr]
|
||||
|
||||
updated_stride = layer.stride if new_stride is None else new_stride
|
||||
updated_padding = layer.padding if new_padding is None else new_padding
|
||||
|
||||
new_layer = Conv2d(
|
||||
new_in_channels,
|
||||
layer.out_channels,
|
||||
kernel_size=layer.kernel_size, # type: ignore[arg-type]
|
||||
stride=updated_stride, # type: ignore[arg-type]
|
||||
padding=updated_padding, # type: ignore[arg-type]
|
||||
dilation=layer.dilation, # type: ignore[arg-type]
|
||||
groups=layer.groups,
|
||||
bias=use_bias,
|
||||
padding_mode=layer.padding_mode,
|
||||
)
|
||||
nn.init.kaiming_normal_( # type: ignore[no-untyped-call]
|
||||
new_layer.weight, mode="fan_out", nonlinearity="relu"
|
||||
)
|
||||
|
||||
if keep_rgb_weights:
|
||||
new_layer.weight.data[:, :3, :, :] = w_old
|
||||
if use_bias:
|
||||
new_layer.bias.data = b_old # type: ignore[union-attr]
|
||||
|
||||
return new_layer
|
||||
|
|
Загрузка…
Ссылка в новой задаче