This commit is contained in:
Eren 2018-09-19 15:08:43 +02:00
Родитель a165cd7bda
Коммит f2ef1ca36a
1 изменённых файлов: 2 добавлений и 1 удалений

Просмотреть файл

@ -131,7 +131,8 @@ class AttentionRNNCell(nn.Module):
mask = mask.view(memory.size(0), -1)
alignment.masked_fill_(1 - mask, -float("inf"))
# Normalize context weight
alignment = F.softmax(alignment, dim=-1)
# alignment = F.softmax(alignment, dim=-1)
alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
# Attention context vector
# (batch, 1, dim)
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j