diff --git a/config.json b/config.json
index 1a617ea..3505e11 100644
--- a/config.json
+++ b/config.json
@@ -13,6 +13,7 @@
     "epochs": 2000,
     "lr": 0.005,
+    "warmup_steps": 4000,
     "batch_size": 180,
     "r": 5,
diff --git a/train.py b/train.py
index 230f0be..9b8e8c9 100644
--- a/train.py
+++ b/train.py
@@ -21,7 +21,7 @@ from tensorboardX import SummaryWriter
 from utils.generic_utils import (Progbar, remove_experiment_folder,
                                  create_experiment_folder, save_checkpoint,
                                  save_best_model, load_config, lr_decay,
-                                 count_parameters)
+                                 count_parameters, check_update)
 from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
 from datasets.LJSpeech import LJSpeechDataset
@@ -150,9 +150,9 @@ def main(args):
             current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1

             # setup lr
-            # current_lr = lr_decay(c.lr, current_step)
-            # for params_group in optimizer.param_groups:
-            #     params_group['lr'] = current_lr
+            current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
+            for params_group in optimizer.param_groups:
+                params_group['lr'] = current_lr

             optimizer.zero_grad()

@@ -197,10 +197,13 @@ def main(args):
                 + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                                   linear_spec_var[:, :, :n_priority_freq])
             loss = mel_loss + linear_loss
-            # loss = loss.cuda()

             loss.backward()
-            grad_norm = nn.utils.clip_grad_norm(model.parameters(), 0.5)  ## TODO: maybe no need
+            grad_norm, skip_flag = check_update(model, 0.5, 100)
+            if skip_flag:
+                optimizer.zero_grad()
+                print(" | > Iteration skipped!!")
+                continue
             optimizer.step()

             step_time = time.time() - start_time
@@ -265,7 +268,6 @@ def main(args):
                                            best_loss, OUT_PATH,
                                            current_step, epoch)
-            #lr_scheduler.step(loss.data[0])

         tb.add_scalar('Time/EpochTime', epoch_time, epoch)
         epoch_time = 0

diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index ed9661f..4832ec4 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -95,8 +95,21 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
     return best_loss


-def lr_decay(init_lr, global_step):
-    warmup_steps = 4000.0
+def check_update(model, grad_clip, grad_top):
+    r'''Check model gradient against unexpected jumps and failures'''
+    skip_flag = False
+    grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), grad_clip)
+    if np.isinf(grad_norm):
+        print(" | > Gradient is INF !!")
+        skip_flag = True
+    elif grad_norm > grad_top:
+        print(" | > Gradient is above the top limit !!")
+        skip_flag = True
+    return grad_norm, skip_flag
+
+
+def lr_decay(init_lr, global_step, warmup_steps):
+    r'''from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py'''
     step = global_step + 1.
     lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
                                                   step**-0.5)
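
Reviewer note: the `lr_decay` wired in above is the Noam-style warmup schedule (as the borrowed r9y9/tacotron_pytorch link suggests): the learning rate ramps up for `warmup_steps` steps, peaks at exactly `init_lr` when `step == warmup_steps` (the two branches of the `minimum` are equal there), then decays as `1/sqrt(step)`. A minimal standalone sketch of the curve, using the `lr` and `warmup_steps` values from config.json:

```python
import numpy as np

def lr_decay(init_lr, global_step, warmup_steps):
    # Same formula as utils/generic_utils.py: warmup up to `warmup_steps`,
    # inverse-sqrt decay afterwards.
    step = global_step + 1.0
    return init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
                                                    step**-0.5)

for s in (0, 1000, 3999, 16000, 64000):
    print(s, lr_decay(0.005, s, 4000))  # peaks at 0.005 around step 4000
```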
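One portability caveat for `check_update`: `torch.nn.utils.clip_grad_norm` (no trailing underscore) was deprecated in PyTorch 0.4 in favor of `clip_grad_norm_`, which clips in place and returns the pre-clipping total norm as a tensor, so `np.isinf` should become `torch.isinf`. A sketch of the same guard against a current PyTorch, with the logic unchanged and only those two calls swapped:

```python
import torch

def check_update(model, grad_clip, grad_top):
    r'''Check model gradient against unexpected jumps and failures'''
    skip_flag = False
    # clip_grad_norm_ returns the total grad norm measured *before* clipping,
    # so the thresholds below see the raw value even after clipping happens.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    if torch.isinf(grad_norm):
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif grad_norm > grad_top:
        print(" | > Gradient is above the top limit !!")
        skip_flag = True
    return grad_norm, skip_flag
```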