diff --git a/train.py b/train.py index 5ae4c03..7078acb 100644 --- a/train.py +++ b/train.py @@ -16,7 +16,7 @@ from tensorboardX import SummaryWriter from utils.generic_utils import ( synthesis, remove_experiment_folder, create_experiment_folder, save_checkpoint, save_best_model, load_config, lr_decay, count_parameters, - check_update, get_commit_hash, sequence_mask) + check_update, get_commit_hash, sequence_mask, AnnealLR) from utils.visual import plot_alignment, plot_spectrogram from models.tacotron import Tacotron from layers.losses import L1LossMasked @@ -312,22 +312,23 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step): # test sentences ap.griffin_lim_iters = 60 for idx, test_sentence in enumerate(test_sentences): - wav, linear_spec, alignments = synthesis(model, ap, test_sentence, - use_cuda, c.text_cleaner) try: + wav, linear_spec, alignments = synthesis(model, ap, test_sentence, + use_cuda, c.text_cleaner) wav_name = 'TestSentences/{}'.format(idx) tb.add_audio( wav_name, wav, current_step, sample_rate=c.sample_rate) + + align_img = alignments[0].data.cpu().numpy() + linear_spec = plot_spectrogram(linear_spec, ap) + align_img = plot_alignment(align_img) + tb.add_figure('TestSentences/{}_Spectrogram'.format(idx), linear_spec, + current_step) + tb.add_figure('TestSentences/{}_Alignment'.format(idx), align_img, + current_step) except: print(" !! Error as creating Test Sentence -", idx) pass - align_img = alignments[0].data.cpu().numpy() - linear_spec = plot_spectrogram(linear_spec, ap) - align_img = plot_alignment(align_img) - tb.add_figure('TestSentences/{}_Spectrogram'.format(idx), linear_spec, - current_step) - tb.add_figure('TestSentences/{}_Alignment'.format(idx), align_img, - current_step) return avg_linear_loss @@ -337,14 +338,6 @@ def main(args): audio = importlib.import_module('utils.' + c.audio_processor) AudioProcessor = getattr(audio, 'AudioProcessor') - print(" > LR scheduler: {} ", c.lr_scheduler) - try: - scheduler = importlib.import_module('torch.optim.lr_scheduler') - scheduler = getattr(scheduler, c.lr_scheduler) - except: - scheduler = importlib.import_module('utils.generic_utils') - scheduler = getattr(scheduler, c.lr_scheduler) - ap = AudioProcessor( sample_rate=c.sample_rate, num_mels=c.num_mels, @@ -426,7 +419,7 @@ def main(args): criterion.cuda() criterion_st.cuda() - scheduler = StepLR(optimizer, step_size=c.decay_step, gamma=c.lr_decay) + scheduler = AnnealLR(optimizer, warmup_steps=c.warmup_steps) num_params = count_parameters(model) print(" | > Model has {} parameters".format(num_params), flush=True) diff --git a/utils/generic_utils.py b/utils/generic_utils.py index fea8c84..a1f72a3 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -143,16 +143,16 @@ def lr_decay(init_lr, global_step, warmup_steps): class AnnealLR(torch.optim.lr_scheduler._LRScheduler): - def __init__(self, optimizer, warmup_steps=0.1): + def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1): self.warmup_steps = float(warmup_steps) super(AnnealLR, self).__init__(optimizer, last_epoch) def get_lr(self): + step = max(self.last_epoch, 1) return [ - base_lr * self.warmup_steps**0.5 * torch.min([ - self.last_epoch * self.warmup_steps**-1.5, self.last_epoch** - -0.5 - ]) for base_lr in self.base_lrs + base_lr * self.warmup_steps**0.5 * min( + step * self.warmup_steps**-1.5, step**-0.5) + for base_lr in self.base_lrs ]