Mirror of https://github.com/mozilla/TTS.git
Graves attention and setting attn type by config.json
Parent
84d81b6579
Commit
adf9ebd629
@@ -60,6 +60,8 @@
    "prenet_dropout": true, // enable/disable dropout at prenet.

    // ATTENTION
    "attention_type": "original", // 'original' or 'graves'
    "attention_heads": 5, // number of attention heads (only for 'graves')
    "attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
    "windowing": false, // Enables attention windowing. Used only in eval mode.
    "use_forward_attn": false, // if it uses forward attention. In general, it aligns faster.
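For example, a training run can be switched to the new GMM-based attention from config.json alone; a minimal sketch of the relevant keys (values are illustrative, not the project defaults):

    "attention_type": "graves",  // select GravesAttention instead of the default location-sensitive attention
    "attention_heads": 5,        // K, the number of Gaussian components (read only when attention_type is 'graves')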
@@ -106,25 +106,33 @@ class LocationLayer(nn.Module):


class GravesAttention(nn.Module):
    """ Graves attention as described here:
    - https://arxiv.org/abs/1910.10288
    """
    COEF = 0.3989422917366028  # numpy.sqrt(1/(2*numpy.pi))

    def __init__(self, query_dim, K, attention_alignment=0.05):
    def __init__(self, query_dim, K):
        super(GravesAttention, self).__init__()
        self._mask_value = -float("inf")
        self._mask_value = 0.0
        self.K = K
        self.attention_alignment = attention_alignment
        # self.attention_alignment = 0.05
        self.epsilon = 1e-5
        self.J = None
        self.N_a = nn.Sequential(
            nn.Linear(query_dim, query_dim//2),
            nn.Tanh(),
            nn.Linear(query_dim//2, 3*K))
        self.mu_tm1 = None
        self.attention_weights = None
        self.mu_prev = None

    def init_states(self, inputs):
        if self.J is None or inputs.shape[1] > self.J.shape[-1]:
            self.J = torch.arange(0, inputs.shape[1]).expand_as(torch.Tensor(inputs.shape[0], self.K, inputs.shape[1])).to(inputs.device)
        self.mu_tm1 = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
        self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
        self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)

    def preprocess_inputs(self, inputs):
        return None

    def forward(self, query, inputs, mask):
        """
@@ -143,9 +151,12 @@ class GravesAttention(nn.Module):
        k_t = gbk_t[:, 2, :]

        # attention GMM parameters
        g_t = torch.softmax(g_t, dim=-1) + self.epsilon  # distribution weight
        sig_t = torch.exp(b_t) + self.epsilon  # variance
        mu_t = self.mu_tm1 + self.attention_alignment * torch.exp(k_t)  # mean
        # g_t = torch.softmax(g_t, dim=-1) + self.epsilon  # distribution weight
        # sig_t = torch.exp(b_t) + self.epsilon  # variance
        # mu_t = self.mu_prev + self.attention_alignment * torch.exp(k_t)  # mean
        sig_t = torch.pow(torch.nn.functional.softplus(b_t), 2)
        mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
        g_t = (torch.softmax(g_t, dim=-1) / sig_t) * self.COEF

        g_t = g_t.unsqueeze(2).expand(g_t.size(0),
                                      g_t.size(1),
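For orientation, the new code path above computes a mixture-of-Gaussians alignment over encoder positions. The sketch below is a minimal, self-contained illustration rather than the commit's code: the helper name and tensor shapes (B batch, T encoder steps, K mixture components) are assumptions, and broadcasting replaces the explicit expand calls used in the diff.

    import torch

    def graves_alignment(gbk_t, mu_prev, J, coef=0.3989422917366028):
        # gbk_t: (B, 3, K) reshaped output of N_a; mu_prev: (B, K); J: (B, K, T) encoder position indices
        g_t, b_t, k_t = gbk_t[:, 0, :], gbk_t[:, 1, :], gbk_t[:, 2, :]
        sig_t = torch.pow(torch.nn.functional.softplus(b_t), 2)   # per-component spread
        mu_t = mu_prev + torch.nn.functional.softplus(k_t)        # means only move forward, so alignment is monotonic
        g_t = (torch.softmax(g_t, dim=-1) / sig_t) * coef         # mixture weights with the Gaussian normalizer folded in
        # evaluate each component over encoder positions and sum over the K components
        phi_t = g_t.unsqueeze(-1) * torch.exp(-0.5 * sig_t.unsqueeze(-1) * (mu_t.unsqueeze(-1) - J) ** 2)
        alpha_t = torch.sum(phi_t, 1)                              # (B, T) alignment weights
        return alpha_t, mu_t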
@@ -156,27 +167,33 @@ class GravesAttention(nn.Module):

        # attention weights
        phi_t = g_t * torch.exp(-0.5 * sig_t * (mu_t_ - j)**2)
        alpha_t = self.COEF * torch.sum(phi_t, 1)
        alpha_t = torch.sum(phi_t, 1)

        # apply masking
        # if mask is not None:
        # alpha_t.data.masked_fill_(~mask, self._mask_value)
        if mask is not None:
            alpha_t.data.masked_fill_(~mask, self._mask_value)

        context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)
        self.attention_weights = alpha_t
        self.mu_prev = mu_t
        breakpoint()

        c_t = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)
        self.mu_tm1 = mu_t
        return c_t, mu_t, alpha_t
        return context


class Attention(nn.Module):
class OriginalAttention(nn.Module):
    """Following the methods proposed here:
    - https://arxiv.org/abs/1712.05884
    - https://arxiv.org/abs/1807.06736 + state masking at inference
    - Using sigmoid instead of softmax normalization
    - Attention windowing at inference time
    """
    # Pylint gets confused by PyTorch conventions here
    #pylint: disable=attribute-defined-outside-init
    def __init__(self, query_dim, embedding_dim, attention_dim,
                 location_attention, attention_location_n_filters,
                 attention_location_kernel_size, windowing, norm, forward_attn,
                 trans_agent, forward_attn_mask):
        super(Attention, self).__init__()
        super(OriginalAttention, self).__init__()
        self.query_layer = Linear(
            query_dim, attention_dim, bias=False, init_gain='tanh')
        self.inputs_layer = Linear(
@@ -229,6 +246,9 @@ class Attention(nn.Module):
        if self.windowing:
            self.init_win_idx()

    def preprocess_inputs(self, inputs):
        return self.inputs_layer(inputs)

    def update_location_attention(self, alignments):
        self.attention_weights_cum += alignments
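Both attention classes now expose preprocess_inputs(), so the decoder can cache any attention-side projection once per utterance without knowing which attention type was configured. A hypothetical illustration (not part of the commit; `attention` and `encoder_outputs` are placeholder names):

    processed_inputs = attention.preprocess_inputs(encoder_outputs)
    # OriginalAttention returns self.inputs_layer(encoder_outputs);
    # GravesAttention has nothing to precompute and returns None.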
@@ -337,3 +357,21 @@ class Attention(nn.Module):
            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
            self.u = torch.sigmoid(self.ta(ta_input))
        return context


def init_attn(attn_type, query_dim, embedding_dim, attention_dim,
              location_attention, attention_location_n_filters,
              attention_location_kernel_size, windowing, norm, forward_attn,
              trans_agent, forward_attn_mask, attn_K):
    if attn_type == "original":
        return OriginalAttention(query_dim, embedding_dim, attention_dim,
                                 location_attention,
                                 attention_location_n_filters,
                                 attention_location_kernel_size, windowing,
                                 norm, forward_attn, trans_agent,
                                 forward_attn_mask)
    elif attn_type == "graves":
        return GravesAttention(query_dim, attn_K)
    else:
        raise RuntimeError(
            f" [!] Given attention type '{attn_type}' does not exist.")
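As a usage sketch (argument values here are illustrative, not project defaults), the factory keeps both attention classes behind a single call, which is exactly how the decoders below construct their attention module:

    attention = init_attn(attn_type="graves",  # or "original"
                          query_dim=256,
                          embedding_dim=512,
                          attention_dim=128,
                          location_attention=True,
                          attention_location_n_filters=32,
                          attention_location_kernel_size=31,
                          windowing=False,
                          norm="sigmoid",
                          forward_attn=False,
                          trans_agent=False,
                          forward_attn_mask=False,
                          attn_K=5)  # number of GMM components, read only by the "graves" branch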
@@ -1,7 +1,7 @@
# coding: utf-8
import torch
from torch import nn
from .common_layers import Prenet, Attention, Linear, GravesAttention
from .common_layers import Prenet, init_attn, Linear


class BatchNormConv1d(nn.Module):
@@ -263,9 +263,9 @@ class Decoder(nn.Module):
    # Pylint gets confused by PyTorch conventions here
    #pylint: disable=attribute-defined-outside-init
    def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,
    def __init__(self, in_features, memory_dim, r, memory_size, attn_type, attn_windowing,
                 attn_norm, prenet_type, prenet_dropout, forward_attn,
                 trans_agent, forward_attn_mask, location_attn,
                 trans_agent, forward_attn_mask, location_attn, attn_K,
                 separate_stopnet, speaker_embedding_dim):
        super(Decoder, self).__init__()
        self.r_init = r
@@ -288,18 +288,19 @@ class Decoder(nn.Module):
        # attention_rnn generates queries for the attention mechanism
        self.attention_rnn = nn.GRUCell(in_features + 128, self.query_dim)

        # self.attention = Attention(query_dim=self.query_dim,
        #                            embedding_dim=in_features,
        #                            attention_dim=128,
        #                            location_attention=location_attn,
        #                            attention_location_n_filters=32,
        #                            attention_location_kernel_size=31,
        #                            windowing=attn_windowing,
        #                            norm=attn_norm,
        #                            forward_attn=forward_attn,
        #                            trans_agent=trans_agent,
        #                            forward_attn_mask=forward_attn_mask)
        self.attention = GravesAttention(self.query_dim, 5)
        self.attention = init_attn(attn_type=attn_type,
                                   query_dim=self.query_dim,
                                   embedding_dim=in_features,
                                   attention_dim=128,
                                   location_attention=location_attn,
                                   attention_location_n_filters=32,
                                   attention_location_kernel_size=31,
                                   windowing=attn_windowing,
                                   norm=attn_norm,
                                   forward_attn=forward_attn,
                                   trans_agent=trans_agent,
                                   forward_attn_mask=forward_attn_mask,
                                   attn_K=attn_K)
        # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
        self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
        # decoder_RNN_input -> |RNN| -> RNN_state
@@ -343,7 +344,7 @@ class Decoder(nn.Module):
        ]
        self.context_vec = inputs.data.new(B, self.in_features).zero_()
        # cache attention inputs
        # self.processed_inputs = self.attention.inputs_layer(inputs)
        self.processed_inputs = self.attention.preprocess_inputs(inputs)

    def _parse_outputs(self, outputs, attentions, stop_tokens):
        # Back to batch first
@@ -363,7 +364,7 @@ class Decoder(nn.Module):
            torch.cat((processed_memory, self.context_vec), -1),
            self.attention_rnn_hidden)
        self.context_vec = self.attention(
            self.attention_rnn_hidden, inputs, mask)
            self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
        # Concat RNN output and attention context vector
        decoder_input = self.project_to_decoder_in(
            torch.cat((self.attention_rnn_hidden, self.context_vec), -1))
@@ -2,7 +2,7 @@ import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from .common_layers import Attention, Prenet, Linear
from .common_layers import init_attn, Prenet, Linear


class ConvBNBlock(nn.Module):
@@ -98,9 +98,9 @@ class Encoder(nn.Module):
class Decoder(nn.Module):
    # Pylint gets confused by PyTorch conventions here
    #pylint: disable=attribute-defined-outside-init
    def __init__(self, in_features, memory_dim, r, attn_win, attn_norm,
    def __init__(self, in_features, memory_dim, r, attn_type, attn_win, attn_norm,
                 prenet_type, prenet_dropout, forward_attn, trans_agent,
                 forward_attn_mask, location_attn, separate_stopnet,
                 forward_attn_mask, location_attn, attn_K, separate_stopnet,
                 speaker_embedding_dim):
        super(Decoder, self).__init__()
        self.memory_dim = memory_dim
@@ -128,7 +128,8 @@ class Decoder(nn.Module):
        self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
                                         self.query_dim)

        self.attention = Attention(query_dim=self.query_dim,
        self.attention = init_attn(attn_type=attn_type,
                                   query_dim=self.query_dim,
                                   embedding_dim=in_features,
                                   attention_dim=128,
                                   location_attention=location_attn,
@@ -138,7 +139,8 @@ class Decoder(nn.Module):
                                   norm=attn_norm,
                                   forward_attn=forward_attn,
                                   trans_agent=trans_agent,
                                   forward_attn_mask=forward_attn_mask)
                                   forward_attn_mask=forward_attn_mask,
                                   attn_K=attn_K)

        self.decoder_rnn = nn.LSTMCell(self.query_dim + in_features,
                                       self.decoder_rnn_dim, 1)
@@ -15,6 +15,7 @@ class Tacotron(nn.Module):
                 postnet_output_dim=1025,
                 decoder_output_dim=80,
                 memory_size=5,
                 attn_type='original',
                 attn_win=False,
                 gst=False,
                 attn_norm="sigmoid",
@@ -24,6 +25,7 @@ class Tacotron(nn.Module):
                 trans_agent=False,
                 forward_attn_mask=False,
                 location_attn=True,
                 attn_K=5,
                 separate_stopnet=True,
                 bidirectional_decoder=False):
        super(Tacotron, self).__init__()
@@ -41,10 +43,10 @@ class Tacotron(nn.Module):
        self.embedding.weight.data.normal_(0, 0.3)
        # boilerplate model
        self.encoder = Encoder(encoder_dim)
        self.decoder = Decoder(decoder_dim, decoder_output_dim, r, memory_size, attn_win,
        self.decoder = Decoder(decoder_dim, decoder_output_dim, r, memory_size, attn_type, attn_win,
                               attn_norm, prenet_type, prenet_dropout,
                               forward_attn, trans_agent, forward_attn_mask,
                               location_attn, separate_stopnet,
                               location_attn, attn_K, separate_stopnet,
                               proj_speaker_dim)
        if self.bidirectional_decoder:
            self.decoder_backward = copy.deepcopy(self.decoder)
@@ -14,6 +14,7 @@ class Tacotron2(nn.Module):
                 r,
                 postnet_output_dim=80,
                 decoder_output_dim=80,
                 attn_type='original',
                 attn_win=False,
                 attn_norm="softmax",
                 prenet_type="original",
@@ -22,6 +23,7 @@ class Tacotron2(nn.Module):
                 trans_agent=False,
                 forward_attn_mask=False,
                 location_attn=True,
                 attn_K=5,
                 separate_stopnet=True,
                 bidirectional_decoder=False):
        super(Tacotron2, self).__init__()
@@ -42,10 +44,10 @@ class Tacotron2(nn.Module):
        self.speaker_embeddings = None
        self.speaker_embeddings_projected = None
        self.encoder = Encoder(encoder_dim)
        self.decoder = Decoder(decoder_dim, self.decoder_output_dim, r, attn_win,
        self.decoder = Decoder(decoder_dim, self.decoder_output_dim, r, attn_type, attn_win,
                               attn_norm, prenet_type, prenet_dropout,
                               forward_attn, trans_agent, forward_attn_mask,
                               location_attn, separate_stopnet, proj_speaker_dim)
                               location_attn, attn_K, separate_stopnet, proj_speaker_dim)
        if self.bidirectional_decoder:
            self.decoder_backward = copy.deepcopy(self.decoder)
        self.postnet = Postnet(self.decoder_output_dim)
@@ -287,6 +287,7 @@ def setup_model(num_chars, num_speakers, c):
                        decoder_output_dim=c.audio['num_mels'],
                        gst=c.use_gst,
                        memory_size=c.memory_size,
                        attn_type=c.attention_type,
                        attn_win=c.windowing,
                        attn_norm=c.attention_norm,
                        prenet_type=c.prenet_type,
@@ -295,6 +296,7 @@ def setup_model(num_chars, num_speakers, c):
                        trans_agent=c.transition_agent,
                        forward_attn_mask=c.forward_attn_mask,
                        location_attn=c.location_attn,
                        attn_K=c.attention_heads,
                        separate_stopnet=c.separate_stopnet,
                        bidirectional_decoder=c.bidirectional_decoder)
    elif c.model.lower() == "tacotron2":
@@ -303,6 +305,7 @@ def setup_model(num_chars, num_speakers, c):
                        r=c.r,
                        postnet_output_dim=c.audio['num_mels'],
                        decoder_output_dim=c.audio['num_mels'],
                        attn_type=c.attention_type,
                        attn_win=c.windowing,
                        attn_norm=c.attention_norm,
                        prenet_type=c.prenet_type,
@@ -311,6 +314,7 @@ def setup_model(num_chars, num_speakers, c):
                        trans_agent=c.transition_agent,
                        forward_attn_mask=c.forward_attn_mask,
                        location_attn=c.location_attn,
                        attn_K=c.attention_heads,
                        separate_stopnet=c.separate_stopnet,
                        bidirectional_decoder=c.bidirectional_decoder)
    return model