Mirror of https://github.com/mozilla/TTS.git
Checkpoint fix
Parent: 1fa791f83e
Commit: 49f61d0b9e
config.json (15 changed lines)
@@ -9,21 +9,18 @@
   "ref_level_db": 20,
   "hidden_size": 128,
   "embedding_size": 256,
+  "text_cleaner": "english_cleaners",
 
-  "epochs": 10000,
+  "epochs": 200,
   "lr": 0.01,
-  "decay_step": [500000, 1000000, 2000000],
-  "batch_size": 128,
-  "max_iters": 200,
+  "lr_patience": 2,
+  "lr_decay": 0.5,
+  "batch_size": 256,
   "griffinf_lim_iters": 60,
   "power": 1.5,
   "r": 5,
 
-  "log_step": 100,
-  "save_step": 2000,
+  "save_step": 1,
 
-  "text_cleaner": "english_cleaners",
 
   "data_path": "/data/shared/KeithIto/LJSpeech-1.0",
   "output_path": "result",
   "log_dir": "/home/erogol/projects/TTS/logs/"
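The training script reads these values through an attribute-style config object (c.epochs, c.lr_decay, c.lr_patience in the train.py diff below). A minimal sketch of loading a JSON config that way follows; the repository has its own loader in utils.generic_utils, so the AttrDict and load_config names here are illustrative assumptions, not the project's API.

# Illustrative sketch only: load config.json and expose its keys as
# attributes so code can write c.epochs, c.lr_decay, c.lr_patience the
# way train.py does. The real project loads its config via a helper in
# utils.generic_utils; the AttrDict/load_config names here are assumed.
import json

class AttrDict(dict):
    """A dict whose keys can also be read as attributes (c.epochs)."""
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(name) from err

def load_config(path):
    with open(path, "r") as f:
        return AttrDict(json.load(f))

# Example usage (the path is an assumption):
# c = load_config("config.json")
# print(c.epochs, c.lr_decay, c.lr_patience)  # 200 0.5 2 with the new values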
train.py (32 changed lines)
@@ -1,6 +1,7 @@
 import os
 import sys
 import time
+import datetime
 import shutil
 import torch
 import signal
@@ -13,6 +14,7 @@ import torch.nn as nn
 from torch import optim
 from torch.autograd import Variable
 from torch.utils.data import DataLoader
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from tensorboardX import SummaryWriter
 
 from utils.generic_utils import (Progbar, remove_experiment_folder,
@@ -97,12 +99,15 @@ def main(args):
 
     n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
 
+    lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay,
+                                     patience=c.lr_patience, verbose=True)
+
     for epoch in range(c.epochs):
 
         dataloader = DataLoader(dataset, batch_size=c.batch_size,
                                 shuffle=True, collate_fn=dataset.collate_fn,
                                 drop_last=True, num_workers=32)
-        print("\n | > Epoch {}".format(epoch))
+        print("\n | > Epoch {}/{}".format(epoch, c.epochs))
         progbar = Progbar(len(dataset) / c.batch_size)
 
         for i, data in enumerate(dataloader):
@@ -162,7 +167,7 @@ def main(args):
             optimizer.step()
 
             time_per_step = time.time() - start_time
-            progbar.update(i, values=[('total_loss', loss.data[0]),
+            progbar.update(i+1, values=[('total_loss', loss.data[0]),
                                       ('linear_loss', linear_loss.data[0]),
                                       ('mel_loss', mel_loss.data[0])])
 
@@ -181,27 +186,8 @@ def main(args):
                             'mel_loss': mel_loss.data[0],
                             'date': datetime.date.today().strftime("%B %d, %Y")},
                            checkpoint_path)
-                print(" > Checkpoint is saved : {}".format(checkpoint_path))
+                print("\n | > Checkpoint is saved : {}".format(checkpoint_path))
+                lr_scheduler.step(loss.data[0])
-            if current_step in c.decay_step:
-                optimizer = adjust_learning_rate(optimizer, current_step)
 
-
-def adjust_learning_rate(optimizer, step):
-    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
-    if step == 500000:
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = 0.0005
-
-    elif step == 1000000:
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = 0.0003
-
-    elif step == 2000000:
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = 0.0001
-
-    return optimizer
-
-
 if __name__ == '__main__':
 
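The net effect of the train.py changes above is to retire the hard-coded, step-count-based adjust_learning_rate() in favor of PyTorch's ReduceLROnPlateau, driven by the new lr_decay and lr_patience keys from config.json. A minimal, self-contained sketch of that pattern follows; the model, data, loss, and optimizer choice are toy stand-ins, and only the scheduler wiring mirrors the commit.

# Minimal sketch of the scheduler pattern adopted in this commit: the LR is
# cut by a fixed factor whenever the monitored loss stops improving, instead
# of at hard-coded global steps. Model, data, loss, and optimizer are toy
# stand-ins, not the project's actual training setup.
import torch
import torch.nn as nn
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = nn.Linear(10, 1)                              # stand-in for the TTS model
optimizer = optim.Adam(model.parameters(), lr=0.01)   # "lr": 0.01 as in config.json
# factor/patience mirror "lr_decay": 0.5 and "lr_patience": 2 from config.json
lr_scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=2)

for epoch in range(5):                                # toy loop; train.py iterates batches
    x = torch.randn(128, 10)
    target = torch.randn(128, 1)
    loss = nn.functional.mse_loss(model(x), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # the commit calls lr_scheduler.step(loss.data[0]); loss.item() is the
    # modern way to read the scalar value of a 0-dim tensor
    lr_scheduler.step(loss.item())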
(hunk from an additional changed file; the file name was not captured)
@@ -5,6 +5,7 @@ import time
 import shutil
 import datetime
 import json
+import torch
 import numpy as np
 