diff --git a/layers/common_layers.py b/layers/common_layers.py
index c15c0b1..d939fbc 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -203,16 +203,13 @@ class Attention(nn.Module):
         alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
                   self.u * prev_alpha) + 1e-8) * alignment
         # force incremental alignment - TODO: make configurable
-        if not self.training and alignment.shape[0] == 1:
+        if not self.training:
             _, n = prev_alpha.max(1)
             val, n2 = alpha.max(1)
             for b in range(alignment.shape[0]):
                 alpha[b, n + 2:] = 0
-                alpha[b, :(
-                    n - 1
-                )] = 0  # ignore all previous states to prevent repetition.
-                alpha[b, (
-                    n - 2)] = 0.01 * val  # smoothing factor for the prev step
+                alpha[b, :(n - 1)] = 0  # ignore all previous states to prevent repetition.
+                alpha[b, (n - 2)] = 0.01 * val  # smoothing factor for the prev step
         # compute attention weights
         self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
         # compute context
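The branch this hunk relaxes implements forced incremental alignment at inference: each decoder step's attention is clamped to a small window around the previous step's most-attended encoder position, with a small smoothing weight left on the position just behind it, and the result is renormalized. Below is a minimal standalone sketch of that masking, assuming [B, T] attention tensors; the function name force_incremental_alignment and the smoothing argument are invented here, and the peak index and value are applied per batch item (n[b], val[b]) for clarity rather than as whole tensors as in the hunk.

import torch

def force_incremental_alignment(alpha, prev_alpha, smoothing=0.01):
    # Hypothetical standalone sketch of the masking applied in the hunk above.
    # alpha, prev_alpha: [B, T] attention weights for the current / previous decoder step.
    _, n = prev_alpha.max(1)   # previous peak position per batch item
    val, _ = alpha.max(1)      # current peak value per batch item
    for b in range(alpha.shape[0]):
        alpha[b, n[b] + 2:] = 0                  # cannot jump far ahead of the previous peak
        alpha[b, :n[b] - 1] = 0                  # ignore all previous states to prevent repetition
        alpha[b, n[b] - 2] = smoothing * val[b]  # small smoothing weight for the prev step
    return alpha / alpha.sum(dim=1).unsqueeze(1)  # renormalize, as in the hunk's trailing context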