From bf72e7f18aa086577f09e1f66c3353e720d46414 Mon Sep 17 00:00:00 2001 From: Miltos Date: Fri, 3 Sep 2021 11:30:18 +0100 Subject: [PATCH] Improve einsum notation --- ptgnn/neuralmodels/reduceops/varsizedsummary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ptgnn/neuralmodels/reduceops/varsizedsummary.py b/ptgnn/neuralmodels/reduceops/varsizedsummary.py index 816778c..c51a16e 100644 --- a/ptgnn/neuralmodels/reduceops/varsizedsummary.py +++ b/ptgnn/neuralmodels/reduceops/varsizedsummary.py @@ -151,7 +151,7 @@ class MultiheadSelfAttentionVarSizedElementReduce(AbstractVarSizedElementReduce) keys = self.__key_layer(inputs.element_embeddings) # [num_elements, H] keys = keys.reshape((keys.shape[0], self.__num_heads, keys.shape[1] // self.__num_heads)) - attention_scores = torch.einsum("bkh,bkh->bk", queries_per_element, keys) / sqrt( + attention_scores = torch.einsum("bhk,bhk->bh", queries_per_element, keys) / sqrt( keys.shape[-1] ) # [num_elements, num_heads] attention_probs = torch.exp(