зеркало из https://github.com/mozilla/TTS.git
687 строк
28 KiB
Python
687 строк
28 KiB
Python
import argparse
|
|
import os
|
|
import sys
|
|
import glob
|
|
import time
|
|
import traceback
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
|
|
from TTS.datasets.TTSDataset import MyDataset
|
|
from distribute import (DistributedSampler, apply_gradient_allreduce,
|
|
init_distributed, reduce_tensor)
|
|
from TTS.layers.losses import TacotronLoss
|
|
from TTS.utils.audio import AudioProcessor
|
|
from TTS.utils.generic_utils import (count_parameters, create_experiment_folder, remove_experiment_folder,
|
|
get_git_branch, set_init_dict,
|
|
setup_model, KeepAverage, check_config)
|
|
from TTS.utils.io import (save_best_model, save_checkpoint,
|
|
load_config, copy_config_file)
|
|
from TTS.utils.training import (NoamLR, check_update, adam_weight_decay,
|
|
gradual_training_scheduler, set_weight_decay)
|
|
from TTS.utils.tensorboard_logger import TensorboardLogger
|
|
from TTS.utils.console_logger import ConsoleLogger
|
|
from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
|
|
get_speakers
|
|
from TTS.utils.synthesis import synthesis
|
|
from TTS.utils.text.symbols import make_symbols, phonemes, symbols
|
|
from TTS.utils.visual import plot_alignment, plot_spectrogram
|
|
from TTS.datasets.preprocess import load_meta_data
|
|
from TTS.utils.radam import RAdam
|
|
from TTS.utils.measures import alignment_diagonal_score
|
|
|
|
# Global torch configuration, applied once when the script is imported.
# cuDNN stays enabled but its autotuner (benchmark mode) is switched off and
# the RNG seed is fixed -- presumably to make runs more repeatable
# (NOTE(review): confirm intent).
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(54321)

# Device discovery. Both names are module globals read by setup_loader,
# format_data, train, evaluate and main.
num_gpus = torch.cuda.device_count()
use_cuda = torch.cuda.is_available()

print(" > Using CUDA: ", use_cuda)
print(" > Number of GPUs: ", num_gpus)
|
|
|
|
|
|
def setup_loader(ap, r, is_val=False, verbose=False):
    """Build the training or validation ``DataLoader``.

    Reads the module-level config ``c``, the metadata lists
    ``meta_data_train`` / ``meta_data_eval`` and ``num_gpus``.

    Args:
        ap: AudioProcessor handed to the dataset.
        r: decoder reduction factor, passed as the dataset's first argument.
        is_val: build the validation loader instead of the training one.
        verbose: forwarded to ``MyDataset``.

    Returns:
        A ``DataLoader``, or ``None`` when a validation loader is requested
        while ``c.run_eval`` is disabled.
    """
    if is_val and not c.run_eval:
        return None

    dataset = MyDataset(
        r,
        c.text_cleaner,
        # only the plain Tacotron model trains on linear spectrograms here
        compute_linear_spec=c.model.lower() == 'tacotron',
        meta_data=meta_data_eval if is_val else meta_data_train,
        ap=ap,
        tp=c.characters if 'characters' in c.keys() else None,
        # no length-bucket grouping for validation
        batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
        min_seq_len=c.min_seq_len,
        max_seq_len=c.max_seq_len,
        phoneme_cache_path=c.phoneme_cache_path,
        use_phonemes=c.use_phonemes,
        phoneme_language=c.phoneme_language,
        enable_eos_bos=c.enable_eos_bos_chars,
        verbose=verbose)

    # Shard the dataset across processes for multi-GPU training; otherwise no
    # sampler is used and (with shuffle=False) items are visited in the
    # dataset's own order.
    sampler = DistributedSampler(dataset) if num_gpus > 1 else None

    if is_val:
        batch_size, n_workers = c.eval_batch_size, c.num_val_loader_workers
    else:
        batch_size, n_workers = c.batch_size, c.num_loader_workers

    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=False,
                      collate_fn=dataset.collate_fn,
                      drop_last=False,
                      sampler=sampler,
                      num_workers=n_workers,
                      pin_memory=False)
|
|
|
|
|
|
# Speaker mappings already loaded from disk, keyed by the directory they were
# read from. main() writes the mapping once before training starts, so
# re-reading the file for every batch is unnecessary.
_SPEAKER_MAPPING_CACHE = {}


def _cached_speaker_mapping(out_path):
    """Return the speaker-name -> id mapping stored under ``out_path``,
    loading it from disk only on first use."""
    if out_path not in _SPEAKER_MAPPING_CACHE:
        _SPEAKER_MAPPING_CACHE[out_path] = load_speaker_mapping(out_path)
    return _SPEAKER_MAPPING_CACHE[out_path]


def format_data(data):
    """Unpack a collated batch, build per-step stop targets and move it to GPU.

    Reads the module-level config ``c`` (``use_speaker_embedding``, ``model``,
    ``r``), ``use_cuda`` and ``OUT_PATH``.

    Args:
        data: collated batch indexed as
            ``[text, text_lengths, speaker_names, linear, mel, mel_lengths,
            stop_targets, ...]``.

    Returns:
        Tuple ``(text_input, text_lengths, mel_input, mel_lengths,
        linear_input, stop_targets, speaker_ids, avg_text_length,
        avg_spec_length)``. ``linear_input`` is ``None`` unless
        ``c.model == "Tacotron"``; ``speaker_ids`` is ``None`` unless speaker
        embeddings are enabled.
    """
    # setup input data
    text_input = data[0]
    text_lengths = data[1]
    speaker_names = data[2]
    linear_input = data[3] if c.model in ["Tacotron"] else None
    mel_input = data[4]
    mel_lengths = data[5]
    stop_targets = data[6]
    avg_text_length = torch.mean(text_lengths.float())
    avg_spec_length = torch.mean(mel_lengths.float())

    if c.use_speaker_embedding:
        speaker_mapping = _cached_speaker_mapping(OUT_PATH)
        speaker_ids = torch.LongTensor(
            [speaker_mapping[name] for name in speaker_names])
    else:
        speaker_ids = None

    # set stop targets view, we predict a single stop token per iteration.
    # Frames are grouped into chunks of c.r (one chunk per decoder step); a
    # chunk is a stop target when any of its frames is.
    stop_targets = stop_targets.view(text_input.shape[0],
                                     stop_targets.size(1) // c.r, -1)
    stop_targets = (stop_targets.sum(2) > 0.0).float()

    # dispatch data to GPU
    if use_cuda:
        text_input = text_input.cuda(non_blocking=True)
        text_lengths = text_lengths.cuda(non_blocking=True)
        mel_input = mel_input.cuda(non_blocking=True)
        mel_lengths = mel_lengths.cuda(non_blocking=True)
        if linear_input is not None:
            linear_input = linear_input.cuda(non_blocking=True)
        stop_targets = stop_targets.cuda(non_blocking=True)
        if speaker_ids is not None:
            speaker_ids = speaker_ids.cuda(non_blocking=True)
    return (text_input, text_lengths, mel_input, mel_lengths, linear_input,
            stop_targets, speaker_ids, avg_text_length, avg_spec_length)
|
|
|
|
|
|
def train(model, criterion, optimizer, optimizer_st, scheduler,
          ap, global_step, epoch):
    """Run one training epoch.

    Relies on module globals: the config ``c``, ``use_cuda``, ``num_gpus``,
    ``args`` (``args.rank``), ``OUT_PATH`` and the loggers ``c_logger`` /
    ``tb_logger``.

    Args:
        model: Tacotron-style model; ``model.decoder.r`` is the reduction factor.
        criterion: loss module returning a dict of named losses.
        optimizer: optimizer for the model parameters.
        optimizer_st: separate stopnet optimizer, or ``None``.
        scheduler: LR scheduler, stepped every iteration when ``c.noam_schedule``.
        ap: AudioProcessor used for diagnostic plots and audio.
        global_step: step counter before this epoch.
        epoch: epoch index (dataset is built verbosely only for epoch 0).

    Returns:
        ``(keep_avg.avg_values, global_step)``: running averages of the
        epoch's stats and the updated step counter.
    """
    data_loader = setup_loader(ap, model.decoder.r, is_val=False,
                               verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    train_values = {
        'avg_postnet_loss': 0,
        'avg_decoder_loss': 0,
        'avg_stopnet_loss': 0,
        'avg_align_error': 0,
        'avg_step_time': 0,
        'avg_loader_time': 0
    }
    if c.bidirectional_decoder:
        train_values['avg_decoder_b_loss'] = 0  # decoder backward loss
        train_values['avg_decoder_c_loss'] = 0  # decoder consistency loss
    if c.ga_alpha > 0:
        train_values['avg_ga_loss'] = 0  # guided attention loss
    keep_avg = KeepAverage()
    keep_avg.add_values(train_values)
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        (text_input, text_lengths, mel_input, mel_lengths, linear_input,
         stop_targets, speaker_ids, avg_text_length,
         avg_spec_length) = format_data(data)
        # time since the previous step finished (data loading + formatting)
        loader_time = time.time() - end_time

        global_step += 1

        # setup lr
        if c.noam_schedule:
            scheduler.step()
        optimizer.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        # forward pass model
        if c.bidirectional_decoder:
            (decoder_output, postnet_output, alignments, stop_tokens,
             decoder_backward_output, alignments_backward) = model(
                 text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
        else:
            decoder_output, postnet_output, alignments, stop_tokens = model(
                text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
            decoder_backward_output = None

        # set the alignment lengths wrt reduction factor for guided attention
        if mel_lengths.max() % model.decoder.r != 0:
            alignment_lengths = (mel_lengths + (model.decoder.r - (
                mel_lengths.max() % model.decoder.r))) // model.decoder.r
        else:
            alignment_lengths = mel_lengths // model.decoder.r

        # compute loss
        loss_dict = criterion(postnet_output, decoder_output, mel_input,
                              linear_input, stop_tokens, stop_targets,
                              mel_lengths, decoder_backward_output,
                              alignments, alignment_lengths, text_lengths)
        if c.bidirectional_decoder:
            # same key as evaluate() reads ('decoder_b_loss')
            keep_avg.update_values({
                'avg_decoder_b_loss': loss_dict['decoder_b_loss'].item(),
                'avg_decoder_c_loss': loss_dict['decoder_c_loss'].item()
            })
        if c.ga_alpha > 0:
            keep_avg.update_values({'avg_ga_loss': loss_dict['ga_loss'].item()})

        # backward pass
        loss_dict['loss'].backward()
        optimizer, current_lr = adam_weight_decay(optimizer)
        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
        optimizer.step()

        # compute alignment error (the lower the better)
        align_error = 1 - alignment_diagonal_score(alignments)
        keep_avg.update_value('avg_align_error', align_error)
        loss_dict['align_error'] = align_error

        # backpass and check the grad norm for stop loss; main() only builds
        # optimizer_st when the stopnet is enabled, so guard on it as well.
        if c.separate_stopnet and optimizer_st is not None:
            loss_dict['stopnet_loss'].backward()
            optimizer_st, _ = adam_weight_decay(optimizer_st)
            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
            optimizer_st.step()
        else:
            grad_norm_st = 0

        step_time = time.time() - start_time
        epoch_time += step_time

        # update avg stats; the stopnet loss may be a plain float or a tensor
        stopnet_loss = loss_dict['stopnet_loss']
        update_train_values = {
            'avg_postnet_loss': float(loss_dict['postnet_loss'].item()),
            'avg_decoder_loss': float(loss_dict['decoder_loss'].item()),
            'avg_stopnet_loss': stopnet_loss if isinstance(stopnet_loss, float)
                                else float(stopnet_loss.item()),
            'avg_step_time': step_time,
            'avg_loader_time': loader_time
        }
        keep_avg.update_values(update_train_values)

        if global_step % c.print_step == 0:
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      avg_spec_length, avg_text_length,
                                      step_time, loader_time, current_lr,
                                      loss_dict, keep_avg.avg_values)

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus)
            loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus)
            loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)
            loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus) if c.stopnet else loss_dict['stopnet_loss']

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % 10 == 0:
                iter_stats = {
                    "loss_posnet": loss_dict['postnet_loss'].item(),
                    "loss_decoder": loss_dict['decoder_loss'].item(),
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "grad_norm_st": grad_norm_st,
                    "step_time": step_time
                }
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
                                    optimizer_st=optimizer_st,
                                    model_loss=loss_dict['postnet_loss'].item())

                # Diagnostic visualizations
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img),
                }

                if c.bidirectional_decoder:
                    figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy())

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                if c.model in ["Tacotron", "TacotronGST"]:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        # Plot Training Epoch Stats
        epoch_stats = {
            "loss_postnet": keep_avg['avg_postnet_loss'],
            "loss_decoder": keep_avg['avg_decoder_loss'],
            "stopnet_loss": keep_avg['avg_stopnet_loss'],
            "alignment_score": keep_avg['avg_align_error'],
            "epoch_time": epoch_time
        }
        if c.bidirectional_decoder:
            # mirrors the validation stats logged by evaluate()
            epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_b_loss']
        if c.ga_alpha > 0:
            epoch_stats['guided_attention_loss'] = keep_avg['avg_ga_loss']
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
|
|
|
|
|
|
@torch.no_grad()
|
|
def evaluate(model, criterion, ap, global_step, epoch):
|
|
data_loader = setup_loader(ap, model.decoder.r, is_val=True)
|
|
model.eval()
|
|
epoch_time = 0
|
|
eval_values_dict = {
|
|
'avg_postnet_loss': 0,
|
|
'avg_decoder_loss': 0,
|
|
'avg_stopnet_loss': 0,
|
|
'avg_align_error': 0
|
|
}
|
|
if c.bidirectional_decoder:
|
|
eval_values_dict['avg_decoder_b_loss'] = 0 # decoder backward loss
|
|
eval_values_dict['avg_decoder_c_loss'] = 0 # decoder consistency loss
|
|
if c.ga_alpha > 0:
|
|
eval_values_dict['avg_ga_loss'] = 0 # guidede attention loss
|
|
keep_avg = KeepAverage()
|
|
keep_avg.add_values(eval_values_dict)
|
|
|
|
c_logger.print_eval_start()
|
|
if data_loader is not None:
|
|
for num_iter, data in enumerate(data_loader):
|
|
start_time = time.time()
|
|
|
|
# format data
|
|
text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(data)
|
|
assert mel_input.shape[1] % model.decoder.r == 0
|
|
|
|
# forward pass model
|
|
if c.bidirectional_decoder:
|
|
decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
|
|
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
|
|
else:
|
|
decoder_output, postnet_output, alignments, stop_tokens = model(
|
|
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
|
|
decoder_backward_output = None
|
|
|
|
# set the alignment lengths wrt reduction factor for guided attention
|
|
if mel_lengths.max() % model.decoder.r != 0:
|
|
alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r
|
|
else:
|
|
alignment_lengths = mel_lengths // model.decoder.r
|
|
|
|
# compute loss
|
|
loss_dict = criterion(postnet_output, decoder_output, mel_input,
|
|
linear_input, stop_tokens, stop_targets,
|
|
mel_lengths, decoder_backward_output,
|
|
alignments, alignment_lengths, text_lengths)
|
|
if c.bidirectional_decoder:
|
|
keep_avg.update_values({'avg_decoder_b_loss': loss_dict['decoder_b_loss'].item(),
|
|
'avg_decoder_c_loss': loss_dict['decoder_c_loss'].item()})
|
|
if c.ga_alpha > 0:
|
|
keep_avg.update_values({'avg_ga_loss': loss_dict['ga_loss'].item()})
|
|
|
|
# step time
|
|
step_time = time.time() - start_time
|
|
epoch_time += step_time
|
|
|
|
# compute alignment score
|
|
align_error = 1 - alignment_diagonal_score(alignments)
|
|
keep_avg.update_value('avg_align_error', align_error)
|
|
|
|
# aggregate losses from processes
|
|
if num_gpus > 1:
|
|
loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus)
|
|
loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus)
|
|
if c.stopnet:
|
|
loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus)
|
|
|
|
keep_avg.update_values({
|
|
'avg_postnet_loss':
|
|
float(loss_dict['postnet_loss'].item()),
|
|
'avg_decoder_loss':
|
|
float(loss_dict['decoder_loss'].item()),
|
|
'avg_stopnet_loss':
|
|
float(loss_dict['stopnet_loss'].item()),
|
|
})
|
|
|
|
if c.print_eval:
|
|
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
|
|
|
|
if args.rank == 0:
|
|
# Diagnostic visualizations
|
|
idx = np.random.randint(mel_input.shape[0])
|
|
const_spec = postnet_output[idx].data.cpu().numpy()
|
|
gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
|
|
"Tacotron", "TacotronGST"
|
|
] else mel_input[idx].data.cpu().numpy()
|
|
align_img = alignments[idx].data.cpu().numpy()
|
|
|
|
eval_figures = {
|
|
"prediction": plot_spectrogram(const_spec, ap),
|
|
"ground_truth": plot_spectrogram(gt_spec, ap),
|
|
"alignment": plot_alignment(align_img)
|
|
}
|
|
|
|
# Sample audio
|
|
if c.model in ["Tacotron", "TacotronGST"]:
|
|
eval_audio = ap.inv_spectrogram(const_spec.T)
|
|
else:
|
|
eval_audio = ap.inv_melspectrogram(const_spec.T)
|
|
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
|
|
c.audio["sample_rate"])
|
|
|
|
# Plot Validation Stats
|
|
epoch_stats = {
|
|
"loss_postnet": keep_avg['avg_postnet_loss'],
|
|
"loss_decoder": keep_avg['avg_decoder_loss'],
|
|
"stopnet_loss": keep_avg['avg_stopnet_loss'],
|
|
"alignment_score": keep_avg['avg_align_error'],
|
|
}
|
|
|
|
if c.bidirectional_decoder:
|
|
epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_b_loss']
|
|
align_b_img = alignments_backward[idx].data.cpu().numpy()
|
|
eval_figures['alignment_backward'] = plot_alignment(align_b_img)
|
|
if c.ga_alpha > 0:
|
|
epoch_stats['guided_attention_loss'] = keep_avg['avg_ga_loss']
|
|
tb_logger.tb_eval_stats(global_step, epoch_stats)
|
|
tb_logger.tb_eval_figures(global_step, eval_figures)
|
|
|
|
if args.rank == 0 and epoch > c.test_delay_epochs:
|
|
if c.test_sentences_file is None:
|
|
test_sentences = [
|
|
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
|
"Be a voice, not an echo.",
|
|
"I'm sorry Dave. I'm afraid I can't do that.",
|
|
"This cake is great. It's so delicious and moist."
|
|
]
|
|
else:
|
|
with open(c.test_sentences_file, "r") as f:
|
|
test_sentences = [s.strip() for s in f.readlines()]
|
|
|
|
# test sentences
|
|
test_audios = {}
|
|
test_figures = {}
|
|
print(" | > Synthesizing test sentences")
|
|
speaker_id = 0 if c.use_speaker_embedding else None
|
|
style_wav = c.get("style_wav_for_test")
|
|
for idx, test_sentence in enumerate(test_sentences):
|
|
try:
|
|
wav, alignment, decoder_output, postnet_output, stop_tokens, inputs = synthesis(
|
|
model,
|
|
test_sentence,
|
|
c,
|
|
use_cuda,
|
|
ap,
|
|
speaker_id=speaker_id,
|
|
style_wav=style_wav,
|
|
truncated=False,
|
|
enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
|
|
use_griffin_lim=True,
|
|
do_trim_silence=False)
|
|
|
|
file_path = os.path.join(AUDIO_PATH, str(global_step))
|
|
os.makedirs(file_path, exist_ok=True)
|
|
file_path = os.path.join(file_path,
|
|
"TestSentence_{}.wav".format(idx))
|
|
ap.save_wav(wav, file_path)
|
|
test_audios['{}-audio'.format(idx)] = wav
|
|
test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
|
|
postnet_output, ap)
|
|
test_figures['{}-alignment'.format(idx)] = plot_alignment(
|
|
alignment)
|
|
except:
|
|
print(" !! Error creating Test Sentence -", idx)
|
|
traceback.print_exc()
|
|
tb_logger.tb_test_audios(global_step, test_audios,
|
|
c.audio['sample_rate'])
|
|
tb_logger.tb_test_figures(global_step, test_figures)
|
|
return keep_avg.avg_values
|
|
|
|
|
|
# FIXME: move args definition/parsing inside of main?
def main(args):  # pylint: disable=redefined-outer-name
    """Set up data, model, optimizers and loss, then run the epoch loop.

    Reads the module-level config ``c``, ``num_gpus``, ``use_cuda``,
    ``OUT_PATH`` and ``c_logger``. Publishes ``meta_data_train`` /
    ``meta_data_eval`` (and, when ``c.characters`` is set, ``symbols`` /
    ``phonemes``) as module globals for the other functions.

    Args:
        args: parsed command-line namespace; ``args.restore_step`` is set here.
    """
    # pylint: disable=global-variable-undefined
    global meta_data_train, meta_data_eval, symbols, phonemes
    # Audio processor
    ap = AudioProcessor(**c.audio)
    if 'characters' in c.keys():
        symbols, phonemes = make_symbols(**c.characters)

    # DISTRUBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])
    num_chars = len(phonemes) if c.use_phonemes else len(symbols)

    # load data instances
    meta_data_train, meta_data_eval = load_meta_data(c.datasets)

    # parse speakers
    if c.use_speaker_embedding:
        speakers = get_speakers(meta_data_train)
        if args.restore_path:
            prev_out_path = os.path.dirname(args.restore_path)
            speaker_mapping = load_speaker_mapping(prev_out_path)
            assert all(speaker in speaker_mapping
                       for speaker in speakers), "As of now you, you cannot " \
                                                 "introduce new speakers to " \
                                                 "a previously trained model."
        else:
            speaker_mapping = {name: i for i, name in enumerate(speakers)}
        save_speaker_mapping(OUT_PATH, speaker_mapping)
        num_speakers = len(speaker_mapping)
        print("Training with {} speakers: {}".format(num_speakers,
                                                     ", ".join(speakers)))
    else:
        num_speakers = 0

    model = setup_model(num_chars, num_speakers, c)

    print(" | > Num output units : {}".format(ap.num_freq), flush=True)

    params = set_weight_decay(model, c.wd)
    optimizer = RAdam(params, lr=c.lr, weight_decay=0)
    if c.stopnet and c.separate_stopnet:
        optimizer_st = RAdam(model.decoder.stopnet.parameters(),
                             lr=c.lr,
                             weight_decay=0)
    else:
        optimizer_st = None

    # setup criterion
    criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)

    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            # optimizer.load_state_dict(checkpoint['optimizer'])
            if c.reinit_layers:
                # deliberately take the partial-initialization path below
                raise RuntimeError
            model.load_state_dict(checkpoint['model'])
        except Exception:  # pylint: disable=broad-except
            # Strict loading failed or re-init was requested: let
            # set_init_dict build a loadable state dict from the checkpoint
            # (presumably copying only compatible layers -- see set_init_dict).
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint, c)
            model.load_state_dict(model_dict)
            del model_dict
        for group in optimizer.param_groups:
            group['lr'] = c.lr
        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
        # Drop the loaded checkpoint dict so it is not kept alive for the
        # whole training run.
        del checkpoint
    else:
        args.restore_step = 0

    if use_cuda:
        model.cuda()
        criterion.cuda()

    # DISTRUBUTED
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)

    if c.noam_schedule:
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.warmup_steps,
                           last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    best_loss = float('inf')

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        # set gradual training: may change both r and the batch size
        if c.gradual_training is not None:
            r, c.batch_size = gradual_training_scheduler(global_step, c)
            c.r = r
            model.decoder.set_r(r)
            if c.bidirectional_decoder:
                model.decoder_backward.set_r(r)
            print("\n > Number of output frames:", model.decoder.r)

        train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                 optimizer_st, scheduler, ap,
                                                 global_step, epoch)
        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        # select the best model on validation loss when evaluation runs
        target_loss = train_avg_loss_dict['avg_postnet_loss']
        if c.run_eval:
            target_loss = eval_avg_loss_dict['avg_postnet_loss']
        best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
                                    OUT_PATH)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument(
|
|
'--continue_path',
|
|
type=str,
|
|
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
|
|
default='',
|
|
required='--config_path' not in sys.argv)
|
|
parser.add_argument(
|
|
'--restore_path',
|
|
type=str,
|
|
help='Model file to be restored. Use to finetune a model.',
|
|
default='')
|
|
parser.add_argument(
|
|
'--config_path',
|
|
type=str,
|
|
help='Path to config file for training.',
|
|
required='--continue_path' not in sys.argv
|
|
)
|
|
parser.add_argument('--debug',
|
|
type=bool,
|
|
default=False,
|
|
help='Do not verify commit integrity to run training.')
|
|
|
|
# DISTRUBUTED
|
|
parser.add_argument(
|
|
'--rank',
|
|
type=int,
|
|
default=0,
|
|
help='DISTRIBUTED: process rank for distributed training.')
|
|
parser.add_argument('--group_id',
|
|
type=str,
|
|
default="",
|
|
help='DISTRIBUTED: process group id.')
|
|
args = parser.parse_args()
|
|
|
|
if args.continue_path != '':
|
|
args.output_path = args.continue_path
|
|
args.config_path = os.path.join(args.continue_path, 'config.json')
|
|
list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv
|
|
latest_model_file = max(list_of_files, key=os.path.getctime)
|
|
args.restore_path = latest_model_file
|
|
print(f" > Training continues for {args.restore_path}")
|
|
|
|
# setup output paths and read configs
|
|
c = load_config(args.config_path)
|
|
check_config(c)
|
|
_ = os.path.dirname(os.path.realpath(__file__))
|
|
|
|
OUT_PATH = args.continue_path
|
|
if args.continue_path == '':
|
|
OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
|
|
|
|
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
|
|
|
|
c_logger = ConsoleLogger()
|
|
|
|
if args.rank == 0:
|
|
os.makedirs(AUDIO_PATH, exist_ok=True)
|
|
new_fields = {}
|
|
if args.restore_path:
|
|
new_fields["restore_path"] = args.restore_path
|
|
new_fields["github_branch"] = get_git_branch()
|
|
copy_config_file(args.config_path,
|
|
os.path.join(OUT_PATH, 'config.json'), new_fields)
|
|
os.chmod(AUDIO_PATH, 0o775)
|
|
os.chmod(OUT_PATH, 0o775)
|
|
|
|
LOG_DIR = OUT_PATH
|
|
tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')
|
|
|
|
# write model desc to tensorboard
|
|
tb_logger.tb_add_text('model-description', c['run_description'], 0)
|
|
|
|
try:
|
|
main(args)
|
|
except KeyboardInterrupt:
|
|
remove_experiment_folder(OUT_PATH)
|
|
try:
|
|
sys.exit(0)
|
|
except SystemExit:
|
|
os._exit(0) # pylint: disable=protected-access
|
|
except Exception: # pylint: disable=broad-except
|
|
remove_experiment_folder(OUT_PATH)
|
|
traceback.print_exc()
|
|
sys.exit(1)
|