Mirror of https://github.com/mozilla/TTS.git
Merge branch 'upstream_clean' of https://github.com/geneing/TTS into geneing-upstream_clean
Commit 71af8da293
@@ -127,8 +127,8 @@ class GravesAttention(nn.Module):
         self.init_layers()
 
     def init_layers(self):
-        torch.nn.init.constant_(self.N_a[2].bias[10:15], 0.5)
-        torch.nn.init.constant_(self.N_a[2].bias[5:10], 10)
+        torch.nn.init.constant_(self.N_a[2].bias[(2*self.K):(3*self.K)], 1.)
+        torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10)
 
     def init_states(self, inputs):
         if self.J is None or inputs.shape[1] > self.J.shape[-1]:
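Note on the init change (added for context, not part of the commit): the new slices express the old hardcoded indices in terms of the number of mixture components K, so the initialization no longer silently assumes K == 5; the k-chunk value also moves from 0.5 to 1.0. A minimal sketch of the equivalence, assuming the default K = 5 and a [g | b | k] chunk layout for the 3*K bias entries of the attention network's final linear layer:

    import torch

    K = 5                      # assumed default component count
    bias = torch.zeros(3 * K)  # assumed layout: [g | b | k], K entries each
    torch.nn.init.constant_(bias[K:(2 * K)], 10)        # b chunk; == bias[5:10] when K == 5
    torch.nn.init.constant_(bias[(2 * K):(3 * K)], 1.)  # k chunk; == bias[10:15] when K == 5
    assert torch.equal(bias[5:10], torch.full((5,), 10.))
    assert torch.equal(bias[10:15], torch.full((5,), 1.))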
@@ -159,20 +159,21 @@ class GravesAttention(nn.Module):
         k_t = gbk_t[:, 2, :]
 
         # attention GMM parameters
-        inv_sig_t = torch.exp(-torch.clamp(b_t, min=-6, max=9)) # variance
+        sig_t = torch.nn.functional.softplus(b_t)+self.eps
+
         mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
-        g_t = torch.softmax(g_t, dim=-1) * inv_sig_t + self.eps
+        g_t = torch.softmax(g_t, dim=-1) / sig_t + self.eps
 
         # each B x K x T_in
         g_t = g_t.unsqueeze(2).expand(g_t.size(0),
                                       g_t.size(1),
                                       inputs.size(1))
-        inv_sig_t = inv_sig_t.unsqueeze(2).expand_as(g_t)
+        sig_t = sig_t.unsqueeze(2).expand_as(g_t)
         mu_t_ = mu_t.unsqueeze(2).expand_as(g_t)
         j = self.J[:g_t.size(0), :, :inputs.size(1)]
 
         # attention weights
-        phi_t = g_t * torch.exp(-0.5 * inv_sig_t * (mu_t_ - j)**2)
+        phi_t = g_t * torch.exp(-0.5 * (mu_t_ - j)**2 / (sig_t**2))
         alpha_t = self.COEF * torch.sum(phi_t, 1)
 
         # apply masking
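For reference (not part of the diff): the block above is the GMM attention of Graves (2013), reparameterized so that the spread of each component comes from a softplus rather than an exponential of a clamped logit, which keeps sig_t positive without hard clipping. A standalone sketch under those assumptions, with (B, K)-shaped parameters and COEF taken to be 1/sqrt(2*pi) (names mirror the diff; this is not the repo's exact code):

    import math
    import torch

    COEF = 1.0 / math.sqrt(2.0 * math.pi)  # assumed value of self.COEF

    def gmm_attention(g_t, b_t, k_t, mu_prev, T_in, eps=1e-5):
        sig_t = torch.nn.functional.softplus(b_t) + eps     # positive std dev, (B, K)
        mu_t = mu_prev + torch.nn.functional.softplus(k_t)  # means only move forward
        g_t = torch.softmax(g_t, dim=-1) / sig_t + eps      # sigma-scaled mixture weights
        j = torch.arange(T_in, dtype=g_t.dtype)             # encoder positions 0..T_in-1
        # per-component Gaussian over encoder steps, broadcast to (B, K, T_in)
        phi_t = g_t.unsqueeze(-1) * torch.exp(
            -0.5 * (mu_t.unsqueeze(-1) - j) ** 2 / sig_t.unsqueeze(-1) ** 2)
        return COEF * phi_t.sum(1), mu_t                    # alpha_t: (B, T_in)

Because mu_t is mu_prev plus a non-negative shift, the attention window can only advance along the input, which is what keeps this alignment monotonic.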
@@ -55,8 +55,11 @@ def tts():
 
 
 if __name__ == '__main__':
-    if not config or not synthesizer:
-        args = create_argparser().parse_args()
+    args = create_argparser().parse_args()
+
+    # Setup synthesizer from CLI args if they're specified or no embedded model
+    # is present.
+    if not config or not synthesizer or args.tts_checkpoint or args.tts_config:
         synthesizer = Synthesizer(args)
     app.run(debug=config.debug, host='0.0.0.0', port=config.port)
 
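Context for the server change (not in the diff): argument parsing is now unconditional, and an explicit --tts_checkpoint or --tts_config on the command line forces a fresh Synthesizer even when a model is embedded in the package. A sketch of the decision the new condition encodes, as a hypothetical helper:

    def needs_new_synthesizer(config, synthesizer, args):
        # rebuild when no embedded model is present, or when the user
        # points at a specific checkpoint/config on the command line
        return (not config or not synthesizer
                or bool(args.tts_checkpoint) or bool(args.tts_config))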
@@ -53,15 +53,14 @@ class Synthesizer(object):
             num_speakers = 0
         self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config)
         # load model state
-        map_location = None if use_cuda else torch.device('cpu')
-        cp = torch.load(tts_checkpoint, map_location=map_location)
+        cp = torch.load(tts_checkpoint, map_location=torch.device('cpu'))
         # load the model
         self.tts_model.load_state_dict(cp['model'])
         if use_cuda:
             self.tts_model.cuda()
         self.tts_model.eval()
         self.tts_model.decoder.max_decoder_steps = 3000
-        if 'r' in cp and self.tts_config.model in ["Tacotron", "TacotronGST"]:
+        if 'r' in cp:
             self.tts_model.decoder.set_r(cp['r'])
 
     def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda):
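The loading change above settles on a CPU-first pattern: the checkpoint is always deserialized onto the CPU and the module is moved to the GPU afterwards, so a checkpoint saved on a GPU machine still restores cleanly on a CPU-only host. A minimal sketch of the pattern (hypothetical helper; assumes the checkpoint dict keeps its weights under 'model', as in the diff):

    import torch

    def load_checkpoint(model, checkpoint_path, use_cuda=False):
        # map all tensors to CPU regardless of the device they were saved on
        cp = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        model.load_state_dict(cp['model'])
        if use_cuda:
            model.cuda()  # move to GPU only after a successful CPU load
        model.eval()      # inference mode
        return model, cp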