зеркало из https://github.com/mozilla/TTS.git
Update attention module Possible BUG FIX2
This commit is contained in:
Родитель
b6c5771a6f
Коммит
3cafc6568c
|
@ -63,6 +63,8 @@ class AttentionWrapper(nn.Module):
|
|||
# Alignment
|
||||
# (batch, max_time)
|
||||
# e_{ij} = a(s_{i-1}, h_j)
|
||||
# import ipdb
|
||||
# ipdb.set_trace()
|
||||
alignment = self.alignment_model(cell_state, processed_inputs)
|
||||
|
||||
if mask is not None:
|
||||
|
@ -80,12 +82,13 @@ class AttentionWrapper(nn.Module):
|
|||
|
||||
# Concat input query and previous context_vec context
|
||||
cell_input = torch.cat((query, context_vec), -1)
|
||||
cell_input = cell_input.unsqueeze(1)
|
||||
#cell_input = cell_input.unsqueeze(1)
|
||||
|
||||
# Feed it to RNN
|
||||
# s_i = f(y_{i-1}, c_{i}, s_{i-1})
|
||||
cell_output = self.rnn_cell(cell_input, cell_state)
|
||||
|
||||
context_vec = context_vec.squeeze(1)
|
||||
return cell_output, context_vec, alignment
|
||||
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче