This commit is contained in:
Eren Golge 2018-03-02 05:42:23 -08:00
Родитель 81669c1e58
Коммит 793563b586
4 изменённых файлов: 10 добавлений и 6 удалений

Просмотреть файл

@ -12,7 +12,7 @@
"text_cleaner": "english_cleaners",
"epochs": 2000,
"lr": 0.005,
"lr": 0.0006,
"warmup_steps": 4000,
"batch_size": 180,
"r": 5,

Просмотреть файл

@ -1,4 +1,3 @@
import pandas as pd
import os
import numpy as np
import collections
@ -16,7 +15,10 @@ class LJSpeechDataset(Dataset):
def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
text_cleaner, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, preemphasis, ref_level_db, num_freq, power):
self.frames = pd.read_csv(csv_file, sep='|', header=None)
f = open(csv_file, "r")
self.frames = [line.split('|') for line in f]
f.close()
self.root_dir = root_dir
self.outputs_per_step = outputs_per_step
self.sample_rate = sample_rate
@ -40,7 +42,7 @@ class LJSpeechDataset(Dataset):
def __getitem__(self, idx):
wav_name = os.path.join(self.root_dir,
self.frames.ix[idx, 0]) + '.wav'
text = self.frames.ix[idx, 1]
text = self.frames[idx][1]
text = np.asarray(text_to_sequence(text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
sample = {'text': text, 'wav': wav, 'item_idx': self.frames.ix[idx, 0]}

Просмотреть файл

@ -73,7 +73,8 @@ def main(args):
dataloader = DataLoader(dataset, batch_size=c.batch_size,
shuffle=True, collate_fn=dataset.collate_fn,
drop_last=True, num_workers=c.num_loader_workers)
drop_last=True, num_workers=c.num_loader_workers,
pin_memory=True)
# setup the model
model = Tacotron(c.embedding_size,
@ -108,6 +109,7 @@ def main(args):
start_epoch = checkpoint['step'] // len(dataloader)
best_loss = checkpoint['linear_loss']
start_epoch = 0
args.restore_step = checkpoint['step']
else:
print("\n > Starting a new training")

Просмотреть файл

@ -5,7 +5,7 @@ import matplotlib.pyplot as plt
def plot_alignment(alignment, info=None):
fig, ax = plt.subplots()
fig, ax = plt.subplots(figsize=(16,10))
im = ax.imshow(alignment.T, aspect='auto', origin='lower',
interpolation='none')
fig.colorbar(im, ax=ax)