Merge branch 'larger-attn-larger-model-sigmoid' into attn-smoothing-bgs-sigmoid-wd

This commit is contained in:
Eren 2018-09-26 16:51:34 +02:00
Родитель c89a3098dd 95eb3367bd
Коммит 3dda49f1c9
3 изменённых файлов: 6 добавлений и 3 удалений

Просмотреть файл

@ -1,6 +1,6 @@
{
"model_name": "TTS-weight-decay",
"model_description": "Weight decay as in FastAI",
"model_name": "TTS-sigmoid",
"model_description": "Net outputting Sigmoid unit",
"audio_processor": "audio",
"num_mels": 80,
"num_freq": 1025,

Просмотреть файл

@ -374,6 +374,7 @@ class Decoder(nn.Module):
decoder_output = decoder_input
# predict mel vectors from decoder vectors
output = self.proj_to_mel(decoder_output)
output = torch.sigmoid(output)
stop_input = output
# predict stop token
stop_token, stopnet_rnn_hidden = self.stopnet(

Просмотреть файл

@ -23,7 +23,9 @@ class Tacotron(nn.Module):
self.encoder = Encoder(embedding_dim)
self.decoder = Decoder(256, mel_dim, r)
self.postnet = PostCBHG(mel_dim)
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim)
self.last_linear = nn.Sequential(
nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
nn.Sigmoid())
def forward(self, characters, mel_specs=None, mask=None):
B = characters.size(0)