зеркало из https://github.com/mozilla/TTS.git
Change config to json
This commit is contained in:
Родитель
fd18e1cf34
Коммит
72e1357c80
41
train.py
41
train.py
|
@ -1,9 +1,11 @@
|
|||
import os
|
||||
import sys
|
||||
import time
|
||||
import shutil
|
||||
import torch
|
||||
import signal
|
||||
import argparse
|
||||
import importlib
|
||||
import numpy as np
|
||||
|
||||
import torch.nn as nn
|
||||
|
@ -11,37 +13,43 @@ from torch import optim
|
|||
from torch.autograd import Variable
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import train_config as c
|
||||
from utils.generic_utils import (Progbar, remove_experiment_folder,
|
||||
create_experiment_folder, save_checkpoint)
|
||||
create_experiment_folder, save_checkpoint,
|
||||
load_config)
|
||||
from utils.model import get_param_size
|
||||
from datasets.LJSpeech import LJSpeechDataset
|
||||
from models.tacotron import Tacotron
|
||||
|
||||
use_cuda = torch.cuda.is_available()
|
||||
|
||||
|
||||
def main(args):
|
||||
|
||||
# setup output paths and read configs
|
||||
c = load_config(args.config_path)
|
||||
_ = os.path.dirname(os.path.realpath(__file__))
|
||||
OUT_PATH = os.path.join(_, c.output_path)
|
||||
OUT_PATH = create_experiment_folder(OUT_PATH)
|
||||
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
|
||||
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
|
||||
|
||||
# Ctrl+C handler to remove empty experiment folder
|
||||
def signal_handler(signal, frame):
|
||||
print(" !! Pressed Ctrl+C !!")
|
||||
remove_experiment_folder(OUT_PATH)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def main(args):
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
|
||||
os.path.join(c.data_path, 'wavs'),
|
||||
c.dec_out_per_step
|
||||
c.r
|
||||
)
|
||||
|
||||
model = Tacotron(c.embedding_size,
|
||||
c.hidden_size,
|
||||
c.num_mels,
|
||||
c.num_freq,
|
||||
c.dec_out_per_step)
|
||||
c.r)
|
||||
if use_cuda:
|
||||
model = nn.DataParallel(model.cuda())
|
||||
|
||||
|
@ -49,7 +57,7 @@ def main(args):
|
|||
|
||||
try:
|
||||
checkpoint = torch.load(os.path.join(
|
||||
c.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step))
|
||||
CHECKPOINT_PATH, 'checkpoint_%d.pth.tar' % args.restore_step))
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
print("\n > Model restored from step %d\n" % args.restore_step)
|
||||
|
@ -59,8 +67,8 @@ def main(args):
|
|||
|
||||
model = model.train()
|
||||
|
||||
if not os.path.exists(c.checkpoint_path):
|
||||
os.mkdir(c.checkpoint_path)
|
||||
if not os.path.exists(CHECKPOINT_PATH):
|
||||
os.mkdir(CHECKPOINT_PATH)
|
||||
|
||||
if use_cuda:
|
||||
criterion = nn.L1Loss().cuda()
|
||||
|
@ -71,10 +79,10 @@ def main(args):
|
|||
|
||||
for epoch in range(c.epochs):
|
||||
|
||||
dataloader = DataLoader(dataset, batch_size=args.batch_size,
|
||||
dataloader = DataLoader(dataset, batch_size=c.batch_size,
|
||||
shuffle=True, collate_fn=dataset.collate_fn,
|
||||
drop_last=True, num_workers=32)
|
||||
progbar = Progbar(len(dataset) / args.batch_size)
|
||||
progbar = Progbar(len(dataset) / c.batch_size)
|
||||
|
||||
for i, data in enumerate(dataloader):
|
||||
text_input = data[0]
|
||||
|
@ -87,7 +95,7 @@ def main(args):
|
|||
|
||||
try:
|
||||
mel_input = np.concatenate((np.zeros(
|
||||
[args.batch_size, 1, c.num_mels], dtype=np.float32),
|
||||
[c.batch_size, 1, c.num_mels], dtype=np.float32),
|
||||
mel_input[:, 1:, :]), axis=1)
|
||||
except:
|
||||
raise TypeError("not same dimension")
|
||||
|
@ -175,12 +183,7 @@ if __name__ == '__main__':
|
|||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--restore_step', type=int,
|
||||
help='Global step to restore checkpoint', default=128)
|
||||
parser.add_argument('--batch_size', type=int,
|
||||
help='Batch size', default=128)
|
||||
parser.add_argument('--config', type=str,
|
||||
parser.add_argument('--config_path', type=str,
|
||||
help='path to config file for training',)
|
||||
args = parser.parse_args()
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
main(args)
|
||||
|
|
|
@ -1,33 +0,0 @@
|
|||
# Audio
|
||||
num_mels = 80
|
||||
num_freq = 1024
|
||||
sample_rate = 20000
|
||||
frame_length_ms = 50.
|
||||
frame_shift_ms = 12.5
|
||||
preemphasis = 0.97
|
||||
min_level_db = -100
|
||||
ref_level_db = 20
|
||||
hidden_size = 128
|
||||
embedding_size = 256
|
||||
|
||||
# training
|
||||
epochs = 10000
|
||||
lr = 0.001
|
||||
decay_step = [500000, 1000000, 2000000]
|
||||
batch_size = 128
|
||||
max_iters = 200
|
||||
griffin_lim_iters = 60
|
||||
power = 1.5
|
||||
dec_out_per_step = 5
|
||||
#teacher_forcing_ratio = 1.0
|
||||
|
||||
# outputing
|
||||
log_step = 100
|
||||
save_step = 2000
|
||||
|
||||
# text processing
|
||||
cleaners = 'english_cleaners'
|
||||
|
||||
# data settings
|
||||
data_path = '/data/shared/KeithIto/LJSpeech-1.0/'
|
||||
output_path = './result'
|
Двоичные данные
utils/.generic_utils.py.swp
Двоичные данные
utils/.generic_utils.py.swp
Двоичный файл не отображается.
|
@ -4,9 +4,22 @@ import glob
|
|||
import time
|
||||
import shutil
|
||||
import datetime
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
def load_config(config_path):
|
||||
config = AttrDict()
|
||||
config.update(json.load(open(config_path, "r")))
|
||||
return config
|
||||
|
||||
|
||||
def create_experiment_folder(root_path):
|
||||
""" Create a folder with the current date and time """
|
||||
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I:%M%p")
|
||||
|
@ -20,7 +33,7 @@ def remove_experiment_folder(experiment_path):
|
|||
"""Check folder if there is a checkpoint, otherwise remove the folder"""
|
||||
|
||||
checkpoint_files = glob.glob(experiment_path+"/*.pth.tar")
|
||||
if len(checkpoint_files) == 0:
|
||||
if len(checkpoint_files) < 2:
|
||||
shutil.rmtree(experiment_path)
|
||||
print(" ! Run is removed from {}".format(experiment_path))
|
||||
else:
|
||||
|
|
Загрузка…
Ссылка в новой задаче