diff --git a/train.py b/train.py index d0a0412..e238757 100644 --- a/train.py +++ b/train.py @@ -264,7 +264,7 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step): # ('linear_loss', linear_loss.item()), # ('mel_loss', mel_loss.item()), # ('stop_loss', stop_loss.item())]) - if current_step % c.print_step == 0: + if num_iter % c.print_step == 0: print(" | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "\ "StopLoss: {:.5f} ".format(loss.item(), linear_loss.item(), @@ -319,7 +319,6 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step): def main(args): - # Setup the dataset # Setup the dataset train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'), os.path.join(c.data_path, 'wavs'),