Mirror of https://github.com/mozilla/TTS.git
pep8 format all
Parent: 3238ffa3e6
Commit: f5537dc48f
@@ -8,21 +8,28 @@ import torch
 from torch.utils.data import Dataset

 from utils.text import text_to_sequence
-from utils.data import (prepare_data, pad_per_step,
-                        prepare_tensor, prepare_stop_target)
+from utils.data import (prepare_data, pad_per_step, prepare_tensor,
+                        prepare_stop_target)


 class MyDataset(Dataset):
-
-    def __init__(self, root_dir, csv_file, outputs_per_step,
-                 text_cleaner, ap, min_seq_len=0):
+    def __init__(self,
+                 root_dir,
+                 csv_file,
+                 outputs_per_step,
+                 text_cleaner,
+                 ap,
+                 min_seq_len=0):
         self.root_dir = root_dir
         self.wav_dir = os.path.join(root_dir, 'wav')
         self.wav_files = glob.glob(os.path.join(self.wav_dir, '*.wav'))
         self._create_file_dict()
         self.csv_dir = os.path.join(root_dir, csv_file)
         with open(self.csv_dir, "r", encoding="utf8") as f:
-            self.frames = [line.split('\t') for line in f if line.split('\t')[0] in self.wav_files_dict.keys()]
+            self.frames = [
+                line.split('\t') for line in f
+                if line.split('\t')[0] in self.wav_files_dict.keys()
+            ]
         self.outputs_per_step = outputs_per_step
         self.sample_rate = ap.sample_rate
         self.cleaners = text_cleaner
@@ -43,10 +50,8 @@ class MyDataset(Dataset):
             print(" !! Cannot read file : {}".format(filename))

     def _trim_silence(self, wav):
-        return librosa.effects.trim(
-            wav, top_db=40,
-            frame_length=1024,
-            hop_length=256)[0]
+        return librosa.effects.trim(
+            wav, top_db=40, frame_length=1024, hop_length=256)[0]

     def _create_file_dict(self):
         self.wav_files_dict = {}
@@ -87,11 +92,10 @@ class MyDataset(Dataset):
         sidx = self.frames[idx][0]
         sidx_files = self.wav_files_dict[sidx]
         file_name = random.choice(sidx_files)
-        wav_name = os.path.join(self.wav_dir,
-                                file_name)
+        wav_name = os.path.join(self.wav_dir, file_name)
         text = self.frames[idx][2]
-        text = np.asarray(text_to_sequence(
-            text, [self.cleaners]), dtype=np.int32)
+        text = np.asarray(
+            text_to_sequence(text, [self.cleaners]), dtype=np.int32)
         wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
         sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
         return sample
@@ -121,12 +125,13 @@ class MyDataset(Dataset):
             mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame

             # compute 'stop token' targets
-            stop_targets = [np.array([0.]*(mel_len-1))
-                            for mel_len in mel_lengths]
+            stop_targets = [
+                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
+            ]

             # PAD stop targets
-            stop_targets = prepare_stop_target(
-                stop_targets, self.outputs_per_step)
+            stop_targets = prepare_stop_target(stop_targets,
+                                               self.outputs_per_step)

             # PAD sequences with largest length of the batch
             text = prepare_data(text).astype(np.int32)
@@ -150,8 +155,8 @@ class MyDataset(Dataset):
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)

-            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
+            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
+                0]

         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
-                         found {}"
-                         .format(type(batch[0]))))
+                         found {}".format(type(batch[0]))))
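The whole commit is a mechanical restyle: long argument lists are split one per line with a four-space hanging indent, wrapped expressions get spaces around operators, and stray blank lines are normalized. That layout matches what an automatic formatter such as yapf emits with a pep8-derived style; the commit does not record which tool was used, so the snippet below is only a sketch of how the same pass could be reproduced, assuming yapf's FormatCode API.

from yapf.yapflib.yapf_api import FormatCode  # assumes yapf is installed

with open('datasets/LJSpeech.py') as f:  # hypothetical target file
    source = f.read()
# FormatCode returns (formatted_source, changed_flag) under a pep8 style.
formatted, changed = FormatCode(source, style_config='pep8')
print(formatted if changed else 'already formatted')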
@@ -6,14 +6,18 @@ import torch
 from torch.utils.data import Dataset

 from utils.text import text_to_sequence
-from utils.data import (prepare_data, pad_per_step,
-                        prepare_tensor, prepare_stop_target)
+from utils.data import (prepare_data, pad_per_step, prepare_tensor,
+                        prepare_stop_target)


 class MyDataset(Dataset):
-
-    def __init__(self, root_dir, csv_file, outputs_per_step,
-                 text_cleaner, ap, min_seq_len=0):
+    def __init__(self,
+                 root_dir,
+                 csv_file,
+                 outputs_per_step,
+                 text_cleaner,
+                 ap,
+                 min_seq_len=0):
         self.root_dir = root_dir
         self.wav_dir = os.path.join(root_dir, 'wavs')
         self.csv_dir = os.path.join(root_dir, csv_file)
@@ -60,11 +64,10 @@ class MyDataset(Dataset):
         return len(self.frames)

     def __getitem__(self, idx):
-        wav_name = os.path.join(self.wav_dir,
-                                self.frames[idx][0]) + '.wav'
+        wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav'
         text = self.frames[idx][1]
-        text = np.asarray(text_to_sequence(
-            text, [self.cleaners]), dtype=np.int32)
+        text = np.asarray(
+            text_to_sequence(text, [self.cleaners]), dtype=np.int32)
         wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
         sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
         return sample
@@ -94,12 +97,13 @@ class MyDataset(Dataset):
             mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame

             # compute 'stop token' targets
-            stop_targets = [np.array([0.]*(mel_len-1))
-                            for mel_len in mel_lengths]
+            stop_targets = [
+                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
+            ]

             # PAD stop targets
-            stop_targets = prepare_stop_target(
-                stop_targets, self.outputs_per_step)
+            stop_targets = prepare_stop_target(stop_targets,
+                                               self.outputs_per_step)

             # PAD sequences with largest length of the batch
             text = prepare_data(text).astype(np.int32)
@@ -123,8 +127,8 @@ class MyDataset(Dataset):
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)

-            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
+            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
+                0]

         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
-                         found {}"
-                         .format(type(batch[0]))))
+                         found {}".format(type(batch[0]))))
@@ -6,14 +6,18 @@ import torch
 from torch.utils.data import Dataset

 from utils.text import text_to_sequence
-from utils.data import (prepare_data, pad_per_step,
-                        prepare_tensor, prepare_stop_target)
+from utils.data import (prepare_data, pad_per_step, prepare_tensor,
+                        prepare_stop_target)


 class MyDataset(Dataset):
-
-    def __init__(self, root_dir, csv_file, outputs_per_step,
-                 text_cleaner, ap, min_seq_len=0):
+    def __init__(self,
+                 root_dir,
+                 csv_file,
+                 outputs_per_step,
+                 text_cleaner,
+                 ap,
+                 min_seq_len=0):
         self.root_dir = root_dir
         self.wav_dir = os.path.join(root_dir, 'wavs')
         self.feat_dir = os.path.join(root_dir, 'loader_data')
@@ -35,7 +39,7 @@ class MyDataset(Dataset):
             return audio
         except RuntimeError as e:
             print(" !! Cannot read file : {}".format(filename))
-
+
     def load_np(self, filename):
         data = np.load(filename).astype('float32')
         return data
@@ -66,20 +70,24 @@ class MyDataset(Dataset):

     def __getitem__(self, idx):
         if self.items[idx] is None:
-            wav_name = os.path.join(self.wav_dir,
-                                    self.frames[idx][0]) + '.wav'
+            wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav'
             mel_name = os.path.join(self.feat_dir,
                                     self.frames[idx][0]) + '.mel.npy'
             linear_name = os.path.join(self.feat_dir,
                                        self.frames[idx][0]) + '.linear.npy'
             text = self.frames[idx][1]
-            text = np.asarray(text_to_sequence(
-                text, [self.cleaners]), dtype=np.int32)
+            text = np.asarray(
+                text_to_sequence(text, [self.cleaners]), dtype=np.int32)
             wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
             mel = self.load_np(mel_name)
             linear = self.load_np(linear_name)
-            sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0],
-                      'mel':mel, 'linear': linear}
+            sample = {
+                'text': text,
+                'wav': wav,
+                'item_idx': self.frames[idx][0],
+                'mel': mel,
+                'linear': linear
+            }
             self.items[idx] = sample
         else:
             sample = self.items[idx]
@@ -109,12 +117,13 @@ class MyDataset(Dataset):
             mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame

             # compute 'stop token' targets
-            stop_targets = [np.array([0.]*(mel_len-1))
-                            for mel_len in mel_lengths]
+            stop_targets = [
+                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
+            ]

             # PAD stop targets
-            stop_targets = prepare_stop_target(
-                stop_targets, self.outputs_per_step)
+            stop_targets = prepare_stop_target(stop_targets,
+                                               self.outputs_per_step)

             # PAD sequences with largest length of the batch
             text = prepare_data(text).astype(np.int32)
@@ -138,8 +147,8 @@ class MyDataset(Dataset):
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)

-            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
+            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
+                0]

         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
-                         found {}"
-                         .format(type(batch[0]))))
+                         found {}".format(type(batch[0]))))
@@ -7,15 +7,25 @@ from torch.utils.data import Dataset

 from TTS.utils.text import text_to_sequence
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.data import (prepare_data, pad_per_step,
-                            prepare_tensor, prepare_stop_target)
+from TTS.utils.data import (prepare_data, pad_per_step, prepare_tensor,
+                            prepare_stop_target)


 class TWEBDataset(Dataset):
-
-    def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
-                 text_cleaner, num_mels, min_level_db, frame_shift_ms,
-                 frame_length_ms, preemphasis, ref_level_db, num_freq, power,
+    def __init__(self,
+                 csv_file,
+                 root_dir,
+                 outputs_per_step,
+                 sample_rate,
+                 text_cleaner,
+                 num_mels,
+                 min_level_db,
+                 frame_shift_ms,
+                 frame_length_ms,
+                 preemphasis,
+                 ref_level_db,
+                 num_freq,
+                 power,
                  min_seq_len=0):

         with open(csv_file, "r") as f:
@@ -25,8 +35,9 @@ class TWEBDataset(Dataset):
         self.sample_rate = sample_rate
         self.cleaners = text_cleaner
         self.min_seq_len = min_seq_len
-        self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
-                                 frame_length_ms, preemphasis, ref_level_db, num_freq, power)
+        self.ap = AudioProcessor(sample_rate, num_mels, min_level_db,
+                                 frame_shift_ms, frame_length_ms, preemphasis,
+                                 ref_level_db, num_freq, power)
         print(" > Reading TWEB from - {}".format(root_dir))
         print(" | > Number of instances : {}".format(len(self.frames)))
         self._sort_frames()
@@ -63,11 +74,10 @@ class TWEBDataset(Dataset):
         return len(self.frames)

     def __getitem__(self, idx):
-        wav_name = os.path.join(self.root_dir,
-                                self.frames[idx][0]) + '.wav'
+        wav_name = os.path.join(self.root_dir, self.frames[idx][0]) + '.wav'
         text = self.frames[idx][1]
-        text = np.asarray(text_to_sequence(
-            text, [self.cleaners]), dtype=np.int32)
+        text = np.asarray(
+            text_to_sequence(text, [self.cleaners]), dtype=np.int32)
         wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
         sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
         return sample
@@ -97,12 +107,13 @@ class TWEBDataset(Dataset):
             mel_lengths = [m.shape[1] + 1 for m in mel]  # +1 for zero-frame

             # compute 'stop token' targets
-            stop_targets = [np.array([0.]*(mel_len-1))
-                            for mel_len in mel_lengths]
+            stop_targets = [
+                np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
+            ]

             # PAD stop targets
-            stop_targets = prepare_stop_target(
-                stop_targets, self.outputs_per_step)
+            stop_targets = prepare_stop_target(stop_targets,
+                                               self.outputs_per_step)

             # PAD sequences with largest length of the batch
             text = prepare_data(text).astype(np.int32)
@@ -126,8 +137,8 @@ class TWEBDataset(Dataset):
             mel_lengths = torch.LongTensor(mel_lengths)
             stop_targets = torch.FloatTensor(stop_targets)

-            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
+            return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
+                0]

         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
-                         found {}"
-                         .format(type(batch[0]))))
+                         found {}".format(type(batch[0]))))
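All four dataset collate functions build per-utterance stop-token targets the same way: mel_len - 1 zeros per utterance, handed to prepare_stop_target for padding. prepare_stop_target itself lives in utils/data.py and is not shown in this diff, so the following is a minimal sketch under the assumed behavior of padding each target with the stop value 1.0 to a common batch length aligned to outputs_per_step:

import numpy as np

# Hypothetical stand-in for utils.data.prepare_stop_target (not shown here).
def prepare_stop_target_sketch(stop_targets, out_steps):
    max_len = max(len(t) for t in stop_targets)
    remainder = max_len % out_steps
    pad_len = max_len if remainder == 0 else max_len + out_steps - remainder
    return np.stack([
        np.pad(t, (0, pad_len - len(t)), mode='constant', constant_values=1.0)
        for t in stop_targets
    ])

mel_lengths = [4, 6]  # as computed in the collate functions above
stop_targets = [np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths]
print(prepare_stop_target_sketch(stop_targets, out_steps=2))
# [[0. 0. 0. 1. 1. 1.]
#  [0. 0. 0. 0. 0. 1.]]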
@@ -10,7 +10,6 @@
     "hidden_size": 128,
     "embedding_size": 256,
     "text_cleaner": "english_cleaners",
-
     "epochs": 200,
     "lr": 0.01,
     "lr_patience": 2,
@@ -19,9 +18,7 @@
     "griffinf_lim_iters": 60,
     "power": 1.5,
     "r": 5,
-
     "num_loader_workers": 16,
-
     "save_step": 1,
     "data_path": "/data/shared/KeithIto/LJSpeech-1.0",
     "output_path": "result",
@@ -13,19 +13,19 @@ from utils.generic_utils import load_config

 from multiprocessing import Pool


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--data_path', type=str,
-                        help='Data folder.')
-    parser.add_argument('--out_path', type=str,
-                        help='Output folder.')
-    parser.add_argument('--config', type=str,
-                        help='conf.json file for run settings.')
-    parser.add_argument("--num_proc", type=int, default=8,
-                        help="number of processes.")
-    parser.add_argument("--trim_silence", type=bool, default=False,
-                        help="trim silence in the voice clip.")
+    parser.add_argument('--data_path', type=str, help='Data folder.')
+    parser.add_argument('--out_path', type=str, help='Output folder.')
+    parser.add_argument(
+        '--config', type=str, help='conf.json file for run settings.')
+    parser.add_argument(
+        "--num_proc", type=int, default=8, help="number of processes.")
+    parser.add_argument(
+        "--trim_silence",
+        type=bool,
+        default=False,
+        help="trim silence in the voice clip.")
     args = parser.parse_args()
     DATA_PATH = args.data_path
     OUT_PATH = args.out_path
@@ -34,27 +34,26 @@ if __name__ == "__main__":
     print(" > Input path: ", DATA_PATH)
     print(" > Output path: ", OUT_PATH)

-    audio = importlib.import_module('utils.'+c.audio_processor)
+    audio = importlib.import_module('utils.' + c.audio_processor)
     AudioProcessor = getattr(audio, 'AudioProcessor')
-    ap = AudioProcessor(sample_rate = CONFIG.sample_rate,
-                        num_mels = CONFIG.num_mels,
-                        min_level_db = CONFIG.min_level_db,
-                        frame_shift_ms = CONFIG.frame_shift_ms,
-                        frame_length_ms = CONFIG.frame_length_ms,
-                        ref_level_db = CONFIG.ref_level_db,
-                        num_freq = CONFIG.num_freq,
-                        power = CONFIG.power,
-                        preemphasis = CONFIG.preemphasis,
-                        min_mel_freq = CONFIG.min_mel_freq,
-                        max_mel_freq = CONFIG.max_mel_freq)
+    ap = AudioProcessor(
+        sample_rate=CONFIG.sample_rate,
+        num_mels=CONFIG.num_mels,
+        min_level_db=CONFIG.min_level_db,
+        frame_shift_ms=CONFIG.frame_shift_ms,
+        frame_length_ms=CONFIG.frame_length_ms,
+        ref_level_db=CONFIG.ref_level_db,
+        num_freq=CONFIG.num_freq,
+        power=CONFIG.power,
+        preemphasis=CONFIG.preemphasis,
+        min_mel_freq=CONFIG.min_mel_freq,
+        max_mel_freq=CONFIG.max_mel_freq)

     def trim_silence(self, wav):
         margin = int(CONFIG.sample_rate * 0.1)
         wav = wav[margin:-margin]
-        return librosa.effects.trim(
-            wav, top_db=40,
-            frame_length=1024,
-            hop_length=256)[0]
+        return librosa.effects.trim(
+            wav, top_db=40, frame_length=1024, hop_length=256)[0]

     def extract_mel(file_path):
         # x, fs = sf.read(file_path)
@@ -63,23 +62,25 @@ if __name__ == "__main__":
         x = trim_silence(x)
         mel = ap.melspectrogram(x.astype('float32')).astype('float32')
         linear = ap.spectrogram(x.astype('float32')).astype('float32')
-        file_name = os.path.basename(file_path).replace(".wav","")
+        file_name = os.path.basename(file_path).replace(".wav", "")
         mel_file = file_name + ".mel"
         linear_file = file_name + ".linear"
         np.save(os.path.join(OUT_PATH, mel_file), mel, allow_pickle=False)
-        np.save(os.path.join(OUT_PATH, linear_file), linear, allow_pickle=False)
+        np.save(
+            os.path.join(OUT_PATH, linear_file), linear, allow_pickle=False)
         mel_len = mel.shape[1]
         linear_len = linear.shape[1]
         wav_len = x.shape[0]
         print(" > " + file_path, flush=True)
-        return file_path, mel_file, linear_file, str(wav_len), str(mel_len), str(linear_len)
+        return file_path, mel_file, linear_file, str(wav_len), str(
+            mel_len), str(linear_len)

     glob_path = os.path.join(DATA_PATH, "*.wav")
     print(" > Reading wav: {}".format(glob_path))
     file_names = glob.glob(glob_path, recursive=True)

     if __name__ == "__main__":
-        print(" > Number of files: %i"%(len(file_names)))
+        print(" > Number of files: %i" % (len(file_names)))
         if not os.path.exists(OUT_PATH):
             os.makedirs(OUT_PATH)
             print(" > A new folder created at {}".format(OUT_PATH))
@@ -88,7 +89,10 @@ if __name__ == "__main__":
         if args.num_proc > 1:
             print(" > Using {} processes.".format(args.num_proc))
             with Pool(args.num_proc) as p:
-                r = list(tqdm.tqdm(p.imap(extract_mel, file_names), total=len(file_names)))
+                r = list(
+                    tqdm.tqdm(
+                        p.imap(extract_mel, file_names),
+                        total=len(file_names)))
                 # r = list(p.imap(extract_mel, file_names))
         else:
             print(" > Using single process run.")
@@ -100,5 +104,5 @@ if __name__ == "__main__":
         file = open(file_path, "w")
         for line in r:
             line = ", ".join(line)
-            file.write(line+'\n')
+            file.write(line + '\n')
         file.close()
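The extraction script above wraps Pool.imap in tqdm so the progress bar advances as each worker finishes a file. A self-contained sketch of that same pattern:

import tqdm
from multiprocessing import Pool

def work(x):
    # Module-level function so it is picklable for the worker processes.
    return x * x

if __name__ == "__main__":
    items = list(range(100))
    # imap yields results lazily, letting tqdm report progress as they arrive.
    with Pool(4) as p:
        results = list(tqdm.tqdm(p.imap(work, items), total=len(items)))
    print(sum(results))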
@@ -24,8 +24,8 @@ class BahdanauAttention(nn.Module):
         processed_query = self.query_layer(query)
         processed_annots = self.annot_layer(annots)
         # (batch, max_time, 1)
-        alignment = self.v(nn.functional.tanh(
-            processed_query + processed_annots))
+        alignment = self.v(
+            nn.functional.tanh(processed_query + processed_annots))
         # (batch, max_time)
         return alignment.squeeze(-1)
@@ -33,15 +33,24 @@ class BahdanauAttention(nn.Module):
 class LocationSensitiveAttention(nn.Module):
     """Location sensitive attention following
     https://arxiv.org/pdf/1506.07503.pdf"""
-    def __init__(self, annot_dim, query_dim, attn_dim,
-                 kernel_size=7, filters=20):
+
+    def __init__(self,
+                 annot_dim,
+                 query_dim,
+                 attn_dim,
+                 kernel_size=7,
+                 filters=20):
         super(LocationSensitiveAttention, self).__init__()
         self.kernel_size = kernel_size
         self.filters = filters
         padding = int((kernel_size - 1) / 2)
-        self.loc_conv = nn.Conv1d(2, filters,
-                                  kernel_size=kernel_size, stride=1,
-                                  padding=padding, bias=False)
+        self.loc_conv = nn.Conv1d(
+            2,
+            filters,
+            kernel_size=kernel_size,
+            stride=1,
+            padding=padding,
+            bias=False)
         self.loc_linear = nn.Linear(filters, attn_dim)
         self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
         self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
@@ -62,8 +71,9 @@ class LocationSensitiveAttention(nn.Module):
         processed_loc = self.loc_linear(loc_conv)
         processed_query = self.query_layer(query)
         processed_annots = self.annot_layer(annot)
-        alignment = self.v(nn.functional.tanh(
-            processed_query + processed_annots + processed_loc))
+        alignment = self.v(
+            nn.functional.tanh(processed_query + processed_annots +
+                               processed_loc))
         # (batch, max_time)
         return alignment.squeeze(-1)
@@ -85,16 +95,23 @@ class AttentionRNNCell(nn.Module):
         self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)
         # pick bahdanau or location sensitive attention
         if align_model == 'b':
-            self.alignment_model = BahdanauAttention(annot_dim, rnn_dim, out_dim)
+            self.alignment_model = BahdanauAttention(annot_dim, rnn_dim,
+                                                     out_dim)
         if align_model == 'ls':
-            self.alignment_model = LocationSensitiveAttention(annot_dim, rnn_dim, out_dim)
+            self.alignment_model = LocationSensitiveAttention(
+                annot_dim, rnn_dim, out_dim)
         else:
             raise RuntimeError(" Wrong alignment model name: {}. Use\
-                'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))
+                'b' (Bahdanau) or 'ls' (Location Sensitive)."
+                               .format(align_model))

-
-    def forward(self, memory, context, rnn_state, annots,
-                atten, annot_lens=None):
+    def forward(self,
+                memory,
+                context,
+                rnn_state,
+                annots,
+                atten,
+                annot_lens=None):
         """
         Shapes:
             - memory: (batch, 1, dim) or (batch, dim)
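Both attention classes in this file score alignments additively: project the decoder query and the encoder annotations into a shared space, add them, squash with tanh, and reduce to one score per time step. A minimal sketch of that scoring, assuming self.v in the original is a Linear(attn_dim, 1) as its usage implies (its definition sits outside the shown hunks), and using torch.tanh in place of the deprecated nn.functional.tanh:

import torch
from torch import nn

class AdditiveScore(nn.Module):
    # annots: (batch, max_time, annot_dim); query: (batch, query_dim)
    def __init__(self, annot_dim, query_dim, attn_dim):
        super().__init__()
        self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
        self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
        self.v = nn.Linear(attn_dim, 1, bias=False)  # assumed shape of v

    def forward(self, annots, query):
        # broadcast the query over time, score, drop the trailing dim
        score = self.v(torch.tanh(
            self.annot_layer(annots) + self.query_layer(query).unsqueeze(1)))
        return score.squeeze(-1)  # (batch, max_time)

annots = torch.rand(2, 7, 256)
query = torch.rand(2, 128)
print(AdditiveScore(256, 128, 64)(annots, query).shape)  # torch.Size([2, 7])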
@@ -2,7 +2,6 @@
 import torch
 from torch import nn

-
 # class StopProjection(nn.Module):
 #     r""" Simple projection layer to predict the "stop token"

@@ -5,7 +5,6 @@ from utils.generic_utils import sequence_mask


 class L1LossMasked(nn.Module):
-
     def __init__(self):
         super(L1LossMasked, self).__init__()

@@ -31,21 +30,20 @@ class L1LossMasked(nn.Module):
         # target_flat: (batch * max_len, dim)
         target_flat = target.view(-1, target.shape[-1])
         # losses_flat: (batch * max_len, dim)
-        losses_flat = functional.l1_loss(input, target_flat, size_average=False,
-                                         reduce=False)
+        losses_flat = functional.l1_loss(
+            input, target_flat, size_average=False, reduce=False)
         # losses: (batch, max_len, dim)
         losses = losses_flat.view(*target.size())

         # mask: (batch, max_len, 1)
-        mask = sequence_mask(sequence_length=length,
-                             max_len=target.size(1)).unsqueeze(2)
+        mask = sequence_mask(
+            sequence_length=length, max_len=target.size(1)).unsqueeze(2)
         losses = losses * mask.float()
         loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
         return loss


 class MSELossMasked(nn.Module):
-
     def __init__(self):
         super(MSELossMasked, self).__init__()

@@ -71,14 +69,14 @@ class MSELossMasked(nn.Module):
         # target_flat: (batch * max_len, dim)
         target_flat = target.view(-1, target.shape[-1])
         # losses_flat: (batch * max_len, dim)
-        losses_flat = functional.mse_loss(input, target_flat, size_average=False,
-                                          reduce=False)
+        losses_flat = functional.mse_loss(
+            input, target_flat, size_average=False, reduce=False)
         # losses: (batch, max_len, dim)
         losses = losses_flat.view(*target.size())

         # mask: (batch, max_len, 1)
-        mask = sequence_mask(sequence_length=length,
-                             max_len=target.size(1)).unsqueeze(2)
+        mask = sequence_mask(
+            sequence_length=length, max_len=target.size(1)).unsqueeze(2)
         losses = losses * mask.float()
         loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
         return loss
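L1LossMasked and MSELossMasked normalize the summed element-wise loss by the number of valid frames times the feature dimension, using sequence_mask from utils.generic_utils (not shown in this diff). Below is a sketch of the same computation with a stand-in mask function and the modern reduction='none' argument in place of the deprecated size_average/reduce pair:

import torch
import torch.nn.functional as F

def sequence_mask(sequence_length, max_len=None):
    # Hypothetical stand-in for utils.generic_utils.sequence_mask:
    # True for positions before each sequence's length.
    if max_len is None:
        max_len = sequence_length.max().item()
    steps = torch.arange(max_len, device=sequence_length.device)
    return steps.unsqueeze(0) < sequence_length.unsqueeze(1)

target = torch.zeros(4, 8, 5)
pred = torch.ones(4, 8, 5)
length = torch.tensor([5, 6, 7, 8])
losses = F.l1_loss(pred, target, reduction='none')
mask = sequence_mask(length, max_len=target.size(1)).unsqueeze(2).float()
# Same normalization as the classes above: valid frames * feature dim.
loss = (losses * mask).sum() / (length.float().sum() * target.shape[2])
print(loss)  # tensor(1.) -- every valid element differs by exactly 1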
@@ -3,6 +3,7 @@ import torch
 from torch import nn
 from .attention import AttentionRNNCell

+
 class Prenet(nn.Module):
     r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
     It creates as many layers as given by 'out_features'
@@ -16,9 +17,10 @@ class Prenet(nn.Module):
     def __init__(self, in_features, out_features=[256, 128]):
         super(Prenet, self).__init__()
         in_features = [in_features] + out_features[:-1]
-        self.layers = nn.ModuleList(
-            [nn.Linear(in_size, out_size)
-             for (in_size, out_size) in zip(in_features, out_features)])
+        self.layers = nn.ModuleList([
+            nn.Linear(in_size, out_size)
+            for (in_size, out_size) in zip(in_features, out_features)
+        ])
         self.relu = nn.ReLU()
         self.dropout = nn.Dropout(0.5)

@@ -46,12 +48,21 @@ class BatchNormConv1d(nn.Module):
         - output: batch x dims
     """

-    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride,
+                 padding,
                  activation=None):
         super(BatchNormConv1d, self).__init__()
-        self.conv1d = nn.Conv1d(in_channels, out_channels,
-                                kernel_size=kernel_size,
-                                stride=stride, padding=padding, bias=False)
+        self.conv1d = nn.Conv1d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            bias=False)
         # Following tensorflow's default parameters
         self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
         self.activation = activation
@@ -96,16 +107,25 @@ class CBHG(nn.Module):
         - output: batch x time x dim*2
     """

-    def __init__(self, in_features, K=16, projections=[128, 128], num_highways=4):
+    def __init__(self,
+                 in_features,
+                 K=16,
+                 projections=[128, 128],
+                 num_highways=4):
         super(CBHG, self).__init__()
         self.in_features = in_features
         self.relu = nn.ReLU()
         # list of conv1d bank with filter size k=1...K
         # TODO: try dilational layers instead
-        self.conv1d_banks = nn.ModuleList(
-            [BatchNormConv1d(in_features, in_features, kernel_size=k, stride=1,
-                             padding=k // 2, activation=self.relu)
-             for k in range(1, K + 1)])
+        self.conv1d_banks = nn.ModuleList([
+            BatchNormConv1d(
+                in_features,
+                in_features,
+                kernel_size=k,
+                stride=1,
+                padding=k // 2,
+                activation=self.relu) for k in range(1, K + 1)
+        ])
         # max pooling of conv bank
         # TODO: try average pooling OR larger kernel size
         self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
@@ -114,9 +134,15 @@ class CBHG(nn.Module):
         activations += [None]
         # setup conv1d projection layers
         layer_set = []
-        for (in_size, out_size, ac) in zip(out_features, projections, activations):
-            layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
-                                    padding=1, activation=ac)
+        for (in_size, out_size, ac) in zip(out_features, projections,
+                                           activations):
+            layer = BatchNormConv1d(
+                in_size,
+                out_size,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                activation=ac)
             layer_set.append(layer)
         self.conv1d_projections = nn.ModuleList(layer_set)
         # setup Highway layers
@@ -204,10 +230,14 @@ class Decoder(nn.Module):
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
-        self.attention_rnn = AttentionRNNCell(out_dim=128, rnn_dim=256, annot_dim=in_features,
-                                              memory_dim=128, align_model='ls')
+        self.attention_rnn = AttentionRNNCell(
+            out_dim=128,
+            rnn_dim=256,
+            annot_dim=in_features,
+            memory_dim=128,
+            align_model='ls')
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
-        self.project_to_decoder_in = nn.Linear(256+in_features, 256)
+        self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
         self.decoder_rnns = nn.ModuleList(
             [nn.GRUCell(256, 256) for _ in range(2)])
@@ -241,17 +271,20 @@ class Decoder(nn.Module):
         # Grouping multiple frames if necessary
         if memory.size(-1) == self.memory_dim:
             memory = memory.view(B, memory.size(1) // self.r, -1)
-            " !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
-                                                          self.memory_dim, self.r)
+            " !! Dimension mismatch {} vs {} * {}".format(
+                memory.size(-1), self.memory_dim, self.r)
         T_decoder = memory.size(1)
         # go frame as zeros matrix
         initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
         # decoder states
         attention_rnn_hidden = inputs.data.new(B, 256).zero_()
-        decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
-                               for _ in range(len(self.decoder_rnns))]
+        decoder_rnn_hiddens = [
+            inputs.data.new(B, 256).zero_()
+            for _ in range(len(self.decoder_rnns))
+        ]
         current_context_vec = inputs.data.new(B, self.in_features).zero_()
-        stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
+        stopnet_rnn_hidden = inputs.data.new(B,
+                                             self.r * self.memory_dim).zero_()
         # attention states
         attention = inputs.data.new(B, T).zero_()
         attention_cum = inputs.data.new(B, T).zero_()
@@ -268,13 +301,12 @@ class Decoder(nn.Module):
             if greedy:
                 memory_input = outputs[-1]
             else:
-                memory_input = memory[t-1]
+                memory_input = memory[t - 1]
             # Prenet
             processed_memory = self.prenet(memory_input)
             # Attention RNN
-            attention_cat = torch.cat((attention.unsqueeze(1),
-                                       attention_cum.unsqueeze(1)),
-                                      dim=1)
+            attention_cat = torch.cat(
+                (attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
             attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
                 processed_memory, current_context_vec, attention_rnn_hidden,
                 inputs, attention_cat, input_lens)
@@ -293,16 +325,18 @@ class Decoder(nn.Module):
             output = self.proj_to_mel(decoder_output)
             stop_input = output
             # predict stop token
-            stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden)
+            stop_token, stopnet_rnn_hidden = self.stopnet(
+                stop_input, stopnet_rnn_hidden)
             outputs += [output]
             attentions += [attention]
             stop_tokens += [stop_token]
             t += 1
-            if (not greedy and self.training) or (greedy and memory is not None):
+            if (not greedy and self.training) or (greedy
+                                                  and memory is not None):
                 if t >= T_decoder:
                     break
             else:
-                if t > inputs.shape[1]/2 and stop_token > 0.6:
+                if t > inputs.shape[1] / 2 and stop_token > 0.6:
                     break
                 elif t > self.max_decoder_steps:
                     print(" | | > Decoder stopped with 'max_decoder_steps")
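For reference, the Prenet reformatted above is small enough to reproduce in full. Its forward pass is not part of any hunk here, so the relu-then-dropout chain below is the assumed standard Tacotron prenet behavior:

import torch
from torch import nn

class Prenet(nn.Module):
    def __init__(self, in_features, out_features=[256, 128]):
        super().__init__()
        in_sizes = [in_features] + out_features[:-1]
        # One Linear per requested output size, chained 400 -> 256 -> 128.
        self.layers = nn.ModuleList([
            nn.Linear(in_size, out_size)
            for (in_size, out_size) in zip(in_sizes, out_features)
        ])
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        for linear in self.layers:
            x = self.dropout(self.relu(linear(x)))
        return x

print(Prenet(80 * 5)(torch.rand(4, 400)).shape)  # torch.Size([4, 128])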
@@ -6,14 +6,18 @@ from layers.tacotron import Prenet, Encoder, Decoder, CBHG


 class Tacotron(nn.Module):
-    def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
-                 r=5, padding_idx=None):
+    def __init__(self,
+                 embedding_dim=256,
+                 linear_dim=1025,
+                 mel_dim=80,
+                 r=5,
+                 padding_idx=None):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
         self.linear_dim = linear_dim
-        self.embedding = nn.Embedding(len(symbols), embedding_dim,
-                                      padding_idx=padding_idx)
+        self.embedding = nn.Embedding(
+            len(symbols), embedding_dim, padding_idx=padding_idx)
         print(" | > Number of characters : {}".format(len(symbols)))
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(embedding_dim)
@@ -40,8 +40,12 @@ def visualize(alignment, spectrogram, stop_tokens, CONFIG):
     plt.plot(range(len(stop_tokens)), list(stop_tokens))

     plt.subplot(3, 1, 3)
-    librosa.display.specshow(spectrogram.T, sr=CONFIG.sample_rate,
-                             hop_length=hop_length, x_axis="time", y_axis="linear")
+    librosa.display.specshow(
+        spectrogram.T,
+        sr=CONFIG.sample_rate,
+        hop_length=hop_length,
+        x_axis="time",
+        y_axis="linear")
     plt.xlabel("Time", fontsize=label_fontsize)
     plt.ylabel("Hz", fontsize=label_fontsize)
     plt.tight_layout()
@@ -1,9 +1,9 @@
 ## TTS example web-server
 Steps to run:
-1. Download one of the models given on the main page.
-2. Checkout the corresponding commit history.
-2. Set paths and other options in the file ```server/conf.json```.
-3. Run the server ```python server/server.py -c conf.json```. (Requires Flask)
+1. Download one of the models given on the main page. Click [here](https://drive.google.com/drive/folders/1Q6BKeEkZyxSGsocK2p_mqgzLwlNvbHFJ?usp=sharing) for the latest model.
+2. Checkout the corresponding commit history or use the ```server``` branch if you like to use the latest model.
+2. Set the paths and the other options in the file ```server/conf.json```.
+3. Run the server ```python server/server.py -c server/conf.json```. (Requires Flask)
 4. Go to ```localhost:[given_port]``` and enjoy.

 Note that the audio quality on browser is slightly worse due to the encoder quantization.
 For high quality results, please use the library versions shown in the ```requirements.txt``` file.
@@ -1,5 +1,5 @@
 {
-    "model_path":"/home/erogol/projects/models/LJSpeech/May-22-2018_03_24PM-e6112f7",
+    "model_path":"../models/May-22-2018_03_24PM-e6112f7",
     "model_name":"checkpoint_272976.pth.tar",
     "model_config":"config.json",
     "port": 5002,
@@ -2,12 +2,11 @@
 import argparse
 from synthesizer import Synthesizer
 from TTS.utils.generic_utils import load_config
-from flask import (Flask, Response, request,
-                   render_template, send_file)
+from flask import Flask, Response, request, render_template, send_file

 parser = argparse.ArgumentParser()
-parser.add_argument('-c', '--config_path', type=str,
-                    help='path to config file for training')
+parser.add_argument(
+    '-c', '--config_path', type=str, help='path to config file for training')
 args = parser.parse_args()

 config = load_config(args.config_path)
@@ -16,17 +15,19 @@ synthesizer = Synthesizer()
 synthesizer.load_model(config.model_path, config.model_name,
                        config.model_config, config.use_cuda)

+
 @app.route('/')
 def index():
     return render_template('index.html')

+
 @app.route('/api/tts', methods=['GET'])
 def tts():
     text = request.args.get('text')
     print(" > Model input: {}".format(text))
     data = synthesizer.tts(text)
-    return send_file(data,
-                     mimetype='audio/wav')
+    return send_file(data, mimetype='audio/wav')

+
 if __name__ == '__main__':
-    app.run(debug=True, host='0.0.0.0', port=config.port)
+    app.run(debug=True, host='0.0.0.0', port=config.port)
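Given the route, query parameter, and default port defined in server/server.py and server/conf.json above, a client request against a locally running server would look like this:

import requests

# '/api/tts', the 'text' GET parameter, and port 5002 all come from the
# server code and config shown in this same commit.
resp = requests.get(
    "http://localhost:5002/api/tts", params={"text": "Hello world."})
with open("tts_output.wav", "wb") as f:
    f.write(resp.content)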
@@ -13,39 +13,44 @@ from matplotlib import pylab as plt


 class Synthesizer(object):
-
     def load_model(self, model_path, model_name, model_config, use_cuda):
         model_config = os.path.join(model_path, model_config)
-        self.model_file = os.path.join(model_path, model_name)
+        self.model_file = os.path.join(model_path, model_name)
         print(" > Loading model ...")
         print(" | > model config: ", model_config)
         print(" | > model file: ", self.model_file)
         config = load_config(model_config)
         self.config = config
         self.use_cuda = use_cuda
-        self.model = Tacotron(config.embedding_size, config.num_freq, config.num_mels, config.r)
-        self.ap = AudioProcessor(config.sample_rate, config.num_mels, config.min_level_db,
-                                 config.frame_shift_ms, config.frame_length_ms, config.preemphasis,
-                                 config.ref_level_db, config.num_freq, config.power, griffin_lim_iters=60)
+        self.model = Tacotron(config.embedding_size, config.num_freq,
+                              config.num_mels, config.r)
+        self.ap = AudioProcessor(
+            config.sample_rate,
+            config.num_mels,
+            config.min_level_db,
+            config.frame_shift_ms,
+            config.frame_length_ms,
+            config.preemphasis,
+            config.ref_level_db,
+            config.num_freq,
+            config.power,
+            griffin_lim_iters=60)
         # load model state
         if use_cuda:
             cp = torch.load(self.model_file)
         else:
-            cp = torch.load(self.model_file, map_location=lambda storage, loc: storage)
+            cp = torch.load(
+                self.model_file, map_location=lambda storage, loc: storage)
         # load the model
         self.model.load_state_dict(cp['model'])
         if use_cuda:
             self.model.cuda()
-        self.model.eval()

+        self.model.eval()

     def save_wav(self, wav, path):
         wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
-        # sf.write(path, wav.astype(np.int32), self.config.sample_rate, format='wav')
-        # wav = librosa.util.normalize(wav.astype(np.float), norm=np.inf, axis=None)
-        # wav = wav / wav.max()
-        # sf.write(path, wav.astype('float'), self.config.sample_rate, format='ogg')
-        scipy.io.wavfile.write(path, self.config.sample_rate, wav.astype(np.int16))
-        # librosa.output.write_wav(path, wav.astype(np.int16), self.config.sample_rate, norm=True)
+        librosa.output.write_wav(path, wav.astype(np.int16),
+                                 self.config.sample_rate)

     def tts(self, text):
         text_cleaner = [self.config.text_cleaner]
@@ -54,14 +59,15 @@ class Synthesizer(object):
             if len(sen) < 3:
                 continue
             sen = sen.strip()
-            sen +='.'
+            sen += '.'
             print(sen)
             sen = sen.strip()
             seq = np.array(text_to_sequence(text, text_cleaner))
-            chars_var = torch.from_numpy(seq).unsqueeze(0)
+            chars_var = torch.from_numpy(seq).unsqueeze(0).long()
             if self.use_cuda:
                 chars_var = chars_var.cuda()
-            mel_out, linear_out, alignments, stop_tokens = self.model.forward(chars_var)
+            mel_out, linear_out, alignments, stop_tokens = self.model.forward(
+                chars_var)
             linear_out = linear_out[0].data.cpu().numpy()
             wav = self.ap.inv_spectrogram(linear_out.T)
             # wav = wav[:self.ap.find_endpoint(wav)]
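The CPU branch of load_model above uses the standard PyTorch trick for loading a CUDA-saved checkpoint onto a CPU-only machine. Isolated, with the checkpoint name taken from server/conf.json:

import torch

# map_location keeps every tensor on the storage it arrives on (the CPU),
# instead of trying to restore it to the GPU it was saved from.
cp = torch.load(
    "checkpoint_272976.pth.tar",
    map_location=lambda storage, loc: storage)
# model.load_state_dict(cp['model'])  # 'model' key as used in load_model above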
setup.py (67 lines changed)
@@ -25,7 +25,6 @@ else:


 class build_py(setuptools.command.build_py.build_py):
-
     def run(self):
         self.create_version_file()
         setuptools.command.build_py.build_py.run(self)
@@ -40,7 +39,6 @@ class build_py(setuptools.command.build_py.build_py):


 class develop(setuptools.command.develop.develop):
-
     def run(self):
         build_py.create_version_file()
         setuptools.command.develop.develop.run(self)
@@ -50,8 +48,11 @@ def create_readme_rst():
     global cwd
     try:
         subprocess.check_call(
-            ["pandoc", "--from=markdown", "--to=rst", "--output=README.rst",
-             "README.md"], cwd=cwd)
+            [
+                "pandoc", "--from=markdown", "--to=rst", "--output=README.rst",
+                "README.md"
+            ],
+            cwd=cwd)
         print("Generated README.rst from README.md using pandoc.")
     except subprocess.CalledProcessError:
         pass
@@ -59,33 +60,31 @@ def create_readme_rst():
         pass


-setup(name='TTS',
-      version=version,
-      url='https://github.com/mozilla/TTS',
-      description='Text to Speech with Deep Learning',
-
-      packages=find_packages(),
-      cmdclass={
-          'build_py': build_py,
-          'develop': develop,
-      },
-      setup_requires=[
-          "numpy"
-      ],
-      install_requires=[
-          "scipy",
-          "torch == 0.4.0",
-          "librosa",
-          "unidecode",
-          "tensorboardX",
-          "matplotlib",
-          "Pillow",
-          "flask",
-          "lws",
-      ],
-      extras_require={
-          "bin": [
-              "tqdm",
-              "requests",
-          ],
-      })
+setup(
+    name='TTS',
+    version=version,
+    url='https://github.com/mozilla/TTS',
+    description='Text to Speech with Deep Learning',
+    packages=find_packages(),
+    cmdclass={
+        'build_py': build_py,
+        'develop': develop,
+    },
+    setup_requires=["numpy"],
+    install_requires=[
+        "scipy",
+        "torch == 0.4.0",
+        "librosa",
+        "unidecode",
+        "tensorboardX",
+        "matplotlib",
+        "Pillow",
+        "flask",
+        "lws",
+    ],
+    extras_require={
+        "bin": [
+            "tqdm",
+            "requests",
+        ],
+    })
@@ -8,19 +8,17 @@ OUT_PATH = '/tmp/test.pth.tar'


 class ModelSavingTests(unittest.TestCase):
-
     def save_checkpoint_test(self):
         # create a dummy model
         model = Prenet(128, out_features=[256, 128])
         model = T.nn.DataParallel(layer)

         # save the model
-        save_checkpoint(model, None, 100,
-                        OUTPATH, 1, 1)
+        save_checkpoint(model, None, 100, OUTPATH, 1, 1)

         # load the model to CPU
-        model_dict = torch.load(MODEL_PATH, map_location=lambda storage,
-                                loc: storage)
+        model_dict = torch.load(
+            MODEL_PATH, map_location=lambda storage, loc: storage)
         model.load_state_dict(model_dict['model'])

     def save_best_model_test(self):
@@ -29,11 +27,9 @@ class ModelSavingTests(unittest.TestCase):
         model = T.nn.DataParallel(layer)

         # save the model
-        best_loss = save_best_model(model, None, 0,
-                                    100, OUT_PATH,
-                                    10, 1)
+        best_loss = save_best_model(model, None, 0, 100, OUT_PATH, 10, 1)

         # load the model to CPU
-        model_dict = torch.load(MODEL_PATH, map_location=lambda storage,
-                                loc: storage)
+        model_dict = torch.load(
+            MODEL_PATH, map_location=lambda storage, loc: storage)
         model.load_state_dict(model_dict['model'])
@@ -7,7 +7,6 @@ from TTS.utils.generic_utils import sequence_mask


 class PrenetTests(unittest.TestCase):
-
     def test_in_out(self):
         layer = Prenet(128, out_features=[256, 128])
         dummy_input = T.rand(4, 128)
@@ -19,7 +18,6 @@ class PrenetTests(unittest.TestCase):


 class CBHGTests(unittest.TestCase):
-
     def test_in_out(self):
         layer = CBHG(128, K=6, projections=[128, 128], num_highways=2)
         dummy_input = T.rand(4, 8, 128)
@@ -32,7 +30,6 @@ class CBHGTests(unittest.TestCase):


 class DecoderTests(unittest.TestCase):
-
     def test_in_out(self):
         layer = Decoder(in_features=256, memory_dim=80, r=2)
         dummy_input = T.rand(4, 8, 256)
@@ -49,7 +46,6 @@ class DecoderTests(unittest.TestCase):


 class EncoderTests(unittest.TestCase):
-
     def test_in_out(self):
         layer = Encoder(128)
         dummy_input = T.rand(4, 8, 128)
@@ -63,7 +59,6 @@ class EncoderTests(unittest.TestCase):


 class L1LossMaskedTests(unittest.TestCase):
-
     def test_in_out(self):
         layer = L1LossMasked()
         dummy_input = T.ones(4, 8, 128).float()
@@ -80,7 +75,7 @@ class L1LossMaskedTests(unittest.TestCase):
         dummy_input = T.ones(4, 8, 128).float()
         dummy_target = T.zeros(4, 8, 128).float()
         dummy_length = (T.arange(5, 9)).long()
-        mask = ((sequence_mask(dummy_length).float() - 1.0)
-                * 100.0).unsqueeze(2)
+        mask = (
+            (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
         output = layer(dummy_input + mask, dummy_target, dummy_length)
         assert output.item() == 1.0, "1.0 vs {}".format(output.data[0])
@@ -12,34 +12,38 @@ c = load_config(os.path.join(file_path, 'test_config.json'))


 class TestLJSpeechDataset(unittest.TestCase):
-
     def __init__(self, *args, **kwargs):
         super(TestLJSpeechDataset, self).__init__(*args, **kwargs)
         self.max_loader_iter = 4
-        self.ap = AudioProcessor(sample_rate=c.sample_rate,
-                                 num_mels=c.num_mels,
-                                 min_level_db=c.min_level_db,
-                                 frame_shift_ms=c.frame_shift_ms,
-                                 frame_length_ms=c.frame_length_ms,
-                                 ref_level_db=c.ref_level_db,
-                                 num_freq=c.num_freq,
-                                 power=c.power,
-                                 preemphasis=c.preemphasis,
-                                 min_mel_freq=c.min_mel_freq,
-                                 max_mel_freq=c.max_mel_freq)
+        self.ap = AudioProcessor(
+            sample_rate=c.sample_rate,
+            num_mels=c.num_mels,
+            min_level_db=c.min_level_db,
+            frame_shift_ms=c.frame_shift_ms,
+            frame_length_ms=c.frame_length_ms,
+            ref_level_db=c.ref_level_db,
+            num_freq=c.num_freq,
+            power=c.power,
+            preemphasis=c.preemphasis,
+            min_mel_freq=c.min_mel_freq,
+            max_mel_freq=c.max_mel_freq)

     def test_loader(self):
-        dataset = LJSpeech.MyDataset(os.path.join(c.data_path_LJSpeech),
-                                     os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-                                     c.r,
-                                     c.text_cleaner,
-                                     ap = self.ap,
-                                     min_seq_len=c.min_seq_len
-                                     )
+        dataset = LJSpeech.MyDataset(
+            os.path.join(c.data_path_LJSpeech),
+            os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
+            c.r,
+            c.text_cleaner,
+            ap=self.ap,
+            min_seq_len=c.min_seq_len)

-        dataloader = DataLoader(dataset, batch_size=2,
-                                shuffle=True, collate_fn=dataset.collate_fn,
-                                drop_last=True, num_workers=c.num_loader_workers)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=2,
+            shuffle=True,
+            collate_fn=dataset.collate_fn,
+            drop_last=True,
+            num_workers=c.num_loader_workers)

         for i, data in enumerate(dataloader):
             if i == self.max_loader_iter:
@@ -62,18 +66,22 @@ class TestLJSpeechDataset(unittest.TestCase):
             assert mel_input.shape[2] == c.num_mels

     def test_padding(self):
-        dataset = LJSpeech.MyDataset(os.path.join(c.data_path_LJSpeech),
-                                     os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
-                                     1,
-                                     c.text_cleaner,
-                                     ap = self.ap,
-                                     min_seq_len=c.min_seq_len
-                                     )
+        dataset = LJSpeech.MyDataset(
+            os.path.join(c.data_path_LJSpeech),
+            os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
+            1,
+            c.text_cleaner,
+            ap=self.ap,
+            min_seq_len=c.min_seq_len)

         # Test for batch size 1
-        dataloader = DataLoader(dataset, batch_size=1,
-                                shuffle=False, collate_fn=dataset.collate_fn,
-                                drop_last=True, num_workers=c.num_loader_workers)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=1,
+            shuffle=False,
+            collate_fn=dataset.collate_fn,
+            drop_last=True,
+            num_workers=c.num_loader_workers)

         for i, data in enumerate(dataloader):
             if i == self.max_loader_iter:
@@ -98,9 +106,13 @@ class TestLJSpeechDataset(unittest.TestCase):
             assert mel_lengths[0] == mel_input[0].shape[0]

         # Test for batch size 2
-        dataloader = DataLoader(dataset, batch_size=2,
-                                shuffle=False, collate_fn=dataset.collate_fn,
-                                drop_last=False, num_workers=c.num_loader_workers)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=2,
+            shuffle=False,
+            collate_fn=dataset.collate_fn,
+            drop_last=False,
+            num_workers=c.num_loader_workers)

         for i, data in enumerate(dataloader):
             if i == self.max_loader_iter:
@@ -130,9 +142,9 @@ class TestLJSpeechDataset(unittest.TestCase):
             assert mel_lengths[idx] == mel_input[idx].shape[0]

             # check the second itme in the batch
-            assert mel_input[1-idx, -1].sum() == 0
-            assert linear_input[1-idx, -1].sum() == 0
-            assert stop_target[1-idx, -1] == 1
+            assert mel_input[1 - idx, -1].sum() == 0
+            assert linear_input[1 - idx, -1].sum() == 0
+            assert stop_target[1 - idx, -1] == 1
             assert len(mel_lengths.shape) == 1

             # check batch conditions
@@ -141,34 +153,38 @@ class TestLJSpeechDataset(unittest.TestCase):


 class TestKusalDataset(unittest.TestCase):
-
     def __init__(self, *args, **kwargs):
         super(TestKusalDataset, self).__init__(*args, **kwargs)
         self.max_loader_iter = 4
-        self.ap = AudioProcessor(sample_rate=c.sample_rate,
-                                 num_mels=c.num_mels,
-                                 min_level_db=c.min_level_db,
-                                 frame_shift_ms=c.frame_shift_ms,
-                                 frame_length_ms=c.frame_length_ms,
-                                 ref_level_db=c.ref_level_db,
-                                 num_freq=c.num_freq,
-                                 power=c.power,
-                                 preemphasis=c.preemphasis,
-                                 min_mel_freq=c.min_mel_freq,
-                                 max_mel_freq=c.max_mel_freq)
+        self.ap = AudioProcessor(
+            sample_rate=c.sample_rate,
+            num_mels=c.num_mels,
+            min_level_db=c.min_level_db,
+            frame_shift_ms=c.frame_shift_ms,
+            frame_length_ms=c.frame_length_ms,
+            ref_level_db=c.ref_level_db,
+            num_freq=c.num_freq,
+            power=c.power,
+            preemphasis=c.preemphasis,
+            min_mel_freq=c.min_mel_freq,
+            max_mel_freq=c.max_mel_freq)

     def test_loader(self):
-        dataset = Kusal.MyDataset(os.path.join(c.data_path_Kusal),
-                                  os.path.join(c.data_path_Kusal, 'prompts.txt'),
-                                  c.r,
-                                  c.text_cleaner,
-                                  ap = self.ap,
-                                  min_seq_len=c.min_seq_len
-                                  )
+        dataset = Kusal.MyDataset(
+            os.path.join(c.data_path_Kusal),
+            os.path.join(c.data_path_Kusal, 'prompts.txt'),
+            c.r,
+            c.text_cleaner,
+            ap=self.ap,
+            min_seq_len=c.min_seq_len)

-        dataloader = DataLoader(dataset, batch_size=2,
-                                shuffle=True, collate_fn=dataset.collate_fn,
-                                drop_last=True, num_workers=c.num_loader_workers)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=2,
+            shuffle=True,
+            collate_fn=dataset.collate_fn,
+            drop_last=True,
+            num_workers=c.num_loader_workers)

         for i, data in enumerate(dataloader):
             if i == self.max_loader_iter:
@@ -191,18 +207,22 @@ class TestKusalDataset(unittest.TestCase):
             assert mel_input.shape[2] == c.num_mels

     def test_padding(self):
-        dataset = Kusal.MyDataset(os.path.join(c.data_path_Kusal),
-                                  os.path.join(c.data_path_Kusal, 'prompts.txt'),
-                                  1,
-                                  c.text_cleaner,
-                                  ap = self.ap,
-                                  min_seq_len=c.min_seq_len
-                                  )
+        dataset = Kusal.MyDataset(
+            os.path.join(c.data_path_Kusal),
+            os.path.join(c.data_path_Kusal, 'prompts.txt'),
+            1,
+            c.text_cleaner,
+            ap=self.ap,
+            min_seq_len=c.min_seq_len)

         # Test for batch size 1
-        dataloader = DataLoader(dataset, batch_size=1,
-                                shuffle=False, collate_fn=dataset.collate_fn,
-                                drop_last=True, num_workers=c.num_loader_workers)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=1,
+            shuffle=False,
+            collate_fn=dataset.collate_fn,
+            drop_last=True,
+            num_workers=c.num_loader_workers)

         for i, data in enumerate(dataloader):
             if i == self.max_loader_iter:
@@ -227,9 +247,13 @@ class TestKusalDataset(unittest.TestCase):
             assert mel_lengths[0] == mel_input[0].shape[0]

         # Test for batch size 2
-        dataloader = DataLoader(dataset, batch_size=2,
-                                shuffle=False, collate_fn=dataset.collate_fn,
-                                drop_last=False, num_workers=c.num_loader_workers)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=2,
+            shuffle=False,
+            collate_fn=dataset.collate_fn,
+            drop_last=False,
+            num_workers=c.num_loader_workers)

         for i, data in enumerate(dataloader):
             if i == self.max_loader_iter:
@@ -259,16 +283,16 @@ class TestKusalDataset(unittest.TestCase):
             assert mel_lengths[idx] == mel_input[idx].shape[0]

             # check the second itme in the batch
-            assert mel_input[1-idx, -1].sum() == 0
-            assert linear_input[1-idx, -1].sum() == 0
-            assert stop_target[1-idx, -1] == 1
+            assert mel_input[1 - idx, -1].sum() == 0
+            assert linear_input[1 - idx, -1].sum() == 0
+            assert stop_target[1 - idx, -1] == 1
             assert len(mel_lengths.shape) == 1

             # check batch conditions
             assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
             assert (linear_input * stop_target.unsqueeze(2)).sum() == 0


 # class TestTWEBDataset(unittest.TestCase):

 #     def __init__(self, *args, **kwargs):
@@ -339,7 +363,7 @@ class TestKusalDataset(unittest.TestCase):
 #         for i, data in enumerate(dataloader):
 #             if i == self.max_loader_iter:
 #                 break
-
+
 #             text_input = data[0]
 #             text_lengths = data[1]
 #             linear_input = data[2]
@@ -399,4 +423,4 @@ class TestKusalDataset(unittest.TestCase):

 #     # check batch conditions
 #     assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
-#     assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
+#     assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
@ -19,50 +19,51 @@ c = load_config(os.path.join(file_path, 'test_config.json'))

class TacotronTrainTest(unittest.TestCase):
    def test_train_step(self):
        input = torch.randint(0, 24, (8, 128)).long().to(device)
        mel_spec = torch.rand(8, 30, c.num_mels).to(device)
        linear_spec = torch.rand(8, 30, c.num_freq).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
        stop_targets = torch.zeros(8, 30, 1).float().to(device)

        for idx in mel_lengths:
            stop_targets[:, int(idx.item()):, 0] = 1.0

        stop_targets = stop_targets.view(input.shape[0], stop_targets.size(1) // c.r, -1)
        stop_targets = stop_targets.view(input.shape[0],
                                         stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()

        criterion = L1LossMasked().to(device)
        criterion_st = nn.BCELoss().to(device)
        model = Tacotron(c.embedding_size,
                         c.num_freq,
                         c.num_mels,
        model = Tacotron(c.embedding_size, c.num_freq, c.num_mels,
                         c.r).to(device)
        model.train()
        model_ref = copy.deepcopy(model)
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
        for param, param_ref in zip(model.parameters(),
                                    model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        optimizer = optim.Adam(model.parameters(), lr=c.lr)
        for i in range(5):
            mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec)
            mel_out, linear_out, align, stop_tokens = model.forward(
                input, mel_spec)
            assert stop_tokens.data.max() <= 1.0
            assert stop_tokens.data.min() >= 0.0
            optimizer.zero_grad()
            loss = criterion(mel_out, mel_spec, mel_lengths)
            stop_loss = criterion_st(stop_tokens, stop_targets)
            loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss
            loss = loss + criterion(linear_out, linear_spec,
                                    mel_lengths) + stop_loss
            loss.backward()
            optimizer.step()
        # check parameter changes
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            # ignore pre-highway layer since it works conditionally
        for param, param_ref in zip(model.parameters(),
                                    model_ref.parameters()):
            # ignore pre-highway layer since it works conditionally
            if count not in [148, 59]:
                assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref)
                assert (param != param_ref).any(
                ), "param {} with shape {} not updated!! \n{}\n{}".format(
                    count, param.shape, param, param_ref)
            count += 1
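
Aside, not part of the commit: a minimal standalone sketch of the stop-target reshape the test exercises above; shapes and values here are made up for illustration.

import torch

B, T, r = 2, 6, 3                        # batch, frames, frames per decoder step
stop_targets = torch.zeros(B, T, 1)
stop_targets[:, 4:, 0] = 1.0             # frames past the audio end are "stop"

grouped = stop_targets.view(B, T // r, -1)            # (2, 2, 3): r frames per group
labels = (grouped.sum(2) > 0.0).unsqueeze(2).float()  # (2, 2, 1): one label per group
print(labels[0].squeeze().tolist())      # [0.0, 1.0] -> only the second group stops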
226 train.py
@ -1,39 +1,34 @@
import os
import sys
import time
import datetime
import shutil
import torch
import signal
import argparse
import importlib
import pickle
import traceback
import numpy as np

import torch.nn as nn
from torch import optim
from torch import onnx
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tensorboardX import SummaryWriter

from utils.generic_utils import (synthesis, remove_experiment_folder,
                                 create_experiment_folder, save_checkpoint,
                                 save_best_model, load_config, lr_decay,
                                 count_parameters, check_update, get_commit_hash)
from utils.generic_utils import (
    synthesis, remove_experiment_folder, create_experiment_folder,
    save_checkpoint, save_best_model, load_config, lr_decay, count_parameters,
    check_update, get_commit_hash)
from utils.visual import plot_alignment, plot_spectrogram
from models.tacotron import Tacotron
from layers.losses import L1LossMasked
from utils.audio import AudioProcessor

torch.manual_seed(1)
torch.set_num_threads(4)
use_cuda = torch.cuda.is_available()

def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, ap, epoch):
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
          ap, epoch):
    model = model.train()
    epoch_time = 0
    avg_linear_loss = 0
@ -54,7 +49,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
        stop_targets = data[5]

        # set the stop-target view; we predict a single stop token per r frames of prediction
        stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
        stop_targets = stop_targets.view(text_input.shape[0],
                                         stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()

        current_step = num_iter + args.restore_step + \
@ -89,7 +85,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
        # loss computation
        stop_loss = criterion_st(stop_tokens, stop_targets)
        mel_loss = criterion(mel_output, mel_input, mel_lengths)
        linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
        linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths)\
            + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                              linear_input[:, :, :n_priority_freq],
                              mel_lengths)
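
Aside, not part of the commit: the weighted linear loss above gives the low, priority frequency bins double weight. A sketch with assumed config values shows which bins those are (the definition of n_priority_freq is visible in evaluate() further below).

sample_rate, num_freq = 22050, 1025      # assumed config values, for illustration
n_priority_freq = int(3000 / (sample_rate * 0.5) * num_freq)
print(n_priority_freq)                   # 278 -> bins 0..277 (roughly 0-3 kHz) get the extra 0.5 weight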
@ -106,7 +102,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
        # backward pass and grad-norm check for the stop loss
        stop_loss.backward()
        grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet, 0.5, 100)
        grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet,
                                               0.5, 100)
        if skip_flag:
            optimizer_st.zero_grad()
            print(" | | > Iteration skipped for stopnet!!")
@ -117,9 +114,10 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
        epoch_time += step_time

        if current_step % c.print_step == 0:
            print(" | | > Step:{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "\
                  "MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "\
                  "GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter, current_step,
            print(" | | > Step:{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
                  "MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
                  "GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter,
                                                             current_step,
                                                             loss.item(),
                                                             linear_loss.item(),
                                                             mel_loss.item(),
@ -147,8 +145,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
        if current_step % c.save_step == 0:
            if c.checkpoint:
                # save model
                save_checkpoint(model, optimizer, optimizer_st, linear_loss.item(),
                                OUT_PATH, current_step, epoch)
                save_checkpoint(model, optimizer, optimizer_st,
                                linear_loss.item(), OUT_PATH, current_step,
                                epoch)

            # Diagnostic visualizations
            const_spec = linear_output[0].data.cpu().numpy()
@ -168,8 +167,11 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
            ap.griffin_lim_iters = 60
            audio_signal = ap.inv_spectrogram(audio_signal.T)
            try:
                tb.add_audio('SampleAudio', audio_signal, current_step,
                             sample_rate=c.sample_rate)
                tb.add_audio(
                    'SampleAudio',
                    audio_signal,
                    current_step,
                    sample_rate=c.sample_rate)
            except:
                pass
@ -180,9 +182,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
    avg_step_time /= (num_iter + 1)

    # print epoch stats
    print(" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "\
          "AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "\
          "AvgStopLoss:{:.5f} EpochTime:{:.2f} "\
    print(" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
          "AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
          "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
          "AvgStepTime:{:.2f}".format(current_step,
                                      avg_total_loss,
                                      avg_linear_loss,
@ -209,10 +211,12 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
    avg_mel_loss = 0
    avg_stop_loss = 0
    print(" | > Validation")
    test_sentences = ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                      "Be a voice, not an echo.",
                      "I'm sorry Dave. I'm afraid I can't do that.",
                      "This cake is great. It's so delicious and moist."]
    test_sentences = [
        "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
        "Be a voice, not an echo.",
        "I'm sorry Dave. I'm afraid I can't do that.",
        "This cake is great. It's so delicious and moist."
    ]
    n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
    with torch.no_grad():
        if data_loader is not None:
@ -228,7 +232,9 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
                stop_targets = data[5]

                # set the stop-target view; we predict a single stop token per r frames of prediction
                stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
                stop_targets = stop_targets.view(text_input.shape[0],
                                                 stop_targets.size(1) // c.r,
                                                 -1)
                stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()

                # dispatch data to GPU
@ -256,11 +262,11 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
                epoch_time += step_time

                if num_iter % c.print_step == 0:
                    print(" | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "\
                          "StopLoss: {:.5f} ".format(loss.item(),
                                                     linear_loss.item(),
                                                     mel_loss.item(),
                                                     stop_loss.item()), flush=True)
                    print(
                        " | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "
                        "StopLoss: {:.5f} ".format(loss.item(), linear_loss.item(),
                                                   mel_loss.item(), stop_loss.item()),
                        flush=True)

                avg_linear_loss += linear_loss.item()
                avg_mel_loss += mel_loss.item()
@ -278,15 +284,19 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):

            tb.add_image('ValVisual/Reconstruction', const_spec, current_step)
            tb.add_image('ValVisual/GroundTruth', gt_spec, current_step)
            tb.add_image('ValVisual/ValidationAlignment', align_img, current_step)
            tb.add_image('ValVisual/ValidationAlignment', align_img,
                         current_step)

            # Sample audio
            audio_signal = linear_output[idx].data.cpu().numpy()
            ap.griffin_lim_iters = 60
            audio_signal = ap.inv_spectrogram(audio_signal.T)
            try:
                tb.add_audio('ValSampleAudio', audio_signal, current_step,
                             sample_rate=c.sample_rate)
                tb.add_audio(
                    'ValSampleAudio',
                    audio_signal,
                    current_step,
                    sample_rate=c.sample_rate)
            except:
                # sometimes the audio signal is out of boundaries
                pass
@ -298,81 +308,88 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
            avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss

            # Plot Learning Stats
            tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss, current_step)
            tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step)
            tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss,
                          current_step)
            tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss,
                          current_step)
            tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
            tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss, current_step)
            tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss,
                          current_step)

    # test sentences
    ap.griffin_lim_iters = 60
    for idx, test_sentence in enumerate(test_sentences):
        wav, linear_spec, alignments = synthesis(model, ap, test_sentence, use_cuda,
                                                 c.text_cleaner)
        wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
                                                 use_cuda, c.text_cleaner)
        try:
            wav_name = 'TestSentences/{}'.format(idx)
            tb.add_audio(wav_name, wav, current_step,
                         sample_rate=c.sample_rate)
            tb.add_audio(
                wav_name, wav, current_step, sample_rate=c.sample_rate)
        except:
            pass
        align_img = alignments[0].data.cpu().numpy()
        linear_spec = plot_spectrogram(linear_spec, ap)
        align_img = plot_alignment(align_img)
        tb.add_image('TestSentences/{}_Spectrogram'.format(idx), linear_spec, current_step)
        tb.add_image('TestSentences/{}_Alignment'.format(idx), align_img, current_step)
        tb.add_image('TestSentences/{}_Spectrogram'.format(idx), linear_spec,
                     current_step)
        tb.add_image('TestSentences/{}_Alignment'.format(idx), align_img,
                     current_step)
    return avg_linear_loss

def main(args):
    dataset = importlib.import_module('datasets.'+c.dataset)
    dataset = importlib.import_module('datasets.' + c.dataset)
    Dataset = getattr(dataset, 'MyDataset')
    audio = importlib.import_module('utils.'+c.audio_processor)
    audio = importlib.import_module('utils.' + c.audio_processor)
    AudioProcessor = getattr(audio, 'AudioProcessor')

    ap = AudioProcessor(sample_rate=c.sample_rate,
                        num_mels=c.num_mels,
                        min_level_db=c.min_level_db,
                        frame_shift_ms=c.frame_shift_ms,
                        frame_length_ms=c.frame_length_ms,
                        ref_level_db=c.ref_level_db,
                        num_freq=c.num_freq,
                        power=c.power,
                        preemphasis=c.preemphasis,
                        min_mel_freq=c.min_mel_freq,
                        max_mel_freq=c.max_mel_freq)
    ap = AudioProcessor(
        sample_rate=c.sample_rate,
        num_mels=c.num_mels,
        min_level_db=c.min_level_db,
        frame_shift_ms=c.frame_shift_ms,
        frame_length_ms=c.frame_length_ms,
        ref_level_db=c.ref_level_db,
        num_freq=c.num_freq,
        power=c.power,
        preemphasis=c.preemphasis,
        min_mel_freq=c.min_mel_freq,
        max_mel_freq=c.max_mel_freq)

    # Setup the dataset
    train_dataset = Dataset(c.data_path,
                            c.meta_file_train,
                            c.r,
                            c.text_cleaner,
                            ap = ap,
                            min_seq_len=c.min_seq_len
                            )
    train_dataset = Dataset(
        c.data_path,
        c.meta_file_train,
        c.r,
        c.text_cleaner,
        ap=ap,
        min_seq_len=c.min_seq_len)

    train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
                              shuffle=False, collate_fn=train_dataset.collate_fn,
                              drop_last=False, num_workers=c.num_loader_workers,
                              pin_memory=True)
    train_loader = DataLoader(
        train_dataset,
        batch_size=c.batch_size,
        shuffle=False,
        collate_fn=train_dataset.collate_fn,
        drop_last=False,
        num_workers=c.num_loader_workers,
        pin_memory=True)

    if c.run_eval:
        val_dataset = Dataset(c.data_path,
                              c.meta_file_val,
                              c.r,
                              c.text_cleaner,
                              ap = ap
                              )
        val_dataset = Dataset(
            c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap)

        val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size,
                                shuffle=False, collate_fn=val_dataset.collate_fn,
                                drop_last=False, num_workers=4,
                                pin_memory=True)
        val_loader = DataLoader(
            val_dataset,
            batch_size=c.eval_batch_size,
            shuffle=False,
            collate_fn=val_dataset.collate_fn,
            drop_last=False,
            num_workers=4,
            pin_memory=True)
    else:
        val_loader = None

    model = Tacotron(c.embedding_size,
                     ap.num_freq,
                     c.num_mels,
                     c.r)
    model = Tacotron(c.embedding_size, ap.num_freq, c.num_mels, c.r)
    print(" | > Num output units : {}".format(ap.num_freq), flush=True)

    optimizer = optim.Adam(model.parameters(), lr=c.lr)
@ -394,7 +411,8 @@ def main(args):
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.cuda()
        print(" > Model restored from step %d" % checkpoint['step'], flush=True)
        print(
            " > Model restored from step %d" % checkpoint['step'], flush=True)
        start_epoch = checkpoint['step'] // len(train_loader)
        best_loss = checkpoint['linear_loss']
        args.restore_step = checkpoint['step']
@ -416,22 +434,36 @@ def main(args):
        best_loss = float('inf')

    for epoch in range(0, c.epochs):
        train_loss, current_step = train(model, criterion, criterion_st, train_loader, optimizer, optimizer_st, ap, epoch)
        val_loss = evaluate(model, criterion, criterion_st, val_loader, ap, current_step)
        print(" | > Train Loss: {:.5f} Validation Loss: {:.5f}".format(train_loss, val_loss), flush=True)
        best_loss = save_best_model(model, optimizer, train_loss,
                                    best_loss, OUT_PATH,
                                    current_step, epoch)
        train_loss, current_step = train(model, criterion, criterion_st,
                                         train_loader, optimizer, optimizer_st,
                                         ap, epoch)
        val_loss = evaluate(model, criterion, criterion_st, val_loader, ap,
                            current_step)
        print(
            " | > Train Loss: {:.5f} Validation Loss: {:.5f}".format(
                train_loss, val_loss),
            flush=True)
        best_loss = save_best_model(model, optimizer, train_loss, best_loss,
                                    OUT_PATH, current_step, epoch)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--restore_path', type=str,
                        help='Folder path to checkpoints', default=0)
    parser.add_argument('--config_path', type=str,
                        help='path to config file for training',)
    parser.add_argument('--debug', type=bool, default=False,
                        help='do not ask for git hash before run.')
    parser.add_argument(
        '--restore_path',
        type=str,
        help='Folder path to checkpoints',
        default=0)
    parser.add_argument(
        '--config_path',
        type=str,
        help='path to config file for training',
    )
    parser.add_argument(
        '--debug',
        type=bool,
        default=False,
        help='do not ask for git hash before run.')
    args = parser.parse_args()

    # setup output paths and read configs

@ -9,10 +9,19 @@ _mel_basis = None

class AudioProcessor(object):
    def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
                 frame_length_ms, ref_level_db, num_freq, power, preemphasis,
                 min_mel_freq, max_mel_freq, griffin_lim_iters=None):
    def __init__(self,
                 sample_rate,
                 num_mels,
                 min_level_db,
                 frame_shift_ms,
                 frame_length_ms,
                 ref_level_db,
                 num_freq,
                 power,
                 preemphasis,
                 min_mel_freq,
                 max_mel_freq,
                 griffin_lim_iters=None):

        self.sample_rate = sample_rate
        self.num_mels = num_mels
@ -30,7 +39,8 @@ class AudioProcessor(object):
    def save_wav(self, wav, path):
        wav *= 32767 / max(0.01, np.max(np.abs(wav)))
        librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate, norm=True)
        librosa.output.write_wav(
            path, wav.astype(np.float), self.sample_rate, norm=True)

    def _linear_to_mel(self, spectrogram):
        global _mel_basis
@ -40,8 +50,9 @@ class AudioProcessor(object):
    def _build_mel_basis(self, ):
        n_fft = (self.num_freq - 1) * 2
        return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
        # fmin=self.min_mel_freq, fmax=self.max_mel_freq)
        return librosa.filters.mel(
            self.sample_rate, n_fft, n_mels=self.num_mels)
        # fmin=self.min_mel_freq, fmax=self.max_mel_freq)

    def _normalize(self, S):
        return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
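
Aside, not part of the commit: why n_fft = (num_freq - 1) * 2 above. An n_fft-point real STFT yields n_fft // 2 + 1 frequency bins, so the relation inverts exactly; a sketch with an assumed num_freq:

num_freq = 1025                          # assumed value of the num_freq setting
n_fft = (num_freq - 1) * 2               # 2048
assert n_fft // 2 + 1 == num_freq        # 2048-point FFT -> 1025 linear bins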
@ -66,7 +77,7 @@ class AudioProcessor(object):
        if self.preemphasis == 0:
            raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
        return signal.lfilter([1, -self.preemphasis], [1], x)

    def apply_inv_preemphasis(self, x):
        if self.preemphasis == 0:
            raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
@ -86,9 +97,9 @@ class AudioProcessor(object):
        S = self._db_to_amp(S + self.ref_level_db)  # Convert back to linear
        # Reconstruct phase
        if self.preemphasis != 0:
            return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
        else:
            return self._griffin_lim(S ** self.power)
            return self._griffin_lim(S**self.power)

    def _griffin_lim(self, S):
        '''Applies the Griffin-Lim algorithm.
@ -113,7 +124,8 @@ class AudioProcessor(object):
    def _stft(self, y):
        n_fft, hop_length, win_length = self._stft_parameters()
        return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        return librosa.stft(
            y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)

    def _istft(self, y):
        _, hop_length, win_length = self._stft_parameters()

@ -8,11 +8,23 @@ import lws

_mel_basis = None

class AudioProcessor(object):
    def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
                 frame_length_ms, ref_level_db, num_freq, power, preemphasis,
                 min_mel_freq, max_mel_freq, griffin_lim_iters=None, ):
class AudioProcessor(object):
    def __init__(
            self,
            sample_rate,
            num_mels,
            min_level_db,
            frame_shift_ms,
            frame_length_ms,
            ref_level_db,
            num_freq,
            power,
            preemphasis,
            min_mel_freq,
            max_mel_freq,
            griffin_lim_iters=None,
    ):
        print(" > Setting up Audio Processor...")
        self.sample_rate = sample_rate
        self.num_mels = num_mels
@ -25,18 +37,19 @@ class AudioProcessor(object):
        self.min_mel_freq = min_mel_freq
        self.max_mel_freq = max_mel_freq
        self.griffin_lim_iters = griffin_lim_iters
        self.preemphasis =preemphasis
        self.preemphasis = preemphasis
        self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
        if preemphasis == 0:
            print(" | > Preemphasis is deactivated.")

    def save_wav(self, wav, path):
        wav *= 32767 / max(0.01, np.max(np.abs(wav)))
        librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate, norm=True)
        librosa.output.write_wav(
            path, wav.astype(np.float), self.sample_rate, norm=True)

    def _stft_parameters(self, ):
        n_fft = int((self.num_freq - 1) * 2)
        hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
        win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
        if n_fft % hop_length != 0:
            hop_length = n_fft / 8
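
Aside, not part of the commit: a sketch of the millisecond-to-sample conversion in _stft_parameters, with assumed 22050 Hz audio and 12.5/50 ms frame settings; the divisibility checks around it then snap these values to divisors of n_fft.

sample_rate = 22050                             # assumed values, for illustration
hop_length = int(12.5 / 1000.0 * sample_rate)   # 275 samples
win_length = int(50.0 / 1000.0 * sample_rate)   # 1102 samples
print(hop_length, win_length)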
@ -44,14 +57,21 @@ class AudioProcessor(object):
        if n_fft % win_length != 0:
            win_length = n_fft / 2
            print(" | > win_length is set to default ({}).".format(win_length))
        print(" | > fft size: {}, hop length: {}, win length: {}".format(n_fft, hop_length, win_length))
        print(" | > fft size: {}, hop length: {}, win length: {}".format(
            n_fft, hop_length, win_length))
        return int(n_fft), int(hop_length), int(win_length)

    def _lws_processor(self):
        try:
            return lws.lws(self.win_length, self.hop_length, fftsize=self.n_fft, mode="speech")
            return lws.lws(
                self.win_length,
                self.hop_length,
                fftsize=self.n_fft,
                mode="speech")
        except:
            raise RuntimeError(" !! WindowLength({}) is not multiple of HopLength({}).".format(self.win_length, self.hop_length))
            raise RuntimeError(
                " !! WindowLength({}) is not multiple of HopLength({}).".
                format(self.win_length, self.hop_length))

    def _amp_to_db(self, x):
        min_level = np.exp(self.min_level_db / 20 * np.log(10))
@ -70,7 +90,7 @@ class AudioProcessor(object):
        if self.preemphasis == 0:
            raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
        return signal.lfilter([1, -self.preemphasis], [1], x)

    def apply_inv_preemphasis(self, x):
        if self.preemphasis == 0:
            raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
@ -96,14 +116,14 @@ class AudioProcessor(object):
        S = self._denormalize(spectrogram)
        S = self._db_to_amp(S + self.ref_level_db)  # Convert back to linear
        processor = self._lws_processor()
        D = processor.run_lws(S.astype(np.float64).T ** self.power)
        D = processor.run_lws(S.astype(np.float64).T**self.power)
        y = processor.istft(D).astype(np.float32)
        # Reconstruct phase
        if self.preemphasis:
            return self.apply_inv_preemphasis(y)
        sys.stdout = old_out
        return y

    def _linear_to_mel(self, spectrogram):
        global _mel_basis
        if _mel_basis is None:
@ -111,7 +131,10 @@ class AudioProcessor(object):
        return np.dot(_mel_basis, spectrogram)

    def _build_mel_basis(self, ):
        return librosa.filters.mel(self.sample_rate, self.n_fft, n_mels=self.num_mels)
        return librosa.filters.mel(
            self.sample_rate, self.n_fft, n_mels=self.num_mels)

    # fmin=self.min_mel_freq, fmax=self.max_mel_freq)

    def melspectrogram(self, y):
@ -124,4 +147,4 @@ class AudioProcessor(object):
        D = self._lws_processor().stft(y).T
        S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
        sys.stdout = old_out
        return self._normalize(S)

@ -4,9 +4,8 @@ import numpy as np
def _pad_data(x, length):
    _pad = 0
    assert x.ndim == 1
    return np.pad(x, (0, length - x.shape[0]),
                  mode='constant',
                  constant_values=_pad)
    return np.pad(
        x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)


def prepare_data(inputs):
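
Aside, not part of the commit: what _pad_data/prepare_data do, as a tiny standalone example.

import numpy as np

batch = [np.array([1, 2, 3]), np.array([4, 5])]
max_len = max(x.shape[0] for x in batch)
padded = np.stack([
    np.pad(x, (0, max_len - x.shape[0]), mode='constant', constant_values=0)
    for x in batch
])
print(padded)                            # [[1 2 3] [4 5 0]] -> a dense (2, 3) array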
@ -17,8 +16,10 @@ def prepare_data(inputs):
def _pad_tensor(x, length):
    _pad = 0
    assert x.ndim == 2
    x = np.pad(x, [[0, 0], [0, length - x.shape[1]]],
               mode='constant', constant_values=_pad)
    x = np.pad(
        x, [[0, 0], [0, length - x.shape[1]]],
        mode='constant',
        constant_values=_pad)
    return x
@ -32,7 +33,8 @@ def prepare_tensor(inputs, out_steps):
def _pad_stop_target(x, length):
    _pad = 1.
    assert x.ndim == 1
    return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
    return np.pad(
        x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)


def prepare_stop_target(inputs, out_steps):
@ -44,6 +46,7 @@ def prepare_stop_target(inputs, out_steps):

def pad_per_step(inputs, pad_len):
    timesteps = inputs.shape[-1]
    return np.pad(inputs, [[0, 0], [0, 0],
                           [0, pad_len]],
                  mode='constant', constant_values=0.0)
    return np.pad(
        inputs, [[0, 0], [0, 0], [0, pad_len]],
        mode='constant',
        constant_values=0.0)

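
Aside, not part of the commit: note that _pad_stop_target pads with 1.0 rather than 0, so padded frames are already labelled "stop" and need no extra masking; a tiny example:

import numpy as np

stop = np.array([0., 0., 1.])            # last real frame is the stop frame
padded = np.pad(stop, (0, 2), mode='constant', constant_values=1.)
print(padded)                            # [0. 0. 1. 1. 1.]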
@ -28,10 +28,13 @@ def load_config(config_path):
def get_commit_hash():
    """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
    try:
        subprocess.check_output(['git', 'diff-index', '--quiet', 'HEAD'])  # Verify client is clean
        subprocess.check_output(['git', 'diff-index', '--quiet',
                                 'HEAD'])  # Verify client is clean
    except:
        raise RuntimeError(" !! Commit before training to get the commit hash.")
        raise RuntimeError(
            " !! Commit before training to get the commit hash.")
    commit = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
    commit = subprocess.check_output(['git', 'rev-parse', '--short',
                                      'HEAD']).decode().strip()
    print(' > Git Hash: {}'.format(commit))
    return commit
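
Aside, not part of the commit: the two git calls in get_commit_hash, standalone. `git diff-index --quiet HEAD` exits non-zero on a dirty working tree (so check_output raises CalledProcessError), while `git rev-parse --short HEAD` prints the abbreviated hash.

import subprocess

subprocess.check_output(['git', 'diff-index', '--quiet', 'HEAD'])   # raises if tree is dirty
commit = subprocess.check_output(
    ['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
print(commit)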
@ -43,7 +46,8 @@ def create_experiment_folder(root_path, model_name, debug):
        commit_hash = 'debug'
    else:
        commit_hash = get_commit_hash()
    output_folder = os.path.join(root_path, date_str + '-' + model_name + '-' + commit_hash)
    output_folder = os.path.join(
        root_path, date_str + '-' + model_name + '-' + commit_hash)
    os.makedirs(output_folder, exist_ok=True)
    print(" > Experiment folder: {}".format(output_folder))
    return output_folder
@ -52,7 +56,7 @@ def create_experiment_folder(root_path, model_name, debug):
def remove_experiment_folder(experiment_path):
    """Remove the experiment folder unless it contains a checkpoint."""

    checkpoint_files = glob.glob(experiment_path+"/*.pth.tar")
    checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
    if len(checkpoint_files) < 1:
        if os.path.exists(experiment_path):
            shutil.rmtree(experiment_path)
@ -86,13 +90,15 @@ def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path,
    print(" | | > Checkpoint saving : {}".format(checkpoint_path))

    new_state_dict = _trim_model_state_dict(model.state_dict())
    state = {'model': new_state_dict,
             'optimizer': optimizer.state_dict(),
             'optimizer_st': optimizer_st.state_dict(),
             'step': current_step,
             'epoch': epoch,
             'linear_loss': model_loss,
             'date': datetime.date.today().strftime("%B %d, %Y")}
    state = {
        'model': new_state_dict,
        'optimizer': optimizer.state_dict(),
        'optimizer_st': optimizer_st.state_dict(),
        'step': current_step,
        'epoch': epoch,
        'linear_loss': model_loss,
        'date': datetime.date.today().strftime("%B %d, %Y")
    }
    torch.save(state, checkpoint_path)
@ -100,12 +106,14 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
                    current_step, epoch):
    if model_loss < best_loss:
        new_state_dict = _trim_model_state_dict(model.state_dict())
        state = {'model': new_state_dict,
                 'optimizer': optimizer.state_dict(),
                 'step': current_step,
                 'epoch': epoch,
                 'linear_loss': model_loss,
                 'date': datetime.date.today().strftime("%B %d, %Y")}
        state = {
            'model': new_state_dict,
            'optimizer': optimizer.state_dict(),
            'step': current_step,
            'epoch': epoch,
            'linear_loss': model_loss,
            'date': datetime.date.today().strftime("%B %d, %Y")
        }
        best_loss = model_loss
        bestmodel_path = 'best_model.pth.tar'
        bestmodel_path = os.path.join(out_path, bestmodel_path)
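
Aside, not part of the commit: reading back a checkpoint written by the functions above; the path here is an assumption, the keys match the state dict being saved.

import torch

ckpt = torch.load('best_model.pth.tar', map_location='cpu')   # assumed path
print(ckpt['step'], ckpt['epoch'], ckpt['linear_loss'], ckpt['date'])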
@ -161,12 +169,12 @@ def sequence_mask(sequence_length, max_len=None):


def synthesis(model, ap, text, use_cuda, text_cleaner):
    text_cleaner = [text_cleaner]
    seq = np.array(text_to_sequence(text, text_cleaner))
    chars_var = torch.from_numpy(seq).unsqueeze(0)
    if use_cuda:
        chars_var = chars_var.cuda().long()
    _, linear_out, alignments, _ = model.forward(chars_var)
    linear_out = linear_out[0].data.cpu().numpy()
    wav = ap.inv_spectrogram(linear_out.T)
    return wav, linear_out, alignments

@ -4,7 +4,6 @@ import re
from utils.text import cleaners
from utils.text.symbols import symbols

# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}

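
Aside, not part of the commit: the two comprehensions above build inverse mappings; a toy symbol list shows the round trip.

symbols = ['_', '~', 'a', 'b']           # toy list, not the real symbol set
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
assert _id_to_symbol[_symbol_to_id['a']] == 'a'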
@ -14,31 +14,31 @@ import re
from unidecode import unidecode
from .numbers import normalize_numbers

# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')

# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
                  for x in [
    ('mrs', 'misess'),
    ('mr', 'mister'),
    ('dr', 'doctor'),
    ('st', 'saint'),
    ('co', 'company'),
    ('jr', 'junior'),
    ('maj', 'major'),
    ('gen', 'general'),
    ('drs', 'doctors'),
    ('rev', 'reverend'),
    ('lt', 'lieutenant'),
    ('hon', 'honorable'),
    ('sgt', 'sergeant'),
    ('capt', 'captain'),
    ('esq', 'esquire'),
    ('ltd', 'limited'),
    ('col', 'colonel'),
    ('ft', 'fort'),
]]


def expand_abbreviations(text):

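
Aside, not part of the commit: how the abbreviation table is applied. The loop body of expand_abbreviations (shown truncated above) is the standard substitution pattern, reproduced here as an illustration; the regexes require the trailing period, so "Dr." expands but a bare "Dr" does not.

import re

_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
                  for x in [('mrs', 'misess'), ('dr', 'doctor')]]

def expand_abbreviations(text):
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text

print(expand_abbreviations('Dr. Smith met Mrs. Jones'))   # doctor Smith met misess Jones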
@ -1,17 +1,16 @@
# -*- coding: utf-8 -*-

import re

valid_symbols = [
    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
    'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
    'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
    'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
    'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
    'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
    'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1',
    'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0',
    'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0',
    'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0',
    'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG',
    'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH',
    'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W',
    'Y', 'Z', 'ZH'
]

_valid_symbol_set = set(valid_symbols)
@ -27,8 +26,10 @@ class CMUDict:
        else:
            entries = _parse_cmudict(file_or_path)
        if not keep_ambiguous:
            entries = {word: pron for word,
                       pron in entries.items() if len(pron) == 1}
            entries = {
                word: pron
                for word, pron in entries.items() if len(pron) == 1
            }
        self._entries = entries

    def __len__(self):

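
Aside, not part of the commit: the keep_ambiguous filter above, on toy entries.

entries = {'read': [['R', 'IY1', 'D'], ['R', 'EH1', 'D']],   # two pronunciations
           'cat': [['K', 'AE1', 'T']]}                       # one pronunciation
entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
print(list(entries))                     # ['cat'] -> ambiguous words are dropped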
@ -8,61 +8,45 @@ _ordinal_re = re.compile(r'([0-9]+)(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')

_units = [
    '',
    'one',
    'two',
    'three',
    'four',
    'five',
    'six',
    'seven',
    'eight',
    'nine',
    'ten',
    'eleven',
    'twelve',
    'thirteen',
    'fourteen',
    'fifteen',
    'sixteen',
    'seventeen',
    'eighteen',
    'nineteen'
    '', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine',
    'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen',
    'seventeen', 'eighteen', 'nineteen'
]

_tens = [
    '',
    'ten',
    'twenty',
    'thirty',
    'forty',
    'fifty',
    'sixty',
    'seventy',
    'eighty',
    'ninety',
]

_digit_groups = [
    '',
    'thousand',
    'million',
    'billion',
    'trillion',
    'quadrillion',
]

_ordinal_suffixes = [
    ('one', 'first'),
    ('two', 'second'),
    ('three', 'third'),
    ('five', 'fifth'),
    ('eight', 'eighth'),
    ('nine', 'ninth'),
    ('twelve', 'twelfth'),
    ('ty', 'tieth'),
]


def _remove_commas(m):
    return m.group(1).replace(',', '')
@ -114,7 +98,7 @@ def _standard_number_to_words(n, digit_group):
def _number_to_words(n):
    # Handle special cases first, then go to the standard case:
    if n >= 1000000000000000000:
        return str(n)  # Too large, just return the digits
    elif n == 0:
        return 'zero'
    elif n % 100 == 0 and n % 1000 != 0 and n < 3000:

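
Aside, not part of the commit: a toy helper showing how the _units/_tens tables above compose small numbers; this is an illustration, not the module's actual _standard_number_to_words.

_units = ['', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight',
          'nine', 'ten', 'eleven', 'twelve', 'thirteen', 'fourteen',
          'fifteen', 'sixteen', 'seventeen', 'eighteen', 'nineteen']
_tens = ['', 'ten', 'twenty', 'thirty', 'forty', 'fifty', 'sixty', 'seventy',
         'eighty', 'ninety']

def two_digits_to_words(n):              # hypothetical helper, for illustration
    assert 0 <= n < 100
    if n < 20:
        return _units[n] or 'zero'
    tens, units = divmod(n, 10)
    return _tens[tens] + (' ' + _units[units] if units else '')

print(two_digits_to_words(42))           # forty two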
@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-

'''
Defines the set of symbols used in text input to the model.

@ -19,6 +17,5 @@ _arpabet = ['@' + s for s in cmudict.valid_symbols]
# Export all symbols:
symbols = [_pad, _eos] + list(_characters) + _arpabet

if __name__ == '__main__':
    print(symbols)

@ -6,8 +6,8 @@ import matplotlib.pyplot as plt
def plot_alignment(alignment, info=None):
    fig, ax = plt.subplots(figsize=(16, 10))
    im = ax.imshow(alignment.T, aspect='auto', origin='lower',
                   interpolation='none')
    im = ax.imshow(
        alignment.T, aspect='auto', origin='lower', interpolation='none')
    fig.colorbar(im, ax=ax)
    xlabel = 'Decoder timestep'
    if info is not None:

@ -17,7 +17,7 @@ def plot_alignment(alignment, info=None):
    plt.tight_layout()
    fig.canvas.draw()
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
    plt.close()
    return data
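
Aside, not part of the commit: the canvas-to-array round trip used in both plotting functions, standalone; np.frombuffer is the non-deprecated equivalent of the np.fromstring call above.

import numpy as np
import matplotlib
matplotlib.use('Agg')                    # headless backend, an assumption here
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1])
fig.canvas.draw()
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
plt.close(fig)
print(data.shape)                        # (height, width, 3)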
@ -30,6 +30,6 @@ def plot_spectrogram(linear_output, audio):
    plt.tight_layout()
    fig.canvas.draw()
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
    plt.close()
    return data