From 19b77cc2c290e811ae1c1ea44125776414b4e431 Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Mon, 29 Mar 2021 13:53:28 -0700 Subject: [PATCH] Bugfix: reset should set clear _computed (#147) * Bugfix: reset should set clear _computed * changelog * Update CHANGELOG.md Co-authored-by: Nicki Skafte --- CHANGELOG.md | 3 +++ tests/bases/test_metric.py | 9 +++++++++ torchmetrics/metric.py | 2 ++ 3 files changed, 14 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 962388e..52244e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed when `_stable_1d_sort` to work when n >= N ([#6177](https://github.com/PyTorchLightning/pytorch-lightning/pull/6177)) +- Fixed `_computed` attribute not being correctly reset ([#147](https://github.com/PyTorchLightning/metrics/pull/147)) + + ## [0.2.0] - 2021-03-12 diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py index 164b59d..cb4eb55 100644 --- a/tests/bases/test_metric.py +++ b/tests/bases/test_metric.py @@ -95,6 +95,15 @@ def test_reset(): assert isinstance(b.x, list) and len(b.x) == 0 +def test_reset_compute(): + a = DummyMetricSum() + assert a.x == 0 + a.update(tensor(5)) + assert a.compute() == 5 + a.reset() + assert a.compute() == 0 + + def test_update(): class A(DummyMetric): diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 29921ee..4ea08e3 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -254,6 +254,8 @@ class Metric(nn.Module, ABC): """ This method automatically resets the metric state variables to their default value. """ + self._computed = None + for attr, default in self._defaults.items(): current_val = getattr(self, attr) if isinstance(default, Tensor):