Mirror of https://github.com/mozilla/TTS.git
check gradients for big erroneous changes
Parent: b21fa4dd44
Commit: d6b2af7ca9
config.json

@@ -13,6 +13,7 @@
     "epochs": 2000,
     "lr": 0.005,
+    "warmup_steps": 4000,
     "batch_size": 180,
     "r": 5,
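For context: train.py below reads these values via attribute access (c.lr, c.warmup_steps), so load_config from utils.generic_utils presumably parses this JSON into an attribute-style dict. A minimal sketch of that assumption (AttrDict and load_config_sketch are illustrative stand-ins, not the repo's actual implementation):

    import json

    class AttrDict(dict):
        # hypothetical stand-in for whatever load_config actually returns
        def __getattr__(self, key):
            return self[key]

    def load_config_sketch(path):
        with open(path, "r") as f:
            return AttrDict(json.load(f))

    # c = load_config_sketch("config.json")
    # c.lr == 0.005, c.warmup_steps == 4000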
train.py (16 changed lines)
@@ -21,7 +21,7 @@ from tensorboardX import SummaryWriter
 from utils.generic_utils import (Progbar, remove_experiment_folder,
                                  create_experiment_folder, save_checkpoint,
                                  save_best_model, load_config, lr_decay,
-                                 count_parameters)
+                                 count_parameters, check_update)
 from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
 from datasets.LJSpeech import LJSpeechDataset
@@ -150,9 +150,9 @@ def main(args):
         current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1

         # setup lr
-        # current_lr = lr_decay(c.lr, current_step)
-        # for params_group in optimizer.param_groups:
-        #     params_group['lr'] = current_lr
+        current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
+        for params_group in optimizer.param_groups:
+            params_group['lr'] = current_lr

         optimizer.zero_grad()
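The uncommented block is the standard PyTorch pattern for a hand-rolled schedule: the optimizer reads lr from each entry of optimizer.param_groups at step() time, so overwriting it every iteration is sufficient. A minimal self-contained sketch of the pattern (toy model and a stand-in schedule, not the repo's code):

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.005)

    for step in range(1, 4):
        current_lr = 0.005 / step  # stand-in for lr_decay(c.lr, step, c.warmup_steps)
        for params_group in optimizer.param_groups:
            params_group['lr'] = current_lr  # picked up by the next optimizer.step()
        optimizer.zero_grad()
        loss = model(torch.randn(2, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()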
@@ -197,10 +197,13 @@ def main(args):
                 + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                                   linear_spec_var[: ,: ,:n_priority_freq])
         loss = mel_loss + linear_loss
         # loss = loss.cuda()

         loss.backward()
-        grad_norm = nn.utils.clip_grad_norm(model.parameters(), 0.5)  ## TODO: maybe no need
+        grad_norm, skip_flag = check_update(model, 0.5, 100)
+        if skip_flag:
+            optimizer.zero_grad()
+            print(" | > Iteration skipped!!")
+            continue
         optimizer.step()

         step_time = time.time() - start_time
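This replaces plain clipping with a guard: clip_grad_norm returns the total gradient norm computed before clipping, so check_update can clip and simultaneously detect an exploded batch; on inf (e.g. from a NaN loss) or a norm above grad_top, the gradients are zeroed and the whole update is dropped via continue rather than stepping on suspect values. A quick illustration of the pre-clip return value (note that current PyTorch renames the function to the underscore-suffixed clip_grad_norm_):

    import torch

    p = torch.nn.Parameter(torch.zeros(3))
    p.grad = torch.tensor([3.0, 4.0, 0.0])           # gradient norm = 5.0
    norm = torch.nn.utils.clip_grad_norm_([p], 0.5)  # underscore-suffixed in current PyTorch
    print(float(norm))  # 5.0 -- the pre-clip norm, what check_update compares to grad_top
    print(p.grad)       # tensor([0.3000, 0.4000, 0.0000]) -- rescaled in place to norm 0.5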
@@ -265,7 +268,6 @@ def main(args):
                             best_loss, OUT_PATH,
                             current_step, epoch)

-        #lr_scheduler.step(loss.data[0])
         tb.add_scalar('Time/EpochTime', epoch_time, epoch)
         epoch_time = 0

utils/generic_utils.py
@@ -95,8 +95,21 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
     return best_loss

-def lr_decay(init_lr, global_step):
-    warmup_steps = 4000.0
+def check_update(model, grad_clip, grad_top):
+    r'''Check model gradient against unexpected jumps and failures'''
+    skip_flag = False
+    grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), grad_clip)
+    if np.isinf(grad_norm):
+        print(" | > Gradient is INF !!")
+        skip_flag = True
+    elif grad_norm > grad_top:
+        print(" | > Gradient is above the top limit !!")
+        skip_flag = True
+    return grad_norm, skip_flag
+
+
+def lr_decay(init_lr, global_step, warmup_steps):
     r'''from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py'''
     step = global_step + 1.
     lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
                                                   step**-0.5)
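lr_decay is the Noam-style warmup schedule from the linked tacotron_pytorch repo, now taking warmup_steps from the config instead of hardcoding 4000: a linear ramp until warmup_steps, a peak at init_lr, then step**-0.5 decay. A quick check of the values, assuming the function ends with a return lr just past the hunk's context:

    import numpy as np

    def lr_decay(init_lr, global_step, warmup_steps):
        # body as in the hunk above; the trailing return is assumed
        step = global_step + 1.
        lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
                                                      step**-0.5)
        return lr

    for global_step in (0, 399, 3999, 15999):
        print(global_step + 1, lr_decay(0.005, global_step, 4000.0))
    # step     1: 1.25e-06   (linear ramp: init_lr * step / warmup_steps)
    # step   400: 5.00e-04
    # step  4000: 5.00e-03   (peak = init_lr, reached at warmup_steps)
    # step 16000: 2.50e-03   (decays as step**-0.5 afterwards)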