зеркало из https://github.com/mozilla/TTS.git
guided attn #10
This commit is contained in:
Родитель
8c4e03cebf
Коммит
830a051a78
6
train.py
6
train.py
|
@ -106,10 +106,8 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
|||
|
||||
# create attention mask
|
||||
# TODO: vectorize
|
||||
print(text_input_var.shape)
|
||||
print(mel_spec_var.shape)
|
||||
N = text_input_var.shape[1]
|
||||
T = mel_spec_var.shape[1]
|
||||
T = mel_spec_var.shape[1] / c.r
|
||||
M = np.zeros([N, T])
|
||||
for t in range(T):
|
||||
for n in range(N):
|
||||
|
@ -117,7 +115,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
|||
M[n, t] = val
|
||||
e_x = np.exp(M - np.max(M))
|
||||
M = e_x / e_x.sum(axis=0) # only difference
|
||||
M = Variable(torch.FloatTensor(M)).cuda()
|
||||
M = Variable(torch.FloatTensor(M).t()).cuda()
|
||||
M = torch.stack([M]*32)
|
||||
|
||||
# forward pass
|
||||
|
|
Загрузка…
Ссылка в новой задаче