Dont use teacher forcing at test time

This commit is contained in:
Eren Golge 2018-03-19 10:38:47 -07:00
Родитель 9b4aa92667
Коммит cb48406383
2 изменённых файлов: 5 добавлений и 4 удалений

Просмотреть файл

@ -48,6 +48,7 @@ class BatchNormConv1d(nn.Module):
- input: batch x dims - input: batch x dims
- output: batch x dims - output: batch x dims
""" """
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
activation=None): activation=None):
super(BatchNormConv1d, self).__init__() super(BatchNormConv1d, self).__init__()
@ -241,7 +242,8 @@ class Decoder(nn.Module):
Args: Args:
inputs: Encoder outputs. inputs: Encoder outputs.
memory (None): Decoder memory (autoregression. If None (at eval-time), memory (None): Decoder memory (autoregression. If None (at eval-time),
decoder outputs are used as decoder inputs. decoder outputs are used as decoder inputs. If None, it uses the last
output as the input.
Shapes: Shapes:
- inputs: batch x time x encoder_out_dim - inputs: batch x time x encoder_out_dim
@ -293,7 +295,7 @@ class Decoder(nn.Module):
memory_input = torch.div(outputs[-1] + memory[t-1], 2.0) memory_input = torch.div(outputs[-1] + memory[t-1], 2.0)
# add a random noise # add a random noise
noise = torch.autograd.Variable( noise = torch.autograd.Variable(
memory_input.data.new(memory_input.size()).normal_(0.0, 1.0)) memory_input.data.new(memory_input.size()).normal_(0.0, 0.5))
memory_input = memory_input + noise memory_input = memory_input + noise
# Prenet # Prenet

Просмотреть файл

@ -228,8 +228,7 @@ def evaluate(model, criterion, data_loader, current_step):
linear_spec_var = linear_spec_var.cuda() linear_spec_var = linear_spec_var.cuda()
# forward pass # forward pass
mel_output, linear_output, alignments =\ mel_output, linear_output, alignments = model.forward(text_input_var)
model.forward(text_input_var, mel_spec_var)
# loss computation # loss computation
mel_loss = criterion(mel_output, mel_spec_var) mel_loss = criterion(mel_output, mel_spec_var)