This commit is contained in:
Eren Golge 2018-04-24 11:39:02 -07:00
Родитель 8c4e03cebf
Коммит 830a051a78
1 изменённых файлов: 2 добавлений и 4 удалений

Просмотреть файл

@ -106,10 +106,8 @@ def train(model, criterion, data_loader, optimizer, epoch):
# create attention mask
# TODO: vectorize
print(text_input_var.shape)
print(mel_spec_var.shape)
N = text_input_var.shape[1]
T = mel_spec_var.shape[1]
T = mel_spec_var.shape[1] / c.r
M = np.zeros([N, T])
for t in range(T):
for n in range(N):
@ -117,7 +115,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
M[n, t] = val
e_x = np.exp(M - np.max(M))
M = e_x / e_x.sum(axis=0) # only difference
M = Variable(torch.FloatTensor(M)).cuda()
M = Variable(torch.FloatTensor(M).t()).cuda()
M = torch.stack([M]*32)
# forward pass