Fix: Allow hashing of metrics with lists in their state (PL^5939)

* Fix: Allow hashing of metrics with lists in their state

* Add test case and modify semantics of Metric __hash__ in order to be compatible with structural equality checks

* Fix pep8 style issue

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
(cherry picked from commit a1c26d36e4260543f8642a4fb581f04c5b09c04c)
This commit is contained in:
Philip E Blair 2021-02-18 10:54:12 +01:00 коммит произвёл Jirka Borovec
Родитель c998634d40
Коммит 40418b52ee
2 изменённых файлов: 33 добавлений и 1 удалений

Просмотреть файл

@ -333,7 +333,13 @@ class Metric(nn.Module, ABC):
hash_vals = [self.__class__.__name__]
for key in self._defaults.keys():
hash_vals.append(getattr(self, key))
val = getattr(self, key)
# Special case: allow list values, so long
# as their elements are hashable
if hasattr(val, '__iter__') and not isinstance(val, torch.Tensor):
hash_vals.extend(val)
else:
hash_vals.append(val)
return hash(tuple(hash_vals))

Просмотреть файл

@ -154,6 +154,32 @@ def test_compute():
assert a.compute() == 5
def test_hash():
class A(Dummy):
pass
class B(DummyList):
pass
a1 = A()
a2 = A()
assert hash(a1) != hash(a2)
b1 = B()
b2 = B()
assert hash(b1) == hash(b2)
assert isinstance(b1.x, list) and len(b1.x) == 0
b1.x.append(torch.tensor(5))
assert isinstance(hash(b1), int) # <- check that nothing crashes
assert isinstance(b1.x, list) and len(b1.x) == 1
b2.x.append(torch.tensor(5))
# Sanity:
assert isinstance(b2.x, list) and len(b2.x) == 1
# Now that they have tensor contents, they should have different hashes:
assert hash(b1) != hash(b2)
def test_forward():
class A(Dummy):