TTS/train.py

725 строки
30 KiB
Python
Исходник Обычный вид История

import argparse
2018-01-22 12:48:59 +03:00
import os
import sys
import glob
2018-01-22 12:48:59 +03:00
import time
import traceback
2018-01-22 12:48:59 +03:00
import numpy as np
import torch
2018-01-22 12:48:59 +03:00
import torch.nn as nn
from torch.utils.data import DataLoader
from TTS.datasets.TTSDataset import MyDataset
from distribute import (DistributedSampler, apply_gradient_allreduce,
init_distributed, reduce_tensor)
from TTS.layers.losses import L1LossMasked, MSELossMasked
from TTS.utils.audio import AudioProcessor
2019-10-04 19:36:32 +03:00
from TTS.utils.generic_utils import (
NoamLR, check_update, count_parameters, create_experiment_folder,
get_git_branch, load_config, remove_experiment_folder, save_best_model,
save_checkpoint, adam_weight_decay, set_init_dict, copy_config_file,
setup_model, gradual_training_scheduler, KeepAverage,
set_weight_decay)
from TTS.utils.logger import Logger
from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
2019-07-10 19:38:55 +03:00
get_speakers
from TTS.utils.synthesis import synthesis
from TTS.utils.text.symbols import phonemes, symbols
from TTS.utils.visual import plot_alignment, plot_spectrogram
2019-09-30 16:03:18 +03:00
from TTS.datasets.preprocess import load_meta_data
2019-08-30 11:15:54 +03:00
from TTS.utils.radam import RAdam
from TTS.utils.measures import alignment_diagonal_score
2019-08-30 11:15:54 +03:00
# Global PyTorch setup: cuDNN is enabled, its benchmark mode is switched off,
# and PyTorch's RNG is seeded with a fixed value.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(54321)
2018-01-22 12:48:59 +03:00
# Device discovery, done once at import time. `use_cuda` and `num_gpus` are
# module-level flags read by the functions defined below.
use_cuda = torch.cuda.is_available()
num_gpus = torch.cuda.device_count()
print(" > Using CUDA: ", use_cuda)
print(" > Number of GPUs: ", num_gpus)
2018-01-22 12:48:59 +03:00
2018-03-02 18:54:35 +03:00
def setup_loader(ap, r, is_val=False, verbose=False):
    """Build the ``DataLoader`` for the training or the validation split.

    Args:
        ap: audio processor handed to the dataset.
        r: first positional argument of ``MyDataset`` (callers in this file
            pass ``model.decoder.r``).
        is_val: build the validation loader instead of the training one.
        verbose: forwarded to the dataset.

    Returns:
        A ``DataLoader``, or ``None`` when a validation loader is requested
        while ``c.run_eval`` is falsy.
    """
    # Evaluation is disabled in the config: there is nothing to load.
    if is_val and not c.run_eval:
        return None

    # Split-dependent dataset settings.
    if is_val:
        split_meta_data = meta_data_eval
        group_size = 0
    else:
        split_meta_data = meta_data_train
        group_size = c.batch_group_size * c.batch_size

    dataset = MyDataset(r,
                        c.text_cleaner,
                        meta_data=split_meta_data,
                        ap=ap,
                        batch_group_size=group_size,
                        min_seq_len=c.min_seq_len,
                        max_seq_len=c.max_seq_len,
                        phoneme_cache_path=c.phoneme_cache_path,
                        use_phonemes=c.use_phonemes,
                        phoneme_language=c.phoneme_language,
                        enable_eos_bos=c.enable_eos_bos_chars,
                        verbose=verbose)

    # A distributed sampler is only used when more than one GPU is visible.
    if num_gpus > 1:
        sampler = DistributedSampler(dataset)
    else:
        sampler = None

    # Split-dependent loader settings.
    if is_val:
        loader_batch_size = c.eval_batch_size
        worker_count = c.num_val_loader_workers
    else:
        loader_batch_size = c.batch_size
        worker_count = c.num_loader_workers

    return DataLoader(dataset,
                      batch_size=loader_batch_size,
                      shuffle=False,
                      collate_fn=dataset.collate_fn,
                      drop_last=False,
                      sampler=sampler,
                      num_workers=worker_count,
                      pin_memory=False)
def format_data(data):
2019-07-10 19:38:55 +03:00
if c.use_speaker_embedding:
speaker_mapping = load_speaker_mapping(OUT_PATH)
# setup input data
text_input = data[0]
text_lengths = data[1]
speaker_names = data[2]
2019-11-19 15:07:06 +03:00
linear_input = data[3] if c.model in ["Tacotron"] else None
mel_input = data[4]
mel_lengths = data[5]
stop_targets = data[6]
avg_text_length = torch.mean(text_lengths.float())
avg_spec_length = torch.mean(mel_lengths.float())
if c.use_speaker_embedding:
speaker_ids = [
speaker_mapping[speaker_name] for speaker_name in speaker_names
]
speaker_ids = torch.LongTensor(speaker_ids)
else:
speaker_ids = None
2019-11-19 15:07:06 +03:00
# set stop targets view, we predict a single stop token per iteration.
stop_targets = stop_targets.view(text_input.shape[0],
stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) >
0.0).unsqueeze(2).float().squeeze(2)
# dispatch data to GPU
if use_cuda:
text_input = text_input.cuda(non_blocking=True)
text_lengths = text_lengths.cuda(non_blocking=True)
mel_input = mel_input.cuda(non_blocking=True)
mel_lengths = mel_lengths.cuda(non_blocking=True)
2019-11-19 15:07:06 +03:00
linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron"] else None
stop_targets = stop_targets.cuda(non_blocking=True)
if speaker_ids is not None:
speaker_ids = speaker_ids.cuda(non_blocking=True)
return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
          ap, global_step, epoch):
    """Run one training epoch.

    Args:
        model: network being trained; ``model.decoder.r`` and
            ``model.decoder.stopnet`` are read here.
        criterion: spectrogram loss, called as ``(output, target)`` or as
            ``(output, target, lengths)`` when ``c.loss_masking`` is set.
        criterion_st: stop-token loss, used only when ``c.stopnet`` is set.
        optimizer: optimizer of the main model parameters.
        optimizer_st: separate stopnet optimizer, or ``None``.
        scheduler: LR scheduler, stepped once per iteration when
            ``c.noam_schedule`` is set.
        ap: audio processor used for plots and audio samples.
        global_step: number of training steps done before this epoch.
        epoch: index of the current epoch.

    Returns:
        Tuple ``(avg_postnet_loss, global_step)``; ``global_step`` has been
        advanced by one for every batch processed.
    """
    data_loader = setup_loader(ap, model.decoder.r, is_val=False,
                               verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    train_values = {
        'avg_postnet_loss': 0,
        'avg_decoder_loss': 0,
        'avg_stop_loss': 0,
        'avg_align_score': 0,
        'avg_step_time': 0,
        'avg_loader_time': 0
    }
    if c.bidirectional_decoder:
        train_values['avg_decoder_b_loss'] = 0  # decoder backward loss
        train_values['avg_decoder_c_loss'] = 0  # decoder consistency loss
    keep_avg = KeepAverage()
    keep_avg.add_values(train_values)
    print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)

    # Number of iterations per epoch, only used for the progress print.
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)

    # These models take the linear input as postnet target, the rest use mel.
    uses_linear_target = c.model in ["Tacotron", "TacotronGST"]
    # The stopnet gets its own backward pass/optimizer only when it is both
    # enabled and configured as separate; main() creates `optimizer_st` under
    # exactly this condition.
    train_stopnet_separately = c.stopnet and c.separate_stopnet

    end_time = time.time()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        # setup lr
        if c.noam_schedule:
            scheduler.step()
        optimizer.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        # forward pass model
        if c.bidirectional_decoder:
            decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
                text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
        else:
            decoder_output, postnet_output, alignments, stop_tokens = model(
                text_input, text_lengths, mel_input, speaker_ids=speaker_ids)

        # loss computation
        stop_loss = criterion_st(stop_tokens,
                                 stop_targets) if c.stopnet else torch.zeros(1)
        # Masked criteria additionally receive the spectrogram lengths.
        length_args = (mel_lengths, ) if c.loss_masking else ()
        postnet_target = linear_input if uses_linear_target else mel_input
        decoder_loss = criterion(decoder_output, mel_input, *length_args)
        postnet_loss = criterion(postnet_output, postnet_target, *length_args)
        loss = decoder_loss + postnet_loss
        if not c.separate_stopnet and c.stopnet:
            loss += stop_loss

        # backward decoder
        if c.bidirectional_decoder:
            # The backward decoder output is flipped along dim 1 before it is
            # compared with the targets and with the forward decoder output.
            flipped_backward_output = torch.flip(decoder_backward_output,
                                                 dims=(1, ))
            decoder_backward_loss = criterion(flipped_backward_output,
                                              mel_input, *length_args)
            decoder_c_loss = torch.nn.functional.l1_loss(
                flipped_backward_output, decoder_output)
            loss += decoder_backward_loss + decoder_c_loss
            keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()})

        loss.backward()
        optimizer, current_lr = adam_weight_decay(optimizer)
        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
        optimizer.step()

        # compute alignment score
        align_score = alignment_diagonal_score(alignments)
        keep_avg.update_value('avg_align_score', align_score)

        # backpass and check the grad norm for stop loss
        if train_stopnet_separately:
            stop_loss.backward()
            optimizer_st, _ = adam_weight_decay(optimizer_st)
            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
            optimizer_st.step()
        else:
            grad_norm_st = 0

        step_time = time.time() - start_time
        epoch_time += step_time

        if global_step % c.print_step == 0:
            print(
                " | > Step:{}/{} GlobalStep:{} PostnetLoss:{:.5f} "
                "DecoderLoss:{:.5f} StopLoss:{:.5f} AlignScore:{:.4f} GradNorm:{:.5f} "
                "GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} "
                "LoaderTime:{:.2f} LR:{:.6f}".format(
                    num_iter, batch_n_iter, global_step, postnet_loss.item(),
                    decoder_loss.item(), stop_loss.item(), align_score,
                    grad_norm, grad_norm_st, avg_text_length, avg_spec_length,
                    step_time, loader_time, current_lr),
                flush=True)

        # aggregate losses from processes
        if num_gpus > 1:
            postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
            decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
            loss = reduce_tensor(loss.data, num_gpus)
            stop_loss = reduce_tensor(stop_loss.data,
                                      num_gpus) if c.stopnet else stop_loss

        # Only the rank-0 process keeps running averages, logs and saves.
        if args.rank == 0:
            update_train_values = {
                'avg_postnet_loss':
                float(postnet_loss.item()),
                'avg_decoder_loss':
                float(decoder_loss.item()),
                'avg_stop_loss':
                stop_loss
                if isinstance(stop_loss, float) else float(stop_loss.item()),
                'avg_step_time':
                step_time,
                'avg_loader_time':
                loader_time
            }
            keep_avg.update_values(update_train_values)

            # Plot Training Iter Stats
            # reduce TB load
            if global_step % 10 == 0:
                iter_stats = {
                    "loss_posnet": postnet_loss.item(),
                    "loss_decoder": decoder_loss.item(),
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "grad_norm_st": grad_norm_st,
                    "step_time": step_time
                }
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, optimizer_st,
                                    postnet_loss.item(), OUT_PATH, global_step,
                                    epoch)

                # Diagnostic visualizations (first item of the batch)
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_source = linear_input if uses_linear_target else mel_input
                gt_spec = gt_source[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img),
                }
                if c.bidirectional_decoder:
                    figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy())
                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                if uses_linear_target:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_mel_spectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        # Reference point for measuring the next batch's loader time.
        end_time = time.time()

    # print epoch stats
    print(" | > EPOCH END -- GlobalStep:{} "
          "AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
          "AvgStopLoss:{:.5f} AvgAlignScore:{:.3f} EpochTime:{:.2f} "
          "AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(
              global_step, keep_avg['avg_postnet_loss'],
              keep_avg['avg_decoder_loss'], keep_avg['avg_stop_loss'],
              keep_avg['avg_align_score'], epoch_time,
              keep_avg['avg_step_time'], keep_avg['avg_loader_time']),
          flush=True)

    # Plot Epoch Stats
    if args.rank == 0:
        # Plot Training Epoch Stats
        epoch_stats = {
            "loss_postnet": keep_avg['avg_postnet_loss'],
            "loss_decoder": keep_avg['avg_decoder_loss'],
            "stop_loss": keep_avg['avg_stop_loss'],
            "alignment_score": keep_avg['avg_align_score'],
            "epoch_time": epoch_time
        }
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg['avg_postnet_loss'], global_step
2018-03-02 18:54:35 +03:00
2018-04-03 13:24:57 +03:00
def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
    """Run validation and, after a warm-up, synthesize the test sentences.

    Args:
        model: network to evaluate; ``model.decoder.r`` is read here.
        criterion: spectrogram loss, called as ``(output, target)`` or as
            ``(output, target, lengths)`` when ``c.loss_masking`` is set.
        criterion_st: stop-token loss, used only when ``c.stopnet`` is set.
        ap: audio processor used for plots, audio samples and wav export.
        global_step: current training step, used to tag logs and outputs.
        epoch: index of the current epoch; test sentences are synthesized
            only once it exceeds ``c.test_delay_epochs``.

    Returns:
        The running average of the postnet loss (the tracker's initial value
        when no validation batch was processed).
    """
    data_loader = setup_loader(ap, model.decoder.r, is_val=True)
    model.eval()
    eval_values_dict = {
        'avg_postnet_loss': 0,
        'avg_decoder_loss': 0,
        'avg_stop_loss': 0,
        'avg_align_score': 0
    }
    if c.bidirectional_decoder:
        eval_values_dict['avg_decoder_b_loss'] = 0  # decoder backward loss
        eval_values_dict['avg_decoder_c_loss'] = 0  # decoder consistency loss
    keep_avg = KeepAverage()
    keep_avg.add_values(eval_values_dict)

    # These models take the linear input as postnet target, the rest use mel.
    uses_linear_target = c.model in ["Tacotron", "TacotronGST"]
    # The diagnostics below read tensors of the last batch, so they must be
    # skipped when the loader did not yield any batch.
    batch_seen = False

    print("\n > Validation")
    with torch.no_grad():
        if data_loader is not None:
            for num_iter, data in enumerate(data_loader):
                batch_seen = True

                # format data
                text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(data)
                assert mel_input.shape[1] % model.decoder.r == 0

                # forward pass model
                if c.bidirectional_decoder:
                    decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
                        text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
                else:
                    decoder_output, postnet_output, alignments, stop_tokens = model(
                        text_input, text_lengths, mel_input, speaker_ids=speaker_ids)

                # loss computation
                if c.stopnet:
                    stop_loss = criterion_st(stop_tokens, stop_targets)
                else:
                    # Placeholder on the same device as the other losses so
                    # that it can be added to them below.
                    stop_loss = torch.zeros(1, device=mel_input.device)
                # Masked criteria additionally receive the spectrogram lengths.
                length_args = (mel_lengths, ) if c.loss_masking else ()
                postnet_target = linear_input if uses_linear_target else mel_input
                decoder_loss = criterion(decoder_output, mel_input,
                                         *length_args)
                postnet_loss = criterion(postnet_output, postnet_target,
                                         *length_args)
                loss = decoder_loss + postnet_loss + stop_loss

                # backward decoder loss
                if c.bidirectional_decoder:
                    flipped_backward_output = torch.flip(
                        decoder_backward_output, dims=(1, ))
                    decoder_backward_loss = criterion(flipped_backward_output,
                                                      mel_input, *length_args)
                    decoder_c_loss = torch.nn.functional.l1_loss(
                        flipped_backward_output, decoder_output)
                    loss += decoder_backward_loss + decoder_c_loss
                    keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()})

                # compute alignment score
                align_score = alignment_diagonal_score(alignments)
                keep_avg.update_value('avg_align_score', align_score)

                # aggregate losses from processes
                if num_gpus > 1:
                    postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
                    decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
                    if c.stopnet:
                        stop_loss = reduce_tensor(stop_loss.data, num_gpus)

                keep_avg.update_values({
                    'avg_postnet_loss':
                    float(postnet_loss.item()),
                    'avg_decoder_loss':
                    float(decoder_loss.item()),
                    'avg_stop_loss':
                    float(stop_loss.item()),
                })

                if num_iter % c.print_step == 0:
                    print(
                        " | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} "
                        "StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}"
                        .format(loss.item(), postnet_loss.item(),
                                keep_avg['avg_postnet_loss'],
                                decoder_loss.item(),
                                keep_avg['avg_decoder_loss'], stop_loss.item(),
                                keep_avg['avg_stop_loss'], align_score,
                                keep_avg['avg_align_score']),
                        flush=True)

            if args.rank == 0 and batch_seen:
                # Diagnostic visualizations for a random item of the last batch
                idx = np.random.randint(mel_input.shape[0])
                const_spec = postnet_output[idx].data.cpu().numpy()
                gt_source = linear_input if uses_linear_target else mel_input
                gt_spec = gt_source[idx].data.cpu().numpy()
                align_img = alignments[idx].data.cpu().numpy()

                eval_figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img)
                }

                # Sample audio
                if uses_linear_target:
                    eval_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    eval_audio = ap.inv_mel_spectrogram(const_spec.T)
                tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                         c.audio["sample_rate"])

                # Plot Validation Stats
                epoch_stats = {
                    "loss_postnet": keep_avg['avg_postnet_loss'],
                    "loss_decoder": keep_avg['avg_decoder_loss'],
                    "stop_loss": keep_avg['avg_stop_loss'],
                    "alignment_score": keep_avg['avg_align_score']
                }
                if c.bidirectional_decoder:
                    epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_b_loss']
                    align_b_img = alignments_backward[idx].data.cpu().numpy()
                    eval_figures['alignment_backward'] = plot_alignment(align_b_img)
                tb_logger.tb_eval_stats(global_step, epoch_stats)
                tb_logger.tb_eval_figures(global_step, eval_figures)

    if args.rank == 0 and epoch > c.test_delay_epochs:
        if c.test_sentences_file is None:
            test_sentences = [
                "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                "Be a voice, not an echo.",
                "I'm sorry Dave. I'm afraid I can't do that.",
                "This cake is great. It's so delicious and moist."
            ]
        else:
            with open(c.test_sentences_file, "r") as f:
                test_sentences = [s.strip() for s in f.readlines()]

        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        speaker_id = 0 if c.use_speaker_embedding else None
        style_wav = c.get("style_wav_for_test")
        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
                    model,
                    test_sentence,
                    c,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    style_wav=style_wav)
                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path,
                                         "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios['{}-audio'.format(idx)] = wav
                test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
                    postnet_output, ap)
                test_figures['{}-alignment'.format(idx)] = plot_alignment(
                    alignment)
            # Best effort: a failing sentence is reported and skipped, but
            # KeyboardInterrupt/SystemExit must still reach the caller.
            except Exception:  # pylint: disable=broad-except
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios,
                                 c.audio['sample_rate'])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg['avg_postnet_loss']
2018-04-03 13:24:57 +03:00
# FIXME: move args definition/parsing inside of main?
def main(args):  # pylint: disable=redefined-outer-name
    """Set up data, model, optimizers and losses, then run the training loop.

    Args:
        args: parsed command-line arguments. ``rank``, ``group_id`` and
            ``restore_path`` are read; ``restore_step`` is written.

    Side effects:
        Assigns the module-level ``meta_data_train`` / ``meta_data_eval``
        (read by ``setup_loader``), may write the speaker mapping to
        ``OUT_PATH``, and updates ``c.r`` / ``c.batch_size`` when gradual
        training is configured.
    """
    global meta_data_train, meta_data_eval

    # Audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRUBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])
    num_chars = len(phonemes) if c.use_phonemes else len(symbols)

    # load data instances
    meta_data_train, meta_data_eval = load_meta_data(c.datasets)

    # parse speakers
    if c.use_speaker_embedding:
        speakers = get_speakers(meta_data_train)
        if args.restore_path:
            # Reuse the mapping stored next to the restored checkpoint.
            prev_out_path = os.path.dirname(args.restore_path)
            speaker_mapping = load_speaker_mapping(prev_out_path)
            assert all(speaker in speaker_mapping
                       for speaker in speakers), "As of now you, you cannot " \
                                                 "introduce new speakers to " \
                                                 "a previously trained model."
        else:
            speaker_mapping = {name: i for i, name in enumerate(speakers)}
        save_speaker_mapping(OUT_PATH, speaker_mapping)
        num_speakers = len(speaker_mapping)
        print("Training with {} speakers: {}".format(num_speakers,
                                                     ", ".join(speakers)))
    else:
        num_speakers = 0

    model = setup_model(num_chars, num_speakers, c)

    print(" | > Num output units : {}".format(ap.num_freq), flush=True)

    params = set_weight_decay(model, c.wd)
    optimizer = RAdam(params, lr=c.lr, weight_decay=0)
    if c.stopnet and c.separate_stopnet:
        optimizer_st = RAdam(model.decoder.stopnet.parameters(),
                             lr=c.lr,
                             weight_decay=0)
    else:
        optimizer_st = None

    # These models are trained with an L1 criterion, the rest with MSE.
    uses_l1_loss = c.model in ["Tacotron", "TacotronGST"]
    if c.loss_masking:
        criterion = L1LossMasked() if uses_l1_loss else MSELossMasked()
    else:
        criterion = nn.L1Loss() if uses_l1_loss else nn.MSELoss()
    # The weight is given as a float tensor because it is multiplied into a
    # floating-point loss.
    criterion_st = nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor(10.0)) if c.stopnet else None

    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            # optimizer.load_state_dict(checkpoint['optimizer'])
            if c.reinit_layers:
                # Jump straight to the partial initialization below.
                raise RuntimeError
            model.load_state_dict(checkpoint['model'])
        # Any failure of the full restore falls back to a partial one, but
        # KeyboardInterrupt/SystemExit must still propagate.
        except Exception:  # pylint: disable=broad-except
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint, c)
            model.load_state_dict(model_dict)
            del model_dict
        for group in optimizer.param_groups:
            group['lr'] = c.lr
        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model.cuda()
        criterion.cuda()
        if criterion_st:
            criterion_st.cuda()

    # DISTRUBUTED
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)

    if c.noam_schedule:
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.warmup_steps,
                           last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    # No best loss is carried over from the checkpoint; start from infinity.
    best_loss = float('inf')

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        # set gradual training
        if c.gradual_training is not None:
            r, c.batch_size = gradual_training_scheduler(global_step, c)
            c.r = r
            model.decoder.set_r(r)
            if c.bidirectional_decoder:
                model.decoder_backward.set_r(r)
        print(" > Number of outputs per iteration:", model.decoder.r)

        train_loss, global_step = train(model, criterion, criterion_st,
                                        optimizer, optimizer_st, scheduler, ap,
                                        global_step, epoch)
        val_loss = evaluate(model, criterion, criterion_st, ap, global_step,
                            epoch)
        print(" | > Training Loss: {:.5f} Validation Loss: {:.5f}".format(
            train_loss, val_loss),
              flush=True)
        # The best model is picked on the validation loss when evaluation is
        # enabled, on the training loss otherwise.
        target_loss = train_loss
        if c.run_eval:
            target_loss = val_loss
        best_loss = save_best_model(model, optimizer, target_loss, best_loss,
                                    OUT_PATH, global_step, epoch)
2018-02-13 12:45:52 +03:00
2018-04-03 13:24:57 +03:00
2018-01-22 12:48:59 +03:00
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``type=bool`` cannot be used with argparse for this purpose: it would
    turn every non-empty string, including ``"False"``, into ``True``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'Expected a boolean value, got {!r}'.format(value))


def _flag_given(flag, argv):
    """Return True if ``flag`` occurs in ``argv`` as ``flag`` or ``flag=value``."""
    prefix = flag + '='
    return any(arg == flag or arg.startswith(prefix) for arg in argv)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Exactly one of --continue_path / --config_path has to be supplied, so
    # each one is required only when the other is absent from the command line.
    parser.add_argument(
        '--continue_path',
        type=str,
        help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
        default='',
        required=not _flag_given('--config_path', sys.argv))
    parser.add_argument(
        '--restore_path',
        type=str,
        help='Model file to be restored. Use to finetune a model.',
        default='')
    parser.add_argument(
        '--config_path',
        type=str,
        help='Path to config file for training.',
        required=not _flag_given('--continue_path', sys.argv)
    )
    parser.add_argument('--debug',
                        type=_str2bool,
                        default=True,
                        help='Do not verify commit integrity to run training.')

    # DISTRUBUTED
    parser.add_argument(
        '--rank',
        type=int,
        default=0,
        help='DISTRIBUTED: process rank for distributed training.')
    parser.add_argument('--group_id',
                        type=str,
                        default="",
                        help='DISTRIBUTED: process group id.')
    args = parser.parse_args()

    if args.continue_path != '':
        # Continue from the most recently created checkpoint of the folder,
        # using the config stored in that folder.
        args.output_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, 'config.json')
        list_of_files = glob.glob(args.continue_path + "/*.pth.tar")
        if not list_of_files:
            raise ValueError(
                "No '*.pth.tar' checkpoint found in '{}'.".format(
                    args.continue_path))
        latest_model_file = max(list_of_files, key=os.path.getctime)
        args.restore_path = latest_model_file
        print(f" > Training continues for {args.restore_path}")

    # setup output paths and read configs
    c = load_config(args.config_path)

    OUT_PATH = args.continue_path
    if args.continue_path == '':
        OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)

    AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')

    # Only the rank-0 process prepares the output folder and the logger.
    if args.rank == 0:
        os.makedirs(AUDIO_PATH, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        copy_config_file(args.config_path,
                         os.path.join(OUT_PATH, 'config.json'), new_fields)
        os.chmod(AUDIO_PATH, 0o775)
        os.chmod(OUT_PATH, 0o775)

    if args.rank == 0:
        LOG_DIR = OUT_PATH
        tb_logger = Logger(LOG_DIR)

    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)