Fix Metric.state_dict (#5614)

* Fix Metric.state_dict

* Update CHANGELOG.md

* Update CHANGELOG.md

* Detach tensors in a list if needed

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>

(cherry picked from commit c3df5122e71d497982b6764bbe51c22a26aa4fef)
(cherry picked from commit 7b0c79a7920403cc543a4a79b6efa0cca0508434)
This commit is contained in:
manipopopo 2021-01-25 21:48:12 +01:00 committed by Jirka Borovec
Parent 2dd5749968
Commit 635d9d8b7b
2 changed files: 36 additions and 4 deletions

View file

@ -298,14 +298,26 @@ class Metric(nn.Module, ABC):
for key in self._persistent.keys():
self._persistent[key] = mode
def state_dict(self, destination=None, prefix='', keep_vars=False):
    """Return the module state dict extended with the metric's registered states.

    Mirrors ``nn.Module.state_dict``: ``destination`` collects the entries,
    ``prefix`` is prepended to every key, and ``keep_vars`` controls whether
    tensors are detached from the autograd graph before being stored.

    NOTE(review): this span was unified-diff residue mixing the old
    ``state_dict(*args, **kwargs)`` implementation with the new one; this is
    the reconstructed post-patch version.
    """
    destination = super().state_dict(
        destination=destination,
        prefix=prefix,
        keep_vars=keep_vars
    )
    # Register metric states to be part of the state_dict
    for key in self._defaults.keys():
        if self._persistent[key]:
            current_val = getattr(self, key)
            if not keep_vars:
                # Detach so the saved state carries no autograd history;
                # list-valued states are detached element-wise.
                if torch.is_tensor(current_val):
                    current_val = current_val.detach()
                elif isinstance(current_val, list):
                    current_val = [
                        cur_v.detach() if torch.is_tensor(cur_v) else cur_v
                        for cur_v in current_val
                    ]
            destination[prefix + key] = current_val
    return destination
def _filter_kwargs(self, **kwargs):
""" filter kwargs such that they match the update signature of the metric """

View file

@ -6,6 +6,7 @@ import cloudpickle
import numpy as np
import pytest
import torch
from torch import nn
from pytorch_lightning.metrics.metric import Metric, MetricCollection
@ -211,6 +212,25 @@ def test_state_dict(tmpdir):
assert metric.state_dict() == OrderedDict()
def test_child_metric_state_dict():
    """ test that child metric states will be added to parent state dict """

    class TestModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.metric = Dummy()
            self.metric.add_state('a', torch.tensor(0), persistent=True)
            self.metric.add_state('b', [], persistent=True)
            self.metric.register_buffer('c', torch.tensor(0))

    module = TestModule()
    # Persistent metric states ('a', 'b') and the plain buffer 'c' must all
    # surface in the parent module's state dict under the 'metric.' prefix.
    expected_state_dict = {
        'metric.a': torch.tensor(0),
        'metric.b': [],
        'metric.c': torch.tensor(0),
    }
    assert module.state_dict() == expected_state_dict
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
def test_device_and_dtype_transfer(tmpdir):
metric = DummyMetric1()