зеркало из https://github.com/mozilla/TTS.git
train.py - add with torch.no_grad():
This commit is contained in:
Родитель
7c40455edd
Коммит
b4bd713581
65
train.py
65
train.py
|
@ -191,45 +191,46 @@ def evaluate(model, criterion, data_loader, current_step):
|
|||
print(" | > Validation")
|
||||
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
|
||||
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
|
||||
for num_iter, data in enumerate(data_loader):
|
||||
start_time = time.time()
|
||||
with torch.no_grad():
|
||||
for num_iter, data in enumerate(data_loader):
|
||||
start_time = time.time()
|
||||
|
||||
# setup input data
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
# setup input data
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
|
||||
# dispatch data to GPU
|
||||
if use_cuda:
|
||||
text_input = text_input.cuda()
|
||||
mel_input = mel_input.cuda()
|
||||
mel_lengths = mel_lengths.cuda()
|
||||
linear_input = linear_input.cuda()
|
||||
# dispatch data to GPU
|
||||
if use_cuda:
|
||||
text_input = text_input.cuda()
|
||||
mel_input = mel_input.cuda()
|
||||
mel_lengths = mel_lengths.cuda()
|
||||
linear_input = linear_input.cuda()
|
||||
|
||||
# forward pass
|
||||
mel_output, linear_output, alignments =\
|
||||
model.forward(text_input, mel_input)
|
||||
# forward pass
|
||||
mel_output, linear_output, alignments =\
|
||||
model.forward(text_input, mel_input)
|
||||
|
||||
# loss computation
|
||||
mel_loss = criterion(mel_output, mel_input, mel_lengths)
|
||||
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
|
||||
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
||||
linear_input[:, :, :n_priority_freq],
|
||||
mel_lengths)
|
||||
loss = mel_loss + linear_loss
|
||||
# loss computation
|
||||
mel_loss = criterion(mel_output, mel_input, mel_lengths)
|
||||
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
|
||||
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
||||
linear_input[:, :, :n_priority_freq],
|
||||
mel_lengths)
|
||||
loss = mel_loss + linear_loss
|
||||
|
||||
step_time = time.time() - start_time
|
||||
epoch_time += step_time
|
||||
step_time = time.time() - start_time
|
||||
epoch_time += step_time
|
||||
|
||||
# update
|
||||
progbar.update(num_iter+1, values=[('total_loss', loss.item()),
|
||||
('linear_loss', linear_loss.item()),
|
||||
('mel_loss', mel_loss.item())])
|
||||
# update
|
||||
progbar.update(num_iter+1, values=[('total_loss', loss.item()),
|
||||
('linear_loss', linear_loss.item()),
|
||||
('mel_loss', mel_loss.item())])
|
||||
|
||||
avg_linear_loss += linear_loss.item()
|
||||
avg_mel_loss += mel_loss.item()
|
||||
avg_linear_loss += linear_loss.item()
|
||||
avg_mel_loss += mel_loss.item()
|
||||
|
||||
# Diagnostic visualizations
|
||||
idx = np.random.randint(mel_input.shape[0])
|
||||
|
|
Загрузка…
Ссылка в новой задаче