Bugfix: reset should set clear _computed (#147)
* Bugfix: reset should set clear _computed * changelog * Update CHANGELOG.md Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
This commit is contained in:
Родитель
3d0823532b
Коммит
19b77cc2c2
|
@ -63,6 +63,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- Fixed when `_stable_1d_sort` to work when n >= N ([#6177](https://github.com/PyTorchLightning/pytorch-lightning/pull/6177))
|
||||
|
||||
|
||||
- Fixed `_computed` attribute not being correctly reset ([#147](https://github.com/PyTorchLightning/metrics/pull/147))
|
||||
|
||||
|
||||
## [0.2.0] - 2021-03-12
|
||||
|
||||
|
||||
|
|
|
@ -95,6 +95,15 @@ def test_reset():
|
|||
assert isinstance(b.x, list) and len(b.x) == 0
|
||||
|
||||
|
||||
def test_reset_compute():
|
||||
a = DummyMetricSum()
|
||||
assert a.x == 0
|
||||
a.update(tensor(5))
|
||||
assert a.compute() == 5
|
||||
a.reset()
|
||||
assert a.compute() == 0
|
||||
|
||||
|
||||
def test_update():
|
||||
|
||||
class A(DummyMetric):
|
||||
|
|
|
@ -254,6 +254,8 @@ class Metric(nn.Module, ABC):
|
|||
"""
|
||||
This method automatically resets the metric state variables to their default value.
|
||||
"""
|
||||
self._computed = None
|
||||
|
||||
for attr, default in self._defaults.items():
|
||||
current_val = getattr(self, attr)
|
||||
if isinstance(default, Tensor):
|
||||
|
|
Загрузка…
Ссылка в новой задаче