Fix metric state reset (PL^5273)
* Fix metric state reset * Fix test * Improve formatting Co-authored-by: Ananya Harsh Jha <ananya@pytorchlightning.ai> (cherry picked from commit 26488794e1af6e2fbe05bf21de52cf999252cbe7) (cherry picked from commit c55f02a770b5fb5f6d59c48154256ef50f8fa6f1)
This commit is contained in:
Родитель
21081669b0
Коммит
20859b966e
|
@ -94,7 +94,8 @@ class Metric(nn.Module, ABC):
|
|||
reset to this value when ``self.reset()`` is called.
|
||||
dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode.
|
||||
If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
|
||||
and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
|
||||
and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction
|
||||
only makes sense if the state is a list, and not a tensor. The user can also pass a custom
|
||||
function in this parameter.
|
||||
persistent (Optional): whether the state will be saved as part of the modules ``state_dict``.
|
||||
Default is ``False``.
|
||||
|
@ -244,7 +245,7 @@ class Metric(nn.Module, ABC):
|
|||
"""
|
||||
for attr, default in self._defaults.items():
|
||||
current_val = getattr(self, attr)
|
||||
if isinstance(current_val, torch.Tensor):
|
||||
if isinstance(default, torch.Tensor):
|
||||
setattr(self, attr, deepcopy(default).to(current_val.device))
|
||||
else:
|
||||
setattr(self, attr, deepcopy(default))
|
||||
|
|
|
@ -26,6 +26,20 @@ class Dummy(Metric):
|
|||
pass
|
||||
|
||||
|
||||
class DummyList(Metric):
|
||||
name = "DummyList"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_state("x", list(), dist_reduce_fx=None)
|
||||
|
||||
def update(self):
|
||||
pass
|
||||
|
||||
def compute(self):
|
||||
pass
|
||||
|
||||
|
||||
def test_inherit():
|
||||
Dummy()
|
||||
|
||||
|
@ -77,12 +91,21 @@ def test_reset():
|
|||
class A(Dummy):
|
||||
pass
|
||||
|
||||
class B(DummyList):
|
||||
pass
|
||||
|
||||
a = A()
|
||||
assert a.x == 0
|
||||
a.x = torch.tensor(5)
|
||||
a.reset()
|
||||
assert a.x == 0
|
||||
|
||||
b = B()
|
||||
assert isinstance(b.x, list) and len(b.x) == 0
|
||||
b.x = torch.tensor(5)
|
||||
b.reset()
|
||||
assert isinstance(b.x, list) and len(b.x) == 0
|
||||
|
||||
|
||||
def test_update():
|
||||
class A(Dummy):
|
||||
|
|
Загрузка…
Ссылка в новой задаче