Fix Metric.state_dict (PL^5614)
* Fix Metric.state_dict
* Update CHANGELOG.md
* Update CHANGELOG.md
* Detach tensors in a list if needed

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
(cherry picked from commit c3df5122e71d497982b6764bbe51c22a26aa4fef)
(cherry picked from commit 7b0c79a7920403cc543a4a79b6efa0cca0508434)
Parent: 2dd5749968
Commit: 635d9d8b7b
pytorch_lightning/metrics/metric.py
@@ -298,14 +298,26 @@ class Metric(nn.Module, ABC):
         for key in self._persistent.keys():
             self._persistent[key] = mode
 
-    def state_dict(self, *args, **kwargs):
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        destination = super().state_dict(
+            destination=destination,
+            prefix=prefix,
+            keep_vars=keep_vars
+        )
         # Register metric states to be part of the state_dict
-        state_dict = super().state_dict()
         for key in self._defaults.keys():
             if self._persistent[key]:
                 current_val = getattr(self, key)
-                state_dict.update({key: current_val})
-        return state_dict
+                if not keep_vars:
+                    if torch.is_tensor(current_val):
+                        current_val = current_val.detach()
+                    elif isinstance(current_val, list):
+                        current_val = [
+                            cur_v.detach() if torch.is_tensor(cur_v) else cur_v
+                            for cur_v in current_val
+                        ]
+                destination[prefix + key] = current_val
+        return destination
 
     def _filter_kwargs(self, **kwargs):
         """ filter kwargs such that they match the update signature of the metric """
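A note on the detach logic added above: a metric state can be a single tensor or a list of tensors, so when keep_vars is False the list entries are detached one by one. A minimal standalone sketch of that pattern (the `state` list below is a hypothetical list-valued metric state, not library code):

    import torch

    # A tensor produced inside autograd still carries a grad_fn.
    state = [torch.ones(2, requires_grad=True) * 2]

    # The same per-element detach the diff applies for keep_vars=False.
    detached = [v.detach() if torch.is_tensor(v) else v for v in state]

    assert state[0].grad_fn is not None       # still tied to the graph
    assert not detached[0].requires_grad      # safe to store in a state dict

Without the per-element detach, the returned state dict would keep references into the autograd graph of the metric's last update.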
tests/metrics/test_metric.py
@@ -6,6 +6,7 @@ import cloudpickle
 import numpy as np
 import pytest
 import torch
+from torch import nn
 
 from pytorch_lightning.metrics.metric import Metric, MetricCollection
 
@@ -211,6 +212,25 @@ def test_state_dict(tmpdir):
     assert metric.state_dict() == OrderedDict()
 
 
+def test_child_metric_state_dict():
+    """ test that child metric states will be added to parent state dict """
+    class TestModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.metric = Dummy()
+            self.metric.add_state('a', torch.tensor(0), persistent=True)
+            self.metric.add_state('b', [], persistent=True)
+            self.metric.register_buffer('c', torch.tensor(0))
+
+    module = TestModule()
+    expected_state_dict = {
+        'metric.a': torch.tensor(0),
+        'metric.b': [],
+        'metric.c': torch.tensor(0)
+    }
+    assert module.state_dict() == expected_state_dict
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
 def test_device_and_dtype_transfer(tmpdir):
     metric = DummyMetric1()
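For context, a short usage sketch of what the fix enables: persistent states of a child Metric now end up in the owning module's state_dict, under the child's prefix and detached by default. The `MySum` metric and `Model` module below are hypothetical examples written against the `Metric` base class shown in the diff, assuming its usual `add_state`/`update`/`compute` API; they are not part of the library:

    import torch
    from torch import nn
    from pytorch_lightning.metrics.metric import Metric


    class MySum(Metric):
        """Hypothetical metric keeping a running sum as a persistent state."""

        def __init__(self):
            super().__init__()
            self.add_state('total', torch.tensor(0.), persistent=True)

        def update(self, x):
            self.total = self.total + x

        def compute(self):
            return self.total


    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.metric = MySum()


    model = Model()
    model.metric.update(torch.tensor(3.))

    # With the fix, 'metric.total' shows up here, detached from autograd;
    # before it, child metric states were silently dropped from the parent's dict.
    print(model.state_dict())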