зеркало из https://github.com/mozilla/TTS.git
masked loss
This commit is contained in:
Родитель
e4a0eec77e
Коммит
32d9c734b2
|
@ -1,5 +1,6 @@
|
|||
import torch
|
||||
from torch import functional
|
||||
from torch.nn import functional
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
||||
|
@ -7,7 +8,7 @@ def _sequence_mask(sequence_length, max_len=None):
|
|||
if max_len is None:
|
||||
max_len = sequence_length.data.max()
|
||||
batch_size = sequence_length.size(0)
|
||||
seq_range = torch.range(0, max_len - 1).long()
|
||||
seq_range = torch.arange(0, max_len).long()
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
||||
seq_range_expand = Variable(seq_range_expand)
|
||||
if sequence_length.is_cuda:
|
||||
|
@ -31,18 +32,20 @@ def L1LossMasked(input, target, length):
|
|||
Returns:
|
||||
loss: An average loss value masked by the length.
|
||||
"""
|
||||
input = input.contiguous()
|
||||
target = target.contiguous()
|
||||
|
||||
# logits_flat: (batch * max_len, num_classes)
|
||||
# logits_flat: (batch * max_len, dim)
|
||||
input = input.view(-1, input.size(-1))
|
||||
# target_flat: (batch * max_len, 1)
|
||||
# target_flat: (batch * max_len, dim)
|
||||
target_flat = target.view(-1, 1)
|
||||
# losses_flat: (batch * max_len, 1)
|
||||
losees_flat = functional.l1_loss(input, target, size_average=False,
|
||||
# losses_flat: (batch * max_len, dim)
|
||||
losses_flat = functional.l1_loss(input, target, size_average=False,
|
||||
reduce=False)
|
||||
# losses: (batch, max_len)
|
||||
losses = losses_flat.view(*target.size())
|
||||
# mask: (batch, max_len)
|
||||
mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
|
||||
mask = _sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2)
|
||||
losses = losses * mask.float()
|
||||
loss = losses.sum() / length.float().sum()
|
||||
return loss
|
||||
return loss / input.shape[0]
|
20
train.py
20
train.py
|
@ -26,7 +26,7 @@ from utils.model import get_param_size
|
|||
from utils.visual import plot_alignment, plot_spectrogram
|
||||
from datasets.LJSpeech import LJSpeechDataset
|
||||
from models.tacotron import Tacotron
|
||||
from losses import
|
||||
from layers.losses import L1LossMasked
|
||||
|
||||
|
||||
use_cuda = torch.cuda.is_available()
|
||||
|
@ -95,7 +95,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
|||
# convert inputs to variables
|
||||
text_input_var = Variable(text_input)
|
||||
mel_spec_var = Variable(mel_input)
|
||||
mel_length_var = Variable(mel_lengths)
|
||||
mel_lengths_var = Variable(mel_lengths)
|
||||
linear_spec_var = Variable(linear_input, volatile=True)
|
||||
|
||||
# sort sequence by length for curriculum learning
|
||||
|
@ -105,6 +105,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
|||
sorted_lengths = sorted_lengths.long().numpy()
|
||||
text_input_var = text_input_var[indices]
|
||||
mel_spec_var = mel_spec_var[indices]
|
||||
mel_lengths_var = mel_lengths_var[indices]
|
||||
linear_spec_var = linear_spec_var[indices]
|
||||
|
||||
# dispatch data to GPU
|
||||
|
@ -119,11 +120,11 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
|||
model.forward(text_input_var, mel_spec_var)
|
||||
|
||||
# loss computation
|
||||
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths)
|
||||
linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths) \
|
||||
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
|
||||
linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
|
||||
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
||||
linear_spec_var[: ,: ,:n_priority_freq],
|
||||
mel_lengths)
|
||||
mel_lengths_var)
|
||||
loss = mel_loss + linear_loss
|
||||
|
||||
# backpass and check the grad norm
|
||||
|
@ -240,10 +241,10 @@ def evaluate(model, criterion, data_loader, current_step):
|
|||
|
||||
# loss computation
|
||||
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths)
|
||||
linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths) \
|
||||
linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
|
||||
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
||||
linear_spec_var[: ,: ,:n_priority_freq],
|
||||
mel_lengths)
|
||||
mel_lengths_var)
|
||||
loss = mel_loss + linear_loss
|
||||
|
||||
step_time = time.time() - start_time
|
||||
|
@ -348,10 +349,7 @@ def main(args):
|
|||
|
||||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
|
||||
if use_cuda:
|
||||
criterion = nn.L1Loss().cuda()
|
||||
else:
|
||||
criterion = nn.L1Loss()
|
||||
criterion = L1LossMasked
|
||||
|
||||
if args.restore_path:
|
||||
checkpoint = torch.load(args.restore_path)
|
||||
|
|
Загрузка…
Ссылка в новой задаче