Mirror of https://github.com/mozilla/TTS.git
lint updates
Parent: a82f7d129d
Commit: fd8f1ecb7d
@@ -4,12 +4,12 @@ from abc import ABC, abstractmethod
 import torch
 from torch import nn

 from TTS.layers.gst_layers import GST
 from TTS.utils.generic_utils import sequence_mask


 class TacotronAbstract(ABC, nn.Module):
-    def __init__(self, num_chars,
+    def __init__(self,
+                 num_chars,
                  num_speakers,
                  r,
                  postnet_output_dim=80,
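The only change in this hunk is cosmetic: the __init__ signature is re-wrapped with one argument per line. Of the imports shown as context, sequence_mask is the usual lengths-to-boolean-mask helper used to mask padded positions. The sketch below is a minimal re-implementation of that idea for illustration only, not necessarily the code in TTS.utils.generic_utils:

import torch

def sequence_mask_sketch(lengths, max_len=None):
    # lengths: (B,) tensor of valid lengths per sequence.
    # Returns a (B, max_len) boolean mask, True at valid positions.
    if max_len is None:
        max_len = int(lengths.max().item())
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)

# Example: lengths 3 and 5, padded to T=5.
print(sequence_mask_sketch(torch.tensor([3, 5])))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])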
@@ -31,6 +31,7 @@ class TacotronAbstract(ABC, nn.Module):
                  gst=False):
         """ Abstract Tacotron class """
         super().__init__()
         self.num_chars = num_chars
         self.r = r
         self.decoder_output_dim = decoder_output_dim
         self.postnet_output_dim = postnet_output_dim
@@ -39,6 +40,17 @@ class TacotronAbstract(ABC, nn.Module):
         self.bidirectional_decoder = bidirectional_decoder
         self.double_decoder_consistency = double_decoder_consistency
         self.ddc_r = ddc_r
         self.attn_type = attn_type
         self.attn_win = attn_win
         self.attn_norm = attn_norm
         self.prenet_type = prenet_type
         self.prenet_dropout = prenet_dropout
         self.forward_attn = forward_attn
         self.trans_agent = trans_agent
         self.forward_attn_mask = forward_attn_mask
         self.location_attn = location_attn
         self.attn_K = attn_K
         self.separate_stopnet = separate_stopnet

         # layers
         self.embedding = None
@@ -48,9 +60,16 @@ class TacotronAbstract(ABC, nn.Module):

         # global style token
         if self.gst:
             gst_embedding_dim = None
             self.gst_layer = None

         # model states
         self.speaker_embeddings = None
         self.speaker_embeddings_projected = None

         # additional layers
         self.decoder_backward = None
         self.coarse_decoder = None

         #############################
         # INIT FUNCTIONS
         #############################
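A note on the placeholders above: self.gst_layer, self.decoder_backward and self.coarse_decoder start out as None and are only replaced with real modules later, so static analysis sees attributes of type None wherever they are eventually called; that is the kind of warning the pylint directive added further down in this commit silences. A minimal, hypothetical illustration of the pattern (not TTS code):

from torch import nn

class LazyModule(nn.Module):
    # Hypothetical example of the None-placeholder pattern.
    def __init__(self):
        super().__init__()
        # Placeholder: the real layer is created later, so static
        # analysis only sees an attribute initialised to None here.
        self.proj = None

    def build(self, dim):
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        # pylint: disable=not-callable
        # Without the directive, pylint can flag self.proj(x) because
        # __init__ assigned None to the attribute.
        return self.proj(x)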
@@ -114,7 +133,7 @@ class TacotronAbstract(ABC, nn.Module):
             (0, 0, 0, padding_size, 0, 0))
         decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder(
             encoder_outputs.detach(), mel_specs, input_mask)
-        scale_factor = self.decoder.r_init / self.decoder.r
+        # scale_factor = self.decoder.r_init / self.decoder.r
         alignments_backward = torch.nn.functional.interpolate(
             alignments_backward.transpose(1, 2),
             size=alignments.shape[1],
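For context on the hunk above: with double decoder consistency the coarse decoder runs at a coarser time resolution (typically a larger reduction factor), so its alignment matrix has fewer decoder steps than the main decoder's, and torch.nn.functional.interpolate stretches it back to the same length before the consistency term is computed. Since interpolation is done by explicit size, the scale_factor variable (the r_init / r ratio) is presumably unused, which fits a lint pass commenting it out. A small sketch with dummy tensors and assumed shapes:

import torch
import torch.nn.functional as F

batch, enc_len = 2, 40
main_steps, coarse_steps = 120, 20          # assumed decoder step counts

alignments = torch.rand(batch, main_steps, enc_len)             # main decoder
alignments_backward = torch.rand(batch, coarse_steps, enc_len)  # coarse decoder

# Move the decoder-step axis last, resize it to the main decoder's
# length, then move it back: (B, T_de, T_en) -> (B, T_en, T_de) -> resize.
alignments_backward = F.interpolate(
    alignments_backward.transpose(1, 2),
    size=alignments.shape[1],
    mode='nearest').transpose(1, 2)

print(alignments_backward.shape)  # torch.Size([2, 120, 40])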
@@ -141,6 +160,7 @@ class TacotronAbstract(ABC, nn.Module):

     def compute_gst(self, inputs, mel_specs):
         """ Compute global style token """
+        # pylint: disable=not-callable
         gst_outputs = self.gst_layer(mel_specs)
         inputs = self._add_speaker_embedding(inputs, gst_outputs)
         return inputs
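compute_gst runs the (later-assigned) GST layer on a reference mel spectrogram and folds the resulting style vector into the encoder outputs. The sketch below shows one plausible way such conditioning can be wired (a broadcast add over time); it is an illustration only, not necessarily what _add_speaker_embedding does in TTS:

import torch

def add_style_embedding_sketch(encoder_outputs, style_embedding):
    # encoder_outputs: (B, T_en, D); style_embedding: (B, D) or (B, 1, D).
    # One plausible scheme: broadcast-add the style vector over every
    # encoder timestep (other implementations concatenate instead).
    if style_embedding.dim() == 2:
        style_embedding = style_embedding.unsqueeze(1)
    return encoder_outputs + style_embedding

enc = torch.rand(2, 40, 256)
style = torch.rand(2, 256)   # e.g. the output of a GST layer
print(add_style_embedding_sketch(enc, style).shape)  # torch.Size([2, 40, 256])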