This commit is contained in:
Eren Golge 2018-03-22 14:35:02 -07:00
Parent e4a0eec77e
Commit 32d9c734b2
2 changed files with 20 additions and 19 deletions

View file

@@ -1,5 +1,6 @@
 import torch
-from torch import functional
+from torch.nn import functional
 from torch.autograd import Variable
 
 # from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
@@ -7,7 +8,7 @@ def _sequence_mask(sequence_length, max_len=None):
     if max_len is None:
         max_len = sequence_length.data.max()
     batch_size = sequence_length.size(0)
-    seq_range = torch.range(0, max_len - 1).long()
+    seq_range = torch.arange(0, max_len).long()
     seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
     seq_range_expand = Variable(seq_range_expand)
     if sequence_length.is_cuda:
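A note on the fix above: torch.range is deprecated and, unlike torch.arange, includes its upper endpoint, so torch.arange(0, max_len) is the idiomatic way to enumerate 0 .. max_len - 1. A minimal standalone sketch of the mask this helper builds, with made-up lengths:

import torch

lengths = torch.LongTensor([2, 4, 3])        # hypothetical per-example lengths
max_len = int(lengths.max())                 # 4
seq_range = torch.arange(0, max_len).long()  # tensor([0, 1, 2, 3])
seq_range_expand = seq_range.unsqueeze(0).expand(len(lengths), max_len)
mask = seq_range_expand < lengths.unsqueeze(1)
# mask:
# [[1, 1, 0, 0],
#  [1, 1, 1, 1],
#  [1, 1, 1, 0]]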
@@ -31,18 +32,20 @@ def L1LossMasked(input, target, length):
     Returns:
         loss: An average loss value masked by the length.
     """
+    input = input.contiguous()
+    target = target.contiguous()
 
-    # logits_flat: (batch * max_len, num_classes)
+    # logits_flat: (batch * max_len, dim)
     input = input.view(-1, input.size(-1))
-    # target_flat: (batch * max_len, 1)
+    # target_flat: (batch * max_len, dim)
     target_flat = target.view(-1, 1)
-    # losses_flat: (batch * max_len, 1)
-    losees_flat = functional.l1_loss(input, target, size_average=False,
+    # losses_flat: (batch * max_len, dim)
+    losses_flat = functional.l1_loss(input, target, size_average=False,
                                      reduce=False)
     # losses: (batch, max_len)
     losses = losses_flat.view(*target.size())
     # mask: (batch, max_len)
-    mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
+    mask = _sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2)
     losses = losses * mask.float()
     loss = losses.sum() / length.float().sum()
-    return loss
+    return loss / input.shape[0]
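The hunk above fixes the losees_flat typo, broadcasts the mask over the feature dimension via unsqueeze(2), and normalizes by the total number of valid frames. One caveat worth noting: by the return statement, input has already been flattened, so input.shape[0] equals batch * max_len rather than the batch size. A self-contained sketch of the masking idea with toy shapes (not the repository's API):

import torch

batch, max_len, dim = 2, 3, 1
pred = torch.ones(batch, max_len, dim)       # hypothetical predictions
target = torch.zeros(batch, max_len, dim)    # hypothetical targets
lengths = torch.LongTensor([3, 1])           # valid frames per example

losses = torch.abs(pred - target)                     # (batch, max_len, dim)
mask = (torch.arange(max_len).unsqueeze(0)
        < lengths.unsqueeze(1)).float().unsqueeze(2)  # (batch, max_len, 1)
loss = (losses * mask).sum() / lengths.float().sum()  # mean over the 4 valid cells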

View file

@@ -26,7 +26,7 @@ from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
 from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
-from losses import
+from layers.losses import L1LossMasked
 
 use_cuda = torch.cuda.is_available()
@@ -95,7 +95,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
         # convert inputs to variables
         text_input_var = Variable(text_input)
         mel_spec_var = Variable(mel_input)
-        mel_length_var = Variable(mel_lengths)
+        mel_lengths_var = Variable(mel_lengths)
         linear_spec_var = Variable(linear_input, volatile=True)
 
         # sort sequence by length for curriculum learning
@@ -105,6 +105,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
         sorted_lengths = sorted_lengths.long().numpy()
         text_input_var = text_input_var[indices]
         mel_spec_var = mel_spec_var[indices]
+        mel_lengths_var = mel_lengths_var[indices]
         linear_spec_var = linear_spec_var[indices]
 
         # dispatch data to GPU
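The one-line addition above keeps the length tensor aligned with its batch: once the batch is re-ordered by the sorted indices, every per-example tensor, lengths included, must be permuted with the same index, or later masking pairs each example with the wrong length. A toy illustration with invented shapes:

import torch

lengths = torch.LongTensor([3, 5, 2])
inputs = torch.randn(3, 5, 8)                  # (batch, time, features)
sorted_lengths, indices = lengths.sort(0, descending=True)
inputs = inputs[indices]                       # reorder the batch...
lengths = lengths[indices]                     # ...and its lengths together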
@@ -119,11 +120,11 @@ def train(model, criterion, data_loader, optimizer, epoch):
             model.forward(text_input_var, mel_spec_var)
 
         # loss computation
-        mel_loss = criterion(mel_output, mel_spec_var, mel_lengths)
-        linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths) \
+        mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
+        linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
                 + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                                   linear_spec_var[: ,: ,:n_priority_freq],
-                                  mel_lengths)
+                                  mel_lengths_var)
         loss = mel_loss + linear_loss
 
         # backpass and check the grad norm
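For reference, the linear-spectrogram loss above averages a full-band term with a term restricted to the lowest n_priority_freq bins, giving extra weight to the low-frequency band. A sketch with a stand-in masked loss and invented shapes (masked_l1 and all values here are illustrative, not the repository's code):

import torch

def masked_l1(pred, target, lengths):
    # minimal stand-in for L1LossMasked, for illustration only
    mask = (torch.arange(pred.size(1)).unsqueeze(0)
            < lengths.unsqueeze(1)).float().unsqueeze(2)
    return (torch.abs(pred - target) * mask).sum() / lengths.float().sum()

linear_output = torch.randn(2, 6, 513)   # hypothetical (batch, time, freq bins)
linear_target = torch.randn(2, 6, 513)
lengths = torch.LongTensor([6, 4])
n_priority_freq = 100                    # hypothetical low-band size

linear_loss = 0.5 * masked_l1(linear_output, linear_target, lengths) \
            + 0.5 * masked_l1(linear_output[:, :, :n_priority_freq],
                              linear_target[:, :, :n_priority_freq],
                              lengths)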
@@ -240,10 +241,10 @@ def evaluate(model, criterion, data_loader, current_step):
 
             # loss computation
             mel_loss = criterion(mel_output, mel_spec_var, mel_lengths)
-            linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths) \
+            linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
                     + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                                       linear_spec_var[: ,: ,:n_priority_freq],
-                                      mel_lengths)
+                                      mel_lengths_var)
             loss = mel_loss + linear_loss
 
             step_time = time.time() - start_time
@@ -348,10 +349,7 @@ def main(args):
     optimizer = optim.Adam(model.parameters(), lr=c.lr)
 
-    if use_cuda:
-        criterion = nn.L1Loss().cuda()
-    else:
-        criterion = nn.L1Loss()
+    criterion = L1LossMasked
 
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
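The final hunk swaps nn.L1Loss for the masked loss. Since L1LossMasked is a plain function rather than an nn.Module, it is bound without instantiation and needs no .cuda() call of its own; _sequence_mask above checks sequence_length.is_cuda to place its range tensor accordingly. A minimal sketch of the binding, assuming it runs from the repository root as train.py does:

from layers.losses import L1LossMasked   # the import added above

criterion = L1LossMasked                 # a function, not an nn.Module instance
# later, per batch:
#     loss = criterion(predictions, targets, lengths)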