Mirror of https://github.com/mozilla/TTS.git
Merge branch 'loc-sens-attn' into loc-sens-attn-new and attention without attention-cum
Commit 9f52833151
README.md

@@ -23,6 +23,7 @@ Checkout [here](https://mycroft.ai/blog/available-voices/#the-human-voice-is-the
 | [iter-62410](https://drive.google.com/open?id=1pjJNzENL3ZNps9n7k_ktGbpEl6YPIkcZ)| [99d56f7](https://github.com/mozilla/TTS/tree/99d56f7e93ccd7567beb0af8fcbd4d24c48e59e9) | [link](https://soundcloud.com/user-565970875/99d56f7-iter62410)|First model with plain Tacotron implementation.|
 | [iter-170K](https://drive.google.com/open?id=16L6JbPXj6MSlNUxEStNn28GiSzi4fu1j) | [e00bc66](https://github.com/mozilla/TTS/tree/e00bc66) |[link](https://soundcloud.com/user-565970875/april-13-2018-07-06pm-e00bc66-iter170k)|More stable and longer trained model.|
 | Best: [iter-270K](https://drive.google.com/drive/folders/1Q6BKeEkZyxSGsocK2p_mqgzLwlNvbHFJ?usp=sharing)|[256ed63](https://github.com/mozilla/TTS/tree/256ed63)|[link](https://soundcloud.com/user-565970875/sets/samples-1650226)|Stop-Token prediction is added, to detect end of speech.|
+| Best: [iter-K] | [bla]() | [link]() | Location Sensitive attention |

 ## Data

 Currently TTS provides data loaders for
layers/attention.py

@@ -13,8 +13,8 @@ class BahdanauAttention(nn.Module):
     def forward(self, annots, query):
         """
         Shapes:
-            - query: (batch, 1, dim) or (batch, dim)
             - annots: (batch, max_time, dim)
+            - query: (batch, 1, dim) or (batch, dim)
         """
         if query.dim() == 2:
             # insert time-axis for broadcasting
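For reference, a minimal shape-check sketch for the additive (Bahdanau) scoring documented above; the constructor arguments (annot_dim, query_dim, hidden_dim) are assumed from how AttentionRNNCell instantiates the model later in this file, and the import path is taken from this diff.

import torch
from layers.attention import BahdanauAttention  # path as in this repo

attn = BahdanauAttention(256, 256, 256)  # (annot_dim, query_dim, hidden_dim) assumed
annots = torch.zeros(2, 37, 256)         # (batch, max_time, dim)
query = torch.zeros(2, 256)              # (batch, dim); forward adds the time axis itself
scores = attn(annots, query)             # raw alignment energies, shape (batch, max_time)
print(scores.shape)                      # torch.Size([2, 37])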
@@ -29,31 +29,70 @@ class BahdanauAttention(nn.Module):
         return alignment.squeeze(-1)


-def get_mask_from_lengths(inputs, inputs_lengths):
-    """Get mask tensor from list of length
+class LocationSensitiveAttention(nn.Module):
+    """Location sensitive attention following
+    https://arxiv.org/pdf/1506.07503.pdf"""
+    def __init__(self, annot_dim, query_dim, hidden_dim,
+                 kernel_size=7, filters=20):
+        super(LocationSensitiveAttention, self).__init__()
+        self.kernel_size = kernel_size
+        self.filters = filters
+        padding = int((kernel_size - 1) / 2)
+        self.loc_conv = nn.Conv1d(2, filters,
+                                  kernel_size=kernel_size, stride=1,
+                                  padding=padding, bias=False)
+        self.loc_linear = nn.Linear(filters, hidden_dim)
+        self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
+        self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
+        self.v = nn.Linear(hidden_dim, 1, bias=False)

-    Args:
-        inputs: Tensor in size (batch, max_time, dim)
-        inputs_lengths: array like
-    """
-    mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_()
-    for idx, l in enumerate(inputs_lengths):
-        mask[idx][:l] = 1
-    return ~mask
+    def forward(self, annot, query, loc):
+        """
+        Shapes:
+            - annot: (batch, max_time, dim)
+            - query: (batch, 1, dim) or (batch, dim)
+            - loc: (batch, 2, max_time)
+        """
+        if query.dim() == 2:
+            # insert time-axis for broadcasting
+            query = query.unsqueeze(1)
+        loc_conv = self.loc_conv(loc)
+        loc_conv = loc_conv.transpose(1, 2)
+        processed_loc = self.loc_linear(loc_conv)
+        processed_query = self.query_layer(query)
+        processed_annots = self.annot_layer(annot)
+        alignment = self.v(nn.functional.tanh(
+            processed_query + processed_annots + processed_loc))
+        # (batch, max_time)
+        return alignment.squeeze(-1)


-class AttentionRNN(nn.Module):
-    def __init__(self, out_dim, annot_dim, memory_dim,
-                 score_mask_value=-float("inf")):
+class AttentionRNNCell(nn.Module):
+    def __init__(self, out_dim, annot_dim, memory_dim, align_model):
         r"""
         General Attention RNN wrapper

         Args:
             out_dim (int): context vector feature dimension.
             annot_dim (int): annotation vector feature dimension.
             memory_dim (int): memory vector (decoder autogression) feature dimension.
+            align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
         """
-        super(AttentionRNN, self).__init__()
+        super(AttentionRNNCell, self).__init__()
+        self.align_model = align_model
         self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
-        self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
-        self.score_mask_value = score_mask_value
+        # pick bahdanau or location sensitive attention
+        if align_model == 'b':
+            self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
+        elif align_model == 'ls':
+            self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim)
+        else:
+            raise RuntimeError(" Wrong alignment model name: {}. Use\
+                'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))

     def forward(self, memory, context, rnn_state, annotations,
-                mask=None, annotations_lengths=None):
-        if annotations_lengths is not None and mask is None:
-            mask = get_mask_from_lengths(annotations, annotations_lengths)
+                attention_vec, mask=None, annotations_lengths=None):
         # Concat input query and previous context context
         rnn_input = torch.cat((memory, context), -1)
         # Feed it to RNN
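A minimal usage sketch for the new LocationSensitiveAttention; the shapes follow its docstring, the dimensions themselves are illustrative, and the import path is assumed from this diff.

import torch
from layers.attention import LocationSensitiveAttention  # path as in this repo

attn = LocationSensitiveAttention(annot_dim=256, query_dim=256, hidden_dim=256)
annot = torch.zeros(2, 37, 256)   # encoder outputs: (batch, max_time, dim)
query = torch.zeros(2, 256)       # decoder state:   (batch, dim)
loc = torch.zeros(2, 2, 37)       # location features: (batch, 2, max_time), e.g. previous
                                  # and cumulative attention stacked on the channel axis
scores = attn(annot, query, loc)  # (batch, max_time) energies; softmax is applied by the caller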
@@ -62,7 +101,10 @@ class AttentionRNN(nn.Module):
         # Alignment
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
-        alignment = self.alignment_model(annotations, rnn_output)
+        if self.align_model == 'b':
+            alignment = self.alignment_model(annotations, rnn_output)
+        else:
+            alignment = self.alignment_model(annotations, rnn_output, attention_vec)
         # TODO: needs recheck.
         if mask is not None:
             mask = mask.view(query.size(0), -1)
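A sketch of driving AttentionRNNCell with the location-sensitive model; the constructor arguments and the (hidden, context, attention) return follow this diff, the tensor sizes are illustrative, and the rest of forward (masking, softmax, context) is assumed to match the previous AttentionRNN. The location input here is shaped for loc_conv's two channels; the decoder loop in this commit passes a single (batch, max_time) vector, which would still need that channel axis.

import torch
from layers.attention import AttentionRNNCell  # path as in this repo

cell = AttentionRNNCell(out_dim=256, annot_dim=256, memory_dim=128, align_model='ls')
B, T = 2, 37
memory = torch.zeros(B, 128)          # prenet output for the current step
context = torch.zeros(B, 256)         # previous context vector
rnn_state = torch.zeros(B, 256)       # previous attention-RNN hidden state
annotations = torch.zeros(B, T, 256)  # encoder outputs
attention_vec = torch.zeros(B, 2, T)  # location features for loc_conv
rnn_state, context, attention = cell(
    memory, context, rnn_state, annotations, attention_vec)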
layers/losses.py

@@ -17,10 +17,10 @@ def _sequence_mask(sequence_length, max_len=None):
     return seq_range_expand < seq_length_expand


-class L1LossMasked(nn.Module):
+class L2LossMasked(nn.Module):

     def __init__(self):
-        super(L1LossMasked, self).__init__()
+        super(L2LossMasked, self).__init__()

     def forward(self, input, target, length):
         """
@@ -44,7 +44,7 @@ class L1LossMasked(nn.Module):
         # target_flat: (batch * max_len, dim)
         target_flat = target.view(-1, target.shape[-1])
         # losses_flat: (batch * max_len, dim)
-        losses_flat = functional.l1_loss(input, target_flat, size_average=False,
+        losses_flat = functional.mse_loss(input, target_flat, size_average=False,
                                          reduce=False)
         # losses: (batch, max_len, dim)
         losses = losses_flat.view(*target.size())
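The renamed loss keeps the masked reduction and only swaps L1 for MSE. A self-contained sketch of the same idea, written with the modern reduction= argument instead of the deprecated size_average/reduce pair used above; names and shapes are illustrative.

import torch
import torch.nn.functional as F

def masked_mse(inputs, targets, lengths):
    # inputs, targets: (batch, max_len, dim); lengths: (batch,) valid steps per sequence
    max_len = targets.size(1)
    mask = (torch.arange(max_len)[None, :] < lengths[:, None]).float()  # (batch, max_len)
    losses = F.mse_loss(inputs, targets, reduction='none').mean(-1)     # (batch, max_len)
    return (losses * mask).sum() / mask.sum()

x, y = torch.rand(2, 5, 80), torch.rand(2, 5, 80)
print(masked_mse(x, y, torch.tensor([5, 3])))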
layers/tacotron.py

@@ -1,9 +1,7 @@
 # coding: utf-8
 import torch
 from torch import nn
-from .attention import AttentionRNN
-from .attention import get_mask_from_lengths
+from .attention import AttentionRNNCell


 class Prenet(nn.Module):
     r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
@@ -12,7 +10,7 @@ class Prenet(nn.Module):
     Args:
         in_features (int): size of the input vector
         out_features (int or list): size of each output sample.
-            If it is a list, for each value, there is created a new layer.
+            If it is a list, for each value, there is created a new layer.
     """

     def __init__(self, in_features, out_features=[256, 128]):
@@ -162,7 +160,7 @@ class CBHG(nn.Module):
             x = highway(x)
         # (B, T_in, in_features*2)
         # TODO: replace GRU with convolution as in Deep Voice 3
-        self.gru.flatten_parameters()
+        # self.gru.flatten_parameters()
         outputs, _ = self.gru(x)
         return outputs
@@ -195,7 +193,6 @@ class Decoder(nn.Module):
         in_features (int): input vector (encoder output) sample size.
         memory_dim (int): memory vector (prev. time-step output) sample size.
         r (int): number of outputs per time step.
-        eps (float): threshold for detecting the end of a sentence.
     """

     def __init__(self, in_features, memory_dim, r):
@@ -205,8 +202,8 @@ class Decoder(nn.Module):
         self.memory_dim = memory_dim
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
-        # processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
-        self.attention_rnn = AttentionRNN(256, in_features, 128)
+        # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
+        self.attention_rnn = AttentionRNNCell(256, in_features, 128, align_model='ls')
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256+in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
@@ -234,6 +231,7 @@ class Decoder(nn.Module):
             - memory: batch x #mel_specs x mel_spec_dim
         """
         B = inputs.size(0)
+        T = inputs.size(1)
         # Run greedy decoding if memory is None
         greedy = not self.training
         if memory is not None:
@@ -243,19 +241,22 @@ class Decoder(nn.Module):
                 " !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
                                                               self.memory_dim, self.r)
             T_decoder = memory.size(1)
-        # go frame - 0 frames tarting the sequence
+        # go frame as zeros matrix
         initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
-        # Init decoder states
+        # decoder states
         attention_rnn_hidden = inputs.data.new(B, 256).zero_()
         decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
                                for _ in range(len(self.decoder_rnns))]
         current_context_vec = inputs.data.new(B, 256).zero_()
         stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
+        # attention states
+        attention = inputs.data.new(B, T).zero_()
+        # attention_cum = inputs.data.new(B, T).zero_()
         # Time first (T_decoder, B, memory_dim)
         if memory is not None:
             memory = memory.transpose(0, 1)
         outputs = []
-        alignments = []
+        attentions = []
         stop_tokens = []
         t = 0
         memory_input = initial_memory
@@ -268,8 +269,12 @@ class Decoder(nn.Module):
             # Prenet
             processed_memory = self.prenet(memory_input)
             # Attention RNN
-            attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
-                processed_memory, current_context_vec, attention_rnn_hidden, inputs)
+            # attention_cat = torch.cat((attention.unsqueeze(1),
+            #                            attention_cum.unsqueeze(1)),
+            #                           dim=1)
+            attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
+                processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention)
+            # attention_cum += attention
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
                 torch.cat((attention_rnn_hidden, current_context_vec), -1))
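For orientation, a stripped-down sketch of the new per-step flow: the previous attention vector is kept as decoder state and fed back into the attention RNN, while the cumulative-attention path stays commented out in this commit. The cell below is only a stand-in with the same interface, not the repo's implementation.

import torch

def attention_rnn_step(memory, context, rnn_state, annots, attention_vec):
    # stand-in for AttentionRNNCell.forward: returns (hidden, context, attention)
    attention = torch.softmax(torch.randn_like(attention_vec), dim=-1)
    context = torch.bmm(attention.unsqueeze(1), annots).squeeze(1)
    return rnn_state, context, attention

B, T, annot_dim = 2, 37, 256
inputs = torch.zeros(B, T, annot_dim)   # encoder outputs
attention = inputs.new_zeros(B, T)      # attention state, re-fed on every step
context = inputs.new_zeros(B, annot_dim)
attn_hidden = inputs.new_zeros(B, 256)

for _ in range(3):
    processed_memory = torch.zeros(B, 128)  # stands in for self.prenet(memory_input)
    attn_hidden, context, attention = attention_rnn_step(
        processed_memory, context, attn_hidden, inputs, attention)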
@@ -286,14 +291,14 @@ class Decoder(nn.Module):
             # predict stop token
             stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden)
             outputs += [output]
-            alignments += [alignment]
+            attentions += [attention]
             stop_tokens += [stop_token]
             t += 1
             if (not greedy and self.training) or (greedy and memory is not None):
                 if t >= T_decoder:
                     break
             else:
-                if t > inputs.shape[1]/2 and stop_token > 0.8:
+                if t > inputs.shape[1]/2 and stop_token > 0.6:
                     break
                 elif t > self.max_decoder_steps:
                     print(" !! Decoder stopped with 'max_decoder_steps'. \
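At inference the loop now stops once the stop token clears 0.6 instead of 0.8, but only after at least half as many decoder steps as there are input tokens. A standalone sketch of that rule with made-up numbers:

max_decoder_steps = 500  # illustrative; the class attribute is set elsewhere in this file

def should_stop(t, num_inputs, stop_token):
    # greedy stopping rule from this hunk
    if t > num_inputs / 2 and stop_token > 0.6:
        return True
    return t > max_decoder_steps

print(should_stop(10, 60, 0.9))   # False: too early relative to the input length
print(should_stop(40, 60, 0.9))   # True: past half the input and the stop token is confident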
@@ -301,28 +306,35 @@ class Decoder(nn.Module):
                     break
         assert greedy or len(outputs) == T_decoder
         # Back to batch first
-        alignments = torch.stack(alignments).transpose(0, 1)
+        attentions = torch.stack(attentions).transpose(0, 1)
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
-        return outputs, alignments, stop_tokens
+        return outputs, attentions, stop_tokens


 class StopNet(nn.Module):
+    r"""
+    Predicting stop-token in decoder.
+
+    Args:
+        r (int): number of output frames of the network.
+        memory_dim (int): feature dimension for each output frame.
+    """

     def __init__(self, r, memory_dim):
         r"""
         Predicts the stop token to stop the decoder at testing time

         Args:
             r (int): number of network output frames.
             memory_dim (int): single feature dim of a single network output frame.
         """
         super(StopNet, self).__init__()
         self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
         self.relu = nn.ReLU()
         self.linear = nn.Linear(r * memory_dim, 1)
         self.sigmoid = nn.Sigmoid()

     def forward(self, inputs, rnn_hidden):
         """
         Args:
@@ -333,4 +345,4 @@ class StopNet(nn.Module):
         outputs = self.relu(rnn_hidden)
         outputs = self.linear(outputs)
         outputs = self.sigmoid(outputs)
-        return outputs, rnn_hidden
+        return outputs, rnn_hidden
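A minimal sketch of stepping StopNet one decoder step at a time; the constructor arguments (r output frames of memory_dim features) and the (output, hidden) return follow this diff, everything else is illustrative.

import torch
from layers.tacotron import StopNet  # path as in this repo

r, memory_dim, B = 5, 80, 2
stopnet = StopNet(r, memory_dim)
rnn_hidden = torch.zeros(B, r * memory_dim)
stop_input = torch.zeros(B, r * memory_dim)           # current output frames, flattened
stop_token, rnn_hidden = stopnet(stop_input, rnn_hidden)
print(stop_token.shape)                                # (B, 1) probability in [0, 1]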
File diff suppressed because one or more lines are too long
train.py
@@ -26,7 +26,7 @@ from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
 from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
-from layers.losses import L1LossMasked
+from layers.losses import L2LossMasked

 torch.manual_seed(1)
 use_cuda = torch.cuda.is_available()
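With the import swapped, the training script builds the masked MSE criterion the same way the L1 version was used; a hedged sketch, with the call signature taken from L2LossMasked.forward(input, target, length) and the shapes illustrative.

import torch
from layers.losses import L2LossMasked  # as imported above

criterion = L2LossMasked()
mel_out = torch.rand(2, 100, 80)        # (batch, max_len, n_mels)
mel_target = torch.rand(2, 100, 80)
mel_lengths = torch.tensor([100, 73])   # valid frames per utterance
loss = criterion(mel_out, mel_target, mel_lengths)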