This commit is contained in:
Miltos 2021-09-03 11:30:18 +01:00 коммит произвёл GitHub
Родитель 8da3c1d91e
Коммит bf72e7f18a
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 1 добавлений и 1 удалений

Просмотреть файл

@ -151,7 +151,7 @@ class MultiheadSelfAttentionVarSizedElementReduce(AbstractVarSizedElementReduce)
keys = self.__key_layer(inputs.element_embeddings) # [num_elements, H]
keys = keys.reshape((keys.shape[0], self.__num_heads, keys.shape[1] // self.__num_heads))
attention_scores = torch.einsum("bkh,bkh->bk", queries_per_element, keys) / sqrt(
attention_scores = torch.einsum("bhk,bhk->bh", queries_per_element, keys) / sqrt(
keys.shape[-1]
) # [num_elements, num_heads]
attention_probs = torch.exp(