load_state_dict does not return the model (#1503)

* Update pretrained_weights.ipynb

Fixed an error in the state dict loading of the turorial and added a comment on the num_classes parameter when creating timm models.

* Update docs/tutorials/pretrained_weights.ipynb

* Update utils.py

* Import Tuple from typing
* Change return of `load_state_dict` from `model` to `Tuple[List[str], List[str]]`, matching the return of the standard PyTorch builtin function.

* Update pretrained_weights.ipynb

Remove example of loading pretrained model without prediction head (`num_classes=0`).

* Update README.md

Adapt new `load_state_dict` function.

* Mimic return type of builtin load_state_dict

* Modern type hints

* Blacken

* Try being explicit

---------

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Konstantin Klemmer 2024-02-06 04:55:15 -05:00 коммит произвёл GitHub
Родитель 91b5e01ef8
Коммит 55b3c50891
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
9 изменённых файлов: 18 добавлений и 14 удалений

Просмотреть файл

@ -132,7 +132,7 @@ from torchgeo.models import ResNet18_Weights
weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"], num_classes=10)
model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
```
These weights can also directly be used in TorchGeo Lightning modules that are shown in the following section via the `weights` argument. For a notebook example, see this [tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/pretrained_weights.html).

Просмотреть файл

@ -228,7 +228,7 @@
"source": [
"in_chans = weights.meta[\"in_chans\"]\n",
"model = timm.create_model(\"resnet18\", in_chans=in_chans, num_classes=10)\n",
"model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False)"
"model.load_state_dict(weights.get_state_dict(progress=True), strict=False)"
]
},
{

Просмотреть файл

@ -41,7 +41,7 @@ def test_get_input_layer_name_and_module() -> None:
def test_load_state_dict(checkpoint: str, model: Module) -> None:
_, state_dict = extract_backbone(checkpoint)
model = load_state_dict(model, state_dict)
load_state_dict(model, state_dict)
def test_load_state_dict_unequal_input_channels(checkpoint: str, model: Module) -> None:
@ -58,7 +58,7 @@ def test_load_state_dict_unequal_input_channels(checkpoint: str, model: Module)
f" model {expected_in_channels}. Overriding with new input channels"
)
with pytest.warns(UserWarning, match=warning):
model = load_state_dict(model, state_dict)
load_state_dict(model, state_dict)
def test_load_state_dict_unequal_classes(checkpoint: str, model: Module) -> None:
@ -74,7 +74,7 @@ def test_load_state_dict_unequal_classes(checkpoint: str, model: Module) -> None
f" {expected_num_classes}. Overriding with new num classes"
)
with pytest.warns(UserWarning, match=warning):
model = load_state_dict(model, state_dict)
load_state_dict(model, state_dict)
def test_reinit_initial_conv_layer() -> None:

Просмотреть файл

@ -343,7 +343,7 @@ class BYOLTask(BaseTask):
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
backbone = utils.load_state_dict(backbone, state_dict)
utils.load_state_dict(backbone, state_dict)
self.model = BYOL(backbone, in_channels=in_channels, image_size=(224, 224))

Просмотреть файл

@ -137,7 +137,7 @@ class ClassificationTask(BaseTask):
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model = utils.load_state_dict(self.model, state_dict)
utils.load_state_dict(self.model, state_dict)
# Freeze backbone and unfreeze classifier head
if self.hparams["freeze_backbone"]:

Просмотреть файл

@ -261,7 +261,7 @@ class MoCoTask(BaseTask):
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.backbone = utils.load_state_dict(self.backbone, state_dict)
utils.load_state_dict(self.backbone, state_dict)
# Create projection (and prediction) head
batch_norm = version == 3

Просмотреть файл

@ -128,7 +128,7 @@ class RegressionTask(BaseTask):
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model = utils.load_state_dict(self.model, state_dict)
utils.load_state_dict(self.model, state_dict)
# Freeze backbone and unfreeze classifier head
if self.hparams["freeze_backbone"]:

Просмотреть файл

@ -172,7 +172,7 @@ class SimCLRTask(BaseTask):
_, state_dict = utils.extract_backbone(weights)
else:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.backbone = utils.load_state_dict(self.backbone, state_dict)
utils.load_state_dict(self.backbone, state_dict)
# Create projection head
input_dim = self.backbone.num_features

Просмотреть файл

@ -71,7 +71,9 @@ def _get_input_layer_name_and_module(model: Module) -> tuple[str, Module]:
return key, module
def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Module:
def load_state_dict(
model: Module, state_dict: "OrderedDict[str, Tensor]"
) -> tuple[list[str], list[str]]:
"""Load pretrained resnet weights to a model.
Args:
@ -79,7 +81,7 @@ def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Mo
state_dict: dict containing tensor parameters
Returns:
the model with pretrained weights
The missing and unexpected keys
Warns:
If input channels in model != pretrained model input channels
@ -115,8 +117,10 @@ def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Mo
state_dict[output_module_key + ".bias"],
)
model.load_state_dict(state_dict, strict=False)
return model
missing_keys: list[str]
unexpected_keys: list[str]
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
return missing_keys, unexpected_keys
def reinit_initial_conv_layer(