train.py update imports for utils refactoring

This commit is contained in:
erogol 2020-05-12 13:46:58 +02:00
Родитель 2d9dcd60ba
Коммит c0c3c6e331
1 изменённых файлов: 12 добавлений и 11 удалений

Просмотреть файл

@ -14,12 +14,13 @@ from distribute import (DistributedSampler, apply_gradient_allreduce,
init_distributed, reduce_tensor)
from TTS.layers.losses import TacotronLoss
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import (
NoamLR, check_update, count_parameters, create_experiment_folder,
get_git_branch, load_config, remove_experiment_folder, save_best_model,
save_checkpoint, adam_weight_decay, set_init_dict, copy_config_file,
setup_model, gradual_training_scheduler, KeepAverage,
set_weight_decay, check_config)
from TTS.utils.generic_utils import (count_parameters, create_experiment_folder, remove_experiment_folder,
get_git_branch, set_init_dict,
setup_model, KeepAverage, check_config)
from TTS.utils.io import (save_best_model, save_checkpoint,
load_config, copy_config_file)
from TTS.utils.training import (NoamLR, check_update, adam_weight_decay,
gradual_training_scheduler, set_weight_decay)
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
@ -251,9 +252,9 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
if global_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model, optimizer, optimizer_st,
loss_dict['postnet_loss'].item(), OUT_PATH, global_step,
epoch)
save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
optimizer_st=optimizer_st,
model_loss=loss_dict['postnet_loss'].item())
# Diagnostic visualizations
const_spec = postnet_output[0].data.cpu().numpy()
@ -596,8 +597,8 @@ def main(args): # pylint: disable=redefined-outer-name
target_loss = train_avg_loss_dict['avg_postnet_loss']
if c.run_eval:
target_loss = eval_avg_loss_dict['avg_postnet_loss']
best_loss = save_best_model(model, optimizer, target_loss, best_loss,
OUT_PATH, global_step, epoch)
best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
OUT_PATH)
if __name__ == '__main__':