enforce monotonic attention in forward attention for batches

Eren Golge 2019-05-28 14:28:32 +02:00
Parent d905f6e795
Commit 0b5a00d29e
1 changed file with 3 additions and 6 deletions

View file

@@ -203,16 +203,13 @@ class Attention(nn.Module):
         alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
                   self.u * prev_alpha) + 1e-8) * alignment
         # force incremental alignment - TODO: make configurable
-        if not self.training and alignment.shape[0] == 1:
+        if not self.training:
             _, n = prev_alpha.max(1)
             val, n2 = alpha.max(1)
             for b in range(alignment.shape[0]):
                 alpha[b, n + 2:] = 0
-                alpha[b, :(
-                    n - 1
-                )] = 0  # ignore all previous states to prevent repetition.
-                alpha[b, (
-                    n - 2)] = 0.01 * val  # smoothing factor for the prev step
+                alpha[b, :(n - 1)] = 0  # ignore all previous states to prevent repetition.
+                alpha[b, (n - 2)] = 0.01 * val  # smoothing factor for the prev step
         # compute attention weights
         self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
         # compute context
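
For reference, a minimal standalone sketch of the same idea, assuming the peak of the previous alignment is indexed per batch item (n[b]) so the constraint also applies when the batch size is greater than one; the function name, the smoothing argument, and the boundary guards are illustrative, not the repository's API:

import torch

def enforce_monotonic_alignment(alpha, prev_alpha, smoothing=0.01):
    # Sketch of the monotonic constraint: per batch item, forbid jumping
    # far ahead of or behind the previous alignment peak, then renormalize.
    _, n = prev_alpha.max(1)              # previous peak position, shape [B]
    val, _ = alpha.max(1)                 # current peak value, shape [B]
    for b in range(alpha.shape[0]):
        nb = n[b].item()                  # per-item peak index (assumed here for batching)
        alpha[b, nb + 2:] = 0             # forbid skipping more than one step ahead
        alpha[b, :max(nb - 1, 0)] = 0     # forbid moving backwards (prevents repetition)
        if nb - 2 >= 0:
            alpha[b, nb - 2] = smoothing * val[b]  # small mass kept on the previous step
    return alpha / alpha.sum(dim=1, keepdim=True)  # renormalize to a distribution

# usage: batch of 2 items over 6 encoder steps
prev_alpha = torch.softmax(torch.randn(2, 6), dim=1)
alpha = torch.softmax(torch.randn(2, 6), dim=1)
alpha = enforce_monotonic_alignment(alpha, prev_alpha)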