train.py - add with torch.no_grad():

This commit is contained in:
Eren Golge 2018-05-10 16:44:37 -07:00
Родитель 7c40455edd
Коммит b4bd713581
1 изменённых файлов: 33 добавлений и 32 удалений

Просмотреть файл

@ -191,45 +191,46 @@ def evaluate(model, criterion, data_loader, current_step):
print(" | > Validation")
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
for num_iter, data in enumerate(data_loader):
start_time = time.time()
with torch.no_grad():
for num_iter, data in enumerate(data_loader):
start_time = time.time()
# setup input data
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
# setup input data
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
# dispatch data to GPU
if use_cuda:
text_input = text_input.cuda()
mel_input = mel_input.cuda()
mel_lengths = mel_lengths.cuda()
linear_input = linear_input.cuda()
# dispatch data to GPU
if use_cuda:
text_input = text_input.cuda()
mel_input = mel_input.cuda()
mel_lengths = mel_lengths.cuda()
linear_input = linear_input.cuda()
# forward pass
mel_output, linear_output, alignments =\
model.forward(text_input, mel_input)
# forward pass
mel_output, linear_output, alignments =\
model.forward(text_input, mel_input)
# loss computation
mel_loss = criterion(mel_output, mel_input, mel_lengths)
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
linear_input[:, :, :n_priority_freq],
mel_lengths)
loss = mel_loss + linear_loss
# loss computation
mel_loss = criterion(mel_output, mel_input, mel_lengths)
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
linear_input[:, :, :n_priority_freq],
mel_lengths)
loss = mel_loss + linear_loss
step_time = time.time() - start_time
epoch_time += step_time
step_time = time.time() - start_time
epoch_time += step_time
# update
progbar.update(num_iter+1, values=[('total_loss', loss.item()),
('linear_loss', linear_loss.item()),
('mel_loss', mel_loss.item())])
# update
progbar.update(num_iter+1, values=[('total_loss', loss.item()),
('linear_loss', linear_loss.item()),
('mel_loss', mel_loss.item())])
avg_linear_loss += linear_loss.item()
avg_mel_loss += mel_loss.item()
avg_linear_loss += linear_loss.item()
avg_mel_loss += mel_loss.item()
# Diagnostic visualizations
idx = np.random.randint(mel_input.shape[0])