зеркало из https://github.com/mozilla/TTS.git
train.py update imports for utils refactoring
This commit is contained in:
Родитель
2d9dcd60ba
Коммит
c0c3c6e331
23
train.py
23
train.py
|
@ -14,12 +14,13 @@ from distribute import (DistributedSampler, apply_gradient_allreduce,
|
|||
init_distributed, reduce_tensor)
|
||||
from TTS.layers.losses import TacotronLoss
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.generic_utils import (
|
||||
NoamLR, check_update, count_parameters, create_experiment_folder,
|
||||
get_git_branch, load_config, remove_experiment_folder, save_best_model,
|
||||
save_checkpoint, adam_weight_decay, set_init_dict, copy_config_file,
|
||||
setup_model, gradual_training_scheduler, KeepAverage,
|
||||
set_weight_decay, check_config)
|
||||
from TTS.utils.generic_utils import (count_parameters, create_experiment_folder, remove_experiment_folder,
|
||||
get_git_branch, set_init_dict,
|
||||
setup_model, KeepAverage, check_config)
|
||||
from TTS.utils.io import (save_best_model, save_checkpoint,
|
||||
load_config, copy_config_file)
|
||||
from TTS.utils.training import (NoamLR, check_update, adam_weight_decay,
|
||||
gradual_training_scheduler, set_weight_decay)
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
from TTS.utils.console_logger import ConsoleLogger
|
||||
from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
|
||||
|
@ -251,9 +252,9 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
|
|||
if global_step % c.save_step == 0:
|
||||
if c.checkpoint:
|
||||
# save model
|
||||
save_checkpoint(model, optimizer, optimizer_st,
|
||||
loss_dict['postnet_loss'].item(), OUT_PATH, global_step,
|
||||
epoch)
|
||||
save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
|
||||
optimizer_st=optimizer_st,
|
||||
model_loss=loss_dict['postnet_loss'].item())
|
||||
|
||||
# Diagnostic visualizations
|
||||
const_spec = postnet_output[0].data.cpu().numpy()
|
||||
|
@ -596,8 +597,8 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
target_loss = train_avg_loss_dict['avg_postnet_loss']
|
||||
if c.run_eval:
|
||||
target_loss = eval_avg_loss_dict['avg_postnet_loss']
|
||||
best_loss = save_best_model(model, optimizer, target_loss, best_loss,
|
||||
OUT_PATH, global_step, epoch)
|
||||
best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
|
||||
OUT_PATH)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Загрузка…
Ссылка в новой задаче