Fix metric state reset (PL^5273)

* Fix metric state reset

* Fix test

* Improve formatting

Co-authored-by: Ananya Harsh Jha <ananya@pytorchlightning.ai>
(cherry picked from commit 26488794e1af6e2fbe05bf21de52cf999252cbe7)
(cherry picked from commit c55f02a770b5fb5f6d59c48154256ef50f8fa6f1)
This commit is contained in:
Tadej Svetina 2020-12-29 22:09:10 +01:00 коммит произвёл Jirka Borovec
Родитель 21081669b0
Коммит 20859b966e
2 изменённых файлов: 26 добавлений и 2 удалений

Просмотреть файл

@ -94,7 +94,8 @@ class Metric(nn.Module, ABC):
reset to this value when ``self.reset()`` is called.
dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode.
If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction
only makes sense if the state is a list, and not a tensor. The user can also pass a custom
function in this parameter.
persistent (Optional): whether the state will be saved as part of the modules ``state_dict``.
Default is ``False``.
@ -244,7 +245,7 @@ class Metric(nn.Module, ABC):
"""
for attr, default in self._defaults.items():
current_val = getattr(self, attr)
if isinstance(current_val, torch.Tensor):
if isinstance(default, torch.Tensor):
setattr(self, attr, deepcopy(default).to(current_val.device))
else:
setattr(self, attr, deepcopy(default))

Просмотреть файл

@ -26,6 +26,20 @@ class Dummy(Metric):
pass
class DummyList(Metric):
name = "DummyList"
def __init__(self):
super().__init__()
self.add_state("x", list(), dist_reduce_fx=None)
def update(self):
pass
def compute(self):
pass
def test_inherit():
Dummy()
@ -77,12 +91,21 @@ def test_reset():
class A(Dummy):
pass
class B(DummyList):
pass
a = A()
assert a.x == 0
a.x = torch.tensor(5)
a.reset()
assert a.x == 0
b = B()
assert isinstance(b.x, list) and len(b.x) == 0
b.x = torch.tensor(5)
b.reset()
assert isinstance(b.x, list) and len(b.x) == 0
def test_update():
class A(Dummy):