This commit is contained in:
Eren Golge 2018-01-22 08:20:20 -08:00
Родитель fd18e1cf34
Коммит 72e1357c80
4 изменённых файлов: 42 добавлений и 59 удалений

Просмотреть файл

@ -1,9 +1,11 @@
import os
import sys
import time
import shutil
import torch
import signal
import argparse
import importlib
import numpy as np
import torch.nn as nn
@ -11,37 +13,43 @@ from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
import train_config as c
from utils.generic_utils import (Progbar, remove_experiment_folder,
create_experiment_folder, save_checkpoint)
create_experiment_folder, save_checkpoint,
load_config)
from utils.model import get_param_size
from datasets.LJSpeech import LJSpeechDataset
from models.tacotron import Tacotron
use_cuda = torch.cuda.is_available()
def main(args):
# setup output paths and read configs
c = load_config(args.config_path)
_ = os.path.dirname(os.path.realpath(__file__))
OUT_PATH = os.path.join(_, c.output_path)
OUT_PATH = create_experiment_folder(OUT_PATH)
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
# Ctrl+C handler to remove empty experiment folder
def signal_handler(signal, frame):
print(" !! Pressed Ctrl+C !!")
remove_experiment_folder(OUT_PATH)
sys.exit(0)
def main(args):
signal.signal(signal.SIGINT, signal_handler)
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
os.path.join(c.data_path, 'wavs'),
c.dec_out_per_step
c.r
)
model = Tacotron(c.embedding_size,
c.hidden_size,
c.num_mels,
c.num_freq,
c.dec_out_per_step)
c.r)
if use_cuda:
model = nn.DataParallel(model.cuda())
@ -49,7 +57,7 @@ def main(args):
try:
checkpoint = torch.load(os.path.join(
c.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step))
CHECKPOINT_PATH, 'checkpoint_%d.pth.tar' % args.restore_step))
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("\n > Model restored from step %d\n" % args.restore_step)
@ -59,8 +67,8 @@ def main(args):
model = model.train()
if not os.path.exists(c.checkpoint_path):
os.mkdir(c.checkpoint_path)
if not os.path.exists(CHECKPOINT_PATH):
os.mkdir(CHECKPOINT_PATH)
if use_cuda:
criterion = nn.L1Loss().cuda()
@ -71,10 +79,10 @@ def main(args):
for epoch in range(c.epochs):
dataloader = DataLoader(dataset, batch_size=args.batch_size,
dataloader = DataLoader(dataset, batch_size=c.batch_size,
shuffle=True, collate_fn=dataset.collate_fn,
drop_last=True, num_workers=32)
progbar = Progbar(len(dataset) / args.batch_size)
progbar = Progbar(len(dataset) / c.batch_size)
for i, data in enumerate(dataloader):
text_input = data[0]
@ -87,7 +95,7 @@ def main(args):
try:
mel_input = np.concatenate((np.zeros(
[args.batch_size, 1, c.num_mels], dtype=np.float32),
[c.batch_size, 1, c.num_mels], dtype=np.float32),
mel_input[:, 1:, :]), axis=1)
except:
raise TypeError("not same dimension")
@ -175,12 +183,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--restore_step', type=int,
help='Global step to restore checkpoint', default=128)
parser.add_argument('--batch_size', type=int,
help='Batch size', default=128)
parser.add_argument('--config', type=str,
parser.add_argument('--config_path', type=str,
help='path to config file for training',)
args = parser.parse_args()
signal.signal(signal.SIGINT, signal_handler)
main(args)

Просмотреть файл

@ -1,33 +0,0 @@
# Audio
num_mels = 80
num_freq = 1024
sample_rate = 20000
frame_length_ms = 50.
frame_shift_ms = 12.5
preemphasis = 0.97
min_level_db = -100
ref_level_db = 20
hidden_size = 128
embedding_size = 256
# training
epochs = 10000
lr = 0.001
decay_step = [500000, 1000000, 2000000]
batch_size = 128
max_iters = 200
griffin_lim_iters = 60
power = 1.5
dec_out_per_step = 5
#teacher_forcing_ratio = 1.0
# outputting
log_step = 100
save_step = 2000
# text processing
cleaners = 'english_cleaners'
# data settings
data_path = '/data/shared/KeithIto/LJSpeech-1.0/'
output_path = './result'

Двоичные данные
utils/.generic_utils.py.swp

Двоичный файл не отображается.

Просмотреть файл

@ -4,9 +4,22 @@ import glob
import time
import shutil
import datetime
import json
import numpy as np
class AttrDict(dict):
    """Dictionary whose entries can also be read and written as attributes.

    The instance's attribute namespace is the mapping itself, so ``d.key``
    and ``d["key"]`` always refer to the same entry.
    """

    def __init__(self, *args, **kwargs):
        # Accept every construction form that a plain ``dict`` accepts.
        dict.__init__(self, *args, **kwargs)
        # Alias the attribute storage to the dict so both access styles
        # share one underlying store.
        self.__dict__ = self
def load_config(config_path):
    """Load a JSON config file into an ``AttrDict``.

    Args:
        config_path: Path of the JSON file to read.

    Returns:
        An ``AttrDict`` updated with the parsed JSON content, so top-level
        keys are reachable as attributes (``config.key``). Nested objects
        are left as whatever ``json.load`` produced for them.
    """
    config = AttrDict()
    # Use a context manager so the file handle is closed deterministically,
    # even when JSON parsing raises; the bare ``open`` call leaked it.
    with open(config_path, "r") as config_file:
        config.update(json.load(config_file))
    return config
def create_experiment_folder(root_path):
""" Create a folder with the current date and time """
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I:%M%p")
@ -20,7 +33,7 @@ def remove_experiment_folder(experiment_path):
"""Check folder if there is a checkpoint, otherwise remove the folder"""
checkpoint_files = glob.glob(experiment_path+"/*.pth.tar")
if len(checkpoint_files) == 0:
if len(checkpoint_files) < 2:
shutil.rmtree(experiment_path)
print(" ! Run is removed from {}".format(experiment_path))
else: