зеркало из https://github.com/mozilla/TTS.git
renaming layers to be converted to TF counterpart
This commit is contained in:
Родитель
bee288fa93
Коммит
d282222553
|
@ -33,7 +33,7 @@ class LinearBN(nn.Module):
|
|||
super(LinearBN, self).__init__()
|
||||
self.linear_layer = torch.nn.Linear(
|
||||
in_features, out_features, bias=bias)
|
||||
self.bn = nn.BatchNorm1d(out_features)
|
||||
self.batch_normalization = nn.BatchNorm1d(out_features)
|
||||
self._init_w(init_gain)
|
||||
|
||||
def _init_w(self, init_gain):
|
||||
|
@ -45,7 +45,7 @@ class LinearBN(nn.Module):
|
|||
out = self.linear_layer(x)
|
||||
if len(out.shape) == 3:
|
||||
out = out.permute(1, 2, 0)
|
||||
out = self.bn(out)
|
||||
out = self.batch_normalization(out)
|
||||
if len(out.shape) == 3:
|
||||
out = out.permute(2, 0, 1)
|
||||
return out
|
||||
|
@ -63,18 +63,18 @@ class Prenet(nn.Module):
|
|||
self.prenet_dropout = prenet_dropout
|
||||
in_features = [in_features] + out_features[:-1]
|
||||
if prenet_type == "bn":
|
||||
self.layers = nn.ModuleList([
|
||||
self.linear_layers = nn.ModuleList([
|
||||
LinearBN(in_size, out_size, bias=bias)
|
||||
for (in_size, out_size) in zip(in_features, out_features)
|
||||
])
|
||||
elif prenet_type == "original":
|
||||
self.layers = nn.ModuleList([
|
||||
self.linear_layers = nn.ModuleList([
|
||||
Linear(in_size, out_size, bias=bias)
|
||||
for (in_size, out_size) in zip(in_features, out_features)
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
for linear in self.layers:
|
||||
for linear in self.linear_layers:
|
||||
if self.prenet_dropout:
|
||||
x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
|
||||
else:
|
||||
|
@ -93,7 +93,7 @@ class LocationLayer(nn.Module):
|
|||
attention_n_filters=32,
|
||||
attention_kernel_size=31):
|
||||
super(LocationLayer, self).__init__()
|
||||
self.location_conv = nn.Conv1d(
|
||||
self.location_conv1d = nn.Conv1d(
|
||||
in_channels=2,
|
||||
out_channels=attention_n_filters,
|
||||
kernel_size=attention_kernel_size,
|
||||
|
@ -104,7 +104,7 @@ class LocationLayer(nn.Module):
|
|||
attention_n_filters, attention_dim, bias=False, init_gain='tanh')
|
||||
|
||||
def forward(self, attention_cat):
|
||||
processed_attention = self.location_conv(attention_cat)
|
||||
processed_attention = self.location_conv1d(attention_cat)
|
||||
processed_attention = self.location_dense(
|
||||
processed_attention.transpose(1, 2))
|
||||
return processed_attention
|
||||
|
|
|
@ -6,130 +6,126 @@ from .common_layers import init_attn, Prenet, Linear
|
|||
|
||||
|
||||
class ConvBNBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, nonlinear=None):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, activation=None):
|
||||
super(ConvBNBlock, self).__init__()
|
||||
assert (kernel_size - 1) % 2 == 0
|
||||
padding = (kernel_size - 1) // 2
|
||||
conv1d = nn.Conv1d(in_channels,
|
||||
self.convolution1d = nn.Conv1d(in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
padding=padding)
|
||||
norm = nn.BatchNorm1d(out_channels)
|
||||
dropout = nn.Dropout(p=0.5)
|
||||
if nonlinear == 'relu':
|
||||
self.net = nn.Sequential(conv1d, norm, nn.ReLU(), dropout)
|
||||
elif nonlinear == 'tanh':
|
||||
self.net = nn.Sequential(conv1d, norm, nn.Tanh(), dropout)
|
||||
self.batch_normalization = nn.BatchNorm1d(out_channels)
|
||||
self.dropout = nn.Dropout(p=0.5)
|
||||
if activation == 'relu':
|
||||
self.activation = nn.ReLU()
|
||||
elif activation == 'tanh':
|
||||
self.activation = nn.Tanh()
|
||||
else:
|
||||
self.net = nn.Sequential(conv1d, norm, dropout)
|
||||
self.activation = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
output = self.net(x)
|
||||
return output
|
||||
o = self.convolution1d(x)
|
||||
o = self.batch_normalization(o)
|
||||
o = self.activation(o)
|
||||
o = self.dropout(o)
|
||||
return o
|
||||
|
||||
|
||||
class Postnet(nn.Module):
|
||||
def __init__(self, mel_dim, num_convs=5):
|
||||
def __init__(self, output_dim, num_convs=5):
|
||||
super(Postnet, self).__init__()
|
||||
self.convolutions = nn.ModuleList()
|
||||
self.convolutions.append(
|
||||
ConvBNBlock(mel_dim, 512, kernel_size=5, nonlinear='tanh'))
|
||||
ConvBNBlock(output_dim, 512, kernel_size=5, activation='tanh'))
|
||||
for _ in range(1, num_convs - 1):
|
||||
self.convolutions.append(
|
||||
ConvBNBlock(512, 512, kernel_size=5, nonlinear='tanh'))
|
||||
ConvBNBlock(512, 512, kernel_size=5, activation='tanh'))
|
||||
self.convolutions.append(
|
||||
ConvBNBlock(512, mel_dim, kernel_size=5, nonlinear=None))
|
||||
ConvBNBlock(512, output_dim, kernel_size=5, activation=None))
|
||||
|
||||
def forward(self, x):
|
||||
o = x
|
||||
for layer in self.convolutions:
|
||||
x = layer(x)
|
||||
return x
|
||||
o = layer(o)
|
||||
return o
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_features=512):
|
||||
def __init__(self, output_input_dim=512):
|
||||
super(Encoder, self).__init__()
|
||||
convolutions = []
|
||||
self.convolutions = nn.ModuleList()
|
||||
for _ in range(3):
|
||||
convolutions.append(
|
||||
ConvBNBlock(in_features, in_features, 5, 'relu'))
|
||||
self.convolutions = nn.Sequential(*convolutions)
|
||||
self.lstm = nn.LSTM(in_features,
|
||||
int(in_features / 2),
|
||||
self.convolutions.append(
|
||||
ConvBNBlock(output_input_dim, output_input_dim, 5, 'relu'))
|
||||
self.lstm = nn.LSTM(output_input_dim,
|
||||
int(output_input_dim / 2),
|
||||
num_layers=1,
|
||||
batch_first=True,
|
||||
bidirectional=True)
|
||||
self.rnn_state = None
|
||||
|
||||
def forward(self, x, input_lengths):
|
||||
x = self.convolutions(x)
|
||||
x = x.transpose(1, 2)
|
||||
x = nn.utils.rnn.pack_padded_sequence(x,
|
||||
o = x
|
||||
for layer in self.convolutions:
|
||||
o = layer(o)
|
||||
o = o.transpose(1, 2)
|
||||
o = nn.utils.rnn.pack_padded_sequence(o,
|
||||
input_lengths,
|
||||
batch_first=True)
|
||||
self.lstm.flatten_parameters()
|
||||
outputs, _ = self.lstm(x)
|
||||
outputs, _ = nn.utils.rnn.pad_packed_sequence(
|
||||
outputs,
|
||||
batch_first=True,
|
||||
)
|
||||
return outputs
|
||||
o, _ = self.lstm(o)
|
||||
o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True)
|
||||
return o
|
||||
|
||||
def inference(self, x):
|
||||
x = self.convolutions(x)
|
||||
x = x.transpose(1, 2)
|
||||
o = x
|
||||
for layer in self.convolutions:
|
||||
o = layer(o)
|
||||
o = x.transpose(1, 2)
|
||||
self.lstm.flatten_parameters()
|
||||
outputs, _ = self.lstm(x)
|
||||
return outputs
|
||||
|
||||
def inference_truncated(self, x):
|
||||
"""
|
||||
Preserve encoder state for continuous inference
|
||||
"""
|
||||
x = self.convolutions(x)
|
||||
x = x.transpose(1, 2)
|
||||
self.lstm.flatten_parameters()
|
||||
outputs, self.rnn_state = self.lstm(x, self.rnn_state)
|
||||
return outputs
|
||||
o, _ = self.lstm(o)
|
||||
return o
|
||||
|
||||
|
||||
# adapted from https://github.com/NVIDIA/tacotron2/
|
||||
class Decoder(nn.Module):
|
||||
# Pylint gets confused by PyTorch conventions here
|
||||
#pylint: disable=attribute-defined-outside-init
|
||||
def __init__(self, in_features, memory_dim, r, attn_type, attn_win, attn_norm,
|
||||
def __init__(self, input_dim, frame_dim, r, attn_type, attn_win, attn_norm,
|
||||
prenet_type, prenet_dropout, forward_attn, trans_agent,
|
||||
forward_attn_mask, location_attn, attn_K, separate_stopnet,
|
||||
speaker_embedding_dim):
|
||||
super(Decoder, self).__init__()
|
||||
self.memory_dim = memory_dim
|
||||
self.frame_dim = frame_dim
|
||||
self.r_init = r
|
||||
self.r = r
|
||||
self.encoder_embedding_dim = in_features
|
||||
self.encoder_embedding_dim = input_dim
|
||||
self.separate_stopnet = separate_stopnet
|
||||
self.max_decoder_steps = 1000
|
||||
self.gate_threshold = 0.5
|
||||
|
||||
# model dimensions
|
||||
self.query_dim = 1024
|
||||
self.decoder_rnn_dim = 1024
|
||||
self.prenet_dim = 256
|
||||
self.max_decoder_steps = 1000
|
||||
self.gate_threshold = 0.5
|
||||
self.attn_dim = 128
|
||||
self.p_attention_dropout = 0.1
|
||||
self.p_decoder_dropout = 0.1
|
||||
|
||||
# memory -> |Prenet| -> processed_memory
|
||||
prenet_dim = self.memory_dim
|
||||
self.prenet = Prenet(
|
||||
prenet_dim,
|
||||
prenet_type,
|
||||
prenet_dropout,
|
||||
out_features=[self.prenet_dim, self.prenet_dim],
|
||||
bias=False)
|
||||
prenet_dim = self.frame_dim
|
||||
self.prenet = Prenet(prenet_dim,
|
||||
prenet_type,
|
||||
prenet_dropout,
|
||||
out_features=[self.prenet_dim, self.prenet_dim],
|
||||
bias=False)
|
||||
|
||||
self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
|
||||
self.attention_rnn = nn.LSTMCell(self.prenet_dim + input_dim,
|
||||
self.query_dim)
|
||||
|
||||
self.attention = init_attn(attn_type=attn_type,
|
||||
query_dim=self.query_dim,
|
||||
embedding_dim=in_features,
|
||||
embedding_dim=input_dim,
|
||||
attention_dim=128,
|
||||
location_attention=location_attn,
|
||||
attention_location_n_filters=32,
|
||||
|
@ -141,15 +137,15 @@ class Decoder(nn.Module):
|
|||
forward_attn_mask=forward_attn_mask,
|
||||
attn_K=attn_K)
|
||||
|
||||
self.decoder_rnn = nn.LSTMCell(self.query_dim + in_features,
|
||||
self.decoder_rnn = nn.LSTMCell(self.query_dim + input_dim,
|
||||
self.decoder_rnn_dim, 1)
|
||||
|
||||
self.linear_projection = Linear(self.decoder_rnn_dim + in_features,
|
||||
self.memory_dim * self.r_init)
|
||||
self.linear_projection = Linear(self.decoder_rnn_dim + input_dim,
|
||||
self.frame_dim * self.r_init)
|
||||
|
||||
self.stopnet = nn.Sequential(
|
||||
nn.Dropout(0.1),
|
||||
Linear(self.decoder_rnn_dim + self.memory_dim * self.r_init,
|
||||
Linear(self.decoder_rnn_dim + self.frame_dim * self.r_init,
|
||||
1,
|
||||
bias=True,
|
||||
init_gain='sigmoid'))
|
||||
|
@ -161,7 +157,7 @@ class Decoder(nn.Module):
|
|||
def get_go_frame(self, inputs):
|
||||
B = inputs.size(0)
|
||||
memory = torch.zeros(1, device=inputs.device).repeat(B,
|
||||
self.memory_dim * self.r)
|
||||
self.frame_dim * self.r)
|
||||
return memory
|
||||
|
||||
def _init_states(self, inputs, mask, keep_states=False):
|
||||
|
@ -187,9 +183,9 @@ class Decoder(nn.Module):
|
|||
Reshape the spectrograms for given 'r'
|
||||
"""
|
||||
# Grouping multiple frames if necessary
|
||||
if memory.size(-1) == self.memory_dim:
|
||||
if memory.size(-1) == self.frame_dim:
|
||||
memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1)
|
||||
# Time first (T_decoder, B, memory_dim)
|
||||
# Time first (T_decoder, B, frame_dim)
|
||||
memory = memory.transpose(0, 1)
|
||||
return memory
|
||||
|
||||
|
@ -197,22 +193,22 @@ class Decoder(nn.Module):
|
|||
alignments = torch.stack(alignments).transpose(0, 1)
|
||||
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
||||
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
||||
outputs = outputs.view(outputs.size(0), -1, self.memory_dim)
|
||||
outputs = outputs.view(outputs.size(0), -1, self.frame_dim)
|
||||
outputs = outputs.transpose(1, 2)
|
||||
return outputs, stop_tokens, alignments
|
||||
|
||||
def _update_memory(self, memory):
|
||||
if len(memory.shape) == 2:
|
||||
return memory[:, self.memory_dim * (self.r - 1):]
|
||||
return memory[:, :, self.memory_dim * (self.r - 1):]
|
||||
return memory[:, self.frame_dim * (self.r - 1):]
|
||||
return memory[:, :, self.frame_dim * (self.r - 1):]
|
||||
|
||||
def decode(self, memory):
|
||||
'''
|
||||
shapes:
|
||||
- memory: B x r * self.memory_dim
|
||||
- memory: B x r * self.frame_dim
|
||||
'''
|
||||
# self.context: B x D_en
|
||||
# query_input: B x D_en + (r * self.memory_dim)
|
||||
# query_input: B x D_en + (r * self.frame_dim)
|
||||
query_input = torch.cat((memory, self.context), -1)
|
||||
# self.query and self.attention_rnn_cell_state : B x D_attn_rnn
|
||||
self.query, self.attention_rnn_cell_state = self.attention_rnn(
|
||||
|
@ -235,16 +231,16 @@ class Decoder(nn.Module):
|
|||
# B x (D_decoder_rnn + D_en)
|
||||
decoder_hidden_context = torch.cat((self.decoder_hidden, self.context),
|
||||
dim=1)
|
||||
# B x (self.r * self.memory_dim)
|
||||
# B x (self.r * self.frame_dim)
|
||||
decoder_output = self.linear_projection(decoder_hidden_context)
|
||||
# B x (D_decoder_rnn + (self.r * self.memory_dim))
|
||||
# B x (D_decoder_rnn + (self.r * self.frame_dim))
|
||||
stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
|
||||
if self.separate_stopnet:
|
||||
stop_token = self.stopnet(stopnet_input.detach())
|
||||
else:
|
||||
stop_token = self.stopnet(stopnet_input)
|
||||
# select outputs for the reduction rate self.r
|
||||
decoder_output = decoder_output[:, :self.r * self.memory_dim]
|
||||
decoder_output = decoder_output[:, :self.r * self.frame_dim]
|
||||
return decoder_output, self.attention.attention_weights, stop_token
|
||||
|
||||
def forward(self, inputs, memories, mask, speaker_embeddings=None):
|
||||
|
|
|
@ -29,7 +29,7 @@ class Tacotron2(nn.Module):
|
|||
super(Tacotron2, self).__init__()
|
||||
self.postnet_output_dim = postnet_output_dim
|
||||
self.decoder_output_dim = decoder_output_dim
|
||||
self.n_frames_per_step = r
|
||||
self.r = r
|
||||
self.bidirectional_decoder = bidirectional_decoder
|
||||
decoder_dim = 512 if num_speakers > 1 else 512
|
||||
encoder_dim = 512 if num_speakers > 1 else 512
|
||||
|
|
Загрузка…
Ссылка в новой задаче