make location attention optional and keep all attention weights in the Attention class

Eren Golge 2019-04-29 11:37:01 +02:00
Parent 3ea34c6488
Commit e2439fde9a
3 changed files with 62 additions and 40 deletions
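In short: the LocationLayer is created only when a new location_attention flag is set, and the current/cumulative attention weights move from the Decoder into the Attention module itself. Below is a minimal standalone sketch of that pattern, not the repo's class; the name MiniAttention and the layer sizes are illustrative assumptions.

# A minimal standalone sketch (not the repo's class); names and sizes are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

class MiniAttention(nn.Module):
    """Content-based attention with an optional location-sensitive term.

    Mirrors the two ideas of this commit: the location layers are built only
    when location_attention=True, and the attention weights (current and
    cumulative) are kept inside the module rather than in the decoder.
    """

    def __init__(self, query_dim, embedding_dim, attn_dim,
                 location_attention=True, n_filters=32, kernel_size=31):
        super().__init__()
        self.query_layer = nn.Linear(query_dim, attn_dim, bias=False)
        self.inputs_layer = nn.Linear(embedding_dim, attn_dim, bias=False)
        self.v = nn.Linear(attn_dim, 1, bias=True)
        self.location_attention = location_attention
        if location_attention:
            # 2 input channels: current and cumulative attention weights
            self.location_conv = nn.Conv1d(
                2, n_filters, kernel_size,
                padding=(kernel_size - 1) // 2, bias=False)
            self.location_dense = nn.Linear(n_filters, attn_dim, bias=False)

    def init_states(self, inputs):
        # inputs: encoder outputs of shape (B, T, embedding_dim)
        B, T = inputs.shape[0], inputs.shape[1]
        self.attention_weights = inputs.new_zeros(B, T)
        if self.location_attention:
            self.attention_weights_cum = inputs.new_zeros(B, T)

    def forward(self, query, inputs, processed_inputs):
        # query: (B, query_dim); processed_inputs: (B, T, attn_dim)
        processed_query = self.query_layer(query.unsqueeze(1))
        energies = processed_query + processed_inputs
        if self.location_attention:
            attention_cat = torch.cat(
                (self.attention_weights.unsqueeze(1),
                 self.attention_weights_cum.unsqueeze(1)), dim=1)
            processed_attn = self.location_dense(
                self.location_conv(attention_cat).transpose(1, 2))
            energies = energies + processed_attn
        energies = self.v(torch.tanh(energies)).squeeze(-1)   # (B, T)
        self.attention_weights = F.softmax(energies, dim=1)
        if self.location_attention:
            self.attention_weights_cum = (
                self.attention_weights_cum + self.attention_weights)
        context = torch.bmm(self.attention_weights.unsqueeze(1), inputs)
        return context.squeeze(1)                              # (B, embedding_dim)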

View file

@@ -120,7 +120,7 @@ class LocationLayer(nn.Module):
class Attention(nn.Module):
def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, location_attention,
attention_location_n_filters, attention_location_kernel_size,
windowing, norm, forward_attn, trans_agent):
super(Attention, self).__init__()
@@ -131,37 +131,64 @@ class Attention(nn.Module):
self.v = Linear(attention_dim, 1, bias=True)
if trans_agent:
self.ta = nn.Linear(attention_dim + embedding_dim, 1, bias=True)
self.location_layer = LocationLayer(attention_location_n_filters,
attention_location_kernel_size,
attention_dim)
if location_attention:
self.location_layer = LocationLayer(attention_location_n_filters,
attention_location_kernel_size,
attention_dim)
self._mask_value = -float("inf")
self.windowing = windowing
self.win_idx = None
self.norm = norm
self.forward_attn = forward_attn
self.trans_agent = trans_agent
self.location_attention = location_attention
def init_win_idx(self):
self.win_idx = -1
self.win_back = 2
self.win_front = 6
def init_forward_attn_state(self, inputs):
"""
Init forward attention states
"""
def init_forward_attn(self, inputs):
B = inputs.shape[0]
T = inputs.shape[1]
self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1] + 1e-7 ], dim=1).to(inputs.device)
self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)
def get_attention(self, query, processed_inputs, attention_cat):
def init_location_attention(self, inputs):
B = inputs.shape[0]
T = inputs.shape[1]
self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())
def init_states(self, inputs):
B = inputs.shape[0]
T = inputs.shape[1]
self.attention_weights = Variable(inputs.data.new(B, T).zero_())
if self.location_attention:
self.init_location_attention(inputs)
if self.forward_attn:
self.init_forward_attn(inputs)
if self.windowing:
self.init_win_idx()
def update_location_attention(self, alignments):
self.attention_weights_cum += alignments
def get_location_attention(self, query, processed_inputs):
attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
self.attention_weights_cum.unsqueeze(1)),
dim=1)
processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_cat)
energies = self.v(
torch.tanh(processed_query + processed_attention_weights +
processed_inputs))
energies = energies.squeeze(-1)
return energies, processed_query
def get_attention(self, query, processed_inputs):
processed_query = self.query_layer(query.unsqueeze(1))
energies = self.v(
torch.tanh(processed_query + processed_inputs))
energies = energies.squeeze(-1)
return energies, processed_query
@@ -192,13 +219,16 @@ class Attention(nn.Module):
if self.trans_agent:
ta_input = torch.cat([context, processed_query.squeeze(1)], dim=-1)
self.u = torch.sigmoid(self.ta(ta_input))
return context, self.alpha, alignment
return context, self.alpha
def forward(self, attention_hidden_state, inputs, processed_inputs,
attention_cat, mask):
attention, processed_query = self.get_attention(
attention_hidden_state, processed_inputs, attention_cat)
mask):
if self.location_attention:
attention, processed_query = self.get_location_attention(
attention_hidden_state, processed_inputs)
else:
attention, processed_query = self.get_attention(
attention_hidden_state, processed_inputs)
# apply masking
if mask is not None:
attention.data.masked_fill_(1 - mask, self._mask_value)
@@ -213,13 +243,15 @@ class Attention(nn.Module):
attention).sum(dim=1).unsqueeze(1)
else:
raise RuntimeError("Unknown value for attention norm type")
if self.location_attention:
self.update_location_attention(alignment)
# apply forward attention if enabled
if self.forward_attn:
return self.apply_forward_attention(inputs, alignment, processed_query)
context, self.attention_weights = self.apply_forward_attention(inputs, alignment, processed_query)
else:
context = torch.bmm(alignment.unsqueeze(1), inputs)
context = context.squeeze(1)
return context, alignment, alignment
return context
class Postnet(nn.Module):
@@ -289,7 +321,7 @@ class Encoder(nn.Module):
# adapted from https://github.com/NVIDIA/tacotron2/
class Decoder(nn.Module):
def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent):
def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent, location_attn):
super(Decoder, self).__init__()
self.mel_channels = inputs_dim
self.r = r
@@ -308,8 +340,8 @@ class Decoder(nn.Module):
self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
self.attention_rnn_dim)
self.attention_layer = Attention(self.attention_rnn_dim, in_features,
128, 32, 31, attn_win, attn_norm, forward_attn, trans_agent)
self.attention_layer = Attention(self.attention_rnn_dim, in_features, 128, location_attn,
32, 31, attn_win, attn_norm, forward_attn, trans_agent)
self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
self.decoder_rnn_dim, 1)
@@ -351,9 +383,6 @@ class Decoder(nn.Module):
self.context = Variable(
inputs.data.new(B, self.encoder_embedding_dim).zero_())
self.attention_weights = Variable(inputs.data.new(B, T).zero_())
self.attention_weights_cum = Variable(inputs.data.new(B, T).zero_())
self.inputs = inputs
self.processed_inputs = self.attention_layer.inputs_layer(inputs)
@@ -384,14 +413,10 @@ class Decoder(nn.Module):
self.attention_cell = F.dropout(
self.attention_cell, self.p_attention_dropout, self.training)
attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
self.attention_weights_cum.unsqueeze(1)),
dim=1)
self.context, self.attention_weights, alignments = self.attention_layer(
self.context = self.attention_layer(
self.attention_hidden, self.inputs, self.processed_inputs,
attention_cat, self.mask)
self.mask)
self.attention_weights_cum += alignments
memory = torch.cat(
(self.attention_hidden, self.context), -1)
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
@@ -410,7 +435,7 @@ class Decoder(nn.Module):
stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1)
gate_prediction = self.stopnet(stopnet_input)
return decoder_output, gate_prediction, self.attention_weights
return decoder_output, gate_prediction, self.attention_layer.attention_weights
def forward(self, inputs, memories, mask):
memory = self.get_go_frame(inputs).unsqueeze(0)
@@ -419,8 +444,7 @@ class Decoder(nn.Module):
memories = self.prenet(memories)
self._init_states(inputs, mask=mask)
if self.attention_layer.forward_attn:
self.attention_layer.init_forward_attn_state(inputs)
self.attention_layer.init_states(inputs)
outputs, stop_tokens, alignments = [], [], []
while len(outputs) < memories.size(0) - 1:
@@ -441,8 +465,7 @@ class Decoder(nn.Module):
self._init_states(inputs, mask=None)
self.attention_layer.init_win_idx()
if self.attention_layer.forward_attn:
self.attention_layer.init_forward_attn_state(inputs)
self.attention_layer.init_states(inputs)
outputs, stop_tokens, alignments, t = [], [], [], 0
stop_flags = [False, False, False]
@@ -484,9 +507,7 @@ class Decoder(nn.Module):
else:
self._init_states(inputs, mask=None, keep_states=True)
self.attention_layer.init_win_idx()
if self.attention_layer.forward_attn:
self.attention_layer.init_forward_attn_state(inputs)
self.attention_layer.init_states(inputs)
outputs, stop_tokens, alignments, t = [], [], [], 0
stop_flags = [False, False, False]
stop_count = 0
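
With the state held by the attention module, the decoder-side API collapses to one init_states() call before decoding and a per-step call without attention_cat, with alignments read back from the module afterwards. A rough usage sketch, reusing the MiniAttention toy class from the sketch above (shapes are made up; masking and windowing are omitted):

# Decoder-side call pattern after this change, using the MiniAttention sketch above.
import torch

B, T, emb_dim, query_dim, attn_dim = 2, 50, 512, 1024, 128

attn = MiniAttention(query_dim, emb_dim, attn_dim, location_attention=True)

inputs = torch.randn(B, T, emb_dim)           # encoder outputs
processed_inputs = attn.inputs_layer(inputs)  # computed once per utterance

attn.init_states(inputs)                      # one call replaces the per-flag inits
for _ in range(5):                            # one decoder step per iteration
    query = torch.randn(B, query_dim)         # attention RNN hidden state
    context = attn(query, inputs, processed_inputs)
    alignments = attn.attention_weights       # read back from the module, as the Decoder now does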

View file

@@ -9,7 +9,7 @@ from utils.generic_utils import sequence_mask
# TODO: match function arguments with tacotron
class Tacotron2(nn.Module):
def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", forward_attn=False, trans_agent=False):
def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", forward_attn=False, trans_agent=False, location_attn=True):
super(Tacotron2, self).__init__()
self.n_mel_channels = 80
self.n_frames_per_step = r
@@ -18,7 +18,7 @@ class Tacotron2(nn.Module):
val = sqrt(3.0) * std # uniform bounds for std
self.embedding.weight.data.uniform_(-val, val)
self.encoder = Encoder(512)
self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent)
self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent, location_attn)
self.postnet = Postnet(self.n_mel_channels)
def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):

View file

@@ -263,5 +263,6 @@ def setup_model(num_chars, c):
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent)
trans_agent=c.transition_agent,
location_attn=c.location_attn)
return model
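
The new flag is threaded from the config down to the attention layer: c.location_attn → Decoder(location_attn=...) → Attention(location_attention=...). A hedged construction example follows; the import path is a guess at the package layout, while the keyword arguments mirror the Tacotron2 constructor shown in this diff and the values are examples only.

# Import path is an assumption; adjust to the actual package layout.
from models.tacotron2 import Tacotron2

# Keyword arguments mirror the constructor in this diff; values are examples.
model = Tacotron2(
    num_chars=61,
    r=1,
    attn_win=False,
    attn_norm="softmax",
    prenet_type="original",
    forward_attn=False,
    trans_agent=False,
    location_attn=False,  # new flag: skip building the LocationLayer
)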