Fix: Allow hashing of metrics with lists in their state (PL^5939)
* Fix: Allow hashing of metrics with lists in their state * Add test case and modify semantics of Metric __hash__ in order to be compatible with structural equality checks * Fix pep8 style issue Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> (cherry picked from commit a1c26d36e4260543f8642a4fb581f04c5b09c04c)
This commit is contained in:
Родитель
c998634d40
Коммит
40418b52ee
|
@ -333,7 +333,13 @@ class Metric(nn.Module, ABC):
|
|||
hash_vals = [self.__class__.__name__]
|
||||
|
||||
for key in self._defaults.keys():
|
||||
hash_vals.append(getattr(self, key))
|
||||
val = getattr(self, key)
|
||||
# Special case: allow list values, so long
|
||||
# as their elements are hashable
|
||||
if hasattr(val, '__iter__') and not isinstance(val, torch.Tensor):
|
||||
hash_vals.extend(val)
|
||||
else:
|
||||
hash_vals.append(val)
|
||||
|
||||
return hash(tuple(hash_vals))
|
||||
|
||||
|
|
|
@ -154,6 +154,32 @@ def test_compute():
|
|||
assert a.compute() == 5
|
||||
|
||||
|
||||
def test_hash():
|
||||
|
||||
class A(Dummy):
|
||||
pass
|
||||
|
||||
class B(DummyList):
|
||||
pass
|
||||
|
||||
a1 = A()
|
||||
a2 = A()
|
||||
assert hash(a1) != hash(a2)
|
||||
|
||||
b1 = B()
|
||||
b2 = B()
|
||||
assert hash(b1) == hash(b2)
|
||||
assert isinstance(b1.x, list) and len(b1.x) == 0
|
||||
b1.x.append(torch.tensor(5))
|
||||
assert isinstance(hash(b1), int) # <- check that nothing crashes
|
||||
assert isinstance(b1.x, list) and len(b1.x) == 1
|
||||
b2.x.append(torch.tensor(5))
|
||||
# Sanity:
|
||||
assert isinstance(b2.x, list) and len(b2.x) == 1
|
||||
# Now that they have tensor contents, they should have different hashes:
|
||||
assert hash(b1) != hash(b2)
|
||||
|
||||
|
||||
def test_forward():
|
||||
|
||||
class A(Dummy):
|
||||
|
|
Загрузка…
Ссылка в новой задаче