Mirror of https://github.com/mozilla/TTS.git

Modularize functions in Tacotron

This commit is contained in:
Parent: cc34fe4c7c
Commit: 1e8fdec084
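In short: the Decoder's monolithic forward() loop is split into _init_states(), decode(), _update_memory_queue(), and _parse_outputs(), with the recurrent state moved onto self, so teacher-forced training (forward) and the new autoregressive inference() share a single decoding step. Below is a minimal runnable sketch of that control flow; the class name, layer sizes, and the body of decode() are illustrative stand-ins, not the real Tacotron modules.

import torch
import torch.nn as nn


class DecoderSketch(nn.Module):
    """Toy decoder mirroring the refactored control flow (not Tacotron)."""

    def __init__(self, in_dim=8, out_dim=4, max_decoder_steps=50):
        super().__init__()
        self.max_decoder_steps = max_decoder_steps
        # stand-in for prenet + attention RNN + decoder RNNs + projections
        self.step = nn.Linear(in_dim + out_dim, out_dim)

    def _init_states(self, inputs):
        B = inputs.size(0)
        self.memory_input = inputs.new_zeros(B, self.step.out_features)

    def _update_memory_queue(self, new_memory):
        # the real code keeps a rolling window of past frames; see below
        self.memory_input = new_memory

    def decode(self, inputs, t):
        context = inputs[:, t % inputs.size(1)]  # stand-in for attention
        output = torch.sigmoid(
            self.step(torch.cat([context, self.memory_input], -1)))
        stop_token = output.mean(dim=-1)
        return output, stop_token

    def _parse_outputs(self, outputs, stop_tokens):
        # back to batch-first, as in the real _parse_outputs
        return torch.stack(outputs, dim=1), torch.stack(stop_tokens, dim=1)

    def forward(self, inputs, memory):
        # teacher forcing: consume ground-truth frames from `memory`
        self._init_states(inputs)
        outputs, stop_tokens = [], []
        for t in range(memory.size(1)):
            if t > 0:
                self._update_memory_queue(memory[:, t - 1])
            output, stop_token = self.decode(inputs, t)
            outputs.append(output)
            stop_tokens.append(stop_token)
        return self._parse_outputs(outputs, stop_tokens)

    def inference(self, inputs):
        # autoregressive: feed back the previous prediction
        self._init_states(inputs)
        outputs, stop_tokens = [], []
        t = 0
        while True:
            if t > 0:
                self._update_memory_queue(outputs[-1])
            output, stop_token = self.decode(inputs, t)
            outputs.append(output)
            stop_tokens.append(stop_token)
            t += 1
            # assumes batch size 1, like the real stop-token check
            if t > self.max_decoder_steps or stop_token.item() > 0.6:
                break
        return self._parse_outputs(outputs, stop_tokens)

Keeping the per-step state on self is what lets inference() reuse decode() unchanged; the trade-off is that the decoder is no longer stateless between calls, so _init_states() must run at the top of every pass.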
config.json
@@ -29,7 +29,6 @@
         "url": "tcp:\/\/localhost:54321"
     },

-    "embedding_size": 256, // Character embedding vector length. You don't need to change it in general.
     "text_cleaner": "phoneme_cleaners",
     "epochs": 1000, // total number of epochs to train.
     "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
layers/tacotron.py
@@ -301,7 +301,8 @@ class Decoder(nn.Module):
         memory_size (int): size of the past window. if <= 0 memory_size = r
     """

-    def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing):
+    def __init__(self, in_features, memory_dim, r, memory_size,
+                 attn_windowing):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
@@ -309,7 +310,8 @@ class Decoder(nn.Module):
         self.memory_size = memory_size if memory_size > 0 else r
         self.memory_dim = memory_dim
         # memory -> |Prenet| -> processed_memory
-        self.prenet = Prenet(memory_dim * self.memory_size, out_features=[256, 128])
+        self.prenet = Prenet(
+            memory_dim * self.memory_size, out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
         self.attention_rnn = AttentionRNNCell(
             out_dim=128,
@@ -360,116 +362,135 @@ class Decoder(nn.Module):
         B = inputs.size(0)
         T = inputs.size(1)
         # go frame as zeros matrix
-        initial_memory = self.memory_init(inputs.data.new_zeros(B).long())
+        self.memory_input = self.memory_init(inputs.data.new_zeros(B).long())

         # decoder states
-        attention_rnn_hidden = self.attention_rnn_init(inputs.data.new_zeros(B).long())
-        decoder_rnn_hiddens = [
-            self.decoder_rnn_inits(inputs.data.new_tensor([idx]*B).long())
+        self.attention_rnn_hidden = self.attention_rnn_init(
+            inputs.data.new_zeros(B).long())
+        self.decoder_rnn_hiddens = [
+            self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long())
             for idx in range(len(self.decoder_rnns))
         ]
-        current_context_vec = inputs.data.new(B, self.in_features).zero_()
+        self.current_context_vec = inputs.data.new(B, self.in_features).zero_()
         # attention states
-        attention = inputs.data.new(B, T).zero_()
-        attention_cum = inputs.data.new(B, T).zero_()
-        return (initial_memory, attention_rnn_hidden, decoder_rnn_hiddens,
-                current_context_vec, attention, attention_cum)
+        self.attention = inputs.data.new(B, T).zero_()
+        self.attention_cum = inputs.data.new(B, T).zero_()

-    def forward(self, inputs, memory=None, mask=None):
+    def _parse_outputs(self, outputs, stop_tokens, attentions):
+        # Back to batch first
+        attentions = torch.stack(attentions).transpose(0, 1)
+        outputs = torch.stack(outputs).transpose(0, 1).contiguous()
+        stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
+        return outputs, stop_tokens, attentions
+
+    def decode(self,
+               inputs,
+               t,
+               mask=None):
+        # Prenet
+        processed_memory = self.prenet(self.memory_input)
+        # Attention RNN
+        attention_cat = torch.cat(
+            (self.attention.unsqueeze(1), self.attention_cum.unsqueeze(1)), dim=1)
+        self.attention_rnn_hidden, self.current_context_vec, self.attention = self.attention_rnn(
+            processed_memory, self.current_context_vec, self.attention_rnn_hidden,
+            inputs, attention_cat, mask, t)
+        del attention_cat
+        self.attention_cum += self.attention
+        # Concat RNN output and attention context vector
+        decoder_input = self.project_to_decoder_in(
+            torch.cat((self.attention_rnn_hidden, self.current_context_vec), -1))
+        # Pass through the decoder RNNs
+        for idx in range(len(self.decoder_rnns)):
+            self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
+                decoder_input, self.decoder_rnn_hiddens[idx])
+            # Residual connection
+            decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
+        decoder_output = decoder_input
+        del decoder_input
+        # predict mel vectors from decoder vectors
+        output = self.proj_to_mel(decoder_output)
+        output = torch.sigmoid(output)
+        # predict stop token
+        stopnet_input = torch.cat([decoder_output, output], -1)
+        del decoder_output
+        stop_token = self.stopnet(stopnet_input)
+        return output, stop_token, self.attention
+
+    def _update_memory_queue(self, new_memory):
+        if self.memory_size > 0:
+            self.memory_input = torch.cat([
+                self.memory_input[:, self.r * self.memory_dim:].clone(),
+                new_memory
+            ], dim=-1)
+        else:
+            self.memory_input = new_memory
+
+    def forward(self, inputs, memory, mask):
         """
         Decoder forward step.

         If decoder inputs are not given (e.g., at testing time), as noted in
         Tacotron paper, greedy decoding is adapted.

         Args:
             inputs: Encoder outputs.
-            memory (None): Decoder memory (autoregression. If None (at eval-time),
+            memory: Decoder memory (autoregression. If None (at eval-time),
               decoder outputs are used as decoder inputs. If None, it uses the last
               output as the input.
-            mask (None): Attention mask for sequence padding.
+            mask: Attention mask for sequence padding.

         Shapes:
             - inputs: batch x time x encoder_out_dim
             - memory: batch x #mel_specs x mel_spec_dim
         """
-        # Run greedy decoding if memory is None
-        greedy = not self.training
-        if memory is not None:
-            memory = self._reshape_memory(memory)
-            T_decoder = memory.size(0)
+        memory = self._reshape_memory(memory)
         outputs = []
         attentions = []
         stop_tokens = []
         t = 0
-        memory_input, attention_rnn_hidden, decoder_rnn_hiddens,\
-            current_context_vec, attention, attention_cum = self._init_states(inputs)
-        while True:
+        self._init_states(inputs)
+        while len(outputs) < memory.size(0):
             if t > 0:
-                if memory is None:
-                    new_memory = outputs[-1]
-                else:
-                    new_memory = memory[t - 1]
-                # Queuing if memory size defined else use previous prediction only.
-                if self.memory_size > 0:
-                    memory_input = torch.cat([memory_input[:, self.r * self.memory_dim:].clone(), new_memory], dim=-1)
-                else:
-                    memory_input = new_memory
-            # Prenet
-            processed_memory = self.prenet(memory_input)
-            # Attention RNN
-            attention_cat = torch.cat(
-                (attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
-            attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
-                processed_memory, current_context_vec, attention_rnn_hidden,
-                inputs, attention_cat, mask, t)
-            del attention_cat
-            attention_cum += attention
-            # Concat RNN output and attention context vector
-            decoder_input = self.project_to_decoder_in(
-                torch.cat((attention_rnn_hidden, current_context_vec), -1))
-            # Pass through the decoder RNNs
-            for idx in range(len(self.decoder_rnns)):
-                decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
-                    decoder_input, decoder_rnn_hiddens[idx])
-                # Residual connection
-                decoder_input = decoder_rnn_hiddens[idx] + decoder_input
-            decoder_output = decoder_input
-            del decoder_input
-            # predict mel vectors from decoder vectors
-            output = self.proj_to_mel(decoder_output)
-            output = torch.sigmoid(output)
-            # predict stop token
-            stopnet_input = torch.cat([decoder_output, output], -1)
-            del decoder_output
-            stop_token = self.stopnet(stopnet_input)
-            del stopnet_input
+                new_memory = memory[t - 1]
+                self._update_memory_queue(new_memory)
+            output, stop_token, attention = self.decode(inputs, t, mask)
             outputs += [output]
             attentions += [attention]
             stop_tokens += [stop_token]
-            del output
             t += 1
-            if memory is not None:
-                if t >= T_decoder:
-                    break
-            else:
-                if t > inputs.shape[1] / 4 and (stop_token > 0.6 or
-                                                attention[:, -1].item() > 0.6):
-                    break
-                elif t > self.max_decoder_steps:
-                    print(" | > Decoder stopped with 'max_decoder_steps")
-                    break
-        # Back to batch first
-        attentions = torch.stack(attentions).transpose(0, 1)
-        outputs = torch.stack(outputs).transpose(0, 1).contiguous()
-        stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
-        return outputs, attentions, stop_tokens
+        return self._parse_outputs(outputs, attentions, stop_tokens)
+
+    def inference(self, inputs):
+        """
+        Args:
+            inputs: Encoder outputs.
+
+        Shapes:
+            - inputs: batch x time x encoder_out_dim
+        """
+        outputs = []
+        attentions = []
+        stop_tokens = []
+        t = 0
+        self._init_states(inputs)
+        while True:
+            if t > 0:
+                new_memory = outputs[-1]
+                self._update_memory_queue(new_memory)
+            output, stop_token, attention = self.decode(inputs, t, None)
+            outputs += [output]
+            attentions += [attention]
+            stop_tokens += [stop_token]
+            t += 1
+            if t > inputs.shape[1] / 4 and (stop_token > 0.6
+                                            or attention[:, -1].item() > 0.6):
+                break
+            elif t > self.max_decoder_steps:
+                print(" | > Decoder stopped with 'max_decoder_steps")
+                break
+        return self._parse_outputs(outputs, attentions, stop_tokens)


 class StopNet(nn.Module):
     r"""
     Predicting stop-token in decoder.

     Args:
         in_features (int): feature dimension of input.
     """
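The queuing logic factored out into _update_memory_queue keeps a fixed-width buffer of the most recent frames: each step drops the oldest r * memory_dim values and appends the newest r frames (ground truth during training, the previous prediction at inference). A standalone illustration with made-up sizes:

import torch

# made-up sizes for illustration; memory_size must cover at least r frames
memory_dim, r, memory_size, B = 4, 2, 6, 1
memory_input = torch.zeros(B, memory_dim * memory_size)  # the queue
new_memory = torch.ones(B, memory_dim * r)               # r fresh frames

# same operation as _update_memory_queue: drop oldest, append newest
memory_input = torch.cat(
    [memory_input[:, r * memory_dim:].clone(), new_memory], dim=-1)
print(memory_input.shape)  # torch.Size([1, 24]) -- width is unchanged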
models/tacotron.py
@@ -8,21 +8,19 @@ from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG
 class Tacotron(nn.Module):
     def __init__(self,
                  num_chars,
-                 embedding_dim=256,
                  linear_dim=1025,
                  mel_dim=80,
                  r=5,
-                 padding_idx=None,
+                 padding_idx=None,
                  memory_size=5,
                  attn_windowing=False):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
         self.linear_dim = linear_dim
-        self.embedding = nn.Embedding(
-            num_chars, embedding_dim, padding_idx=padding_idx)
+        self.embedding = nn.Embedding(num_chars, 256, padding_idx=padding_idx)
         self.embedding.weight.data.normal_(0, 0.3)
-        self.encoder = Encoder(embedding_dim)
+        self.encoder = Encoder(256)
         self.decoder = Decoder(256, mel_dim, r, memory_size, attn_windowing)
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Sequential(
@@ -37,9 +35,22 @@ class Tacotron(nn.Module):
         # batch x time x dim*r
         mel_outputs, alignments, stop_tokens = self.decoder(
             encoder_outputs, mel_specs, mask)
-        # Reshape
         # batch x time x dim
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
         linear_outputs = self.postnet(mel_outputs)
         linear_outputs = self.last_linear(linear_outputs)
         return mel_outputs, linear_outputs, alignments, stop_tokens
+
+    def inference(self, characters):
+        B = characters.size(0)
+        inputs = self.embedding(characters)
+        # batch x time x dim
+        encoder_outputs = self.encoder(inputs)
+        # batch x time x dim*r
+        mel_outputs, alignments, stop_tokens = self.decoder.inference(
+            encoder_outputs)
+        # batch x time x dim
+        mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
+        linear_outputs = self.postnet(mel_outputs)
+        linear_outputs = self.last_linear(linear_outputs)
+        return mel_outputs, linear_outputs, alignments, stop_tokens
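With the embedding width now fixed at 256, the model is constructed without embedding_dim, and evaluation goes through the new Tacotron.inference(). A usage sketch of the two entry points; it assumes the repo root is on PYTHONPATH, and num_chars, shapes, and the dummy input are illustrative:

import torch
from models.tacotron import Tacotron  # assumes the repo is importable

model = Tacotron(num_chars=61, linear_dim=1025, mel_dim=80, r=5)
model.eval()

chars = torch.randint(0, 61, (1, 30)).long()  # batch x time, dummy input
with torch.no_grad():
    # forward() is reserved for teacher-forced training (with mel_specs
    # and a mask); eval-time decoding uses the dedicated inference() method
    mel, linear, alignments, stop_tokens = model.inference(chars)
print(mel.shape, linear.shape)  # batch x frames x 80, batch x frames x 1025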
train.py
@@ -390,7 +390,6 @@ def main(args):
     num_chars = len(phonemes) if c.use_phonemes else len(symbols)
     model = Tacotron(
         num_chars=num_chars,
-        embedding_dim=c.embedding_size,
         linear_dim=ap.num_freq,
         mel_dim=ap.num_mels,
         r=c.r,
@@ -20,7 +20,7 @@ def synthesis(m, s, CONFIG, use_cuda, ap):
     chars_var = torch.from_numpy(seq).unsqueeze(0)
     if use_cuda:
         chars_var = chars_var.cuda()
-    mel_spec, linear_spec, alignments, stop_tokens = m.forward(
+    mel_spec, linear_spec, alignments, stop_tokens = m.inference(
         chars_var.long())
     linear_spec = linear_spec[0].data.cpu().numpy()
     mel_spec = mel_spec[0].data.cpu().numpy()
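After the split, forward() always consumes ground-truth decoder memory (teacher forcing), so the eval-time synthesis helper switches to the model's new inference() entry point.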