This commit is contained in:
erogol 2020-06-04 09:51:20 +02:00
Родитель ed466c8581
Коммит 2b361f0b50
1 изменённых файлов: 4 добавлений и 2 удалений

Просмотреть файл

@ -120,11 +120,13 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
y_hat = model_G(c_G)
y_hat_sub = None
y_G_sub = None
y_hat_vis = y_hat # for visualization
# PQMF formatting
if y_hat.shape[1] > 1:
y_hat_sub = y_hat
y_hat = model_G.pqmf_synthesis(y_hat)
y_hat_vis = y_hat
y_G_sub = model_G.pqmf_analysis(y_G)
if global_step > c.steps_to_start_discriminator:
@ -265,12 +267,12 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
model_losses=loss_dict)
# compute spectrograms
figures = plot_results(y_hat, y_G, ap, global_step,
figures = plot_results(y_hat_vis, y_G, ap, global_step,
'train')
tb_logger.tb_train_figures(global_step, figures)
# Sample audio
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
tb_logger.tb_train_audios(global_step,
{'train/audio': sample_voice},
c.audio["sample_rate"])