Eren G 2018-08-02 16:34:17 +02:00
Parent 3238ffa3e6
Commit f5537dc48f
32 changed files with 766 additions and 599 deletions

View File

@ -8,21 +8,28 @@ import torch
from torch.utils.data import Dataset
from utils.text import text_to_sequence
from utils.data import (prepare_data, pad_per_step,
prepare_tensor, prepare_stop_target)
from utils.data import (prepare_data, pad_per_step, prepare_tensor,
prepare_stop_target)
class MyDataset(Dataset):
def __init__(self, root_dir, csv_file, outputs_per_step,
text_cleaner, ap, min_seq_len=0):
def __init__(self,
root_dir,
csv_file,
outputs_per_step,
text_cleaner,
ap,
min_seq_len=0):
self.root_dir = root_dir
self.wav_dir = os.path.join(root_dir, 'wav')
self.wav_files = glob.glob(os.path.join(self.wav_dir, '*.wav'))
self._create_file_dict()
self.csv_dir = os.path.join(root_dir, csv_file)
with open(self.csv_dir, "r", encoding="utf8") as f:
self.frames = [line.split('\t') for line in f if line.split('\t')[0] in self.wav_files_dict.keys()]
self.frames = [
line.split('\t') for line in f
if line.split('\t')[0] in self.wav_files_dict.keys()
]
self.outputs_per_step = outputs_per_step
self.sample_rate = ap.sample_rate
self.cleaners = text_cleaner
@ -43,10 +50,8 @@ class MyDataset(Dataset):
print(" !! Cannot read file : {}".format(filename))
def _trim_silence(self, wav):
return librosa.effects.trim(
wav, top_db=40,
frame_length=1024,
hop_length=256)[0]
return librosa.effects.trim(
wav, top_db=40, frame_length=1024, hop_length=256)[0]
def _create_file_dict(self):
self.wav_files_dict = {}
@ -87,11 +92,10 @@ class MyDataset(Dataset):
sidx = self.frames[idx][0]
sidx_files = self.wav_files_dict[sidx]
file_name = random.choice(sidx_files)
wav_name = os.path.join(self.wav_dir,
file_name)
wav_name = os.path.join(self.wav_dir, file_name)
text = self.frames[idx][2]
text = np.asarray(text_to_sequence(
text, [self.cleaners]), dtype=np.int32)
text = np.asarray(
text_to_sequence(text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
return sample
@ -121,12 +125,13 @@ class MyDataset(Dataset):
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
# compute 'stop token' targets
stop_targets = [np.array([0.]*(mel_len-1))
for mel_len in mel_lengths]
stop_targets = [
np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
]
# PAD stop targets
stop_targets = prepare_stop_target(
stop_targets, self.outputs_per_step)
stop_targets = prepare_stop_target(stop_targets,
self.outputs_per_step)
# PAD sequences with largest length of the batch
text = prepare_data(text).astype(np.int32)
@ -150,8 +155,8 @@ class MyDataset(Dataset):
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
0]
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}"
.format(type(batch[0]))))
found {}".format(type(batch[0]))))

View File

@ -6,14 +6,18 @@ import torch
from torch.utils.data import Dataset
from utils.text import text_to_sequence
from utils.data import (prepare_data, pad_per_step,
prepare_tensor, prepare_stop_target)
from utils.data import (prepare_data, pad_per_step, prepare_tensor,
prepare_stop_target)
class MyDataset(Dataset):
def __init__(self, root_dir, csv_file, outputs_per_step,
text_cleaner, ap, min_seq_len=0):
def __init__(self,
root_dir,
csv_file,
outputs_per_step,
text_cleaner,
ap,
min_seq_len=0):
self.root_dir = root_dir
self.wav_dir = os.path.join(root_dir, 'wavs')
self.csv_dir = os.path.join(root_dir, csv_file)
@ -60,11 +64,10 @@ class MyDataset(Dataset):
return len(self.frames)
def __getitem__(self, idx):
wav_name = os.path.join(self.wav_dir,
self.frames[idx][0]) + '.wav'
wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav'
text = self.frames[idx][1]
text = np.asarray(text_to_sequence(
text, [self.cleaners]), dtype=np.int32)
text = np.asarray(
text_to_sequence(text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
return sample
@ -94,12 +97,13 @@ class MyDataset(Dataset):
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
# compute 'stop token' targets
stop_targets = [np.array([0.]*(mel_len-1))
for mel_len in mel_lengths]
stop_targets = [
np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
]
# PAD stop targets
stop_targets = prepare_stop_target(
stop_targets, self.outputs_per_step)
stop_targets = prepare_stop_target(stop_targets,
self.outputs_per_step)
# PAD sequences with largest length of the batch
text = prepare_data(text).astype(np.int32)
@ -123,8 +127,8 @@ class MyDataset(Dataset):
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
0]
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}"
.format(type(batch[0]))))
found {}".format(type(batch[0]))))

View File

@ -6,14 +6,18 @@ import torch
from torch.utils.data import Dataset
from utils.text import text_to_sequence
from utils.data import (prepare_data, pad_per_step,
prepare_tensor, prepare_stop_target)
from utils.data import (prepare_data, pad_per_step, prepare_tensor,
prepare_stop_target)
class MyDataset(Dataset):
def __init__(self, root_dir, csv_file, outputs_per_step,
text_cleaner, ap, min_seq_len=0):
def __init__(self,
root_dir,
csv_file,
outputs_per_step,
text_cleaner,
ap,
min_seq_len=0):
self.root_dir = root_dir
self.wav_dir = os.path.join(root_dir, 'wavs')
self.feat_dir = os.path.join(root_dir, 'loader_data')
@ -35,7 +39,7 @@ class MyDataset(Dataset):
return audio
except RuntimeError as e:
print(" !! Cannot read file : {}".format(filename))
def load_np(self, filename):
data = np.load(filename).astype('float32')
return data
@ -66,20 +70,24 @@ class MyDataset(Dataset):
def __getitem__(self, idx):
if self.items[idx] is None:
wav_name = os.path.join(self.wav_dir,
self.frames[idx][0]) + '.wav'
wav_name = os.path.join(self.wav_dir, self.frames[idx][0]) + '.wav'
mel_name = os.path.join(self.feat_dir,
self.frames[idx][0]) + '.mel.npy'
linear_name = os.path.join(self.feat_dir,
self.frames[idx][0]) + '.linear.npy'
text = self.frames[idx][1]
text = np.asarray(text_to_sequence(
text, [self.cleaners]), dtype=np.int32)
text = np.asarray(
text_to_sequence(text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
mel = self.load_np(mel_name)
linear = self.load_np(linear_name)
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0],
'mel':mel, 'linear': linear}
sample = {
'text': text,
'wav': wav,
'item_idx': self.frames[idx][0],
'mel': mel,
'linear': linear
}
self.items[idx] = sample
else:
sample = self.items[idx]
@ -109,12 +117,13 @@ class MyDataset(Dataset):
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
# compute 'stop token' targets
stop_targets = [np.array([0.]*(mel_len-1))
for mel_len in mel_lengths]
stop_targets = [
np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
]
# PAD stop targets
stop_targets = prepare_stop_target(
stop_targets, self.outputs_per_step)
stop_targets = prepare_stop_target(stop_targets,
self.outputs_per_step)
# PAD sequences with largest length of the batch
text = prepare_data(text).astype(np.int32)
@ -138,8 +147,8 @@ class MyDataset(Dataset):
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
0]
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}"
.format(type(batch[0]))))
found {}".format(type(batch[0]))))

View File

@ -7,15 +7,25 @@ from torch.utils.data import Dataset
from TTS.utils.text import text_to_sequence
from TTS.utils.audio import AudioProcessor
from TTS.utils.data import (prepare_data, pad_per_step,
prepare_tensor, prepare_stop_target)
from TTS.utils.data import (prepare_data, pad_per_step, prepare_tensor,
prepare_stop_target)
class TWEBDataset(Dataset):
def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
text_cleaner, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, preemphasis, ref_level_db, num_freq, power,
def __init__(self,
csv_file,
root_dir,
outputs_per_step,
sample_rate,
text_cleaner,
num_mels,
min_level_db,
frame_shift_ms,
frame_length_ms,
preemphasis,
ref_level_db,
num_freq,
power,
min_seq_len=0):
with open(csv_file, "r") as f:
@ -25,8 +35,9 @@ class TWEBDataset(Dataset):
self.sample_rate = sample_rate
self.cleaners = text_cleaner
self.min_seq_len = min_seq_len
self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, preemphasis, ref_level_db, num_freq, power)
self.ap = AudioProcessor(sample_rate, num_mels, min_level_db,
frame_shift_ms, frame_length_ms, preemphasis,
ref_level_db, num_freq, power)
print(" > Reading TWEB from - {}".format(root_dir))
print(" | > Number of instances : {}".format(len(self.frames)))
self._sort_frames()
@ -63,11 +74,10 @@ class TWEBDataset(Dataset):
return len(self.frames)
def __getitem__(self, idx):
wav_name = os.path.join(self.root_dir,
self.frames[idx][0]) + '.wav'
wav_name = os.path.join(self.root_dir, self.frames[idx][0]) + '.wav'
text = self.frames[idx][1]
text = np.asarray(text_to_sequence(
text, [self.cleaners]), dtype=np.int32)
text = np.asarray(
text_to_sequence(text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
return sample
@ -97,12 +107,13 @@ class TWEBDataset(Dataset):
mel_lengths = [m.shape[1] + 1 for m in mel] # +1 for zero-frame
# compute 'stop token' targets
stop_targets = [np.array([0.]*(mel_len-1))
for mel_len in mel_lengths]
stop_targets = [
np.array([0.] * (mel_len - 1)) for mel_len in mel_lengths
]
# PAD stop targets
stop_targets = prepare_stop_target(
stop_targets, self.outputs_per_step)
stop_targets = prepare_stop_target(stop_targets,
self.outputs_per_step)
# PAD sequences with largest length of the batch
text = prepare_data(text).astype(np.int32)
@ -126,8 +137,8 @@ class TWEBDataset(Dataset):
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[
0]
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}"
.format(type(batch[0]))))
found {}".format(type(batch[0]))))

View File

@ -10,7 +10,6 @@
"hidden_size": 128,
"embedding_size": 256,
"text_cleaner": "english_cleaners",
"epochs": 200,
"lr": 0.01,
"lr_patience": 2,
@ -19,9 +18,7 @@
"griffinf_lim_iters": 60,
"power": 1.5,
"r": 5,
"num_loader_workers": 16,
"save_step": 1,
"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
"output_path": "result",

View File

@ -13,19 +13,19 @@ from utils.generic_utils import load_config
from multiprocessing import Pool
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', type=str,
help='Data folder.')
parser.add_argument('--out_path', type=str,
help='Output folder.')
parser.add_argument('--config', type=str,
help='conf.json file for run settings.')
parser.add_argument("--num_proc", type=int, default=8,
help="number of processes.")
parser.add_argument("--trim_silence", type=bool, default=False,
help="trim silence in the voice clip.")
parser.add_argument('--data_path', type=str, help='Data folder.')
parser.add_argument('--out_path', type=str, help='Output folder.')
parser.add_argument(
'--config', type=str, help='conf.json file for run settings.')
parser.add_argument(
"--num_proc", type=int, default=8, help="number of processes.")
parser.add_argument(
"--trim_silence",
type=bool,
default=False,
help="trim silence in the voice clip.")
args = parser.parse_args()
DATA_PATH = args.data_path
OUT_PATH = args.out_path
@ -34,27 +34,26 @@ if __name__ == "__main__":
print(" > Input path: ", DATA_PATH)
print(" > Output path: ", OUT_PATH)
audio = importlib.import_module('utils.'+c.audio_processor)
audio = importlib.import_module('utils.' + c.audio_processor)
AudioProcessor = getattr(audio, 'AudioProcessor')
ap = AudioProcessor(sample_rate = CONFIG.sample_rate,
num_mels = CONFIG.num_mels,
min_level_db = CONFIG.min_level_db,
frame_shift_ms = CONFIG.frame_shift_ms,
frame_length_ms = CONFIG.frame_length_ms,
ref_level_db = CONFIG.ref_level_db,
num_freq = CONFIG.num_freq,
power = CONFIG.power,
preemphasis = CONFIG.preemphasis,
min_mel_freq = CONFIG.min_mel_freq,
max_mel_freq = CONFIG.max_mel_freq)
ap = AudioProcessor(
sample_rate=CONFIG.sample_rate,
num_mels=CONFIG.num_mels,
min_level_db=CONFIG.min_level_db,
frame_shift_ms=CONFIG.frame_shift_ms,
frame_length_ms=CONFIG.frame_length_ms,
ref_level_db=CONFIG.ref_level_db,
num_freq=CONFIG.num_freq,
power=CONFIG.power,
preemphasis=CONFIG.preemphasis,
min_mel_freq=CONFIG.min_mel_freq,
max_mel_freq=CONFIG.max_mel_freq)
def trim_silence(self, wav):
margin = int(CONFIG.sample_rate * 0.1)
wav = wav[margin:-margin]
return librosa.effects.trim(
wav, top_db=40,
frame_length=1024,
hop_length=256)[0]
wav, top_db=40, frame_length=1024, hop_length=256)[0]
def extract_mel(file_path):
# x, fs = sf.read(file_path)
@ -63,23 +62,25 @@ if __name__ == "__main__":
x = trim_silence(x)
mel = ap.melspectrogram(x.astype('float32')).astype('float32')
linear = ap.spectrogram(x.astype('float32')).astype('float32')
file_name = os.path.basename(file_path).replace(".wav","")
file_name = os.path.basename(file_path).replace(".wav", "")
mel_file = file_name + ".mel"
linear_file = file_name + ".linear"
np.save(os.path.join(OUT_PATH, mel_file), mel, allow_pickle=False)
np.save(os.path.join(OUT_PATH, linear_file), linear, allow_pickle=False)
np.save(
os.path.join(OUT_PATH, linear_file), linear, allow_pickle=False)
mel_len = mel.shape[1]
linear_len = linear.shape[1]
wav_len = x.shape[0]
print(" > " + file_path, flush=True)
return file_path, mel_file, linear_file, str(wav_len), str(mel_len), str(linear_len)
return file_path, mel_file, linear_file, str(wav_len), str(
mel_len), str(linear_len)
glob_path = os.path.join(DATA_PATH, "*.wav")
print(" > Reading wav: {}".format(glob_path))
file_names = glob.glob(glob_path, recursive=True)
if __name__ == "__main__":
print(" > Number of files: %i"%(len(file_names)))
print(" > Number of files: %i" % (len(file_names)))
if not os.path.exists(OUT_PATH):
os.makedirs(OUT_PATH)
print(" > A new folder created at {}".format(OUT_PATH))
@ -88,7 +89,10 @@ if __name__ == "__main__":
if args.num_proc > 1:
print(" > Using {} processes.".format(args.num_proc))
with Pool(args.num_proc) as p:
r = list(tqdm.tqdm(p.imap(extract_mel, file_names), total=len(file_names)))
r = list(
tqdm.tqdm(
p.imap(extract_mel, file_names),
total=len(file_names)))
# r = list(p.imap(extract_mel, file_names))
else:
print(" > Using single process run.")
@ -100,5 +104,5 @@ if __name__ == "__main__":
file = open(file_path, "w")
for line in r:
line = ", ".join(line)
file.write(line+'\n')
file.write(line + '\n')
file.close()
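The script fans the per-file feature extraction out over a multiprocessing Pool and wraps the lazy imap iterator in tqdm so progress is reported as results arrive. A stripped-down sketch of that pattern, with square standing in for extract_mel:

import tqdm
from multiprocessing import Pool

def square(x):
    # Stand-in for extract_mel: any picklable function over one work item.
    return x * x

if __name__ == "__main__":
    items = list(range(100))
    with Pool(4) as p:
        # imap yields results lazily, so tqdm can update as each item finishes.
        results = list(tqdm.tqdm(p.imap(square, items), total=len(items)))
    print(sum(results))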

View File

@ -24,8 +24,8 @@ class BahdanauAttention(nn.Module):
processed_query = self.query_layer(query)
processed_annots = self.annot_layer(annots)
# (batch, max_time, 1)
alignment = self.v(nn.functional.tanh(
processed_query + processed_annots))
alignment = self.v(
nn.functional.tanh(processed_query + processed_annots))
# (batch, max_time)
return alignment.squeeze(-1)
@ -33,15 +33,24 @@ class BahdanauAttention(nn.Module):
class LocationSensitiveAttention(nn.Module):
"""Location sensitive attention following
https://arxiv.org/pdf/1506.07503.pdf"""
def __init__(self, annot_dim, query_dim, attn_dim,
kernel_size=7, filters=20):
def __init__(self,
annot_dim,
query_dim,
attn_dim,
kernel_size=7,
filters=20):
super(LocationSensitiveAttention, self).__init__()
self.kernel_size = kernel_size
self.filters = filters
padding = int((kernel_size - 1) / 2)
self.loc_conv = nn.Conv1d(2, filters,
kernel_size=kernel_size, stride=1,
padding=padding, bias=False)
self.loc_conv = nn.Conv1d(
2,
filters,
kernel_size=kernel_size,
stride=1,
padding=padding,
bias=False)
self.loc_linear = nn.Linear(filters, attn_dim)
self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
@ -62,8 +71,9 @@ class LocationSensitiveAttention(nn.Module):
processed_loc = self.loc_linear(loc_conv)
processed_query = self.query_layer(query)
processed_annots = self.annot_layer(annot)
alignment = self.v(nn.functional.tanh(
processed_query + processed_annots + processed_loc))
alignment = self.v(
nn.functional.tanh(processed_query + processed_annots +
processed_loc))
# (batch, max_time)
return alignment.squeeze(-1)
@ -85,16 +95,23 @@ class AttentionRNNCell(nn.Module):
self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)
# pick bahdanau or location sensitive attention
if align_model == 'b':
self.alignment_model = BahdanauAttention(annot_dim, rnn_dim, out_dim)
self.alignment_model = BahdanauAttention(annot_dim, rnn_dim,
out_dim)
if align_model == 'ls':
self.alignment_model = LocationSensitiveAttention(annot_dim, rnn_dim, out_dim)
self.alignment_model = LocationSensitiveAttention(
annot_dim, rnn_dim, out_dim)
else:
raise RuntimeError(" Wrong alignment model name: {}. Use\
'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))
'b' (Bahdanau) or 'ls' (Location Sensitive)."
.format(align_model))
def forward(self, memory, context, rnn_state, annots,
atten, annot_lens=None):
def forward(self,
memory,
context,
rnn_state,
annots,
atten,
annot_lens=None):
"""
Shapes:
- memory: (batch, 1, dim) or (batch, dim)
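Both attention variants score each encoder step by projecting the decoder query and the encoder annotations into a shared space, adding them (plus a convolutional location term in the 'ls' model), applying tanh, and projecting to a scalar. A minimal additive scorer showing the shapes end to end; AdditiveScorer is a sketch of the formula, not the repo's module:

import torch
from torch import nn

class AdditiveScorer(nn.Module):
    def __init__(self, annot_dim, query_dim, attn_dim):
        super().__init__()
        self.query_layer = nn.Linear(query_dim, attn_dim)
        self.annot_layer = nn.Linear(annot_dim, attn_dim)
        self.v = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, annots, query):
        # annots: (batch, max_time, annot_dim), query: (batch, query_dim)
        scores = self.v(torch.tanh(self.annot_layer(annots) +
                                   self.query_layer(query).unsqueeze(1)))
        alignment = torch.softmax(scores.squeeze(-1), dim=1)   # (batch, max_time)
        context = torch.bmm(alignment.unsqueeze(1), annots)    # (batch, 1, annot_dim)
        return alignment, context.squeeze(1)

scorer = AdditiveScorer(annot_dim=256, query_dim=256, attn_dim=128)
align, ctx = scorer(torch.rand(2, 17, 256), torch.rand(2, 256))
print(align.shape, ctx.shape)   # torch.Size([2, 17]) torch.Size([2, 256])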

View File

@ -2,7 +2,6 @@
import torch
from torch import nn
# class StopProjection(nn.Module):
# r""" Simple projection layer to predict the "stop token"

View File

@ -5,7 +5,6 @@ from utils.generic_utils import sequence_mask
class L1LossMasked(nn.Module):
def __init__(self):
super(L1LossMasked, self).__init__()
@ -31,21 +30,20 @@ class L1LossMasked(nn.Module):
# target_flat: (batch * max_len, dim)
target_flat = target.view(-1, target.shape[-1])
# losses_flat: (batch * max_len, dim)
losses_flat = functional.l1_loss(input, target_flat, size_average=False,
reduce=False)
losses_flat = functional.l1_loss(
input, target_flat, size_average=False, reduce=False)
# losses: (batch, max_len, dim)
losses = losses_flat.view(*target.size())
# mask: (batch, max_len, 1)
mask = sequence_mask(sequence_length=length,
max_len=target.size(1)).unsqueeze(2)
mask = sequence_mask(
sequence_length=length, max_len=target.size(1)).unsqueeze(2)
losses = losses * mask.float()
loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
return loss
class MSELossMasked(nn.Module):
def __init__(self):
super(MSELossMasked, self).__init__()
@ -71,14 +69,14 @@ class MSELossMasked(nn.Module):
# target_flat: (batch * max_len, dim)
target_flat = target.view(-1, target.shape[-1])
# losses_flat: (batch * max_len, dim)
losses_flat = functional.mse_loss(input, target_flat, size_average=False,
reduce=False)
losses_flat = functional.mse_loss(
input, target_flat, size_average=False, reduce=False)
# losses: (batch, max_len, dim)
losses = losses_flat.view(*target.size())
# mask: (batch, max_len, 1)
mask = sequence_mask(sequence_length=length,
max_len=target.size(1)).unsqueeze(2)
mask = sequence_mask(
sequence_length=length, max_len=target.size(1)).unsqueeze(2)
losses = losses * mask.float()
loss = losses.sum() / (length.float().sum() * float(target.shape[2]))
return loss
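Both losses zero out padded frames before averaging, using sequence_mask to turn the true lengths into a (batch, max_len) mask. A self-contained version of the idea, with reduction='none' standing in for the older size_average/reduce flags and sequence_mask assumed to be the usual arange-versus-length comparison:

import torch
from torch.nn import functional

def sequence_mask(lengths, max_len):
    # True for valid positions, False for padding (assumed equivalent of the repo helper).
    return torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)

def masked_l1(pred, target, lengths):
    # pred, target: (batch, max_len, dim); lengths: (batch,)
    losses = functional.l1_loss(pred, target, reduction='none')
    mask = sequence_mask(lengths, target.size(1)).unsqueeze(2).float()
    return (losses * mask).sum() / (lengths.float().sum() * target.size(2))

pred = torch.rand(4, 8, 80)
target = torch.rand(4, 8, 80)
print(masked_l1(pred, target, torch.tensor([5, 6, 7, 8])))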

View File

@ -3,6 +3,7 @@ import torch
from torch import nn
from .attention import AttentionRNNCell
class Prenet(nn.Module):
r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
It creates as many layers as given by 'out_features'
@ -16,9 +17,10 @@ class Prenet(nn.Module):
def __init__(self, in_features, out_features=[256, 128]):
super(Prenet, self).__init__()
in_features = [in_features] + out_features[:-1]
self.layers = nn.ModuleList(
[nn.Linear(in_size, out_size)
for (in_size, out_size) in zip(in_features, out_features)])
self.layers = nn.ModuleList([
nn.Linear(in_size, out_size)
for (in_size, out_size) in zip(in_features, out_features)
])
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.5)
@ -46,12 +48,21 @@ class BatchNormConv1d(nn.Module):
- output: batch x dims
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
activation=None):
super(BatchNormConv1d, self).__init__()
self.conv1d = nn.Conv1d(in_channels, out_channels,
kernel_size=kernel_size,
stride=stride, padding=padding, bias=False)
self.conv1d = nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=False)
# Following tensorflow's default parameters
self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
self.activation = activation
@ -96,16 +107,25 @@ class CBHG(nn.Module):
- output: batch x time x dim*2
"""
def __init__(self, in_features, K=16, projections=[128, 128], num_highways=4):
def __init__(self,
in_features,
K=16,
projections=[128, 128],
num_highways=4):
super(CBHG, self).__init__()
self.in_features = in_features
self.relu = nn.ReLU()
# list of conv1d bank with filter size k=1...K
# TODO: try dilational layers instead
self.conv1d_banks = nn.ModuleList(
[BatchNormConv1d(in_features, in_features, kernel_size=k, stride=1,
padding=k // 2, activation=self.relu)
for k in range(1, K + 1)])
self.conv1d_banks = nn.ModuleList([
BatchNormConv1d(
in_features,
in_features,
kernel_size=k,
stride=1,
padding=k // 2,
activation=self.relu) for k in range(1, K + 1)
])
# max pooling of conv bank
# TODO: try average pooling OR larger kernel size
self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
@ -114,9 +134,15 @@ class CBHG(nn.Module):
activations += [None]
# setup conv1d projection layers
layer_set = []
for (in_size, out_size, ac) in zip(out_features, projections, activations):
layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
padding=1, activation=ac)
for (in_size, out_size, ac) in zip(out_features, projections,
activations):
layer = BatchNormConv1d(
in_size,
out_size,
kernel_size=3,
stride=1,
padding=1,
activation=ac)
layer_set.append(layer)
self.conv1d_projections = nn.ModuleList(layer_set)
# setup Highway layers
@ -204,10 +230,14 @@ class Decoder(nn.Module):
# memory -> |Prenet| -> processed_memory
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
self.attention_rnn = AttentionRNNCell(out_dim=128, rnn_dim=256, annot_dim=in_features,
memory_dim=128, align_model='ls')
self.attention_rnn = AttentionRNNCell(
out_dim=128,
rnn_dim=256,
annot_dim=in_features,
memory_dim=128,
align_model='ls')
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(256+in_features, 256)
self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
# decoder_RNN_input -> |RNN| -> RNN_state
self.decoder_rnns = nn.ModuleList(
[nn.GRUCell(256, 256) for _ in range(2)])
@ -241,17 +271,20 @@ class Decoder(nn.Module):
# Grouping multiple frames if necessary
if memory.size(-1) == self.memory_dim:
memory = memory.view(B, memory.size(1) // self.r, -1)
" !! Dimension mismatch {} vs {} * {}".format(memory.size(-1),
self.memory_dim, self.r)
" !! Dimension mismatch {} vs {} * {}".format(
memory.size(-1), self.memory_dim, self.r)
T_decoder = memory.size(1)
# go frame as zeros matrix
initial_memory = inputs.data.new(B, self.memory_dim * self.r).zero_()
# decoder states
attention_rnn_hidden = inputs.data.new(B, 256).zero_()
decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
for _ in range(len(self.decoder_rnns))]
decoder_rnn_hiddens = [
inputs.data.new(B, 256).zero_()
for _ in range(len(self.decoder_rnns))
]
current_context_vec = inputs.data.new(B, self.in_features).zero_()
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
stopnet_rnn_hidden = inputs.data.new(B,
self.r * self.memory_dim).zero_()
# attention states
attention = inputs.data.new(B, T).zero_()
attention_cum = inputs.data.new(B, T).zero_()
@ -268,13 +301,12 @@ class Decoder(nn.Module):
if greedy:
memory_input = outputs[-1]
else:
memory_input = memory[t-1]
memory_input = memory[t - 1]
# Prenet
processed_memory = self.prenet(memory_input)
# Attention RNN
attention_cat = torch.cat((attention.unsqueeze(1),
attention_cum.unsqueeze(1)),
dim=1)
attention_cat = torch.cat(
(attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
processed_memory, current_context_vec, attention_rnn_hidden,
inputs, attention_cat, input_lens)
@ -293,16 +325,18 @@ class Decoder(nn.Module):
output = self.proj_to_mel(decoder_output)
stop_input = output
# predict stop token
stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden)
stop_token, stopnet_rnn_hidden = self.stopnet(
stop_input, stopnet_rnn_hidden)
outputs += [output]
attentions += [attention]
stop_tokens += [stop_token]
t += 1
if (not greedy and self.training) or (greedy and memory is not None):
if (not greedy and self.training) or (greedy
and memory is not None):
if t >= T_decoder:
break
else:
if t > inputs.shape[1]/2 and stop_token > 0.6:
if t > inputs.shape[1] / 2 and stop_token > 0.6:
break
elif t > self.max_decoder_steps:
print(" | | > Decoder stopped with 'max_decoder_steps")

View File

@ -6,14 +6,18 @@ from layers.tacotron import Prenet, Encoder, Decoder, CBHG
class Tacotron(nn.Module):
def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
r=5, padding_idx=None):
def __init__(self,
embedding_dim=256,
linear_dim=1025,
mel_dim=80,
r=5,
padding_idx=None):
super(Tacotron, self).__init__()
self.r = r
self.mel_dim = mel_dim
self.linear_dim = linear_dim
self.embedding = nn.Embedding(len(symbols), embedding_dim,
padding_idx=padding_idx)
self.embedding = nn.Embedding(
len(symbols), embedding_dim, padding_idx=padding_idx)
print(" | > Number of characters : {}".format(len(symbols)))
self.embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(embedding_dim)

View File

@ -40,8 +40,12 @@ def visualize(alignment, spectrogram, stop_tokens, CONFIG):
plt.plot(range(len(stop_tokens)), list(stop_tokens))
plt.subplot(3, 1, 3)
librosa.display.specshow(spectrogram.T, sr=CONFIG.sample_rate,
hop_length=hop_length, x_axis="time", y_axis="linear")
librosa.display.specshow(
spectrogram.T,
sr=CONFIG.sample_rate,
hop_length=hop_length,
x_axis="time",
y_axis="linear")
plt.xlabel("Time", fontsize=label_fontsize)
plt.ylabel("Hz", fontsize=label_fontsize)
plt.tight_layout()

View File

@ -1,9 +1,9 @@
## TTS example web-server
Steps to run:
1. Download one of the models given on the main page.
2. Checkout the corresponding commit history.
2. Set paths and other options in the file ```server/conf.json```.
3. Run the server ```python server/server.py -c conf.json```. (Requires Flask)
1. Download one of the models given on the main page. Click [here](https://drive.google.com/drive/folders/1Q6BKeEkZyxSGsocK2p_mqgzLwlNvbHFJ?usp=sharing) for the lastest model.
2. Checkout the corresponding commit history or use ```server``` branch if you like to use the latest model.
2. Set the paths and the other options in the file ```server/conf.json```.
3. Run the server ```python server/server.py -c server/conf.json```. (Requires Flask)
4. Go to ```localhost:[given_port]``` and enjoy.
Note that the audio quality on browser is slightly worse due to the encoder quantization.
For high quality results, please use the library versions shown in the ```requirements.txt``` file.

View File

@ -1,5 +1,5 @@
{
"model_path":"/home/erogol/projects/models/LJSpeech/May-22-2018_03_24PM-e6112f7",
"model_path":"../models/May-22-2018_03_24PM-e6112f7",
"model_name":"checkpoint_272976.pth.tar",
"model_config":"config.json",
"port": 5002,

View File

@ -2,12 +2,11 @@
import argparse
from synthesizer import Synthesizer
from TTS.utils.generic_utils import load_config
from flask import (Flask, Response, request,
render_template, send_file)
from flask import Flask, Response, request, render_template, send_file
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config_path', type=str,
help='path to config file for training')
parser.add_argument(
'-c', '--config_path', type=str, help='path to config file for training')
args = parser.parse_args()
config = load_config(args.config_path)
@ -16,17 +15,19 @@ synthesizer = Synthesizer()
synthesizer.load_model(config.model_path, config.model_name,
config.model_config, config.use_cuda)
@app.route('/')
def index():
return render_template('index.html')
@app.route('/api/tts', methods=['GET'])
def tts():
text = request.args.get('text')
print(" > Model input: {}".format(text))
data = synthesizer.tts(text)
return send_file(data,
mimetype='audio/wav')
return send_file(data, mimetype='audio/wav')
if __name__ == '__main__':
app.run(debug=True, host='0.0.0.0', port=config.port)
app.run(debug=True, host='0.0.0.0', port=config.port)
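Once the server is up, /api/tts takes the input text as a GET query parameter and returns the synthesized audio as audio/wav bytes. A small client sketch using the requests package; the host and the 5002 port match the sample server/conf.json and may differ in your setup:

import requests

# Assumed host/port; adjust to the values in server/conf.json.
resp = requests.get("http://localhost:5002/api/tts",
                    params={"text": "Be a voice, not an echo."})
resp.raise_for_status()
with open("tts_output.wav", "wb") as f:
    f.write(resp.content)          # the endpoint returns audio/wav bytes
print("wrote", len(resp.content), "bytes")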

View File

@ -13,39 +13,44 @@ from matplotlib import pylab as plt
class Synthesizer(object):
def load_model(self, model_path, model_name, model_config, use_cuda):
model_config = os.path.join(model_path, model_config)
self.model_file = os.path.join(model_path, model_name)
self.model_file = os.path.join(model_path, model_name)
print(" > Loading model ...")
print(" | > model config: ", model_config)
print(" | > model file: ", self.model_file)
config = load_config(model_config)
self.config = config
self.use_cuda = use_cuda
self.model = Tacotron(config.embedding_size, config.num_freq, config.num_mels, config.r)
self.ap = AudioProcessor(config.sample_rate, config.num_mels, config.min_level_db,
config.frame_shift_ms, config.frame_length_ms, config.preemphasis,
config.ref_level_db, config.num_freq, config.power, griffin_lim_iters=60)
self.model = Tacotron(config.embedding_size, config.num_freq,
config.num_mels, config.r)
self.ap = AudioProcessor(
config.sample_rate,
config.num_mels,
config.min_level_db,
config.frame_shift_ms,
config.frame_length_ms,
config.preemphasis,
config.ref_level_db,
config.num_freq,
config.power,
griffin_lim_iters=60)
# load model state
if use_cuda:
cp = torch.load(self.model_file)
else:
cp = torch.load(self.model_file, map_location=lambda storage, loc: storage)
cp = torch.load(
self.model_file, map_location=lambda storage, loc: storage)
# load the model
self.model.load_state_dict(cp['model'])
if use_cuda:
self.model.cuda()
self.model.eval()
self.model.eval()
def save_wav(self, wav, path):
wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
# sf.write(path, wav.astype(np.int32), self.config.sample_rate, format='wav')
# wav = librosa.util.normalize(wav.astype(np.float), norm=np.inf, axis=None)
# wav = wav / wav.max()
# sf.write(path, wav.astype('float'), self.config.sample_rate, format='ogg')
scipy.io.wavfile.write(path, self.config.sample_rate, wav.astype(np.int16))
# librosa.output.write_wav(path, wav.astype(np.int16), self.config.sample_rate, norm=True)
librosa.output.write_wav(path, wav.astype(np.int16),
self.config.sample_rate)
def tts(self, text):
text_cleaner = [self.config.text_cleaner]
@ -54,14 +59,15 @@ class Synthesizer(object):
if len(sen) < 3:
continue
sen = sen.strip()
sen +='.'
sen += '.'
print(sen)
sen = sen.strip()
seq = np.array(text_to_sequence(text, text_cleaner))
chars_var = torch.from_numpy(seq).unsqueeze(0)
chars_var = torch.from_numpy(seq).unsqueeze(0).long()
if self.use_cuda:
chars_var = chars_var.cuda()
mel_out, linear_out, alignments, stop_tokens = self.model.forward(chars_var)
mel_out, linear_out, alignments, stop_tokens = self.model.forward(
chars_var)
linear_out = linear_out[0].data.cpu().numpy()
wav = self.ap.inv_spectrogram(linear_out.T)
# wav = wav[:self.ap.find_endpoint(wav)]
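save_wav rescales the float waveform so its peak maps to the int16 maximum before writing, which keeps the cast from clipping. The scaling in isolation, written out with scipy.io.wavfile purely for illustration (the diff itself switches the writer to librosa.output.write_wav):

import numpy as np
import scipy.io.wavfile

def save_wav(wav, path, sample_rate=22050):
    # Scale so the loudest sample hits the int16 peak; the 1e-8 floor avoids dividing by zero.
    wav = wav * (32767 / max(1e-8, np.max(np.abs(wav))))
    scipy.io.wavfile.write(path, sample_rate, wav.astype(np.int16))

save_wav(np.sin(np.linspace(0, 2 * np.pi * 440, 22050)), "tone.wav")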

View File

@ -25,7 +25,6 @@ else:
class build_py(setuptools.command.build_py.build_py):
def run(self):
self.create_version_file()
setuptools.command.build_py.build_py.run(self)
@ -40,7 +39,6 @@ class build_py(setuptools.command.build_py.build_py):
class develop(setuptools.command.develop.develop):
def run(self):
build_py.create_version_file()
setuptools.command.develop.develop.run(self)
@ -50,8 +48,11 @@ def create_readme_rst():
global cwd
try:
subprocess.check_call(
["pandoc", "--from=markdown", "--to=rst", "--output=README.rst",
"README.md"], cwd=cwd)
[
"pandoc", "--from=markdown", "--to=rst", "--output=README.rst",
"README.md"
],
cwd=cwd)
print("Generated README.rst from README.md using pandoc.")
except subprocess.CalledProcessError:
pass
@ -59,33 +60,31 @@ def create_readme_rst():
pass
setup(name='TTS',
version=version,
url='https://github.com/mozilla/TTS',
description='Text to Speech with Deep Learning',
packages=find_packages(),
cmdclass={
'build_py': build_py,
'develop': develop,
},
setup_requires=[
"numpy"
],
install_requires=[
"scipy",
"torch == 0.4.0",
"librosa",
"unidecode",
"tensorboardX",
"matplotlib",
"Pillow",
"flask",
"lws",
],
extras_require={
"bin": [
"tqdm",
"requests",
],
})
setup(
name='TTS',
version=version,
url='https://github.com/mozilla/TTS',
description='Text to Speech with Deep Learning',
packages=find_packages(),
cmdclass={
'build_py': build_py,
'develop': develop,
},
setup_requires=["numpy"],
install_requires=[
"scipy",
"torch == 0.4.0",
"librosa",
"unidecode",
"tensorboardX",
"matplotlib",
"Pillow",
"flask",
"lws",
],
extras_require={
"bin": [
"tqdm",
"requests",
],
})

View File

@ -8,19 +8,17 @@ OUT_PATH = '/tmp/test.pth.tar'
class ModelSavingTests(unittest.TestCase):
def save_checkpoint_test(self):
# create a dummy model
model = Prenet(128, out_features=[256, 128])
model = T.nn.DataParallel(layer)
# save the model
save_checkpoint(model, None, 100,
OUTPATH, 1, 1)
save_checkpoint(model, None, 100, OUTPATH, 1, 1)
# load the model to CPU
model_dict = torch.load(MODEL_PATH, map_location=lambda storage,
loc: storage)
model_dict = torch.load(
MODEL_PATH, map_location=lambda storage, loc: storage)
model.load_state_dict(model_dict['model'])
def save_best_model_test(self):
@ -29,11 +27,9 @@ class ModelSavingTests(unittest.TestCase):
model = T.nn.DataParallel(layer)
# save the model
best_loss = save_best_model(model, None, 0,
100, OUT_PATH,
10, 1)
best_loss = save_best_model(model, None, 0, 100, OUT_PATH, 10, 1)
# load the model to CPU
model_dict = torch.load(MODEL_PATH, map_location=lambda storage,
loc: storage)
model_dict = torch.load(
MODEL_PATH, map_location=lambda storage, loc: storage)
model.load_state_dict(model_dict['model'])
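Both tests load checkpoints with map_location=lambda storage, loc: storage so tensors saved on a GPU land on the CPU. The same round trip on a throwaway module; the path and the Linear layer are illustrative:

import torch
from torch import nn

model = nn.Linear(8, 4)
torch.save({'model': model.state_dict(), 'step': 100}, '/tmp/ckpt.pth.tar')

# Remap any CUDA storages to CPU while loading, exactly as in the tests above.
state = torch.load('/tmp/ckpt.pth.tar', map_location=lambda storage, loc: storage)
model.load_state_dict(state['model'])
print(state['step'])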

View File

@ -7,7 +7,6 @@ from TTS.utils.generic_utils import sequence_mask
class PrenetTests(unittest.TestCase):
def test_in_out(self):
layer = Prenet(128, out_features=[256, 128])
dummy_input = T.rand(4, 128)
@ -19,7 +18,6 @@ class PrenetTests(unittest.TestCase):
class CBHGTests(unittest.TestCase):
def test_in_out(self):
layer = CBHG(128, K=6, projections=[128, 128], num_highways=2)
dummy_input = T.rand(4, 8, 128)
@ -32,7 +30,6 @@ class CBHGTests(unittest.TestCase):
class DecoderTests(unittest.TestCase):
def test_in_out(self):
layer = Decoder(in_features=256, memory_dim=80, r=2)
dummy_input = T.rand(4, 8, 256)
@ -49,7 +46,6 @@ class DecoderTests(unittest.TestCase):
class EncoderTests(unittest.TestCase):
def test_in_out(self):
layer = Encoder(128)
dummy_input = T.rand(4, 8, 128)
@ -63,7 +59,6 @@ class EncoderTests(unittest.TestCase):
class L1LossMaskedTests(unittest.TestCase):
def test_in_out(self):
layer = L1LossMasked()
dummy_input = T.ones(4, 8, 128).float()
@ -80,7 +75,7 @@ class L1LossMaskedTests(unittest.TestCase):
dummy_input = T.ones(4, 8, 128).float()
dummy_target = T.zeros(4, 8, 128).float()
dummy_length = (T.arange(5, 9)).long()
mask = ((sequence_mask(dummy_length).float() - 1.0)
* 100.0).unsqueeze(2)
mask = (
(sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length)
assert output.item() == 1.0, "1.0 vs {}".format(output.data[0])

View File

@ -12,34 +12,38 @@ c = load_config(os.path.join(file_path, 'test_config.json'))
class TestLJSpeechDataset(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestLJSpeechDataset, self).__init__(*args, **kwargs)
self.max_loader_iter = 4
self.ap = AudioProcessor(sample_rate=c.sample_rate,
num_mels=c.num_mels,
min_level_db=c.min_level_db,
frame_shift_ms=c.frame_shift_ms,
frame_length_ms=c.frame_length_ms,
ref_level_db=c.ref_level_db,
num_freq=c.num_freq,
power=c.power,
preemphasis=c.preemphasis,
min_mel_freq=c.min_mel_freq,
max_mel_freq=c.max_mel_freq)
self.ap = AudioProcessor(
sample_rate=c.sample_rate,
num_mels=c.num_mels,
min_level_db=c.min_level_db,
frame_shift_ms=c.frame_shift_ms,
frame_length_ms=c.frame_length_ms,
ref_level_db=c.ref_level_db,
num_freq=c.num_freq,
power=c.power,
preemphasis=c.preemphasis,
min_mel_freq=c.min_mel_freq,
max_mel_freq=c.max_mel_freq)
def test_loader(self):
dataset = LJSpeech.MyDataset(os.path.join(c.data_path_LJSpeech),
os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
c.r,
c.text_cleaner,
ap = self.ap,
min_seq_len=c.min_seq_len
)
dataset = LJSpeech.MyDataset(
os.path.join(c.data_path_LJSpeech),
os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
c.r,
c.text_cleaner,
ap=self.ap,
min_seq_len=c.min_seq_len)
dataloader = DataLoader(dataset, batch_size=2,
shuffle=True, collate_fn=dataset.collate_fn,
drop_last=True, num_workers=c.num_loader_workers)
dataloader = DataLoader(
dataset,
batch_size=2,
shuffle=True,
collate_fn=dataset.collate_fn,
drop_last=True,
num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
@ -62,18 +66,22 @@ class TestLJSpeechDataset(unittest.TestCase):
assert mel_input.shape[2] == c.num_mels
def test_padding(self):
dataset = LJSpeech.MyDataset(os.path.join(c.data_path_LJSpeech),
os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
1,
c.text_cleaner,
ap = self.ap,
min_seq_len=c.min_seq_len
)
dataset = LJSpeech.MyDataset(
os.path.join(c.data_path_LJSpeech),
os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
1,
c.text_cleaner,
ap=self.ap,
min_seq_len=c.min_seq_len)
# Test for batch size 1
dataloader = DataLoader(dataset, batch_size=1,
shuffle=False, collate_fn=dataset.collate_fn,
drop_last=True, num_workers=c.num_loader_workers)
dataloader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=True,
num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
@ -98,9 +106,13 @@ class TestLJSpeechDataset(unittest.TestCase):
assert mel_lengths[0] == mel_input[0].shape[0]
# Test for batch size 2
dataloader = DataLoader(dataset, batch_size=2,
shuffle=False, collate_fn=dataset.collate_fn,
drop_last=False, num_workers=c.num_loader_workers)
dataloader = DataLoader(
dataset,
batch_size=2,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=False,
num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
@ -130,9 +142,9 @@ class TestLJSpeechDataset(unittest.TestCase):
assert mel_lengths[idx] == mel_input[idx].shape[0]
# check the second itme in the batch
assert mel_input[1-idx, -1].sum() == 0
assert linear_input[1-idx, -1].sum() == 0
assert stop_target[1-idx, -1] == 1
assert mel_input[1 - idx, -1].sum() == 0
assert linear_input[1 - idx, -1].sum() == 0
assert stop_target[1 - idx, -1] == 1
assert len(mel_lengths.shape) == 1
# check batch conditions
@ -141,34 +153,38 @@ class TestLJSpeechDataset(unittest.TestCase):
class TestKusalDataset(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestKusalDataset, self).__init__(*args, **kwargs)
self.max_loader_iter = 4
self.ap = AudioProcessor(sample_rate=c.sample_rate,
num_mels=c.num_mels,
min_level_db=c.min_level_db,
frame_shift_ms=c.frame_shift_ms,
frame_length_ms=c.frame_length_ms,
ref_level_db=c.ref_level_db,
num_freq=c.num_freq,
power=c.power,
preemphasis=c.preemphasis,
min_mel_freq=c.min_mel_freq,
max_mel_freq=c.max_mel_freq)
self.ap = AudioProcessor(
sample_rate=c.sample_rate,
num_mels=c.num_mels,
min_level_db=c.min_level_db,
frame_shift_ms=c.frame_shift_ms,
frame_length_ms=c.frame_length_ms,
ref_level_db=c.ref_level_db,
num_freq=c.num_freq,
power=c.power,
preemphasis=c.preemphasis,
min_mel_freq=c.min_mel_freq,
max_mel_freq=c.max_mel_freq)
def test_loader(self):
dataset = Kusal.MyDataset(os.path.join(c.data_path_Kusal),
os.path.join(c.data_path_Kusal, 'prompts.txt'),
c.r,
c.text_cleaner,
ap = self.ap,
min_seq_len=c.min_seq_len
)
dataset = Kusal.MyDataset(
os.path.join(c.data_path_Kusal),
os.path.join(c.data_path_Kusal, 'prompts.txt'),
c.r,
c.text_cleaner,
ap=self.ap,
min_seq_len=c.min_seq_len)
dataloader = DataLoader(dataset, batch_size=2,
shuffle=True, collate_fn=dataset.collate_fn,
drop_last=True, num_workers=c.num_loader_workers)
dataloader = DataLoader(
dataset,
batch_size=2,
shuffle=True,
collate_fn=dataset.collate_fn,
drop_last=True,
num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
@ -191,18 +207,22 @@ class TestKusalDataset(unittest.TestCase):
assert mel_input.shape[2] == c.num_mels
def test_padding(self):
dataset = Kusal.MyDataset(os.path.join(c.data_path_Kusal),
os.path.join(c.data_path_Kusal, 'prompts.txt'),
1,
c.text_cleaner,
ap = self.ap,
min_seq_len=c.min_seq_len
)
dataset = Kusal.MyDataset(
os.path.join(c.data_path_Kusal),
os.path.join(c.data_path_Kusal, 'prompts.txt'),
1,
c.text_cleaner,
ap=self.ap,
min_seq_len=c.min_seq_len)
# Test for batch size 1
dataloader = DataLoader(dataset, batch_size=1,
shuffle=False, collate_fn=dataset.collate_fn,
drop_last=True, num_workers=c.num_loader_workers)
dataloader = DataLoader(
dataset,
batch_size=1,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=True,
num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
@ -227,9 +247,13 @@ class TestKusalDataset(unittest.TestCase):
assert mel_lengths[0] == mel_input[0].shape[0]
# Test for batch size 2
dataloader = DataLoader(dataset, batch_size=2,
shuffle=False, collate_fn=dataset.collate_fn,
drop_last=False, num_workers=c.num_loader_workers)
dataloader = DataLoader(
dataset,
batch_size=2,
shuffle=False,
collate_fn=dataset.collate_fn,
drop_last=False,
num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
@ -259,16 +283,16 @@ class TestKusalDataset(unittest.TestCase):
assert mel_lengths[idx] == mel_input[idx].shape[0]
# check the second itme in the batch
assert mel_input[1-idx, -1].sum() == 0
assert linear_input[1-idx, -1].sum() == 0
assert stop_target[1-idx, -1] == 1
assert mel_input[1 - idx, -1].sum() == 0
assert linear_input[1 - idx, -1].sum() == 0
assert stop_target[1 - idx, -1] == 1
assert len(mel_lengths.shape) == 1
# check batch conditions
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
# class TestTWEBDataset(unittest.TestCase):
# def __init__(self, *args, **kwargs):
@ -339,7 +363,7 @@ class TestKusalDataset(unittest.TestCase):
# for i, data in enumerate(dataloader):
# if i == self.max_loader_iter:
# break
# text_input = data[0]
# text_lengths = data[1]
# linear_input = data[2]
@ -399,4 +423,4 @@ class TestKusalDataset(unittest.TestCase):
# # check batch conditions
# assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
# assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
# assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
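All of the loader tests drive DataLoader with the dataset's own collate_fn, which is what turns a list of variable-length samples into padded batch tensors. A tiny stand-alone version of that wiring; ToyDataset and pad_collate are illustrative:

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    def __init__(self):
        self.seqs = [torch.arange(n) for n in (3, 5, 2, 4)]
    def __len__(self):
        return len(self.seqs)
    def __getitem__(self, idx):
        return self.seqs[idx]

def pad_collate(batch):
    # Pad every sequence in the batch to the longest one, as collate_fn does for text/mel.
    lengths = torch.tensor([len(x) for x in batch])
    padded = torch.zeros(len(batch), int(lengths.max()), dtype=torch.long)
    for i, seq in enumerate(batch):
        padded[i, :len(seq)] = seq
    return padded, lengths

loader = DataLoader(ToyDataset(), batch_size=2, shuffle=False, collate_fn=pad_collate)
for padded, lengths in loader:
    print(padded.shape, lengths.tolist())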

View File

@ -19,50 +19,51 @@ c = load_config(os.path.join(file_path, 'test_config.json'))
class TacotronTrainTest(unittest.TestCase):
def test_train_step(self):
input = torch.randint(0, 24, (8, 128)).long().to(device)
mel_spec = torch.rand(8, 30, c.num_mels).to(device)
linear_spec = torch.rand(8, 30, c.num_freq).to(device)
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
stop_targets = torch.zeros(8, 30, 1).float().to(device)
for idx in mel_lengths:
stop_targets[:, int(idx.item()):, 0] = 1.0
stop_targets = stop_targets.view(input.shape[0], stop_targets.size(1) // c.r, -1)
stop_targets = stop_targets.view(input.shape[0],
stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
criterion = L1LossMasked().to(device)
criterion_st = nn.BCELoss().to(device)
model = Tacotron(c.embedding_size,
c.num_freq,
c.num_mels,
model = Tacotron(c.embedding_size, c.num_freq, c.num_mels,
c.r).to(device)
model.train()
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
for param, param_ref in zip(model.parameters(),
model_ref.parameters()):
assert (param - param_ref).sum() == 0, param
count += 1
optimizer = optim.Adam(model.parameters(), lr=c.lr)
for i in range(5):
mel_out, linear_out, align, stop_tokens = model.forward(input, mel_spec)
mel_out, linear_out, align, stop_tokens = model.forward(
input, mel_spec)
assert stop_tokens.data.max() <= 1.0
assert stop_tokens.data.min() >= 0.0
optimizer.zero_grad()
loss = criterion(mel_out, mel_spec, mel_lengths)
loss = criterion(mel_out, mel_spec, mel_lengths)
stop_loss = criterion_st(stop_tokens, stop_targets)
loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss
loss = loss + criterion(linear_out, linear_spec,
mel_lengths) + stop_loss
loss.backward()
optimizer.step()
# check parameter changes
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore pre-higway layer since it works conditional
for param, param_ref in zip(model.parameters(),
model_ref.parameters()):
# ignore pre-higway layer since it works conditional
if count not in [148, 59]:
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(count, param.shape, param, param_ref)
assert (param != param_ref).any(
), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref)
count += 1
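The training test checks that every parameter actually moved after a few optimizer steps by comparing against a deepcopy taken before training. The core of that check on a toy model:

import copy
import torch
from torch import nn, optim

model = nn.Linear(10, 2)
model_ref = copy.deepcopy(model)               # frozen snapshot of the initial weights
optimizer = optim.Adam(model.parameters(), lr=0.01)

for _ in range(5):
    loss = model(torch.rand(8, 10)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

for param, param_ref in zip(model.parameters(), model_ref.parameters()):
    assert (param != param_ref).any(), "parameter was not updated"
print("all parameters updated")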

train.py: 226 changed lines
View File

@ -1,39 +1,34 @@
import os
import sys
import time
import datetime
import shutil
import torch
import signal
import argparse
import importlib
import pickle
import traceback
import numpy as np
import torch.nn as nn
from torch import optim
from torch import onnx
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tensorboardX import SummaryWriter
from utils.generic_utils import (synthesis, remove_experiment_folder,
create_experiment_folder, save_checkpoint,
save_best_model, load_config, lr_decay,
count_parameters, check_update, get_commit_hash)
from utils.generic_utils import (
synthesis, remove_experiment_folder, create_experiment_folder,
save_checkpoint, save_best_model, load_config, lr_decay, count_parameters,
check_update, get_commit_hash)
from utils.visual import plot_alignment, plot_spectrogram
from models.tacotron import Tacotron
from layers.losses import L1LossMasked
from utils.audio import AudioProcessor
torch.manual_seed(1)
torch.set_num_threads(4)
use_cuda = torch.cuda.is_available()
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, ap, epoch):
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
ap, epoch):
model = model.train()
epoch_time = 0
avg_linear_loss = 0
@ -54,7 +49,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
stop_targets = data[5]
# set stop targets view, we predict a single stop token per r frames prediction
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
stop_targets = stop_targets.view(text_input.shape[0],
stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
current_step = num_iter + args.restore_step + \
@ -89,7 +85,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
# loss computation
stop_loss = criterion_st(stop_tokens, stop_targets)
mel_loss = criterion(mel_output, mel_input, mel_lengths)
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths)\
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
linear_input[:, :, :n_priority_freq],
mel_lengths)
@ -106,7 +102,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
# backpass and check the grad norm for stop loss
stop_loss.backward()
grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet, 0.5, 100)
grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet,
0.5, 100)
if skip_flag:
optimizer_st.zero_grad()
print(" | | > Iteration skipped fro stopnet!!")
@ -117,9 +114,10 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
epoch_time += step_time
if current_step % c.print_step == 0:
print(" | | > Step:{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "\
"MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "\
"GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter, current_step,
print(" | | > Step:{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
"MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
"GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter,
current_step,
loss.item(),
linear_loss.item(),
mel_loss.item(),
@ -147,8 +145,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
if current_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model, optimizer, optimizer_st, linear_loss.item(),
OUT_PATH, current_step, epoch)
save_checkpoint(model, optimizer, optimizer_st,
linear_loss.item(), OUT_PATH, current_step,
epoch)
# Diagnostic visualizations
const_spec = linear_output[0].data.cpu().numpy()
@ -168,8 +167,11 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
ap.griffin_lim_iters = 60
audio_signal = ap.inv_spectrogram(audio_signal.T)
try:
tb.add_audio('SampleAudio', audio_signal, current_step,
sample_rate=c.sample_rate)
tb.add_audio(
'SampleAudio',
audio_signal,
current_step,
sample_rate=c.sample_rate)
except:
pass
@ -180,9 +182,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
avg_step_time /= (num_iter + 1)
# print epoch stats
print(" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "\
"AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "\
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "\
print(" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
"AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "
"AvgStepTime:{:.2f}".format(current_step,
avg_total_loss,
avg_linear_loss,
@ -209,10 +211,12 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
avg_mel_loss = 0
avg_stop_loss = 0
print(" | > Validation")
test_sentences = ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist."]
test_sentences = [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist."
]
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
with torch.no_grad():
if data_loader is not None:
@ -228,7 +232,9 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
stop_targets = data[5]
# set stop targets view, we predict a single stop token per r frames prediction
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
stop_targets = stop_targets.view(text_input.shape[0],
stop_targets.size(1) // c.r,
-1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
# dispatch data to GPU
@ -256,11 +262,11 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
epoch_time += step_time
if num_iter % c.print_step == 0:
print(" | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "\
"StopLoss: {:.5f} ".format(loss.item(),
linear_loss.item(),
mel_loss.item(),
stop_loss.item()), flush=True)
print(
" | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "
"StopLoss: {:.5f} ".format(loss.item(), linear_loss.item(),
mel_loss.item(), stop_loss.item()),
flush=True)
avg_linear_loss += linear_loss.item()
avg_mel_loss += mel_loss.item()
@ -278,15 +284,19 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
tb.add_image('ValVisual/Reconstruction', const_spec, current_step)
tb.add_image('ValVisual/GroundTruth', gt_spec, current_step)
tb.add_image('ValVisual/ValidationAlignment', align_img, current_step)
tb.add_image('ValVisual/ValidationAlignment', align_img,
current_step)
# Sample audio
audio_signal = linear_output[idx].data.cpu().numpy()
ap.griffin_lim_iters = 60
audio_signal = ap.inv_spectrogram(audio_signal.T)
try:
tb.add_audio('ValSampleAudio', audio_signal, current_step,
sample_rate=c.sample_rate)
tb.add_audio(
'ValSampleAudio',
audio_signal,
current_step,
sample_rate=c.sample_rate)
except:
# sometimes audio signal is out of boundaries
pass
@ -298,81 +308,88 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss
# Plot Learning Stats
tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss, current_step)
tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step)
tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss,
current_step)
tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss,
current_step)
tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss, current_step)
tb.add_scalar('ValEpochLoss/Stop_loss', avg_stop_loss,
current_step)
# test sentences
ap.griffin_lim_iters = 60
for idx, test_sentence in enumerate(test_sentences):
wav, linear_spec, alignments = synthesis(model, ap, test_sentence, use_cuda,
c.text_cleaner)
wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
use_cuda, c.text_cleaner)
try:
wav_name = 'TestSentences/{}'.format(idx)
tb.add_audio(wav_name, wav, current_step,
sample_rate=c.sample_rate)
tb.add_audio(
wav_name, wav, current_step, sample_rate=c.sample_rate)
except:
pass
align_img = alignments[0].data.cpu().numpy()
linear_spec = plot_spectrogram(linear_spec, ap)
align_img = plot_alignment(align_img)
tb.add_image('TestSentences/{}_Spectrogram'.format(idx), linear_spec, current_step)
tb.add_image('TestSentences/{}_Alignment'.format(idx), align_img, current_step)
tb.add_image('TestSentences/{}_Spectrogram'.format(idx), linear_spec,
current_step)
tb.add_image('TestSentences/{}_Alignment'.format(idx), align_img,
current_step)
return avg_linear_loss
def main(args):
dataset = importlib.import_module('datasets.'+c.dataset)
dataset = importlib.import_module('datasets.' + c.dataset)
Dataset = getattr(dataset, 'MyDataset')
audio = importlib.import_module('utils.'+c.audio_processor)
audio = importlib.import_module('utils.' + c.audio_processor)
AudioProcessor = getattr(audio, 'AudioProcessor')
ap = AudioProcessor(sample_rate=c.sample_rate,
num_mels=c.num_mels,
min_level_db=c.min_level_db,
frame_shift_ms=c.frame_shift_ms,
frame_length_ms=c.frame_length_ms,
ref_level_db=c.ref_level_db,
num_freq=c.num_freq,
power=c.power,
preemphasis=c.preemphasis,
min_mel_freq=c.min_mel_freq,
max_mel_freq=c.max_mel_freq)
ap = AudioProcessor(
sample_rate=c.sample_rate,
num_mels=c.num_mels,
min_level_db=c.min_level_db,
frame_shift_ms=c.frame_shift_ms,
frame_length_ms=c.frame_length_ms,
ref_level_db=c.ref_level_db,
num_freq=c.num_freq,
power=c.power,
preemphasis=c.preemphasis,
min_mel_freq=c.min_mel_freq,
max_mel_freq=c.max_mel_freq)
# Setup the dataset
train_dataset = Dataset(c.data_path,
c.meta_file_train,
c.r,
c.text_cleaner,
ap = ap,
min_seq_len=c.min_seq_len
)
train_dataset = Dataset(
c.data_path,
c.meta_file_train,
c.r,
c.text_cleaner,
ap=ap,
min_seq_len=c.min_seq_len)
train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
shuffle=False, collate_fn=train_dataset.collate_fn,
drop_last=False, num_workers=c.num_loader_workers,
pin_memory=True)
train_loader = DataLoader(
train_dataset,
batch_size=c.batch_size,
shuffle=False,
collate_fn=train_dataset.collate_fn,
drop_last=False,
num_workers=c.num_loader_workers,
pin_memory=True)
if c.run_eval:
val_dataset = Dataset(c.data_path,
c.meta_file_val,
c.r,
c.text_cleaner,
ap = ap
)
val_dataset = Dataset(
c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap)
val_loader = DataLoader(val_dataset, batch_size=c.eval_batch_size,
shuffle=False, collate_fn=val_dataset.collate_fn,
drop_last=False, num_workers=4,
pin_memory=True)
val_loader = DataLoader(
val_dataset,
batch_size=c.eval_batch_size,
shuffle=False,
collate_fn=val_dataset.collate_fn,
drop_last=False,
num_workers=4,
pin_memory=True)
else:
val_loader = None
model = Tacotron(c.embedding_size,
ap.num_freq,
c.num_mels,
c.r)
model = Tacotron(c.embedding_size, ap.num_freq, c.num_mels, c.r)
print(" | > Num output units : {}".format(ap.num_freq), flush=True)
optimizer = optim.Adam(model.parameters(), lr=c.lr)
@@ -394,7 +411,8 @@ def main(args):
for k, v in state.items():
if torch.is_tensor(v):
state[k] = v.cuda()
print(" > Model restored from step %d" % checkpoint['step'], flush=True)
print(
" > Model restored from step %d" % checkpoint['step'], flush=True)
start_epoch = checkpoint['step'] // len(train_loader)
best_loss = checkpoint['linear_loss']
args.restore_step = checkpoint['step']
@@ -416,22 +434,36 @@ def main(args):
best_loss = float('inf')
for epoch in range(0, c.epochs):
train_loss, current_step = train(model, criterion, criterion_st, train_loader, optimizer, optimizer_st, ap, epoch)
val_loss = evaluate(model, criterion, criterion_st, val_loader, ap, current_step)
print(" | > Train Loss: {:.5f} Validation Loss: {:.5f}".format(train_loss, val_loss), flush=True)
best_loss = save_best_model(model, optimizer, train_loss,
best_loss, OUT_PATH,
current_step, epoch)
train_loss, current_step = train(model, criterion, criterion_st,
train_loader, optimizer, optimizer_st,
ap, epoch)
val_loss = evaluate(model, criterion, criterion_st, val_loader, ap,
current_step)
print(
" | > Train Loss: {:.5f} Validation Loss: {:.5f}".format(
train_loss, val_loss),
flush=True)
best_loss = save_best_model(model, optimizer, train_loss, best_loss,
OUT_PATH, current_step, epoch)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--restore_path', type=str,
help='Folder path to checkpoints', default=0)
parser.add_argument('--config_path', type=str,
help='path to config file for training',)
parser.add_argument('--debug', type=bool, default=False,
help='do not ask for git hash before running.')
parser.add_argument(
'--restore_path',
type=str,
help='Folder path to checkpoints',
default=0)
parser.add_argument(
'--config_path',
type=str,
help='path to config file for training',
)
parser.add_argument(
'--debug',
type=bool,
default=False,
help='do not ask for git hash before running.')
args = parser.parse_args()
# setup output paths and read configs


@@ -9,10 +9,19 @@ _mel_basis = None
class AudioProcessor(object):
def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, ref_level_db, num_freq, power, preemphasis,
min_mel_freq, max_mel_freq, griffin_lim_iters=None):
def __init__(self,
sample_rate,
num_mels,
min_level_db,
frame_shift_ms,
frame_length_ms,
ref_level_db,
num_freq,
power,
preemphasis,
min_mel_freq,
max_mel_freq,
griffin_lim_iters=None):
self.sample_rate = sample_rate
self.num_mels = num_mels
@@ -30,7 +39,8 @@ class AudioProcessor(object):
def save_wav(self, wav, path):
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate, norm=True)
librosa.output.write_wav(
path, wav.astype(np.float), self.sample_rate, norm=True)
def _linear_to_mel(self, spectrogram):
global _mel_basis
@@ -40,8 +50,9 @@ class AudioProcessor(object):
def _build_mel_basis(self, ):
n_fft = (self.num_freq - 1) * 2
return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
# fmin=self.min_mel_freq, fmax=self.max_mel_freq)
return librosa.filters.mel(
self.sample_rate, n_fft, n_mels=self.num_mels)
# fmin=self.min_mel_freq, fmax=self.max_mel_freq)
def _normalize(self, S):
return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
@@ -66,7 +77,7 @@ class AudioProcessor(object):
if self.preemphasis == 0:
raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
return signal.lfilter([1, -self.preemphasis], [1], x)
def apply_inv_preemphasis(self, x):
if self.preemphasis == 0:
raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
@@ -86,9 +97,9 @@ class AudioProcessor(object):
S = self._db_to_amp(S + self.ref_level_db) # Convert back to linear
# Reconstruct phase
if self.preemphasis != 0:
return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
else:
return self._griffin_lim(S ** self.power)
return self._griffin_lim(S**self.power)
def _griffin_lim(self, S):
'''Applies the Griffin-Lim algorithm.
@@ -113,7 +124,8 @@ class AudioProcessor(object):
def _stft(self, y):
n_fft, hop_length, win_length = self._stft_parameters()
return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
return librosa.stft(
y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
def _istft(self, y):
_, hop_length, win_length = self._stft_parameters()


@@ -8,11 +8,23 @@ import lws
_mel_basis = None
class AudioProcessor(object):
def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, ref_level_db, num_freq, power, preemphasis,
min_mel_freq, max_mel_freq, griffin_lim_iters=None, ):
class AudioProcessor(object):
def __init__(
self,
sample_rate,
num_mels,
min_level_db,
frame_shift_ms,
frame_length_ms,
ref_level_db,
num_freq,
power,
preemphasis,
min_mel_freq,
max_mel_freq,
griffin_lim_iters=None,
):
print(" > Setting up Audio Processor...")
self.sample_rate = sample_rate
self.num_mels = num_mels
@@ -25,18 +37,19 @@ class AudioProcessor(object):
self.min_mel_freq = min_mel_freq
self.max_mel_freq = max_mel_freq
self.griffin_lim_iters = griffin_lim_iters
self.preemphasis =preemphasis
self.preemphasis = preemphasis
self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
if preemphasis == 0:
print(" | > Preemphasis is deactive.")
def save_wav(self, wav, path):
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
librosa.output.write_wav(path, wav.astype(np.float), self.sample_rate, norm=True)
librosa.output.write_wav(
path, wav.astype(np.float), self.sample_rate, norm=True)
def _stft_parameters(self, ):
n_fft = int((self.num_freq - 1) * 2)
hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)
win_length = int(self.frame_length_ms / 1000.0 * self.sample_rate)
if n_fft % hop_length != 0:
hop_length = n_fft / 8
@@ -44,14 +57,21 @@ class AudioProcessor(object):
if n_fft % win_length != 0:
win_length = n_fft / 2
print(" | > win_length is set to default ({}).".format(win_length))
print(" | > fft size: {}, hop length: {}, win length: {}".format(n_fft, hop_length, win_length))
print(" | > fft size: {}, hop length: {}, win length: {}".format(
n_fft, hop_length, win_length))
return int(n_fft), int(hop_length), int(win_length)
def _lws_processor(self):
try:
return lws.lws(self.win_length, self.hop_length, fftsize=self.n_fft, mode="speech")
return lws.lws(
self.win_length,
self.hop_length,
fftsize=self.n_fft,
mode="speech")
except:
raise RuntimeError(" !! WindowLength({}) is not multiple of HopLength({}).".format(self.win_length, self.hop_length))
raise RuntimeError(
" !! WindowLength({}) is not multiple of HopLength({}).".
format(self.win_length, self.hop_length))
def _amp_to_db(self, x):
min_level = np.exp(self.min_level_db / 20 * np.log(10))
@@ -70,7 +90,7 @@ class AudioProcessor(object):
if self.preemphasis == 0:
raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
return signal.lfilter([1, -self.preemphasis], [1], x)
def apply_inv_preemphasis(self, x):
if self.preemphasis == 0:
raise RuntimeError(" !! Preemphasis is applied with factor 0.0. ")
@@ -96,14 +116,14 @@ class AudioProcessor(object):
S = self._denormalize(spectrogram)
S = self._db_to_amp(S + self.ref_level_db) # Convert back to linear
processor = self._lws_processor()
D = processor.run_lws(S.astype(np.float64).T ** self.power)
D = processor.run_lws(S.astype(np.float64).T**self.power)
y = processor.istft(D).astype(np.float32)
# Reconstruct phase
if self.preemphasis:
return self.apply_inv_preemphasis(y)
sys.stdout = old_out
return y
def _linear_to_mel(self, spectrogram):
global _mel_basis
if _mel_basis is None:
@@ -111,7 +131,10 @@ class AudioProcessor(object):
return np.dot(_mel_basis, spectrogram)
def _build_mel_basis(self, ):
return librosa.filters.mel(self.sample_rate, self.n_fft, n_mels=self.num_mels)
return librosa.filters.mel(
self.sample_rate, self.n_fft, n_mels=self.num_mels)
# fmin=self.min_mel_freq, fmax=self.max_mel_freq)
def melspectrogram(self, y):
@@ -124,4 +147,4 @@ class AudioProcessor(object):
D = self._lws_processor().stft(y).T
S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
sys.stdout = old_out
return self._normalize(S)
return self._normalize(S)


@@ -4,9 +4,8 @@ import numpy as np
def _pad_data(x, length):
_pad = 0
assert x.ndim == 1
return np.pad(x, (0, length - x.shape[0]),
mode='constant',
constant_values=_pad)
return np.pad(
x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
def prepare_data(inputs):
@@ -17,8 +16,10 @@ def prepare_data(inputs):
def _pad_tensor(x, length):
_pad = 0
assert x.ndim == 2
x = np.pad(x, [[0, 0], [0, length - x.shape[1]]],
mode='constant', constant_values=_pad)
x = np.pad(
x, [[0, 0], [0, length - x.shape[1]]],
mode='constant',
constant_values=_pad)
return x
@@ -32,7 +33,8 @@ def prepare_tensor(inputs, out_steps):
def _pad_stop_target(x, length):
_pad = 1.
assert x.ndim == 1
return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
return np.pad(
x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
def prepare_stop_target(inputs, out_steps):
@@ -44,6 +46,7 @@ def prepare_stop_target(inputs, out_steps):
def pad_per_step(inputs, pad_len):
timesteps = inputs.shape[-1]
return np.pad(inputs, [[0, 0], [0, 0],
[0, pad_len]],
mode='constant', constant_values=0.0)
return np.pad(
inputs, [[0, 0], [0, 0], [0, pad_len]],
mode='constant',
constant_values=0.0)


@@ -28,10 +28,13 @@ def load_config(config_path):
def get_commit_hash():
"""https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
try:
subprocess.check_output(['git', 'diff-index', '--quiet', 'HEAD']) # Verify client is clean
subprocess.check_output(['git', 'diff-index', '--quiet',
'HEAD']) # Verify client is clean
except:
raise RuntimeError(" !! Commit before training to get the commit hash.")
commit = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
raise RuntimeError(
" !! Commit before training to get the commit hash.")
commit = subprocess.check_output(['git', 'rev-parse', '--short',
'HEAD']).decode().strip()
print(' > Git Hash: {}'.format(commit))
return commit
@@ -43,7 +46,8 @@ def create_experiment_folder(root_path, model_name, debug):
commit_hash = 'debug'
else:
commit_hash = get_commit_hash()
output_folder = os.path.join(root_path, date_str + '-' + model_name + '-' + commit_hash)
output_folder = os.path.join(
root_path, date_str + '-' + model_name + '-' + commit_hash)
os.makedirs(output_folder, exist_ok=True)
print(" > Experiment folder: {}".format(output_folder))
return output_folder
@@ -52,7 +56,7 @@ def create_experiment_folder(root_path, model_name, debug):
def remove_experiment_folder(experiment_path):
"""Check folder if there is a checkpoint, otherwise remove the folder"""
checkpoint_files = glob.glob(experiment_path+"/*.pth.tar")
checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
if len(checkpoint_files) < 1:
if os.path.exists(experiment_path):
shutil.rmtree(experiment_path)
@@ -86,13 +90,15 @@ def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path,
print(" | | > Checkpoint saving : {}".format(checkpoint_path))
new_state_dict = _trim_model_state_dict(model.state_dict())
state = {'model': new_state_dict,
'optimizer': optimizer.state_dict(),
'optimizer_st': optimizer_st.state_dict(),
'step': current_step,
'epoch': epoch,
'linear_loss': model_loss,
'date': datetime.date.today().strftime("%B %d, %Y")}
state = {
'model': new_state_dict,
'optimizer': optimizer.state_dict(),
'optimizer_st': optimizer_st.state_dict(),
'step': current_step,
'epoch': epoch,
'linear_loss': model_loss,
'date': datetime.date.today().strftime("%B %d, %Y")
}
torch.save(state, checkpoint_path)
@@ -100,12 +106,14 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
current_step, epoch):
if model_loss < best_loss:
new_state_dict = _trim_model_state_dict(model.state_dict())
state = {'model': new_state_dict,
'optimizer': optimizer.state_dict(),
'step': current_step,
'epoch': epoch,
'linear_loss': model_loss,
'date': datetime.date.today().strftime("%B %d, %Y")}
state = {
'model': new_state_dict,
'optimizer': optimizer.state_dict(),
'step': current_step,
'epoch': epoch,
'linear_loss': model_loss,
'date': datetime.date.today().strftime("%B %d, %Y")
}
best_loss = model_loss
bestmodel_path = 'best_model.pth.tar'
bestmodel_path = os.path.join(out_path, bestmodel_path)
@@ -161,12 +169,12 @@ def sequence_mask(sequence_length, max_len=None):
def synthesis(model, ap, text, use_cuda, text_cleaner):
text_cleaner = [text_cleaner]
seq = np.array(text_to_sequence(text, text_cleaner))
chars_var = torch.from_numpy(seq).unsqueeze(0)
if use_cuda:
chars_var = chars_var.cuda().long()
_, linear_out, alignments, _ = model.forward(chars_var)
linear_out = linear_out[0].data.cpu().numpy()
wav = ap.inv_spectrogram(linear_out.T)
return wav, linear_out, alignments
text_cleaner = [text_cleaner]
seq = np.array(text_to_sequence(text, text_cleaner))
chars_var = torch.from_numpy(seq).unsqueeze(0)
if use_cuda:
chars_var = chars_var.cuda().long()
_, linear_out, alignments, _ = model.forward(chars_var)
linear_out = linear_out[0].data.cpu().numpy()
wav = ap.inv_spectrogram(linear_out.T)
return wav, linear_out, alignments


@@ -4,7 +4,6 @@ import re
from utils.text import cleaners
from utils.text.symbols import symbols
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}


@@ -14,31 +14,31 @@ import re
from unidecode import unidecode
from .numbers import normalize_numbers
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
def expand_abbreviations(text):


@@ -1,17 +1,16 @@
# -*- coding: utf-8 -*-
import re
valid_symbols = [
'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1',
'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0',
'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0',
'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0',
'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG',
'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH',
'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W',
'Y', 'Z', 'ZH'
]
_valid_symbol_set = set(valid_symbols)
@@ -27,8 +26,10 @@ class CMUDict:
else:
entries = _parse_cmudict(file_or_path)
if not keep_ambiguous:
entries = {word: pron for word,
pron in entries.items() if len(pron) == 1}
entries = {
word: pron
for word, pron in entries.items() if len(pron) == 1
}
self._entries = entries
def __len__(self):


@@ -8,61 +8,45 @@ _ordinal_re = re.compile(r'([0-9]+)(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')
_units = [
'',
'one',
'two',
'three',
'four',
'five',
'six',
'seven',
'eight',
'nine',
'ten',
'eleven',
'twelve',
'thirteen',
'fourteen',
'fifteen',
'sixteen',
'seventeen',
'eighteen',
'nineteen'
'', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine',
'ten', 'eleven', 'twelve', 'thirteen', 'fourteen', 'fifteen', 'sixteen',
'seventeen', 'eighteen', 'nineteen'
]
_tens = [
'',
'ten',
'twenty',
'thirty',
'forty',
'fifty',
'sixty',
'seventy',
'eighty',
'ninety',
'',
'ten',
'twenty',
'thirty',
'forty',
'fifty',
'sixty',
'seventy',
'eighty',
'ninety',
]
_digit_groups = [
'',
'thousand',
'million',
'billion',
'trillion',
'quadrillion',
'',
'thousand',
'million',
'billion',
'trillion',
'quadrillion',
]
_ordinal_suffixes = [
('one', 'first'),
('two', 'second'),
('three', 'third'),
('five', 'fifth'),
('eight', 'eighth'),
('nine', 'ninth'),
('twelve', 'twelfth'),
('ty', 'tieth'),
('one', 'first'),
('two', 'second'),
('three', 'third'),
('five', 'fifth'),
('eight', 'eighth'),
('nine', 'ninth'),
('twelve', 'twelfth'),
('ty', 'tieth'),
]
def _remove_commas(m):
return m.group(1).replace(',', '')
@@ -114,7 +98,7 @@ def _standard_number_to_words(n, digit_group):
def _number_to_words(n):
# Handle special cases first, then go to the standard case:
if n >= 1000000000000000000:
return str(n) # Too large, just return the digits
return str(n) # Too large, just return the digits
elif n == 0:
return 'zero'
elif n % 100 == 0 and n % 1000 != 0 and n < 3000:


@@ -1,6 +1,4 @@
# -*- coding: utf-8 -*-
'''
Defines the set of symbols used in text input to the model.
@@ -19,6 +17,5 @@ _arpabet = ['@' + s for s in cmudict.valid_symbols]
# Export all symbols:
symbols = [_pad, _eos] + list(_characters) + _arpabet
if __name__ == '__main__':
print(symbols)


@@ -6,8 +6,8 @@ import matplotlib.pyplot as plt
def plot_alignment(alignment, info=None):
fig, ax = plt.subplots(figsize=(16, 10))
im = ax.imshow(alignment.T, aspect='auto', origin='lower',
interpolation='none')
im = ax.imshow(
alignment.T, aspect='auto', origin='lower', interpolation='none')
fig.colorbar(im, ax=ax)
xlabel = 'Decoder timestep'
if info is not None:
@@ -17,7 +17,7 @@ def plot_alignment(alignment, info=None):
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
plt.close()
return data
@@ -30,6 +30,6 @@ def plot_spectrogram(linear_output, audio):
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
plt.close()
return data
return data