train.py - replace data[0] with item()

This commit is contained in:
Eren Golge 2018-05-10 16:22:17 -07:00
Родитель 10fd4f62b3
Коммит c8bfe731d6
1 изменённых файлов: 15 добавлений и 17 удалений

Просмотреть файл

@ -118,19 +118,18 @@ def train(model, criterion, data_loader, optimizer, epoch):
epoch_time += step_time
# update
progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
('linear_loss',
linear_loss.data[0]),
('mel_loss', mel_loss.data[0]),
('grad_norm', grad_norm)])
avg_linear_loss += linear_loss.data[0]
avg_mel_loss += mel_loss.data[0]
progbar.update(num_iter+1, values=[('total_loss', loss.item()),
('linear_loss', linear_loss.item()),
('mel_loss', mel_loss.item()),
('grad_norm', grad_norm.item())])
avg_linear_loss += linear_loss.item()
avg_mel_loss += mel_loss.item()
# Plot Training Iter Stats
tb.add_scalar('TrainIterLoss/TotalLoss', loss.data[0], current_step)
tb.add_scalar('TrainIterLoss/LinearLoss', linear_loss.data[0],
tb.add_scalar('TrainIterLoss/TotalLoss', loss.item(), current_step)
tb.add_scalar('TrainIterLoss/LinearLoss', linear_loss.item(),
current_step)
tb.add_scalar('TrainIterLoss/MelLoss', mel_loss.data[0], current_step)
tb.add_scalar('TrainIterLoss/MelLoss', mel_loss.item(), current_step)
tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
current_step)
tb.add_scalar('Params/GradNorm', grad_norm, current_step)
@ -139,7 +138,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
if current_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model, optimizer, linear_loss.data[0],
save_checkpoint(model, optimizer, linear_loss.item(),
OUT_PATH, current_step, epoch)
# Diagnostic visualizations
@ -225,13 +224,12 @@ def evaluate(model, criterion, data_loader, current_step):
epoch_time += step_time
# update
progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
('linear_loss',
linear_loss.data[0]),
('mel_loss', mel_loss.data[0])])
progbar.update(num_iter+1, values=[('total_loss', loss.item()),
('linear_loss', linear_loss.item()),
('mel_loss', mel_loss.item())])
avg_linear_loss += linear_loss.data[0]
avg_mel_loss += mel_loss.data[0]
avg_linear_loss += linear_loss.item()
avg_mel_loss += mel_loss.item()
# Diagnostic visualizations
idx = np.random.randint(mel_input.shape[0])