Mirror of https://github.com/mozilla/TTS.git
enforce monotonic attention in forward attention for batches
Parent: d905f6e795
Commit: 0b5a00d29e
@@ -203,16 +203,13 @@ class Attention(nn.Module):
         alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
                   self.u * prev_alpha) + 1e-8) * alignment
         # force incremental alignment - TODO: make configurable
-        if not self.training and alignment.shape[0] == 1:
+        if not self.training:
             _, n = prev_alpha.max(1)
             val, n2 = alpha.max(1)
             for b in range(alignment.shape[0]):
                 alpha[b, n + 2:] = 0
-                alpha[b, :(
-                    n - 1
-                )] = 0  # ignore all previous states to prevent repetition.
-                alpha[b, (
-                    n - 2)] = 0.01 * val  # smoothing factor for the prev step
+                alpha[b, :(n - 1)] = 0  # ignore all previous states to prevent repetition.
+                alpha[b, (n - 2)] = 0.01 * val  # smoothing factor for the prev step
         # compute attention weights
         self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
         # compute context
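The change drops the alignment.shape[0] == 1 guard, so the inference-time monotonic constraint is applied to every item in a batch rather than only when a single sentence is decoded. Below is a minimal, standalone sketch of that per-item masking, not taken from the repository: the shapes, the random data, the per-item indices nb = int(n[b]) and val[b], and the boundary guards for small nb are all assumptions added so the example runs on its own for batch sizes greater than one.

# Hypothetical, self-contained illustration of per-item monotonic masking at
# inference time; names mirror the diff, but this is not the repository code.
import torch

B, T = 4, 50                      # batch size, number of encoder timesteps
alpha = torch.rand(B, T) + 1e-8   # unnormalized forward-attention weights
prev_alpha = torch.rand(B, T)     # attention weights from the previous decoder step

_, n = prev_alpha.max(1)          # previous attention position per batch item
val, _ = alpha.max(1)             # current peak weight per batch item

for b in range(B):
    nb = int(n[b])
    alpha[b, nb + 2:] = 0                  # forbid jumping far ahead of the previous position
    alpha[b, :max(nb - 1, 0)] = 0          # ignore earlier states to prevent repetition
    if nb >= 2:
        alpha[b, nb - 2] = 0.01 * val[b]   # leave a small smoothing mass on the step before

alpha = alpha / alpha.sum(dim=1, keepdim=True)   # renormalize to a valid distribution

Renormalizing after the masking keeps the attention weights a proper distribution over encoder steps, which is what the final self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1) line in the hunk does.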