From 830a051a78cd6c39a6e1836102d15f45ad6475aa Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Tue, 24 Apr 2018 11:39:02 -0700 Subject: [PATCH] guided attn #10 --- train.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index fcd4d92..6b56ed5 100644 --- a/train.py +++ b/train.py @@ -106,10 +106,8 @@ def train(model, criterion, data_loader, optimizer, epoch): # create attention mask # TODO: vectorize - print(text_input_var.shape) - print(mel_spec_var.shape) N = text_input_var.shape[1] - T = mel_spec_var.shape[1] + T = mel_spec_var.shape[1] / c.r M = np.zeros([N, T]) for t in range(T): for n in range(N): @@ -117,7 +115,7 @@ def train(model, criterion, data_loader, optimizer, epoch): M[n, t] = val e_x = np.exp(M - np.max(M)) M = e_x / e_x.sum(axis=0) # only difference - M = Variable(torch.FloatTensor(M)).cuda() + M = Variable(torch.FloatTensor(M).t()).cuda() M = torch.stack([M]*32) # forward pass