Merge branch 'loc-sens-attn' into loc-sens-attn-new and attention without attention-cum

This commit is contained in:
Eren G 2018-07-13 14:27:51 +02:00
Parent 7f75ff39df f791f4e5e7
Commit 9f52833151
6 changed files with 884 additions and 46 deletions

View file

@ -23,6 +23,7 @@ Checkout [here](https://mycroft.ai/blog/available-voices/#the-human-voice-is-the
| [iter-62410](https://drive.google.com/open?id=1pjJNzENL3ZNps9n7k_ktGbpEl6YPIkcZ)| [99d56f7](https://github.com/mozilla/TTS/tree/99d56f7e93ccd7567beb0af8fcbd4d24c48e59e9) | [link](https://soundcloud.com/user-565970875/99d56f7-iter62410 )|First model with plain Tacotron implementation.|
| [iter-170K](https://drive.google.com/open?id=16L6JbPXj6MSlNUxEStNn28GiSzi4fu1j) | [e00bc66](https://github.com/mozilla/TTS/tree/e00bc66) |[link](https://soundcloud.com/user-565970875/april-13-2018-07-06pm-e00bc66-iter170k)|More stable and longer trained model.|
| Best: [iter-270K](https://drive.google.com/drive/folders/1Q6BKeEkZyxSGsocK2p_mqgzLwlNvbHFJ?usp=sharing)|[256ed63](https://github.com/mozilla/TTS/tree/256ed63)|[link](https://soundcloud.com/user-565970875/sets/samples-1650226)|Stop-Token prediction is added to detect the end of speech.|
| Best: [iter-K] | [bla]() | [link]() | Location Sensitive attention |
## Data
Currently TTS provides data loaders for

View file

@ -13,8 +13,8 @@ class BahdanauAttention(nn.Module):
def forward(self, annots, query):
"""
Shapes:
- query: (batch, 1, dim) or (batch, dim)
- annots: (batch, max_time, dim)
- query: (batch, 1, dim) or (batch, dim)
"""
if query.dim() == 2:
# insert time-axis for broadcasting
@ -29,31 +29,70 @@ class BahdanauAttention(nn.Module):
return alignment.squeeze(-1)
def get_mask_from_lengths(inputs, inputs_lengths):
"""Get mask tensor from list of length
class LocationSensitiveAttention(nn.Module):
"""Location sensitive attention following
https://arxiv.org/pdf/1506.07503.pdf"""
def __init__(self, annot_dim, query_dim, hidden_dim,
kernel_size=7, filters=20):
super(LocationSensitiveAttention, self).__init__()
self.kernel_size = kernel_size
self.filters = filters
padding = int((kernel_size - 1) / 2)
self.loc_conv = nn.Conv1d(2, filters,
kernel_size=kernel_size, stride=1,
padding=padding, bias=False)
self.loc_linear = nn.Linear(filters, hidden_dim)
self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
self.v = nn.Linear(hidden_dim, 1, bias=False)
Args:
inputs: Tensor in size (batch, max_time, dim)
inputs_lengths: array like
"""
mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_()
for idx, l in enumerate(inputs_lengths):
mask[idx][:l] = 1
return ~mask
def forward(self, annot, query, loc):
"""
Shapes:
- annot: (batch, max_time, dim)
- query: (batch, 1, dim) or (batch, dim)
- loc: (batch, 2, max_time)
"""
if query.dim() == 2:
# insert time-axis for broadcasting
query = query.unsqueeze(1)
loc_conv = self.loc_conv(loc)
loc_conv = loc_conv.transpose(1, 2)
processed_loc = self.loc_linear(loc_conv)
processed_query = self.query_layer(query)
processed_annots = self.annot_layer(annot)
alignment = self.v(nn.functional.tanh(
processed_query + processed_annots + processed_loc))
# (batch, max_time)
return alignment.squeeze(-1)
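For reference, the energy computed in this forward pass is the location-sensitive score of Chorowski et al. (https://arxiv.org/pdf/1506.07503.pdf). In LaTeX notation (mine, not taken from the diff):

f_i = F * \alpha_{i-1}, \qquad e_{i,j} = v^\top \tanh\big(W s_i + V h_j + U f_{i,j}\big), \qquad \alpha_{i,j} = \frac{\exp(e_{i,j})}{\sum_k \exp(e_{i,k})}

Here s_i is the attention-RNN query, h_j the encoder annotation, and f_i the location features that self.loc_conv extracts from the previous (and, when enabled, cumulative) alignment channels; the final softmax normalisation is presumably applied by the caller rather than inside this module.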
class AttentionRNN(nn.Module):
def __init__(self, out_dim, annot_dim, memory_dim,
score_mask_value=-float("inf")):
class AttentionRNNCell(nn.Module):
def __init__(self, out_dim, annot_dim, memory_dim, align_model):
r"""
General Attention RNN wrapper
Args:
out_dim (int): context vector feature dimension.
annot_dim (int): annotation vector feature dimension.
memory_dim (int): memory vector (decoder autoregression) feature dimension.
align_model (str): 'b' for Bahdanau, 'ls' for Location Sensitive alignment.
"""
super(AttentionRNN, self).__init__()
self.align_model = align_model
self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
self.score_mask_value = score_mask_value
# pick bahdanau or location sensitive attention
if align_model == 'b':
self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
elif align_model == 'ls':
self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim)
else:
raise RuntimeError(" Wrong alignment model name: {}. Use\
'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))
def forward(self, memory, context, rnn_state, annotations,
mask=None, annotations_lengths=None):
if annotations_lengths is not None and mask is None:
mask = get_mask_from_lengths(annotations, annotations_lengths)
attention_vec, mask=None, annotations_lengths=None):
# Concat input query and previous context vector
rnn_input = torch.cat((memory, context), -1)
# Feed it to RNN
@ -62,7 +101,10 @@ class AttentionRNN(nn.Module):
# Alignment
# (batch, max_time)
# e_{ij} = a(s_{i-1}, h_j)
alignment = self.alignment_model(annotations, rnn_output)
if self.align_model == 'b':
alignment = self.alignment_model(annotations, rnn_output)
else:
alignment = self.alignment_model(annotations, rnn_output, attention_vec)
# TODO: needs recheck.
if mask is not None:
mask = mask.view(query.size(0), -1)
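A minimal shape check for the new layer, as a standalone sketch (the dummy sizes are made up; it assumes the class above is importable as layers.attention.LocationSensitiveAttention):

import torch
from layers.attention import LocationSensitiveAttention

batch, max_time = 2, 50
annot_dim, query_dim, hidden_dim = 256, 256, 128
attn = LocationSensitiveAttention(annot_dim, query_dim, hidden_dim)

annot = torch.zeros(batch, max_time, annot_dim)  # encoder annotations
query = torch.zeros(batch, query_dim)            # attention-RNN state
loc = torch.zeros(batch, 2, max_time)            # previous + cumulative alignments

alignment = attn(annot, query, loc)
print(alignment.shape)                           # expected: torch.Size([2, 50])

The two loc channels match the docstring (previous and cumulative alignments); in this commit the decoder feeds only the previous alignment, with the cumulative channel commented out.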

View file

@ -17,10 +17,10 @@ def _sequence_mask(sequence_length, max_len=None):
return seq_range_expand < seq_length_expand
class L1LossMasked(nn.Module):
class L2LossMasked(nn.Module):
def __init__(self):
super(L1LossMasked, self).__init__()
super(L2LossMasked, self).__init__()
def forward(self, input, target, length):
"""
@ -44,7 +44,7 @@ class L1LossMasked(nn.Module):
# target_flat: (batch * max_len, dim)
target_flat = target.view(-1, target.shape[-1])
# losses_flat: (batch * max_len, dim)
losses_flat = functional.l1_loss(input, target_flat, size_average=False,
losses_flat = functional.mse_loss(input, target_flat, size_average=False,
reduce=False)
# losses: (batch, max_len, dim)
losses = losses_flat.view(*target.size())
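The switch from l1_loss to mse_loss leaves the masking scheme untouched. As a self-contained illustration of that scheme (masked_mse below is a hypothetical helper for exposition, not the repo's L2LossMasked):

import torch
from torch.nn import functional

def masked_mse(input, target, lengths):
    # input, target: (batch, max_len, dim); lengths: (batch,) int tensor
    max_len = target.size(1)
    mask = (torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)).float()
    mask = mask.unsqueeze(-1).expand_as(target)
    # per-element squared error, zeroed past each sequence's true length
    losses = functional.mse_loss(input, target, size_average=False, reduce=False) * mask
    # average over the non-padded elements only
    return losses.sum() / mask.sum()

loss = masked_mse(torch.zeros(2, 6, 80), torch.ones(2, 6, 80), torch.tensor([4, 6]))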

View file

@ -1,9 +1,7 @@
# coding: utf-8
import torch
from torch import nn
from .attention import AttentionRNN
from .attention import get_mask_from_lengths
from .attention import AttentionRNNCell
class Prenet(nn.Module):
r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
@ -12,7 +10,7 @@ class Prenet(nn.Module):
Args:
in_features (int): size of the input vector
out_features (int or list): size of each output sample.
If it is a list, a new layer is created for each value.
If it is a list, a new layer is created for each value.
"""
def __init__(self, in_features, out_features=[256, 128]):
@ -162,7 +160,7 @@ class CBHG(nn.Module):
x = highway(x)
# (B, T_in, in_features*2)
# TODO: replace GRU with convolution as in Deep Voice 3
self.gru.flatten_parameters()
# self.gru.flatten_parameters()
outputs, _ = self.gru(x)
return outputs
@ -195,7 +193,6 @@ class Decoder(nn.Module):
in_features (int): input vector (encoder output) sample size.
memory_dim (int): memory vector (prev. time-step output) sample size.
r (int): number of outputs per time step.
eps (float): threshold for detecting the end of a sentence.
"""
def __init__(self, in_features, memory_dim, r):
@ -205,8 +202,8 @@ class Decoder(nn.Module):
self.memory_dim = memory_dim
# memory -> |Prenet| -> processed_memory
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
# processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
self.attention_rnn = AttentionRNN(256, in_features, 128)
# processed_inputs, processed_memory -> |Attention| -> context, attention, RNN_State
self.attention_rnn = AttentionRNNCell(256, in_features, 128, align_model='ls')
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(256+in_features, 256)
# decoder_RNN_input -> |RNN| -> RNN_state
@ -234,6 +231,7 @@ class Decoder(nn.Module):
- memory: batch x #mel_specs x mel_spec_dim
"""
B = inputs.size(0)
T = inputs.size(1)
# Run greedy decoding if memory is None
greedy = not self.training
if memory is not None:
@ -243,19 +241,22 @@ class Decoder(nn.Module):
" !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
self.memory_dim, self.r)
T_decoder = memory.size(1)
# go frame - 0 frames starting the sequence
# go frame as a zero matrix
initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
# Init decoder states
# decoder states
attention_rnn_hidden = inputs.data.new(B, 256).zero_()
decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
for _ in range(len(self.decoder_rnns))]
current_context_vec = inputs.data.new(B, 256).zero_()
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
# attention states
attention = inputs.data.new(B, T).zero_()
# attention_cum = inputs.data.new(B, T).zero_()
# Time first (T_decoder, B, memory_dim)
if memory is not None:
memory = memory.transpose(0, 1)
outputs = []
alignments = []
attentions = []
stop_tokens = []
t = 0
memory_input = initial_memory
@ -268,8 +269,12 @@ class Decoder(nn.Module):
# Prenet
processed_memory = self.prenet(memory_input)
# Attention RNN
attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
processed_memory, current_context_vec, attention_rnn_hidden, inputs)
# attention_cat = torch.cat((attention.unsqueeze(1),
# attention_cum.unsqueeze(1)),
# dim=1)
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
processed_memory, current_context_vec, attention_rnn_hidden, inputs, attention)
# attention_cum += attention
# Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in(
torch.cat((attention_rnn_hidden, current_context_vec), -1))
@ -286,14 +291,14 @@ class Decoder(nn.Module):
# predict stop token
stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden)
outputs += [output]
alignments += [alignment]
attentions += [attention]
stop_tokens += [stop_token]
t += 1
if (not greedy and self.training) or (greedy and memory is not None):
if t >= T_decoder:
break
else:
if t > inputs.shape[1]/2 and stop_token > 0.8:
if t > inputs.shape[1]/2 and stop_token > 0.6:
break
elif t > self.max_decoder_steps:
print(" !! Decoder stopped with 'max_decoder_steps'. \
@ -301,28 +306,35 @@ class Decoder(nn.Module):
break
assert greedy or len(outputs) == T_decoder
# Back to batch first
alignments = torch.stack(alignments).transpose(0, 1)
attentions = torch.stack(attentions).transpose(0, 1)
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
return outputs, alignments, stop_tokens
return outputs, attentions, stop_tokens
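The substantive change in this loop is the attention-state handoff: each step now passes the previous alignment back into the attention cell as its location input, while the cumulative-attention variant stays commented out. A stripped-down sketch of that pattern (DummyAttentionCell and all sizes are mine, purely for illustration, not the repo's AttentionRNNCell):

import torch

class DummyAttentionCell:
    # stand-in with the same call shape as the attention_rnn call above
    def __call__(self, memory, context, rnn_state, annots, attention_vec):
        B, T = annots.size(0), annots.size(1)
        new_attention = torch.softmax(torch.zeros(B, T), dim=-1)  # fake alignment
        return rnn_state, context, new_attention

B, T = 2, 50
annots = torch.zeros(B, T, 256)
cell = DummyAttentionCell()
rnn_state, context = torch.zeros(B, 256), torch.zeros(B, 256)
attention = torch.zeros(B, T)           # previous alignment, zeros at t = 0
# attention_cum = torch.zeros(B, T)     # cumulative channel, disabled in this commit
for t in range(3):
    memory_input = torch.zeros(B, 128)  # stands in for the prenet output
    rnn_state, context, attention = cell(memory_input, context, rnn_state, annots, attention)
    # attention_cum += attention        # re-enabling this restores the 2-channel location input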
class StopNet(nn.Module):
r"""
Predicting stop-token in decoder.
Args:
r (int): number of output frames of the network.
memory_dim (int): feature dimension for each output frame.
"""
def __init__(self, r, memory_dim):
r"""
Predicts the stop token to stop the decoder at testing time
Args:
r (int): number of network output frames.
memory_dim (int): feature dimension of a single network output frame.
"""
super(StopNet, self).__init__()
self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
self.relu = nn.ReLU()
self.linear = nn.Linear(r * memory_dim, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, inputs, rnn_hidden):
"""
Args:
@ -333,4 +345,4 @@ class StopNet(nn.Module):
outputs = self.relu(rnn_hidden)
outputs = self.linear(outputs)
outputs = self.sigmoid(outputs)
return outputs, rnn_hidden
return outputs, rnn_hidden
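A single StopNet step, as a hypothetical sketch (the sizes are made up; it assumes the class above is importable as models.tacotron.StopNet): the flattened r output frames of the current decoder step and the recurrent stop-net state go in, a sigmoid stop probability and the updated state come out.

import torch
from models.tacotron import StopNet

r, memory_dim, batch = 5, 80, 2
stopnet = StopNet(r, memory_dim)

stop_rnn_hidden = torch.zeros(batch, r * memory_dim)
decoder_output = torch.zeros(batch, r * memory_dim)   # one decoder step's r frames
stop_token, stop_rnn_hidden = stopnet(decoder_output, stop_rnn_hidden)
print(stop_token.shape)  # (batch, 1); greedy decoding stops once this exceeds 0.6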

File diffs are hidden because one or more lines are too long

View file

@ -26,7 +26,7 @@ from utils.model import get_param_size
from utils.visual import plot_alignment, plot_spectrogram
from datasets.LJSpeech import LJSpeechDataset
from models.tacotron import Tacotron
from layers.losses import L1LossMasked
from layers.losses import L2LossMasked
torch.manual_seed(1)
use_cuda = torch.cuda.is_available()