erogol 2020-06-04 14:58:18 +02:00
Parent a82f7d129d
Commit fd8f1ecb7d
1 changed file with 24 additions and 4 deletions


@@ -4,12 +4,12 @@ from abc import ABC, abstractmethod
import torch
from torch import nn
from TTS.layers.gst_layers import GST
from TTS.utils.generic_utils import sequence_mask
class TacotronAbstract(ABC, nn.Module):
def __init__(self, num_chars,
def __init__(self,
num_chars,
num_speakers,
r,
postnet_output_dim=80,
@@ -31,6 +31,7 @@ class TacotronAbstract(ABC, nn.Module):
gst=False):
""" Abstract Tacotron class """
super().__init__()
self.num_chars = num_chars
self.r = r
self.decoder_output_dim = decoder_output_dim
self.postnet_output_dim = postnet_output_dim
@@ -39,6 +40,17 @@ class TacotronAbstract(ABC, nn.Module):
self.bidirectional_decoder = bidirectional_decoder
self.double_decoder_consistency = double_decoder_consistency
self.ddc_r = ddc_r
self.attn_type = attn_type
self.attn_win = attn_win
self.attn_norm = attn_norm
self.prenet_type = prenet_type
self.prenet_dropout = prenet_dropout
self.forward_attn = forward_attn
self.trans_agent = trans_agent
self.forward_attn_mask = forward_attn_mask
self.location_attn = location_attn
self.attn_K = attn_K
self.separate_stopnet = separate_stopnet
# layers
self.embedding = None
@@ -48,9 +60,16 @@ class TacotronAbstract(ABC, nn.Module):
# global style token
if self.gst:
gst_embedding_dim = None
self.gst_layer = None
# model states
self.speaker_embeddings = None
self.speaker_embeddings_projected = None
# additional layers
self.decoder_backward = None
self.coarse_decoder = None
#############################
# INIT FUNCTIONS
#############################
@@ -114,7 +133,7 @@ class TacotronAbstract(ABC, nn.Module):
(0, 0, 0, padding_size, 0, 0))
decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder(
encoder_outputs.detach(), mel_specs, input_mask)
scale_factor = self.decoder.r_init / self.decoder.r
# scale_factor = self.decoder.r_init / self.decoder.r
alignments_backward = torch.nn.functional.interpolate(
alignments_backward.transpose(1, 2),
size=alignments.shape[1],
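
The coarse decoder used for double decoder consistency runs at a larger reduction factor, so it emits fewer decoder steps than the main decoder; the interpolate call above resizes its alignments along the decoder-step axis so the two attention maps can be compared. A minimal standalone sketch of that reshaping, with illustrative shapes (B, T_en, T_de, T_de_coarse are assumptions, not values from this diff):

import torch

# Illustrative shapes: the main decoder produced T_de alignment steps,
# the coarse (DDC) decoder produced fewer steps because it uses a larger r.
B, T_en, T_de, T_de_coarse = 2, 37, 120, 60

alignments = torch.rand(B, T_de, T_en)                  # forward decoder alignments
alignments_backward = torch.rand(B, T_de_coarse, T_en)  # coarse decoder alignments

# Interpolate over the decoder-step axis so both attention maps have the
# same number of steps and can be compared by the consistency loss.
alignments_backward = torch.nn.functional.interpolate(
    alignments_backward.transpose(1, 2),   # (B, T_en, T_de_coarse)
    size=alignments.shape[1],              # resize last dim to T_de
    mode='nearest').transpose(1, 2)        # back to (B, T_de, T_en)

assert alignments_backward.shape == alignments.shape
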
@@ -141,6 +160,7 @@ class TacotronAbstract(ABC, nn.Module):
def compute_gst(self, inputs, mel_specs):
""" Compute global style token """
# pylint: disable=not-callable
gst_outputs = self.gst_layer(mel_specs)
inputs = self._add_speaker_embedding(inputs, gst_outputs)
return inputs
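
compute_gst relies on self.gst_layer, which this abstract class only declares as a None placeholder; a concrete subclass is expected to build the actual GST module. A minimal sketch of how that wiring might look, using a hypothetical ToyGST stand-in rather than the real TTS.layers.gst_layers.GST (whose constructor arguments are not shown in this diff), and assuming the style vector is combined with the encoder outputs the same way speaker embeddings are:

import torch
from torch import nn

class ToyGST(nn.Module):
    """Hypothetical stand-in for a GST layer: maps a mel spectrogram
    (B, T_mel, num_mel) to one style embedding per utterance (B, 1, gst_dim)."""
    def __init__(self, num_mel=80, gst_dim=256):
        super().__init__()
        self.proj = nn.Linear(num_mel, gst_dim)

    def forward(self, mel_specs):
        # Project each frame, then average over time to get a single style vector.
        return self.proj(mel_specs).mean(dim=1, keepdim=True)  # (B, 1, gst_dim)

# How compute_gst could use it: the style vector is broadcast over the
# encoder time axis and combined with the encoder outputs (the real
# _add_speaker_embedding may add or concatenate; addition is assumed here).
B, T_en, enc_dim, num_mel, T_mel = 2, 37, 512, 80, 200
inputs = torch.rand(B, T_en, enc_dim)      # encoder outputs
mel_specs = torch.rand(B, T_mel, num_mel)  # reference mel spectrogram

gst_layer = ToyGST(num_mel=num_mel, gst_dim=enc_dim)
gst_outputs = gst_layer(mel_specs)                   # (B, 1, enc_dim)
inputs = inputs + gst_outputs.expand(-1, T_en, -1)   # style-conditioned encoder outputs
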