Modularize functions in Tacotron

Eren Golge 2019-03-05 13:25:50 +01:00
Parent cc34fe4c7c
Commit 1e8fdec084
5 changed files with 121 additions and 91 deletions

View File

@@ -29,7 +29,6 @@
         "url": "tcp:\/\/localhost:54321"
     },
-    "embedding_size": 256, // Character embedding vector length. You don't need to change it in general.
     "text_cleaner": "phoneme_cleaners",
     "epochs": 1000, // total number of epochs to train.
     "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.

View File

@@ -301,7 +301,8 @@ class Decoder(nn.Module):
         memory_size (int): size of the past window. if <= 0 memory_size = r
     """

-    def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing):
+    def __init__(self, in_features, memory_dim, r, memory_size,
+                 attn_windowing):
        super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
@@ -309,7 +310,8 @@ class Decoder(nn.Module):
         self.memory_size = memory_size if memory_size > 0 else r
         self.memory_dim = memory_dim
         # memory -> |Prenet| -> processed_memory
-        self.prenet = Prenet(memory_dim * self.memory_size, out_features=[256, 128])
+        self.prenet = Prenet(
+            memory_dim * self.memory_size, out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
         self.attention_rnn = AttentionRNNCell(
             out_dim=128,
@@ -360,116 +362,135 @@ class Decoder(nn.Module):
         B = inputs.size(0)
         T = inputs.size(1)
         # go frame as zeros matrix
-        initial_memory = self.memory_init(inputs.data.new_zeros(B).long())
+        self.memory_input = self.memory_init(inputs.data.new_zeros(B).long())
         # decoder states
-        attention_rnn_hidden = self.attention_rnn_init(inputs.data.new_zeros(B).long())
-        decoder_rnn_hiddens = [
-            self.decoder_rnn_inits(inputs.data.new_tensor([idx]*B).long())
+        self.attention_rnn_hidden = self.attention_rnn_init(
+            inputs.data.new_zeros(B).long())
+        self.decoder_rnn_hiddens = [
+            self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long())
             for idx in range(len(self.decoder_rnns))
         ]
-        current_context_vec = inputs.data.new(B, self.in_features).zero_()
+        self.current_context_vec = inputs.data.new(B, self.in_features).zero_()
         # attention states
-        attention = inputs.data.new(B, T).zero_()
-        attention_cum = inputs.data.new(B, T).zero_()
-        return (initial_memory, attention_rnn_hidden, decoder_rnn_hiddens,
-                current_context_vec, attention, attention_cum)
+        self.attention = inputs.data.new(B, T).zero_()
+        self.attention_cum = inputs.data.new(B, T).zero_()

-    def forward(self, inputs, memory=None, mask=None):
+    def _parse_outputs(self, outputs, stop_tokens, attentions):
+        # Back to batch first
+        attentions = torch.stack(attentions).transpose(0, 1)
+        outputs = torch.stack(outputs).transpose(0, 1).contiguous()
+        stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
+        return outputs, stop_tokens, attentions
+
+    def decode(self,
+               inputs,
+               t,
+               mask=None):
+        # Prenet
+        processed_memory = self.prenet(self.memory_input)
+        # Attention RNN
+        attention_cat = torch.cat(
+            (self.attention.unsqueeze(1), self.attention_cum.unsqueeze(1)), dim=1)
+        self.attention_rnn_hidden, self.current_context_vec, self.attention = self.attention_rnn(
+            processed_memory, self.current_context_vec, self.attention_rnn_hidden,
+            inputs, attention_cat, mask, t)
+        del attention_cat
+        self.attention_cum += self.attention
+        # Concat RNN output and attention context vector
+        decoder_input = self.project_to_decoder_in(
+            torch.cat((self.attention_rnn_hidden, self.current_context_vec), -1))
+        # Pass through the decoder RNNs
+        for idx in range(len(self.decoder_rnns)):
+            self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
+                decoder_input, self.decoder_rnn_hiddens[idx])
+            # Residual connection
+            decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
+        decoder_output = decoder_input
+        del decoder_input
+        # predict mel vectors from decoder vectors
+        output = self.proj_to_mel(decoder_output)
+        output = torch.sigmoid(output)
+        # predict stop token
+        stopnet_input = torch.cat([decoder_output, output], -1)
+        del decoder_output
+        stop_token = self.stopnet(stopnet_input)
+        return output, stop_token, self.attention
+
+    def _update_memory_queue(self, new_memory):
+        if self.memory_size > 0:
+            self.memory_input = torch.cat([
+                self.memory_input[:, self.r * self.memory_dim:].clone(),
+                new_memory
+            ], dim=-1)
+        else:
+            self.memory_input = new_memory
+
+    def forward(self, inputs, memory, mask):
         """
         Decoder forward step.

         If decoder inputs are not given (e.g., at test time), greedy decoding
         is used, as noted in the Tacotron paper.

         Args:
             inputs: Encoder outputs.
-            memory (None): Decoder memory (autoregression). If None (at eval time),
+            memory: Decoder memory (autoregression). If None (at eval time),
             decoder outputs are used as decoder inputs, with the last output
             used as the next input.
-            mask (None): Attention mask for sequence padding.
+            mask: Attention mask for sequence padding.

         Shapes:
             - inputs: batch x time x encoder_out_dim
             - memory: batch x #mel_specs x mel_spec_dim
         """
-        # Run greedy decoding if memory is None
-        greedy = not self.training
-        if memory is not None:
-            memory = self._reshape_memory(memory)
-            T_decoder = memory.size(0)
+        memory = self._reshape_memory(memory)
         outputs = []
         attentions = []
         stop_tokens = []
         t = 0
-        memory_input, attention_rnn_hidden, decoder_rnn_hiddens,\
-            current_context_vec, attention, attention_cum = self._init_states(inputs)
-        while True:
+        self._init_states(inputs)
+        while len(outputs) < memory.size(0):
             if t > 0:
-                if memory is None:
-                    new_memory = outputs[-1]
-                else:
-                    new_memory = memory[t - 1]
-                # Queuing if memory size defined else use previous prediction only.
-                if self.memory_size > 0:
-                    memory_input = torch.cat([memory_input[:, self.r * self.memory_dim:].clone(), new_memory], dim=-1)
-                else:
-                    memory_input = new_memory
-            # Prenet
-            processed_memory = self.prenet(memory_input)
-            # Attention RNN
-            attention_cat = torch.cat(
-                (attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
-            attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
-                processed_memory, current_context_vec, attention_rnn_hidden,
-                inputs, attention_cat, mask, t)
-            del attention_cat
-            attention_cum += attention
-            # Concat RNN output and attention context vector
-            decoder_input = self.project_to_decoder_in(
-                torch.cat((attention_rnn_hidden, current_context_vec), -1))
-            # Pass through the decoder RNNs
-            for idx in range(len(self.decoder_rnns)):
-                decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
-                    decoder_input, decoder_rnn_hiddens[idx])
-                # Residual connection
-                decoder_input = decoder_rnn_hiddens[idx] + decoder_input
-            decoder_output = decoder_input
-            del decoder_input
-            # predict mel vectors from decoder vectors
-            output = self.proj_to_mel(decoder_output)
-            output = torch.sigmoid(output)
-            # predict stop token
-            stopnet_input = torch.cat([decoder_output, output], -1)
-            del decoder_output
-            stop_token = self.stopnet(stopnet_input)
-            del stopnet_input
+                new_memory = memory[t - 1]
+                self._update_memory_queue(new_memory)
+            output, stop_token, attention = self.decode(inputs, t, mask)
             outputs += [output]
             attentions += [attention]
             stop_tokens += [stop_token]
-            del output
             t += 1
-            if memory is not None:
-                if t >= T_decoder:
-                    break
-            else:
-                if t > inputs.shape[1] / 4 and (stop_token > 0.6 or
-                                                attention[:, -1].item() > 0.6):
-                    break
-                elif t > self.max_decoder_steps:
-                    print(" | > Decoder stopped with 'max_decoder_steps'")
-                    break
-        # Back to batch first
-        attentions = torch.stack(attentions).transpose(0, 1)
-        outputs = torch.stack(outputs).transpose(0, 1).contiguous()
-        stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
-        return outputs, attentions, stop_tokens
+        return self._parse_outputs(outputs, attentions, stop_tokens)
+
+    def inference(self, inputs):
+        """
+        Args:
+            inputs: Encoder outputs.
+
+        Shapes:
+            - inputs: batch x time x encoder_out_dim
+        """
+        outputs = []
+        attentions = []
+        stop_tokens = []
+        t = 0
+        self._init_states(inputs)
+        while True:
+            if t > 0:
+                new_memory = outputs[-1]
+                self._update_memory_queue(new_memory)
+            output, stop_token, attention = self.decode(inputs, t, None)
+            outputs += [output]
+            attentions += [attention]
+            stop_tokens += [stop_token]
+            t += 1
+            if t > inputs.shape[1] / 4 and (stop_token > 0.6
+                                            or attention[:, -1].item() > 0.6):
+                break
+            elif t > self.max_decoder_steps:
+                print(" | > Decoder stopped with 'max_decoder_steps'")
+                break
+        return self._parse_outputs(outputs, attentions, stop_tokens)


 class StopNet(nn.Module):
     r"""
     Predicting stop-token in decoder.

     Args:
         in_features (int): feature dimension of input.
     """

View File

@@ -8,21 +8,19 @@ from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG

 class Tacotron(nn.Module):
     def __init__(self,
                  num_chars,
-                 embedding_dim=256,
                  linear_dim=1025,
                  mel_dim=80,
                  r=5,
-                 padding_idx=None,
+                 padding_idx=None,
                  memory_size=5,
                  attn_windowing=False):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
         self.linear_dim = linear_dim
-        self.embedding = nn.Embedding(
-            num_chars, embedding_dim, padding_idx=padding_idx)
+        self.embedding = nn.Embedding(num_chars, 256, padding_idx=padding_idx)
         self.embedding.weight.data.normal_(0, 0.3)
-        self.encoder = Encoder(embedding_dim)
+        self.encoder = Encoder(256)
         self.decoder = Decoder(256, mel_dim, r, memory_size, attn_windowing)
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Sequential(
@@ -37,9 +35,22 @@ class Tacotron(nn.Module):
         # batch x time x dim*r
         mel_outputs, alignments, stop_tokens = self.decoder(
             encoder_outputs, mel_specs, mask)
-        # Reshape
+        # batch x time x dim
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
         linear_outputs = self.postnet(mel_outputs)
         linear_outputs = self.last_linear(linear_outputs)
         return mel_outputs, linear_outputs, alignments, stop_tokens
+
+    def inference(self, characters):
+        B = characters.size(0)
+        inputs = self.embedding(characters)
+        # batch x time x dim
+        encoder_outputs = self.encoder(inputs)
+        # batch x time x dim*r
+        mel_outputs, alignments, stop_tokens = self.decoder.inference(
+            encoder_outputs)
+        # batch x time x dim
+        mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
+        linear_outputs = self.postnet(mel_outputs)
+        linear_outputs = self.last_linear(linear_outputs)
+        return mel_outputs, linear_outputs, alignments, stop_tokens
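The same split now exists at the model level: `forward()` is the teacher-forced training path and `inference()` the free-running one, so callers no longer overload a single entry point. A hedged usage sketch; the `forward(characters, mel_specs, mask)` signature is assumed from the call sites in this diff, and the sizes are illustrative:

import torch
from models.tacotron import Tacotron

model = Tacotron(num_chars=61, linear_dim=1025, mel_dim=80, r=5)
characters = torch.randint(0, 61, (1, 42))  # batch x time character ids

# Training: ground-truth mel frames drive the decoder (teacher forcing).
mel_specs = torch.rand(1, 120, 80)
mel_out, linear_out, alignments, stop_tokens = model(characters, mel_specs, None)

# Evaluation: the decoder consumes its own predictions.
model.eval()
with torch.no_grad():
    mel_out, linear_out, alignments, stop_tokens = model.inference(characters)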

View File

@@ -390,7 +390,6 @@ def main(args):
     num_chars = len(phonemes) if c.use_phonemes else len(symbols)
     model = Tacotron(
         num_chars=num_chars,
-        embedding_dim=c.embedding_size,
         linear_dim=ap.num_freq,
         mel_dim=ap.num_mels,
         r=c.r,
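For anyone carrying a custom training script, the only change needed here is dropping the `embedding_dim` keyword; the embedding width is now fixed at 256 inside the model. A migration sketch, reusing the names from `main()` above:

# Before this commit:
#   model = Tacotron(num_chars=num_chars, embedding_dim=c.embedding_size, ...)
# After it, embedding_dim is gone from the constructor:
model = Tacotron(
    num_chars=num_chars,
    linear_dim=ap.num_freq,
    mel_dim=ap.num_mels,
    r=c.r)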

View File

@@ -20,7 +20,7 @@ def synthesis(m, s, CONFIG, use_cuda, ap):
     chars_var = torch.from_numpy(seq).unsqueeze(0)
     if use_cuda:
         chars_var = chars_var.cuda()
-    mel_spec, linear_spec, alignments, stop_tokens = m.forward(
+    mel_spec, linear_spec, alignments, stop_tokens = m.inference(
         chars_var.long())
     linear_spec = linear_spec[0].data.cpu().numpy()
     mel_spec = mel_spec[0].data.cpu().numpy()
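Since `forward()` now always expects mel targets, eval-time code has to call `m.inference()` as above; calling `m.forward()` with characters alone would now fail for lack of those targets. A sketch of the surrounding call chain, with names as in the hunk; `ap.inv_spectrogram` is an assumption about the audio-processor API used elsewhere in the repo:

# Hedged sketch: synthesize a waveform with the new inference entry point.
mel_spec, linear_spec, alignments, stop_tokens = m.inference(chars_var.long())
linear_spec = linear_spec[0].data.cpu().numpy()
wav = ap.inv_spectrogram(linear_spec.T)  # assumption: inverse STFT + Griffin-Lim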