Mirror of https://github.com/microsoft/ptgnn.git
Improve einsum notation
Parent: 8da3c1d91e
Commit: bf72e7f18a
@@ -151,7 +151,7 @@ class MultiheadSelfAttentionVarSizedElementReduce(AbstractVarSizedElementReduce)
         keys = self.__key_layer(inputs.element_embeddings)  # [num_elements, H]
         keys = keys.reshape((keys.shape[0], self.__num_heads, keys.shape[1] // self.__num_heads))

-        attention_scores = torch.einsum("bkh,bkh->bk", queries_per_element, keys) / sqrt(
+        attention_scores = torch.einsum("bhk,bhk->bh", queries_per_element, keys) / sqrt(
             keys.shape[-1]
         )  # [num_elements, num_heads]
         attention_probs = torch.exp(
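The change above renames the einsum subscripts so they match the tensor layout after the reshape: axis 0 is the element (`b`), axis 1 the head (`h`), and axis 2 the per-head dimension (`k`). Both labelings contract the last axis, so the result is numerically unchanged; only the notation is corrected. A minimal sketch (using numpy's `einsum` instead of torch, with made-up shapes) illustrating the equivalence:

```python
import numpy as np

# Hypothetical shapes mirroring the diff: [num_elements, num_heads, head_dim]
num_elements, num_heads, head_dim = 4, 2, 3
rng = np.random.default_rng(0)
queries = rng.standard_normal((num_elements, num_heads, head_dim))
keys = rng.standard_normal((num_elements, num_heads, head_dim))

# Old labeling: "bkh" misleadingly names the head axis "k" and the
# head-dim axis "h", but still contracts the last axis.
old = np.einsum("bkh,bkh->bk", queries, keys) / np.sqrt(keys.shape[-1])

# New labeling: "bhk" matches the actual layout (b=element, h=head,
# k=per-head dim) and contracts the same last axis.
new = np.einsum("bhk,bhk->bh", queries, keys) / np.sqrt(keys.shape[-1])

# Identical results: the commit is a pure relabeling, not a behavior change.
assert np.allclose(old, new)
print(new.shape)
```

The output shape is `(num_elements, num_heads)`, matching the `# [num_elements, num_heads]` comment in the diff.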
|