From b4bd71358109764902fa2e6744800abb6e096112 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Thu, 10 May 2018 16:44:37 -0700 Subject: [PATCH] train.py - add with torch.no_grad(): --- train.py | 65 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/train.py b/train.py index 233c8f6..0337150 100644 --- a/train.py +++ b/train.py @@ -191,45 +191,46 @@ def evaluate(model, criterion, data_loader, current_step): print(" | > Validation") progbar = Progbar(len(data_loader.dataset) / c.batch_size) n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq) - for num_iter, data in enumerate(data_loader): - start_time = time.time() + with torch.no_grad(): + for num_iter, data in enumerate(data_loader): + start_time = time.time() - # setup input data - text_input = data[0] - text_lengths = data[1] - linear_input = data[2] - mel_input = data[3] - mel_lengths = data[4] + # setup input data + text_input = data[0] + text_lengths = data[1] + linear_input = data[2] + mel_input = data[3] + mel_lengths = data[4] - # dispatch data to GPU - if use_cuda: - text_input = text_input.cuda() - mel_input = mel_input.cuda() - mel_lengths = mel_lengths.cuda() - linear_input = linear_input.cuda() + # dispatch data to GPU + if use_cuda: + text_input = text_input.cuda() + mel_input = mel_input.cuda() + mel_lengths = mel_lengths.cuda() + linear_input = linear_input.cuda() - # forward pass - mel_output, linear_output, alignments =\ - model.forward(text_input, mel_input) + # forward pass + mel_output, linear_output, alignments =\ + model.forward(text_input, mel_input) - # loss computation - mel_loss = criterion(mel_output, mel_input, mel_lengths) - linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \ - + 0.5 * criterion(linear_output[:, :, :n_priority_freq], - linear_input[:, :, :n_priority_freq], - mel_lengths) - loss = mel_loss + linear_loss + # loss computation + mel_loss = criterion(mel_output, mel_input, mel_lengths) + linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \ + + 0.5 * criterion(linear_output[:, :, :n_priority_freq], + linear_input[:, :, :n_priority_freq], + mel_lengths) + loss = mel_loss + linear_loss - step_time = time.time() - start_time - epoch_time += step_time + step_time = time.time() - start_time + epoch_time += step_time - # update - progbar.update(num_iter+1, values=[('total_loss', loss.item()), - ('linear_loss', linear_loss.item()), - ('mel_loss', mel_loss.item())]) + # update + progbar.update(num_iter+1, values=[('total_loss', loss.item()), + ('linear_loss', linear_loss.item()), + ('mel_loss', mel_loss.item())]) - avg_linear_loss += linear_loss.item() - avg_mel_loss += mel_loss.item() + avg_linear_loss += linear_loss.item() + avg_mel_loss += mel_loss.item() # Diagnostic visualizations idx = np.random.randint(mel_input.shape[0])