зеркало из https://github.com/microsoft/torchgeo.git
load_state_dict does not return the model (#1503)
* Update pretrained_weights.ipynb Fixed an error in the state dict loading of the tutorial and added a comment on the num_classes parameter when creating timm models. * Update docs/tutorials/pretrained_weights.ipynb * Update utils.py * Import Tuple from typing * Change return of `load_state_dict` from `model` to `Tuple[List[str], List[str]]`, matching the return of the standard PyTorch builtin function. * Update pretrained_weights.ipynb Remove example of loading pretrained model without prediction head (`num_classes=0`). * Update README.md Adapt new `load_state_dict` function. * Mimic return type of builtin load_state_dict * Modern type hints * Blacken * Try being explicit --------- Co-authored-by: Caleb Robinson <calebrob6@gmail.com> Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
91b5e01ef8
Коммит
55b3c50891
|
@ -132,7 +132,7 @@ from torchgeo.models import ResNet18_Weights
|
|||
|
||||
weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
|
||||
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"], num_classes=10)
|
||||
model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
|
||||
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
|
||||
```
|
||||
|
||||
These weights can also be used directly, via the `weights` argument, in the TorchGeo Lightning modules shown in the following section. For a notebook example, see this [tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/pretrained_weights.html).
|
||||
|
|
|
@ -228,7 +228,7 @@
|
|||
"source": [
|
||||
"in_chans = weights.meta[\"in_chans\"]\n",
|
||||
"model = timm.create_model(\"resnet18\", in_chans=in_chans, num_classes=10)\n",
|
||||
"model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False)"
|
||||
"model.load_state_dict(weights.get_state_dict(progress=True), strict=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -41,7 +41,7 @@ def test_get_input_layer_name_and_module() -> None:
|
|||
|
||||
def test_load_state_dict(checkpoint: str, model: Module) -> None:
|
||||
_, state_dict = extract_backbone(checkpoint)
|
||||
model = load_state_dict(model, state_dict)
|
||||
load_state_dict(model, state_dict)
|
||||
|
||||
|
||||
def test_load_state_dict_unequal_input_channels(checkpoint: str, model: Module) -> None:
|
||||
|
@ -58,7 +58,7 @@ def test_load_state_dict_unequal_input_channels(checkpoint: str, model: Module)
|
|||
f" model {expected_in_channels}. Overriding with new input channels"
|
||||
)
|
||||
with pytest.warns(UserWarning, match=warning):
|
||||
model = load_state_dict(model, state_dict)
|
||||
load_state_dict(model, state_dict)
|
||||
|
||||
|
||||
def test_load_state_dict_unequal_classes(checkpoint: str, model: Module) -> None:
|
||||
|
@ -74,7 +74,7 @@ def test_load_state_dict_unequal_classes(checkpoint: str, model: Module) -> None
|
|||
f" {expected_num_classes}. Overriding with new num classes"
|
||||
)
|
||||
with pytest.warns(UserWarning, match=warning):
|
||||
model = load_state_dict(model, state_dict)
|
||||
load_state_dict(model, state_dict)
|
||||
|
||||
|
||||
def test_reinit_initial_conv_layer() -> None:
|
||||
|
|
|
@ -343,7 +343,7 @@ class BYOLTask(BaseTask):
|
|||
_, state_dict = utils.extract_backbone(weights)
|
||||
else:
|
||||
state_dict = get_weight(weights).get_state_dict(progress=True)
|
||||
backbone = utils.load_state_dict(backbone, state_dict)
|
||||
utils.load_state_dict(backbone, state_dict)
|
||||
|
||||
self.model = BYOL(backbone, in_channels=in_channels, image_size=(224, 224))
|
||||
|
||||
|
|
|
@ -137,7 +137,7 @@ class ClassificationTask(BaseTask):
|
|||
_, state_dict = utils.extract_backbone(weights)
|
||||
else:
|
||||
state_dict = get_weight(weights).get_state_dict(progress=True)
|
||||
self.model = utils.load_state_dict(self.model, state_dict)
|
||||
utils.load_state_dict(self.model, state_dict)
|
||||
|
||||
# Freeze backbone and unfreeze classifier head
|
||||
if self.hparams["freeze_backbone"]:
|
||||
|
|
|
@ -261,7 +261,7 @@ class MoCoTask(BaseTask):
|
|||
_, state_dict = utils.extract_backbone(weights)
|
||||
else:
|
||||
state_dict = get_weight(weights).get_state_dict(progress=True)
|
||||
self.backbone = utils.load_state_dict(self.backbone, state_dict)
|
||||
utils.load_state_dict(self.backbone, state_dict)
|
||||
|
||||
# Create projection (and prediction) head
|
||||
batch_norm = version == 3
|
||||
|
|
|
@ -128,7 +128,7 @@ class RegressionTask(BaseTask):
|
|||
_, state_dict = utils.extract_backbone(weights)
|
||||
else:
|
||||
state_dict = get_weight(weights).get_state_dict(progress=True)
|
||||
self.model = utils.load_state_dict(self.model, state_dict)
|
||||
utils.load_state_dict(self.model, state_dict)
|
||||
|
||||
# Freeze backbone and unfreeze classifier head
|
||||
if self.hparams["freeze_backbone"]:
|
||||
|
|
|
@ -172,7 +172,7 @@ class SimCLRTask(BaseTask):
|
|||
_, state_dict = utils.extract_backbone(weights)
|
||||
else:
|
||||
state_dict = get_weight(weights).get_state_dict(progress=True)
|
||||
self.backbone = utils.load_state_dict(self.backbone, state_dict)
|
||||
utils.load_state_dict(self.backbone, state_dict)
|
||||
|
||||
# Create projection head
|
||||
input_dim = self.backbone.num_features
|
||||
|
|
|
@ -71,7 +71,9 @@ def _get_input_layer_name_and_module(model: Module) -> tuple[str, Module]:
|
|||
return key, module
|
||||
|
||||
|
||||
def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Module:
|
||||
def load_state_dict(
|
||||
model: Module, state_dict: "OrderedDict[str, Tensor]"
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Load pretrained resnet weights to a model.
|
||||
|
||||
Args:
|
||||
|
@ -79,7 +81,7 @@ def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Mo
|
|||
state_dict: dict containing tensor parameters
|
||||
|
||||
Returns:
|
||||
the model with pretrained weights
|
||||
The missing and unexpected keys
|
||||
|
||||
Warns:
|
||||
If input channels in model != pretrained model input channels
|
||||
|
@ -115,8 +117,10 @@ def load_state_dict(model: Module, state_dict: "OrderedDict[str, Tensor]") -> Mo
|
|||
state_dict[output_module_key + ".bias"],
|
||||
)
|
||||
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
return model
|
||||
missing_keys: list[str]
|
||||
unexpected_keys: list[str]
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||
return missing_keys, unexpected_keys
|
||||
|
||||
|
||||
def reinit_initial_conv_layer(
|
||||
|
|
Загрузка…
Ссылка в новой задаче