enforce monotonic attention for forward attention at eval time

Eren Golge 2019-05-27 14:41:30 +02:00
Parent ba492f43be
Commit 59ba37904d
1 changed file with 10 additions and 8 deletions


@@ -219,16 +219,18 @@ class Attention(nn.Module):
         # forward attention
         prev_alpha = F.pad(self.alpha[:, :-1].clone(),
                            (1, 0, 0, 0)).to(inputs.device)
-        # force incremental alignment
-        if not self.training:
-            val, n = prev_alpha.max(1)
-            if alignment.shape[0] == 1:
-                alignment[:, n+2:] = 0
-            else:
-                for b in range(alignment.shape[0]):
-                    alignment[b, n[b]+2:]
         # compute transition potentials
         alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
                   self.u * prev_alpha) + 1e-8) * alignment
+        # force incremental alignment - TODO: make configurable
+        if not self.training and alignment.shape[0] == 1:
+            _, n = prev_alpha.max(1)
+            val, n2 = alpha.max(1)
+            for b in range(alignment.shape[0]):
+                alpha[b, n+2:] = 0
+                alpha[b, :(n - 1)] = 0  # ignore all previous states to prevent repetition.
+                alpha[b, (n - 2)] = 0.01 * val  # smoothing factor for the prev step
         # compute attention weights
         self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
         # compute context
         context = torch.bmm(self.alpha.unsqueeze(1), inputs)
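
Read in isolation, the added block enforces monotonicity like this: at eval time (and only for batch size 1), it takes the position n that received the most attention at the previous decoder step, zeroes every weight more than one step ahead of n and every weight before n-1, and leaves a small smoothing term (1% of the current peak) at n-2 before the weights are renormalized. Below is a minimal, self-contained sketch of that masking; the tensor size T, the seed, and the dummy values are illustrative assumptions, not part of the commit.

import torch

# Toy setup: T encoder steps, batch size 1, previous attention peaked at
# position 3. Only the masking arithmetic mirrors the diff above.
torch.manual_seed(0)
T = 10
prev_alpha = torch.zeros(1, T)
prev_alpha[0, 3] = 1.0                 # previous step attended to position 3
alpha = torch.rand(1, T) + 1e-8        # stand-in for the transition potentials

_, n = prev_alpha.max(1)               # most-attended position of the previous step
val, _ = alpha.max(1)                  # peak of the current unnormalized weights
n, val = n.item(), val.item()
for b in range(alpha.shape[0]):
    alpha[b, n + 2:] = 0               # cannot jump more than one step ahead
    alpha[b, :n - 1] = 0               # cannot fall behind position n-1
    alpha[b, n - 2] = 0.01 * val       # small smoothing mass at n-2

alpha = alpha / alpha.sum(dim=1, keepdim=True)  # renormalize, as the diff does
print(alpha)  # nonzero only at positions n-2 .. n+1, i.e. 1..4 here

Two details worth noting when reading the diff itself: n2 is computed but never used, and nothing guards the n < 2 case, where the negative (n - 2) index would wrap around to the end of the tensor.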