stop token prediction update for tacotron model

This commit is contained in:
Eren Golge 2018-05-11 04:15:06 -07:00
Родитель 3ea1a5358d
Коммит 8be07ee3c5
1 изменённых файлов: 3 добавлений и 3 удалений

Просмотреть файл

@ -14,7 +14,7 @@ class Tacotron(nn.Module):
self.linear_dim = linear_dim
self.embedding = nn.Embedding(len(symbols), embedding_dim,
padding_idx=padding_idx)
print(" | > Number of characted : {}".format(len(symbols)))
print(" | > Number of characters : {}".format(len(symbols)))
self.embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(embedding_dim)
self.decoder = Decoder(256, mel_dim, r)
@ -27,11 +27,11 @@ class Tacotron(nn.Module):
# batch x time x dim
encoder_outputs = self.encoder(inputs)
# batch x time x dim*r
mel_outputs, alignments = self.decoder(
mel_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs)
# Reshape
# batch x time x dim
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
linear_outputs = self.postnet(mel_outputs)
linear_outputs = self.last_linear(linear_outputs)
return mel_outputs, linear_outputs, alignments
return mel_outputs, linear_outputs, alignments, stop_tokens