# Mirror of https://github.com/mozilla/TTS.git
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F


class BahdanauAttention(nn.Module):
    """Bahdanau (additive) attention scorer over the encoder memory."""

    def __init__(self, dim):
        super(BahdanauAttention, self).__init__()
        self.query_layer = nn.Linear(dim, dim, bias=False)
        self.tanh = nn.Tanh()
        self.v = nn.Linear(dim, 1, bias=False)

    def forward(self, query, processed_memory):
        """
        Args:
            query: (batch, 1, dim) or (batch, dim)
            processed_memory: (batch, max_time, dim)
        """
        if query.dim() == 2:
            # insert time-axis for broadcasting
            query = query.unsqueeze(1)
        # (batch, 1, dim)
        processed_query = self.query_layer(query)

        # Additive score: v^T tanh(W q + processed_memory)
        # (batch, max_time, 1)
        alignment = self.v(self.tanh(processed_query + processed_memory))

        # (batch, max_time)
        return alignment.squeeze(-1)
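

# A minimal shape-check sketch (not part of the original mozilla/TTS file; the
# sizes below are arbitrary assumptions) illustrating how BahdanauAttention is
# called and what it returns. It only runs when invoked explicitly.
def _bahdanau_attention_demo(dim=256, batch=4, max_time=100):
    attention = BahdanauAttention(dim)
    query = torch.randn(batch, dim)                       # decoder state, (batch, dim)
    processed_memory = torch.randn(batch, max_time, dim)  # projected encoder memory
    alignment = attention(query, processed_memory)        # unnormalized scores
    assert alignment.shape == (batch, max_time)
    return alignment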


def get_mask_from_lengths(memory, memory_lengths):
    """Get mask tensor from a list of lengths

    Args:
        memory: (batch, max_time, dim)
        memory_lengths: array-like of valid lengths per batch item

    Returns:
        mask: (batch, max_time) bool tensor, True at padded positions
    """
    mask = memory.new_zeros(memory.size(0), memory.size(1), dtype=torch.bool)
    for idx, length in enumerate(memory_lengths):
        mask[idx, :length] = True
    return ~mask
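

# Sketch (arbitrary sizes, not in the original file): build a padding mask from
# per-utterance lengths; True marks positions the attention should ignore.
def _mask_demo():
    memory = torch.randn(3, 5, 8)                 # (batch, max_time, dim)
    mask = get_mask_from_lengths(memory, [5, 3, 2])
    # mask[1] == tensor([False, False, False, True, True])
    return mask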


class AttentionWrapper(nn.Module):
    def __init__(self, rnn_cell, attention_mechanism,
                 score_mask_value=-float("inf")):
        super(AttentionWrapper, self).__init__()
        self.rnn_cell = rnn_cell
        self.attention_mechanism = attention_mechanism
        self.score_mask_value = score_mask_value

    def forward(self, query, attention, cell_state, memory,
                processed_memory=None, mask=None, memory_lengths=None):
        if processed_memory is None:
            processed_memory = memory
        if memory_lengths is not None and mask is None:
            mask = get_mask_from_lengths(memory, memory_lengths)

        # Concat input query and previous attention context
        cell_input = torch.cat((query, attention), -1)

        # Feed it to RNN
        cell_output = self.rnn_cell(cell_input, cell_state)

        # Alignment
        # (batch, max_time)
        alignment = self.attention_mechanism(cell_output, processed_memory)

        if mask is not None:
            # Mask out padded timesteps so they get zero attention weight
            mask = mask.view(query.size(0), -1)
            alignment.data.masked_fill_(mask, self.score_mask_value)

        # Normalize attention weights over the time axis
        alignment = F.softmax(alignment, dim=-1)

        # Attention context vector
        # (batch, 1, dim)
        attention = torch.bmm(alignment.unsqueeze(1), memory)

        # (batch, dim)
        attention = attention.squeeze(1)

        return cell_output, attention, alignment
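

# End-to-end usage sketch (not part of the original module; the layer sizes and
# the GRUCell choice are assumptions made for illustration): one decoder step
# through AttentionWrapper with length masking.
if __name__ == "__main__":
    batch, max_time, dim = 4, 100, 256
    wrapper = AttentionWrapper(
        nn.GRUCell(input_size=2 * dim, hidden_size=dim),
        BahdanauAttention(dim),
    )
    memory = torch.randn(batch, max_time, dim)    # encoder outputs
    query = torch.randn(batch, dim)               # decoder input at this step
    prev_attention = torch.zeros(batch, dim)      # previous context vector
    cell_state = torch.zeros(batch, dim)          # previous GRU hidden state
    cell_output, context, alignment = wrapper(
        query, prev_attention, cell_state, memory,
        memory_lengths=[100, 80, 60, 40],
    )
    print(cell_output.shape, context.shape, alignment.shape)
    # torch.Size([4, 256]) torch.Size([4, 256]) torch.Size([4, 100])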