зеркало из https://github.com/mozilla/TTS.git
Dont use teacher forcing at test time
This commit is contained in:
Родитель
9b4aa92667
Коммит
cb48406383
|
@ -48,6 +48,7 @@ class BatchNormConv1d(nn.Module):
|
||||||
- input: batch x dims
|
- input: batch x dims
|
||||||
- output: batch x dims
|
- output: batch x dims
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
|
def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
|
||||||
activation=None):
|
activation=None):
|
||||||
super(BatchNormConv1d, self).__init__()
|
super(BatchNormConv1d, self).__init__()
|
||||||
|
@ -241,7 +242,8 @@ class Decoder(nn.Module):
|
||||||
Args:
|
Args:
|
||||||
inputs: Encoder outputs.
|
inputs: Encoder outputs.
|
||||||
memory (None): Decoder memory (autoregression. If None (at eval-time),
|
memory (None): Decoder memory (autoregression. If None (at eval-time),
|
||||||
decoder outputs are used as decoder inputs.
|
decoder outputs are used as decoder inputs. If None, it uses the last
|
||||||
|
output as the input.
|
||||||
|
|
||||||
Shapes:
|
Shapes:
|
||||||
- inputs: batch x time x encoder_out_dim
|
- inputs: batch x time x encoder_out_dim
|
||||||
|
@ -293,7 +295,7 @@ class Decoder(nn.Module):
|
||||||
memory_input = torch.div(outputs[-1] + memory[t-1], 2.0)
|
memory_input = torch.div(outputs[-1] + memory[t-1], 2.0)
|
||||||
# add a random noise
|
# add a random noise
|
||||||
noise = torch.autograd.Variable(
|
noise = torch.autograd.Variable(
|
||||||
memory_input.data.new(memory_input.size()).normal_(0.0, 1.0))
|
memory_input.data.new(memory_input.size()).normal_(0.0, 0.5))
|
||||||
memory_input = memory_input + noise
|
memory_input = memory_input + noise
|
||||||
|
|
||||||
# Prenet
|
# Prenet
|
||||||
|
|
3
train.py
3
train.py
|
@ -228,8 +228,7 @@ def evaluate(model, criterion, data_loader, current_step):
|
||||||
linear_spec_var = linear_spec_var.cuda()
|
linear_spec_var = linear_spec_var.cuda()
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
mel_output, linear_output, alignments =\
|
mel_output, linear_output, alignments = model.forward(text_input_var)
|
||||||
model.forward(text_input_var, mel_spec_var)
|
|
||||||
|
|
||||||
# loss computation
|
# loss computation
|
||||||
mel_loss = criterion(mel_output, mel_spec_var)
|
mel_loss = criterion(mel_output, mel_spec_var)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче