зеркало из https://github.com/mozilla/TTS.git
Setting up network size according to the reference paper
This commit is contained in:
Родитель
790a1b4639
Коммит
d5febfb187
|
@ -109,18 +109,20 @@ class CBHG(nn.Module):
|
|||
|
||||
def __init__(self,
|
||||
in_features,
|
||||
hid_features=128,
|
||||
K=16,
|
||||
projections=[128, 128],
|
||||
num_highways=4):
|
||||
super(CBHG, self).__init__()
|
||||
self.in_features = in_features
|
||||
self.hid_features = hid_features
|
||||
self.relu = nn.ReLU()
|
||||
# list of conv1d bank with filter size k=1...K
|
||||
# TODO: try dilational layers instead
|
||||
self.conv1d_banks = nn.ModuleList([
|
||||
BatchNormConv1d(
|
||||
in_features,
|
||||
in_features,
|
||||
hid_features,
|
||||
kernel_size=k,
|
||||
stride=1,
|
||||
padding=k // 2,
|
||||
|
@ -129,7 +131,7 @@ class CBHG(nn.Module):
|
|||
# max pooling of conv bank
|
||||
# TODO: try average pooling OR larger kernel size
|
||||
self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
|
||||
out_features = [K * in_features] + projections[:-1]
|
||||
out_features = [K * hid_features] + projections[:-1]
|
||||
activations = [self.relu] * (len(projections) - 1)
|
||||
activations += [None]
|
||||
# setup conv1d projection layers
|
||||
|
@ -146,12 +148,13 @@ class CBHG(nn.Module):
|
|||
layer_set.append(layer)
|
||||
self.conv1d_projections = nn.ModuleList(layer_set)
|
||||
# setup Highway layers
|
||||
self.pre_highway = nn.Linear(projections[-1], in_features, bias=False)
|
||||
if self.hid_features != self.in_features:
|
||||
self.pre_highway = nn.Linear(projections[-1], hid_features, bias=False)
|
||||
self.highways = nn.ModuleList(
|
||||
[Highway(in_features, in_features) for _ in range(num_highways)])
|
||||
[Highway(hid_features, hid_features) for _ in range(num_highways)])
|
||||
# bi-directional GPU layer
|
||||
self.gru = nn.GRU(
|
||||
in_features, in_features, 1, batch_first=True, bidirectional=True)
|
||||
128, 128, 1, batch_first=True, bidirectional=True)
|
||||
|
||||
def forward(self, inputs):
|
||||
# (B, T_in, in_features)
|
||||
|
@ -161,7 +164,7 @@ class CBHG(nn.Module):
|
|||
if x.size(-1) == self.in_features:
|
||||
x = x.transpose(1, 2)
|
||||
T = x.size(-1)
|
||||
# (B, in_features*K, T_in)
|
||||
# (B, hid_features*K, T_in)
|
||||
# Concat conv1d bank outputs
|
||||
outs = []
|
||||
for conv1d in self.conv1d_banks:
|
||||
|
@ -169,35 +172,45 @@ class CBHG(nn.Module):
|
|||
out = out[:, :, :T]
|
||||
outs.append(out)
|
||||
x = torch.cat(outs, dim=1)
|
||||
assert x.size(1) == self.in_features * len(self.conv1d_banks)
|
||||
assert x.size(1) == self.hid_features * len(self.conv1d_banks)
|
||||
x = self.max_pool1d(x)[:, :, :T]
|
||||
for conv1d in self.conv1d_projections:
|
||||
x = conv1d(x)
|
||||
# (B, T_in, in_features)
|
||||
# Back to the original shape
|
||||
# (B, T_in, hid_feature)
|
||||
x = x.transpose(1, 2)
|
||||
if x.size(-1) != self.in_features:
|
||||
# Back to the original shape
|
||||
x += inputs
|
||||
if x.size(-1) != self.hid_features:
|
||||
x = self.pre_highway(x)
|
||||
# Residual connection
|
||||
# TODO: try residual scaling as in Deep Voice 3
|
||||
# TODO: try plain residual layers
|
||||
x += inputs
|
||||
for highway in self.highways:
|
||||
x = highway(x)
|
||||
# (B, T_in, in_features*2)
|
||||
# (B, T_in, hid_features*2)
|
||||
# TODO: replace GRU with convolution as in Deep Voice 3
|
||||
# self.gru.flatten_parameters()
|
||||
outputs, _ = self.gru(x)
|
||||
return outputs
|
||||
|
||||
|
||||
class EncoderCBHG(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(EncoderCBHG, self).__init__()
|
||||
self.cbhg = CBHG(128, hid_features=128, K=16, projections=[128, 128])
|
||||
|
||||
def forward(self, x):
|
||||
return self.cbhg(x)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
r"""Encapsulate Prenet and CBHG modules for encoder"""
|
||||
|
||||
def __init__(self, in_features):
|
||||
super(Encoder, self).__init__()
|
||||
self.prenet = Prenet(in_features, out_features=[256, 128])
|
||||
self.cbhg = CBHG(128, K=16, projections=[128, 128])
|
||||
self.cbhg = EncoderCBHG()
|
||||
|
||||
def forward(self, inputs):
|
||||
r"""
|
||||
|
@ -212,6 +225,16 @@ class Encoder(nn.Module):
|
|||
return self.cbhg(inputs)
|
||||
|
||||
|
||||
class PostCBHG(nn.Module):
|
||||
|
||||
def __init__(self, mel_dim):
|
||||
super(PostCBHG, self).__init__()
|
||||
self.cbhg = CBHG(mel_dim, hid_features=128, K=8, projections=[256, mel_dim])
|
||||
|
||||
def forward(self, x):
|
||||
return self.cbhg(x)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
r"""Decoder module.
|
||||
|
||||
|
@ -336,10 +359,10 @@ class Decoder(nn.Module):
|
|||
if t >= T_decoder:
|
||||
break
|
||||
else:
|
||||
if t > inputs.shape[1] / 2 and stop_token > 0.6:
|
||||
if t > inputs.shape[1] / 4 and stop_token > 0.6:
|
||||
break
|
||||
elif t > self.max_decoder_steps:
|
||||
print(" | | > Decoder stopped with 'max_decoder_steps")
|
||||
print(" | > Decoder stopped with 'max_decoder_steps")
|
||||
break
|
||||
assert greedy or len(outputs) == T_decoder
|
||||
# Back to batch first
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from utils.text.symbols import symbols
|
||||
from layers.tacotron import Prenet, Encoder, Decoder, CBHG
|
||||
from layers.tacotron import Prenet, Encoder, Decoder, PostCBHG
|
||||
|
||||
|
||||
class Tacotron(nn.Module):
|
||||
|
@ -22,8 +22,8 @@ class Tacotron(nn.Module):
|
|||
self.embedding.weight.data.normal_(0, 0.3)
|
||||
self.encoder = Encoder(embedding_dim)
|
||||
self.decoder = Decoder(256, mel_dim, r)
|
||||
self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
|
||||
self.last_linear = nn.Linear(mel_dim * 2, linear_dim)
|
||||
self.postnet = PostCBHG(mel_dim)
|
||||
self.last_linear = nn.Linear(256, linear_dim)
|
||||
|
||||
def forward(self, characters, mel_specs=None, text_lens=None):
|
||||
B = characters.size(0)
|
||||
|
|
4
train.py
4
train.py
|
@ -37,6 +37,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
|||
avg_step_time = 0
|
||||
print(" | > Epoch {}/{}".format(epoch, c.epochs), flush=True)
|
||||
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
|
||||
batch_n_iter = len(data_loader.dataset) / c.batch_size
|
||||
for num_iter, data in enumerate(data_loader):
|
||||
start_time = time.time()
|
||||
|
||||
|
@ -114,9 +115,10 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
|||
epoch_time += step_time
|
||||
|
||||
if current_step % c.print_step == 0:
|
||||
print(" | | > Step:{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
|
||||
print(" | | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
|
||||
"MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
|
||||
"GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter,
|
||||
batch_n_iter,
|
||||
current_step,
|
||||
loss.item(),
|
||||
linear_loss.item(),
|
||||
|
|
|
@ -120,9 +120,9 @@ class AudioProcessor(object):
|
|||
D = processor.run_lws(S.astype(np.float64).T**self.power)
|
||||
y = processor.istft(D).astype(np.float32)
|
||||
# Reconstruct phase
|
||||
sys.stdout = old_out
|
||||
if self.preemphasis:
|
||||
return self.apply_inv_preemphasis(y)
|
||||
sys.stdout = old_out
|
||||
return y
|
||||
|
||||
def _linear_to_mel(self, spectrogram):
|
||||
|
|
Загрузка…
Ссылка в новой задаче