diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index e5f8a0f..cbedcf1 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -298,14 +298,26 @@ class Metric(nn.Module, ABC):
         for key in self._persistent.keys():
             self._persistent[key] = mode
 
-    def state_dict(self, *args, **kwargs):
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        destination = super().state_dict(
+            destination=destination,
+            prefix=prefix,
+            keep_vars=keep_vars
+        )
         # Register metric states to be part of the state_dict
-        state_dict = super().state_dict()
         for key in self._defaults.keys():
             if self._persistent[key]:
                 current_val = getattr(self, key)
-                state_dict.update({key: current_val})
-        return state_dict
+                if not keep_vars:
+                    if torch.is_tensor(current_val):
+                        current_val = current_val.detach()
+                    elif isinstance(current_val, list):
+                        current_val = [
+                            cur_v.detach() if torch.is_tensor(cur_v) else cur_v
+                            for cur_v in current_val
+                        ]
+                destination[prefix + key] = current_val
+        return destination
 
     def _filter_kwargs(self, **kwargs):
         """ filter kwargs such that they match the update signature of the metric """
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py
index 83f400e..e4a4ec9 100644
--- a/tests/metrics/test_metric.py
+++ b/tests/metrics/test_metric.py
@@ -6,6 +6,7 @@ import cloudpickle
 import numpy as np
 import pytest
 import torch
+from torch import nn
 
 from pytorch_lightning.metrics.metric import Metric, MetricCollection
 
@@ -211,6 +212,25 @@ def test_state_dict(tmpdir):
     assert metric.state_dict() == OrderedDict()
 
 
+def test_child_metric_state_dict():
+    """ test that child metric states will be added to parent state dict """
+    class TestModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.metric = Dummy()
+            self.metric.add_state('a', torch.tensor(0), persistent=True)
+            self.metric.add_state('b', [], persistent=True)
+            self.metric.register_buffer('c', torch.tensor(0))
+
+    module = TestModule()
+    expected_state_dict = {
+        'metric.a': torch.tensor(0),
+        'metric.b': [],
+        'metric.c': torch.tensor(0)
+    }
+    assert module.state_dict() == expected_state_dict
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
 def test_device_and_dtype_transfer(tmpdir):
     metric = DummyMetric1()