check gradients for big erroneous changes

Eren Golge 2018-02-27 07:31:07 -08:00
Parent b21fa4dd44
Commit d6b2af7ca9
3 changed files with 25 additions and 9 deletions

View file

@@ -13,6 +13,7 @@
"epochs": 2000,
"lr": 0.005,
"warmup_steps": 4000,
"batch_size": 180,
"r": 5,

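The new "warmup_steps" key is read alongside "lr" and fed into the updated lr_decay() call in train.py. A minimal sketch of that usage, assuming load_config() returns the attribute-style config object used there (the config path and step value below are only illustrative):

    from utils.generic_utils import load_config, lr_decay

    c = load_config("config.json")  # hypothetical path; exposes c.lr, c.warmup_steps, ...
    current_lr = lr_decay(c.lr, 1000, c.warmup_steps)  # learning rate at an example step
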
View file

@@ -21,7 +21,7 @@ from tensorboardX import SummaryWriter
from utils.generic_utils import (Progbar, remove_experiment_folder,
                                 create_experiment_folder, save_checkpoint,
                                 save_best_model, load_config, lr_decay,
                                 count_parameters)
                                 count_parameters, check_update)
from utils.model import get_param_size
from utils.visual import plot_alignment, plot_spectrogram
from datasets.LJSpeech import LJSpeechDataset
@@ -150,9 +150,9 @@ def main(args):
current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1
# setup lr
# current_lr = lr_decay(c.lr, current_step)
# for params_group in optimizer.param_groups:
#     params_group['lr'] = current_lr
current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
for params_group in optimizer.param_groups:
    params_group['lr'] = current_lr
optimizer.zero_grad()
@@ -197,10 +197,13 @@ def main(args):
    + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                      linear_spec_var[: ,: ,:n_priority_freq])
loss = mel_loss + linear_loss
# loss = loss.cuda()
loss.backward()
grad_norm = nn.utils.clip_grad_norm(model.parameters(), 0.5) ## TODO: maybe no need
grad_norm, skip_flag = check_update(model, 0.5, 100)
if skip_flag:
    optimizer.zero_grad()
    print(" | > Iteration skipped!!")
    continue
optimizer.step()
step_time = time.time() - start_time
@@ -265,7 +268,6 @@ def main(args):
                best_loss, OUT_PATH,
                current_step, epoch)
#lr_scheduler.step(loss.data[0])
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
epoch_time = 0
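
The net effect of these train.py changes is that optimizer.step() now runs only when the clipped gradient norm is finite and below a hard cap. A condensed, self-contained sketch of that guarded update, using a toy model, optimizer and loss as stand-ins for the Tacotron model and the combined mel + linear loss (commit-era PyTorch API assumed):

    import torch
    import torch.nn as nn
    from utils.generic_utils import check_update  # helper added in this commit

    # toy stand-ins; in train.py these are the Tacotron model, its optimizer
    # and the mel + linear reconstruction loss
    model = nn.Linear(10, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
    loss = model(torch.randn(8, 10)).pow(2).mean()

    optimizer.zero_grad()
    loss.backward()
    grad_norm, skip_flag = check_update(model, 0.5, 100)  # clip at 0.5, skip above 100
    if skip_flag:
        optimizer.zero_grad()   # discard this iteration's gradients entirely
    else:
        optimizer.step()        # gradient norm is finite and within the limit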

View file

@@ -95,8 +95,21 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
    return best_loss
def lr_decay(init_lr, global_step):
    warmup_steps = 4000.0
def check_update(model, grad_clip, grad_top):
    r'''Check model gradient against unexpected jumps and failures'''
    skip_flag = False
    grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), grad_clip)
    if np.isinf(grad_norm):
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif grad_norm > grad_top:
        print(" | > Gradient is above the top limit !!")
        skip_flag = True
    return grad_norm, skip_flag
def lr_decay(init_lr, global_step, warmup_steps):
    r'''from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py'''
    step = global_step + 1.
    lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
                                                  step**-0.5)
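
For reference, the schedule implemented by lr_decay() above ramps the rate up linearly, peaks at init_lr around warmup_steps, and then decays roughly as 1/sqrt(step). A standalone illustration of the formula with the commit's config defaults (lr=0.005, warmup_steps=4000); the sampled steps are arbitrary:

    import numpy as np

    def lr_decay(init_lr, global_step, warmup_steps):
        # same formula as above, returned directly
        step = global_step + 1.
        return init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
                                                        step**-0.5)

    for s in (0, 1000, 4000, 16000, 64000):
        print(s, lr_decay(0.005, s, 4000.0))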