Bugfix: reset should set clear _computed (#147)
* Bugfix: reset should set clear _computed * changelog * Update CHANGELOG.md Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
This commit is contained in:
Родитель
3d0823532b
Коммит
19b77cc2c2
|
@ -63,6 +63,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||||
- Fixed when `_stable_1d_sort` to work when n >= N ([#6177](https://github.com/PyTorchLightning/pytorch-lightning/pull/6177))
|
- Fixed when `_stable_1d_sort` to work when n >= N ([#6177](https://github.com/PyTorchLightning/pytorch-lightning/pull/6177))
|
||||||
|
|
||||||
|
|
||||||
|
- Fixed `_computed` attribute not being correctly reset ([#147](https://github.com/PyTorchLightning/metrics/pull/147))
|
||||||
|
|
||||||
|
|
||||||
## [0.2.0] - 2021-03-12
|
## [0.2.0] - 2021-03-12
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -95,6 +95,15 @@ def test_reset():
|
||||||
assert isinstance(b.x, list) and len(b.x) == 0
|
assert isinstance(b.x, list) and len(b.x) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_reset_compute():
|
||||||
|
a = DummyMetricSum()
|
||||||
|
assert a.x == 0
|
||||||
|
a.update(tensor(5))
|
||||||
|
assert a.compute() == 5
|
||||||
|
a.reset()
|
||||||
|
assert a.compute() == 0
|
||||||
|
|
||||||
|
|
||||||
def test_update():
|
def test_update():
|
||||||
|
|
||||||
class A(DummyMetric):
|
class A(DummyMetric):
|
||||||
|
|
|
@ -254,6 +254,8 @@ class Metric(nn.Module, ABC):
|
||||||
"""
|
"""
|
||||||
This method automatically resets the metric state variables to their default value.
|
This method automatically resets the metric state variables to their default value.
|
||||||
"""
|
"""
|
||||||
|
self._computed = None
|
||||||
|
|
||||||
for attr, default in self._defaults.items():
|
for attr, default in self._defaults.items():
|
||||||
current_val = getattr(self, attr)
|
current_val = getattr(self, attr)
|
||||||
if isinstance(default, Tensor):
|
if isinstance(default, Tensor):
|
||||||
|
|
Загрузка…
Ссылка в новой задаче