Merge pull request #3 from mozilla/testing_docing

Tests and some changes to the architecture.
This commit is contained in:
Eren Golge 2018-02-13 17:13:22 +01:00 committed by GitHub
Parents b33d1ee043 d87eef9404
Commit ab74c3e0cf
No key found matching this signature
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 1125 additions and 325 deletions

2
.gitignore Vendored
View File

@ -116,3 +116,5 @@ venv.bak/
*.pth.tar
result/
# jupyter dummy files
core

View File

@ -15,7 +15,53 @@ Highly recommended to use [miniconda](https://conda.io/miniconda.html) for easie
* TODO
## Data
TODO
Currently TTS provides data loaders for
- [LJ Speech](https://keithito.com/LJ-Speech-Dataset/)
## Training the network
TODO
To run your own training, you need to define a ```config.json``` file (a simple template is given below) and call it with the following command.
```train.py --config_path config.json```
If you would like to use a specific set of GPUs:
```CUDA_VISIBLE_DEVICES="0,1,4" train.py --config_path config.json```
Each run creates an experiment folder named with the corresponding date and time, under the folder you set in ```config.json```. If there is no checkpoint yet under that folder, it is removed when you press Ctrl+C.
Example ```config.json```:
```
{
// Data loading parameters
"num_mels": 80,
"num_freq": 1024,
"sample_rate": 20000,
"frame_length_ms": 50.0,
"frame_shift_ms": 12.5,
"preemphasis": 0.97,
"min_level_db": -100,
"ref_level_db": 20,
"hidden_size": 128,
"embedding_size": 256,
"text_cleaner": "english_cleaners",
// Training parameters
"epochs": 2000,
"lr": 0.001,
"lr_patience": 2, // lr_scheduler.ReduceLROnPlateau().patience
"lr_decay": 0.5, // lr_scheduler.ReduceLROnPlateau().factor
"batch_size": 256,
"griffinf_lim_iters": 60,
"power": 1.5,
"r": 5, // number of decoder outputs for Tacotron
// Number of data loader processes
"num_loader_workers": 8,
// Experiment logging parameters
"save_step": 200,
"data_path": "/path/to/KeithIto/LJSpeech-1.0",
"output_path": "/path/to/my_experiment",
"log_dir": "/path/to/my/tensorboard/logs/"
}
```
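For reference, and not part of the patch: a minimal sketch of reading such a config from Python with the ```load_config``` helper used by the new tests (it returns an attribute-style object). The inline ```//``` comments shown above would need to be removed first, since they are not valid JSON.
```
# Minimal sketch, assuming a comment-free config.json and that the repo
# root is importable as the TTS package (as in the new test modules).
from TTS.utils.generic_utils import load_config

c = load_config("config.json")
print(c.data_path, c.batch_size, c.r)   # values are exposed as attributes
```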

View File

@ -1,29 +1,31 @@
{
"num_mels": 80,
"num_freq": 1024,
"num_freq": 1025,
"sample_rate": 20000,
"frame_length_ms": 50.0,
"frame_length_ms": 50,
"frame_shift_ms": 12.5,
"preemphasis": 0.97,
"min_level_db": -100,
"ref_level_db": 20,
"hidden_size": 128,
"embedding_size": 256,
"epochs": 10000,
"lr": 0.01,
"decay_step": [500000, 1000000, 2000000],
"batch_size": 128,
"max_iters": 200,
"griffinf_lim_iters": 60,
"power": 1.5,
"r": 5,
"log_step": 100,
"save_step": 2000,
"text_cleaner": "english_cleaners",
"epochs": 2000,
"lr": 0.003,
"lr_patience": 5,
"lr_decay": 0.5,
"batch_size": 180,
"r": 5,
"griffin_lim_iters": 60,
"power": 1.5,
"num_loader_workers": 32,
"checkpoint": false,
"save_step": 69,
"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
"output_path": "result"
"output_path": "result",
"log_dir": "/home/erogol/projects/TTS/logs/"
}

Binary data
datasets/.LJSpeech.py.swp

Binary file not shown.

View File

@ -2,28 +2,34 @@ import pandas as pd
import os
import numpy as np
import collections
import librosa
import torch
from torch.utils.data import Dataset
from Tacotron.utils.text import text_to_sequence
from Tacotron.utils.audio import *
from Tacotron.utils.data import prepare_data, pad_data, pad_per_step
from TTS.utils.text import text_to_sequence
from TTS.utils.audio import AudioProcessor
from TTS.utils.data import prepare_data, pad_data, pad_per_step
class LJSpeechDataset(Dataset):
def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
cleaners):
text_cleaner, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, preemphasis, ref_level_db, num_freq, power):
self.frames = pd.read_csv(csv_file, sep='|', header=None)
self.root_dir = root_dir
self.outputs_per_step = outputs_per_step
self.sample_rate = sample_rate
self.cleaners = cleaners
self.cleaners = text_cleaner
self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, preemphasis, ref_level_db, num_freq, power
)
print(" > Reading LJSpeech from - {}".format(root_dir))
print(" | > Number of instances : {}".format(len(self.frames)))
def load_wav(self, filename):
try:
audio = librosa.load(filename, sr=self.sample_rate)
audio = librosa.core.load(filename, sr=self.sample_rate)
return audio
except RuntimeError as e:
print(" !! Cannot read file : {}".format(filename))
@ -37,36 +43,54 @@ class LJSpeechDataset(Dataset):
text = self.frames.ix[idx, 1]
text = np.asarray(text_to_sequence(text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
sample = {'text': text, 'wav': wav}
sample = {'text': text, 'wav': wav, 'item_idx': self.frames.ix[idx, 0]}
return sample
def get_dummy_data(self):
return torch.autograd.Variable(torch.ones(16, 143)).type(torch.LongTensor)
def collate_fn(self, batch):
# Puts each data field into a tensor with outer dimension batch size
if isinstance(batch[0], collections.Mapping):
keys = list()
text = [d['text'] for d in batch]
wav = [d['wav'] for d in batch]
item_idxs = [d['item_idx'] for d in batch]
text = [d['text'] for d in batch]
text_lenghts = np.array([len(x) for x in text])
max_text_len = np.max(text_lenghts)
# PAD sequences with largest length of the batch
text = prepare_data(text).astype(np.int32)
wav = prepare_data(wav)
magnitude = np.array([spectrogram(w) for w in wav])
mel = np.array([melspectrogram(w) for w in wav])
linear = np.array([self.ap.spectrogram(w).astype('float32') for w in wav])
mel = np.array([self.ap.melspectrogram(w).astype('float32') for w in wav])
assert mel.shape[2] == linear.shape[2]
timesteps = mel.shape[2]
# PAD with zeros that can be divided by outputs per step
if timesteps % self.outputs_per_step != 0:
magnitude = pad_per_step(magnitude, self.outputs_per_step)
mel = pad_per_step(mel, self.outputs_per_step)
if (timesteps + 1) % self.outputs_per_step != 0:
pad_len = self.outputs_per_step - \
((timesteps + 1) % self.outputs_per_step)
pad_len += 1
else:
pad_len = 1
linear = pad_per_step(linear, pad_len)
mel = pad_per_step(mel, pad_len)
# reshape jombo
magnitude = magnitude.transpose(0, 2, 1)
linear = linear.transpose(0, 2, 1)
mel = mel.transpose(0, 2, 1)
return text, magnitude, mel
# convert things to pytorch
text_lenghts = torch.LongTensor(text_lenghts)
text = torch.LongTensor(text)
linear = torch.FloatTensor(linear)
mel = torch.FloatTensor(mel)
return text, text_lenghts, linear, mel, item_idxs[0]
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}"
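Not part of the patch: the padding rule used by ```collate_fn``` above, shown in isolation. At least one extra zero frame is always appended and the total length is rounded up to a multiple of ```outputs_per_step```:
```
# Illustrative check of the pad_len computation in datasets/LJSpeech.py.
def compute_pad_len(timesteps, outputs_per_step):
    if (timesteps + 1) % outputs_per_step != 0:
        pad_len = outputs_per_step - ((timesteps + 1) % outputs_per_step)
        pad_len += 1
    else:
        pad_len = 1
    return pad_len

assert (23 + compute_pad_len(23, 5)) % 5 == 0   # 23 frames -> pad 2 -> 25
assert (24 + compute_pad_len(24, 5)) % 5 == 0   # 24 frames -> pad 1 -> 25
```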

0
datasets/__init__.py Normal file
View File

29
debug_config.py Normal file
View File

@ -0,0 +1,29 @@
{
"num_mels": 80,
"num_freq": 1024,
"sample_rate": 20000,
"frame_length_ms": 50.0,
"frame_shift_ms": 12.5,
"preemphasis": 0.97,
"min_level_db": -100,
"ref_level_db": 20,
"hidden_size": 128,
"embedding_size": 256,
"text_cleaner": "english_cleaners",
"epochs": 200,
"lr": 0.01,
"lr_patience": 2,
"lr_decay": 0.5,
"batch_size": 32,
"griffinf_lim_iters": 60,
"power": 1.5,
"r": 5,
"num_loader_workers": 16,
"save_step":1 ,
"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
"output_path": "result",
"log_dir": "/home/erogol/projects/TTS/logs/"
}

Binary data
layers/.attention.py.swp

Binary file not shown.

Binary data
layers/.tacotron.py.swp

Binary file not shown.

View File

@ -11,11 +11,11 @@ class BahdanauAttention(nn.Module):
self.tanh = nn.Tanh()
self.v = nn.Linear(dim, 1, bias=False)
def forward(self, query, processed_memory):
def forward(self, query, processed_inputs):
"""
Args:
query: (batch, 1, dim) or (batch, dim)
processed_memory: (batch, max_time, dim)
processed_inputs: (batch, max_time, dim)
"""
if query.dim() == 2:
# insert time-axis for broadcasting
@ -24,63 +24,71 @@ class BahdanauAttention(nn.Module):
processed_query = self.query_layer(query)
# (batch, max_time, 1)
alignment = self.v(self.tanh(processed_query + processed_memory))
alignment = self.v(self.tanh(processed_query + processed_inputs))
# (batch, max_time)
return alignment.squeeze(-1)
def get_mask_from_lengths(memory, memory_lengths):
def get_mask_from_lengths(inputs, inputs_lengths):
"""Get mask tensor from list of length
Args:
memory: (batch, max_time, dim)
memory_lengths: array like
inputs: (batch, max_time, dim)
inputs_lengths: array like
"""
mask = memory.data.new(memory.size(0), memory.size(1)).byte().zero_()
for idx, l in enumerate(memory_lengths):
mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_()
for idx, l in enumerate(inputs_lengths):
mask[idx][:l] = 1
return ~mask
class AttentionWrapper(nn.Module):
def __init__(self, rnn_cell, attention_mechanism,
def __init__(self, rnn_cell, alignment_model,
score_mask_value=-float("inf")):
super(AttentionWrapper, self).__init__()
self.rnn_cell = rnn_cell
self.attention_mechanism = attention_mechanism
self.alignment_model = alignment_model
self.score_mask_value = score_mask_value
def forward(self, query, attention, cell_state, memory,
processed_memory=None, mask=None, memory_lengths=None):
if processed_memory is None:
processed_memory = memory
if memory_lengths is not None and mask is None:
mask = get_mask_from_lengths(memory, memory_lengths)
def forward(self, query, context_vec, cell_state, inputs,
processed_inputs=None, mask=None, inputs_lengths=None):
# Concat input query and previous attention context
cell_input = torch.cat((query, attention), -1)
if processed_inputs is None:
processed_inputs = inputs
# Feed it to RNN
cell_output = self.rnn_cell(cell_input, cell_state)
if inputs_lengths is not None and mask is None:
mask = get_mask_from_lengths(inputs, inputs_lengths)
# Alignment
# (batch, max_time)
alignment = self.attention_mechanism(cell_output, processed_memory)
# e_{ij} = a(s_{i-1}, h_j)
# import ipdb
# ipdb.set_trace()
alignment = self.alignment_model(cell_state, processed_inputs)
if mask is not None:
mask = mask.view(query.size(0), -1)
alignment.data.masked_fill_(mask, self.score_mask_value)
# Normalize attention weight
alignment = F.softmax(alignment, dim=0)
# Normalize context_vec weight
alignment = F.softmax(alignment, dim=-1)
# Attention context vector
# (batch, 1, dim)
attention = torch.bmm(alignment.unsqueeze(1), memory)
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
context_vec = torch.bmm(alignment.unsqueeze(1), inputs)
context_vec = context_vec.squeeze(1)
# (batch, dim)
attention = attention.squeeze(1)
# Concat input query and previous context_vec context
cell_input = torch.cat((query, context_vec), -1)
#cell_input = cell_input.unsqueeze(1)
# Feed it to RNN
# s_i = f(y_{i-1}, c_{i}, s_{i-1})
cell_output = self.rnn_cell(cell_input, cell_state)
context_vec = context_vec.squeeze(1)
return cell_output, context_vec, alignment
return cell_output, attention, alignment
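Not part of the patch: a standalone sketch of the additive attention that the renamed ```alignment_model```/```context_vec``` code path computes, using the same (batch, max_time, dim) shape convention. The inputs projection ```U``` stands in for the ```input_layer``` the decoder applies before calling the wrapper.
```
# Shape walk-through of e_ij = v^T tanh(W s_{i-1} + U h_j) and c_i = sum_j a_ij h_j.
import torch

B, T, D = 4, 8, 256
inputs = torch.rand(B, T, D)          # encoder outputs h_j
query = torch.rand(B, D)              # previous decoder state s_{i-1}

W = torch.nn.Linear(D, D)             # query projection (BahdanauAttention.query_layer)
U = torch.nn.Linear(D, D)             # inputs projection (Decoder.input_layer)
v = torch.nn.Linear(D, 1, bias=False)

scores = v(torch.tanh(W(query).unsqueeze(1) + U(inputs))).squeeze(-1)  # (B, T)
alignment = torch.softmax(scores, dim=-1)                              # attention weights
context = torch.bmm(alignment.unsqueeze(1), inputs).squeeze(1)         # (B, D) context vector
```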

View File

@ -7,35 +7,59 @@ from .attention import BahdanauAttention, AttentionWrapper
from .attention import get_mask_from_lengths
class Prenet(nn.Module):
def __init__(self, in_dim, sizes=[256, 128]):
r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
It creates as many layers as given by 'out_features'
Args:
in_features (int): size of the input vector
out_features (int or list): size of each output sample.
If it is a list, for each value, there is created a new layer.
"""
def __init__(self, in_features, out_features=[256, 128]):
super(Prenet, self).__init__()
in_sizes = [in_dim] + sizes[:-1]
in_features = [in_features] + out_features[:-1]
self.layers = nn.ModuleList(
[nn.Linear(in_size, out_size)
for (in_size, out_size) in zip(in_sizes, sizes)])
for (in_size, out_size) in zip(in_features, out_features)])
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.5)
def forward(self, inputs):
for linear in self.layers:
inputs = self.dropout(self.relu(linear(inputs)))
return inputs
class BatchNormConv1d(nn.Module):
def __init__(self, in_dim, out_dim, kernel_size, stride, padding,
r"""A wrapper for Conv1d with BatchNorm. It sets the activation
function between Conv and BatchNorm layers. BatchNorm layer
is initialized with the TF default values for momentum and eps.
Args:
in_channels: size of each input sample
out_channels: size of each output samples
kernel_size: kernel size of conv filters
stride: stride of conv filters
padding: padding of conv filters
activation: activation function set b/w Conv1d and BatchNorm
Shapes:
- input: batch x dims
- output: batch x dims
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
activation=None):
super(BatchNormConv1d, self).__init__()
self.conv1d = nn.Conv1d(in_dim, out_dim,
self.conv1d = nn.Conv1d(in_channels, out_channels,
kernel_size=kernel_size,
stride=stride, padding=padding, bias=False)
# Following tensorflow's default parameters
self.bn = nn.BatchNorm1d(out_dim, momentum=0.99, eps=1e-3)
self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
self.activation = activation
def forward(self, x):
x = self.conv1d(x)
x = self.conv1d(x)
if self.activation is not None:
x = self.activation(x)
return self.bn(x)
@ -62,135 +86,180 @@ class CBHG(nn.Module):
- 1-d convolution banks
- Highway networks + residual connections
- Bidirectional gated recurrent units
Args:
in_features (int): sample size
K (int): max filter size in conv bank
projections (list): conv channel sizes for conv projections
num_highways (int): number of highways layers
Shapes:
- input: batch x time x dim
- output: batch x time x dim*2
"""
def __init__(self, in_dim, K=16, projections=[128, 128]):
def __init__(self, in_features, K=16, projections=[128, 128], num_highways=4):
super(CBHG, self).__init__()
self.in_dim = in_dim
self.in_features = in_features
self.relu = nn.ReLU()
# list of conv1d bank with filter size k=1...K
# TODO: try dilational layers instead
self.conv1d_banks = nn.ModuleList(
[BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1,
[BatchNormConv1d(in_features, in_features, kernel_size=k, stride=1,
padding=k // 2, activation=self.relu)
for k in range(1, K + 1)])
for k in range(1, K + 1)])
# max pooling of conv bank
# TODO: try average pooling OR larger kernel size
self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
in_sizes = [K * in_dim] + projections[:-1]
activations = [self.relu] * (len(projections) - 1) + [None]
self.conv1d_projections = nn.ModuleList(
[BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
padding=1, activation=ac)
for (in_size, out_size, ac) in zip(
in_sizes, projections, activations)])
out_features = [K * in_features] + projections[:-1]
activations = [self.relu] * (len(projections) - 1)
activations += [None]
self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False)
# setup conv1d projection layers
layer_set = []
for (in_size, out_size, ac) in zip(out_features, projections, activations):
layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
padding=1, activation=ac)
layer_set.append(layer)
self.conv1d_projections = nn.ModuleList(layer_set)
# setup Highway layers
self.pre_highway = nn.Linear(projections[-1], in_features, bias=False)
self.highways = nn.ModuleList(
[Highway(in_dim, in_dim) for _ in range(4)])
[Highway(in_features, in_features) for _ in range(num_highways)])
# bi-directional GPU layer
self.gru = nn.GRU(
in_dim, in_dim, 1, batch_first=True, bidirectional=True)
in_features, in_features, 1, batch_first=True, bidirectional=True)
def forward(self, inputs, input_lengths=None):
# (B, T_in, in_dim)
def forward(self, inputs):
# (B, T_in, in_features)
x = inputs
# Needed to perform conv1d on time-axis
# (B, in_dim, T_in)
if x.size(-1) == self.in_dim:
# (B, in_features, T_in)
if x.size(-1) == self.in_features:
x = x.transpose(1, 2)
T = x.size(-1)
# (B, in_dim*K, T_in)
# (B, in_features*K, T_in)
# Concat conv1d bank outputs
x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1)
assert x.size(1) == self.in_dim * len(self.conv1d_banks)
outs = []
for conv1d in self.conv1d_banks:
out = conv1d(x)
out = out[:, :, :T]
outs.append(out)
x = torch.cat(outs, dim=1)
assert x.size(1) == self.in_features * len(self.conv1d_banks)
x = self.max_pool1d(x)[:, :, :T]
for conv1d in self.conv1d_projections:
x = conv1d(x)
# (B, T_in, in_dim)
# (B, T_in, in_features)
# Back to the original shape
x = x.transpose(1, 2)
if x.size(-1) != self.in_dim:
if x.size(-1) != self.in_features:
x = self.pre_highway(x)
# Residual connection
# TODO: try residual scaling as in Deep Voice 3
# TODO: try plain residual layers
x += inputs
for highway in self.highways:
x = highway(x)
if input_lengths is not None:
x = nn.utils.rnn.pack_padded_sequence(
x, input_lengths, batch_first=True)
# (B, T_in, in_dim*2)
self.gru.flatten_parameters()
# (B, T_in, in_features*2)
# TODO: replace GRU with convolution as in Deep Voice 3
self.gru.flatten_parameters()
outputs, _ = self.gru(x)
if input_lengths is not None:
outputs, _ = nn.utils.rnn.pad_packed_sequence(
outputs, batch_first=True)
return outputs
class Encoder(nn.Module):
def __init__(self, in_dim):
r"""Encapsulate Prenet and CBHG modules for encoder"""
def __init__(self, in_features):
super(Encoder, self).__init__()
self.prenet = Prenet(in_dim, sizes=[256, 128])
self.prenet = Prenet(in_features, out_features=[256, 128])
self.cbhg = CBHG(128, K=16, projections=[128, 128])
def forward(self, inputs, input_lengths=None):
def forward(self, inputs):
r"""
Args:
inputs (FloatTensor): embedding features
Shapes:
- inputs: batch x time x in_features
- outputs: batch x time x 128*2
"""
inputs = self.prenet(inputs)
return self.cbhg(inputs, input_lengths)
return self.cbhg(inputs)
class Decoder(nn.Module):
def __init__(self, memory_dim, r):
r"""Decoder module.
Args:
in_features (int): input vector (encoder output) sample size.
memory_dim (int): memory vector (prev. time-step output) sample size.
r (int): number of outputs per time step.
eps (float): threshold for detecting the end of a sentence.
"""
def __init__(self, in_features, memory_dim, r, eps=0.2):
super(Decoder, self).__init__()
self.max_decoder_steps = 200
self.memory_dim = memory_dim
self.eps = eps
self.r = r
self.prenet = Prenet(memory_dim * r, sizes=[256, 128])
# attetion RNN
# input -> |Linear| -> processed_inputs
self.input_layer = nn.Linear(in_features, 256, bias=False)
# memory -> |Prenet| -> processed_memory
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
# processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
self.attention_rnn = AttentionWrapper(
nn.GRUCell(256 + 128, 256),
nn.GRUCell(in_features + 128, 256),
BahdanauAttention(256)
)
self.memory_layer = nn.Linear(256, 256, bias=False)
# concat and project context and attention vectors
# (prenet_out + attention context) -> output
self.project_to_decoder_in = nn.Linear(512, 256)
# decoder RNNs
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(256+in_features, 256)
# decoder_RNN_input -> |RNN| -> RNN_state
self.decoder_rnns = nn.ModuleList(
[nn.GRUCell(256, 256) for _ in range(2)])
# RNN_state -> |Linear| -> mel_spec
self.proj_to_mel = nn.Linear(256, memory_dim * r)
self.max_decoder_steps = 200
def forward(self, decoder_inputs, memory=None, memory_lengths=None):
"""
def forward(self, inputs, memory=None, memory_lengths=None):
r"""
Decoder forward step.
If decoder inputs are not given (e.g., at testing time), as noted in
Tacotron paper, greedy decoding is adapted.
Args:
decoder_inputs: Encoder outputs. (B, T_encoder, dim)
memory: Decoder memory. i.e., mel-spectrogram. If None (at eval-time),
inputs: Encoder outputs.
memory: Decoder memory (autoregression. If None (at eval-time),
decoder outputs are used as decoder inputs.
memory_lengths: Encoder output (memory) lengths. If not None, used for
attention masking.
"""
B = decoder_inputs.size(0)
processed_memory = self.memory_layer(decoder_inputs)
Shapes:
- inputs: batch x time x encoder_out_dim
- memory: batch x #mels_pecs x mel_spec_dim
"""
B = inputs.size(0)
# TODO: take this segment into Attention module.
processed_inputs = self.input_layer(inputs)
if memory_lengths is not None:
mask = get_mask_from_lengths(processed_memory, memory_lengths)
mask = get_mask_from_lengths(processed_inputs, memory_lengths)
else:
mask = None
@ -198,6 +267,7 @@ class Decoder(nn.Module):
greedy = memory is None
if memory is not None:
# Grouping multiple frames if necessary
if memory.size(-1) == self.memory_dim:
memory = memory.view(B, memory.size(1) // self.r, -1)
@ -206,18 +276,18 @@ class Decoder(nn.Module):
self.memory_dim, self.r)
T_decoder = memory.size(1)
# go frames - 0 frames tarting the sequence
initial_input = Variable(
decoder_inputs.data.new(B, self.memory_dim * self.r).zero_())
# go frame - 0 frames tarting the sequence
initial_memory = Variable(
inputs.data.new(B, self.memory_dim * self.r).zero_())
# Init decoder states
attention_rnn_hidden = Variable(
decoder_inputs.data.new(B, 256).zero_())
inputs.data.new(B, 256).zero_())
decoder_rnn_hiddens = [Variable(
decoder_inputs.data.new(B, 256).zero_())
inputs.data.new(B, 256).zero_())
for _ in range(len(self.decoder_rnns))]
current_attention = Variable(
decoder_inputs.data.new(B, 256).zero_())
current_context_vec = Variable(
inputs.data.new(B, 256).zero_())
# Time first (T_decoder, B, memory_dim)
if memory is not None:
@ -227,21 +297,21 @@ class Decoder(nn.Module):
alignments = []
t = 0
current_input = initial_input
memory_input = initial_memory
while True:
if t > 0:
current_input = outputs[-1] if greedy else memory[t - 1]
memory_input = outputs[-1] if greedy else memory[t - 1]
# Prenet
current_input = self.prenet(current_input)
processed_memory = self.prenet(memory_input)
# Attention RNN
attention_rnn_hidden, current_attention, alignment = self.attention_rnn(
current_input, current_attention, attention_rnn_hidden,
decoder_inputs, processed_memory=processed_memory, mask=mask)
attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
processed_memory, current_context_vec, attention_rnn_hidden,
inputs, processed_inputs=processed_inputs, mask=mask)
# Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in(
torch.cat((attention_rnn_hidden, current_attention), -1))
torch.cat((attention_rnn_hidden, current_context_vec), -1))
# Pass through the decoder RNNs
for idx in range(len(self.decoder_rnns)):
@ -261,10 +331,11 @@ class Decoder(nn.Module):
t += 1
if greedy:
if t > 1 and is_end_of_frames(output):
if t > 1 and is_end_of_frames(output, self.eps):
break
elif t > self.max_decoder_steps:
print("Warning! doesn't seems to be converged")
print(" !! Decoder stopped with 'max_decoder_steps'. \
Something is probably wrong.")
break
else:
if t >= T_decoder:
@ -279,5 +350,5 @@ class Decoder(nn.Module):
return outputs, alignments
def is_end_of_frames(output, eps=0.2):
def is_end_of_frames(output, eps=0.1): #0.2
return (output.data <= eps).all()
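Not part of the patch: a tiny illustration of the ```r``` outputs-per-step grouping the decoder relies on, matching the ```memory.view(B, T // r, -1)``` call above.
```
# Illustrative only: grouping r mel frames into one decoder step.
import torch

B, T, memory_dim, r = 2, 120, 80, 5
mel = torch.rand(B, T, memory_dim)             # ground-truth mel frames
grouped = mel.view(B, T // r, memory_dim * r)  # (2, 24, 400): r frames per step
```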

Binary data
models/.tacotron.py.swp

Binary file not shown.

View File

@ -2,8 +2,9 @@
import torch
from torch.autograd import Variable
from torch import nn
from utils.text.symbols import symbols
from Tacotron.layers.tacotron import Prenet, Encoder, Decoder, CBHG
from TTS.utils.text.symbols import symbols
from TTS.layers.tacotron import Prenet, Encoder, Decoder, CBHG
class Tacotron(nn.Module):
def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
@ -15,10 +16,12 @@ class Tacotron(nn.Module):
self.use_memory_mask = use_memory_mask
self.embedding = nn.Embedding(len(symbols), embedding_dim,
padding_idx=padding_idx)
print(" | > Embedding dim : {}".format(len(symbols)))
# Trying smaller std
self.embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(embedding_dim)
self.decoder = Decoder(mel_dim, r)
self.decoder = Decoder(256, mel_dim, r)
self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
self.last_linear = nn.Linear(mel_dim * 2, freq_dim)
@ -28,7 +31,7 @@ class Tacotron(nn.Module):
inputs = self.embedding(characters)
# (B, T', in_dim)
encoder_outputs = self.encoder(inputs, input_lengths)
encoder_outputs = self.encoder(inputs)
if self.use_memory_mask:
memory_lengths = input_lengths

File diff suppressed because one or more lines are too long

51
notebooks/utils.py Normal file
View File

@ -0,0 +1,51 @@
import io
import librosa
import torch
import numpy as np
from TTS.utils.text import text_to_sequence
from matplotlib import pylab as plt
hop_length = 250
def create_speech(m, s, CONFIG, use_cuda, ap):
text_cleaner = [CONFIG.text_cleaner]
seq = np.array(text_to_sequence(s, text_cleaner))
# mel = np.zeros([seq.shape[0], CONFIG.num_mels, 1], dtype=np.float32)
if use_cuda:
chars_var = torch.autograd.Variable(torch.from_numpy(seq), volatile=True).unsqueeze(0).cuda()
# mel_var = torch.autograd.Variable(torch.from_numpy(mel).type(torch.cuda.FloatTensor), volatile=True).cuda()
else:
chars_var = torch.autograd.Variable(torch.from_numpy(seq), volatile=True).unsqueeze(0)
# mel_var = torch.autograd.Variable(torch.from_numpy(mel).type(torch.FloatTensor), volatile=True)
mel_out, linear_out, alignments =m.forward(chars_var)
linear_out = linear_out[0].data.cpu().numpy()
alignment = alignments[0].cpu().data.numpy()
spec = ap._denormalize(linear_out)
wav = ap.inv_spectrogram(linear_out.T)
wav = wav[:ap.find_endpoint(wav)]
out = io.BytesIO()
ap.save_wav(wav, out)
return wav, alignment, spec
def visualize(alignment, spectrogram, CONFIG):
label_fontsize = 16
plt.figure(figsize=(16,16))
plt.subplot(2,1,1)
plt.imshow(alignment.T, aspect="auto", origin="lower", interpolation=None)
plt.xlabel("Decoder timestamp", fontsize=label_fontsize)
plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
plt.colorbar()
plt.subplot(2,1,2)
librosa.display.specshow(spectrogram.T, sr=CONFIG.sample_rate,
hop_length=hop_length, x_axis="time", y_axis="linear")
plt.xlabel("Time", fontsize=label_fontsize)
plt.ylabel("Hz", fontsize=label_fontsize)
plt.tight_layout()
plt.colorbar()
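Not part of the patch: a rough sketch of how these notebook helpers could be wired up with the rest of the package. Paths and the test sentence are placeholders, the ```Tacotron``` constructor call mirrors ```train.py```, and checkpoints saved from a ```DataParallel```-wrapped model may need their ```module.``` key prefix stripped before loading.
```
# Hypothetical usage of create_speech()/visualize() from this module; paths are placeholders.
import torch
from TTS.utils.generic_utils import load_config
from TTS.utils.audio import AudioProcessor
from TTS.models.tacotron import Tacotron

CONFIG = load_config("config.json")
ap = AudioProcessor(CONFIG.sample_rate, CONFIG.num_mels, CONFIG.min_level_db,
                    CONFIG.frame_shift_ms, CONFIG.frame_length_ms, CONFIG.preemphasis,
                    CONFIG.ref_level_db, CONFIG.num_freq, CONFIG.power,
                    griffin_lim_iters=60)
model = Tacotron(CONFIG.embedding_size, CONFIG.hidden_size,
                 CONFIG.num_mels, CONFIG.num_freq, CONFIG.r)   # same call as train.py
state = torch.load("best_model.pth.tar", map_location="cpu")
model.load_state_dict(state["model"])
model.eval()

wav, alignment, spec = create_speech(model, "Hello world.", CONFIG, False, ap)
visualize(alignment, spec, CONFIG)
```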

View File

@ -0,0 +1,5 @@
librosa
inflect
unidecode
tensorboard
tensorboardX

View File

@ -38,17 +38,11 @@ def main(args):
# Sentences for generation
sentences = [
"And it is worth mention in passing that, as an example of fine typography,",
# From July 8, 2017 New York Times:
'Scientists at the CERN laboratory say they have discovered a new particle.',
'Theres a way to measure the acute emotional intelligence that has never gone out of style.',
'President Trump met with other leaders at the Group of 20 conference.',
'The Senate\'s bill to repeal and replace the Affordable Care Act is now imperiled.',
# From Google's Tacotron example page:
'Generative adversarial network or variational auto-encoder.',
'The buses aren\'t the problem, they actually provide a solution.',
'Does the quick brown fox jump over the lazy dog?',
'Talib Kweli confirmed to AllHipHop that he will be releasing an album in the next year.',
"I try my best to translate text to speech. But I know I need more work",
"The new Firefox, Fast for good.",
"Technology is continually providing us with new ways to create and publish stories.",
"For these stories to achieve their full impact, it requires tool.",
"I am allien and I am here to destron your world."
]
# Synthesis and save to wav files

60
tests/layers_tests.py Normal file
View File

@ -0,0 +1,60 @@
import unittest
import torch as T
from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder
class PrenetTests(unittest.TestCase):
def test_in_out(self):
layer = Prenet(128, out_features=[256, 128])
dummy_input = T.autograd.Variable(T.rand(4, 128))
print(layer)
output = layer(dummy_input)
assert output.shape[0] == 4
assert output.shape[1] == 128
class CBHGTests(unittest.TestCase):
def test_in_out(self):
layer = CBHG(128, K= 6, projections=[128, 128], num_highways=2)
dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
print(layer)
output = layer(dummy_input)
assert output.shape[0] == 4
assert output.shape[1] == 8
assert output.shape[2] == 256
class DecoderTests(unittest.TestCase):
def test_in_out(self):
layer = Decoder(in_features=128, memory_dim=32, r=5)
dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
dummy_memory = T.autograd.Variable(T.rand(4, 120, 32))
print(layer)
output, alignment = layer(dummy_input, dummy_memory)
print(output.shape)
assert output.shape[0] == 4
assert output.shape[1] == 120 / 5
assert output.shape[2] == 32 * 5
class EncoderTests(unittest.TestCase):
def test_in_out(self):
layer = Encoder(128)
dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
print(layer)
output = layer(dummy_input)
print(output.shape)
assert output.shape[0] == 4
assert output.shape[1] == 8
assert output.shape[2] == 256 # 128 * 2 BiRNN

93
tests/loader_tests.py Normal file
View File

@ -0,0 +1,93 @@
import os
import unittest
import numpy as np
from torch.utils.data import DataLoader
from TTS.utils.generic_utils import load_config
from TTS.datasets.LJSpeech import LJSpeechDataset
file_path = os.path.dirname(os.path.realpath(__file__))
c = load_config(os.path.join(file_path, 'test_config.json'))
class TestDataset(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestDataset, self).__init__(*args, **kwargs)
self.max_loader_iter = 4
def test_loader(self):
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
os.path.join(c.data_path, 'wavs'),
c.r,
c.sample_rate,
c.text_cleaner,
c.num_mels,
c.min_level_db,
c.frame_shift_ms,
c.frame_length_ms,
c.preemphasis,
c.ref_level_db,
c.num_freq,
c.power
)
dataloader = DataLoader(dataset, batch_size=c.batch_size,
shuffle=True, collate_fn=dataset.collate_fn,
drop_last=True, num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
item_idx = data[4]
neg_values = text_input[text_input < 0]
check_count = len(neg_values)
assert check_count == 0, \
" !! Negative values in text_input: {}".format(check_count)
# TODO: more assertion here
assert linear_input.shape[0] == c.batch_size
assert mel_input.shape[0] == c.batch_size
assert mel_input.shape[2] == c.num_mels
def test_padding(self):
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
os.path.join(c.data_path, 'wavs'),
1,
c.sample_rate,
c.text_cleaner,
c.num_mels,
c.min_level_db,
c.frame_shift_ms,
c.frame_length_ms,
c.preemphasis,
c.ref_level_db,
c.num_freq,
c.power
)
dataloader = DataLoader(dataset, batch_size=1,
shuffle=True, collate_fn=dataset.collate_fn,
drop_last=True, num_workers=c.num_loader_workers)
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
item_idx = data[4]
# check the last time step to be zero padded
assert mel_input[0, -1].sum() == 0
assert mel_input[0, -2].sum() != 0
assert linear_input[0, -1].sum() == 0
assert linear_input[0, -2].sum() != 0

0
tests/tacotron_tests.py Normal file
View File

30
tests/test_config.json Normal file
View File

@ -0,0 +1,30 @@
{
"num_mels": 80,
"num_freq": 1025,
"sample_rate": 20000,
"frame_length_ms": 50,
"frame_shift_ms": 12.5,
"preemphasis": 0.97,
"min_level_db": -100,
"ref_level_db": 20,
"hidden_size": 128,
"embedding_size": 256,
"text_cleaner": "english_cleaners",
"epochs": 2000,
"lr": 0.003,
"lr_patience": 5,
"lr_decay": 0.5,
"batch_size": 2,
"r": 5,
"griffin_lim_iters": 60,
"power": 1.5,
"num_loader_workers": 4,
"save_step": 200,
"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
"output_path": "result",
"log_dir": "/home/erogol/projects/TTS/logs/"
}

252
train.py
View File

@ -1,28 +1,33 @@
import os
import sys
import time
import datetime
import shutil
import torch
import signal
import argparse
import importlib
import pickle
import numpy as np
import torch.nn as nn
from torch import optim
from torch import onnx
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tensorboardX import SummaryWriter
from utils.generic_utils import (Progbar, remove_experiment_folder,
create_experiment_folder, save_checkpoint,
load_config)
save_best_model, load_config, lr_decay)
from utils.model import get_param_size
from utils.visual import plot_alignment, plot_spectrogram
from datasets.LJSpeech import LJSpeechDataset
from models.tacotron import Tacotron
use_cuda = torch.cuda.is_available()
def main(args):
# setup output paths and read configs
@ -33,39 +38,73 @@ def main(args):
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
# save config to tmp place to be loaded by subsequent modules.
file_name = str(os.getpid())
tmp_path = os.path.join("/tmp/", file_name+'_tts')
pickle.dump(c, open(tmp_path, "wb"))
# setup tensorboard
LOG_DIR = OUT_PATH
tb = SummaryWriter(LOG_DIR)
# Ctrl+C handler to remove empty experiment folder
def signal_handler(signal, frame):
print(" !! Pressed Ctrl+C !!")
remove_experiment_folder(OUT_PATH)
sys.exit(0)
sys.exit(1)
signal.signal(signal.SIGINT, signal_handler)
# Setup the dataset
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
os.path.join(c.data_path, 'wavs'),
c.r,
c.sample_rate,
c.text_cleaner
c.text_cleaner,
c.num_mels,
c.min_level_db,
c.frame_shift_ms,
c.frame_length_ms,
c.preemphasis,
c.ref_level_db,
c.num_freq,
c.power
)
dataloader = DataLoader(dataset, batch_size=c.batch_size,
shuffle=True, collate_fn=dataset.collate_fn,
drop_last=True, num_workers=c.num_loader_workers)
# setup the model
model = Tacotron(c.embedding_size,
c.hidden_size,
c.num_mels,
c.num_freq,
c.r)
# plot model on tensorboard
dummy_input = dataset.get_dummy_data()
## TODO: onnx does not support RNN fully yet
# model_proto_path = os.path.join(OUT_PATH, "model.proto")
# onnx.export(model, dummy_input, model_proto_path, verbose=True)
# tb.add_graph_onnx(model_proto_path)
if use_cuda:
model = nn.DataParallel(model.cuda())
optimizer = optim.Adam(model.parameters(), lr=c.lr)
try:
if args.restore_step:
checkpoint = torch.load(os.path.join(
CHECKPOINT_PATH, 'checkpoint_%d.pth.tar' % args.restore_step))
args.restore_path, 'checkpoint_%d.pth.tar' % args.restore_step))
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("\n > Model restored from step %d\n" % args.restore_step)
except:
print("\n > Starting a new training\n")
start_epoch = checkpoint['step'] // len(dataloader)
best_loss = checkpoint['linear_loss']
else:
start_epoch = 0
print("\n > Starting a new training")
model = model.train()
@ -79,112 +118,153 @@ def main(args):
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
for epoch in range(c.epochs):
#lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay,
# patience=c.lr_patience, verbose=True)
epoch_time = 0
best_loss = float('inf')
for epoch in range(0, c.epochs):
dataloader = DataLoader(dataset, batch_size=c.batch_size,
shuffle=True, collate_fn=dataset.collate_fn,
drop_last=True, num_workers=32)
print("\n | > Epoch {}/{}".format(epoch, c.epochs))
progbar = Progbar(len(dataset) / c.batch_size)
for i, data in enumerate(dataloader):
text_input = data[0]
magnitude_input = data[1]
mel_input = data[2]
for num_iter, data in enumerate(dataloader):
start_time = time.time()
current_step = i + args.restore_step + epoch * len(dataloader) + 1
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1
# setup lr
current_lr = lr_decay(c.lr, current_step)
for params_group in optimizer.param_groups:
params_group['lr'] = current_lr
optimizer.zero_grad()
try:
mel_input = np.concatenate((np.zeros(
[c.batch_size, 1, c.num_mels], dtype=np.float32),
mel_input[:, 1:, :]), axis=1)
except:
raise TypeError("not same dimension")
# Add a single frame of zeros to Mel Specs for better end detection
#try:
# mel_input = np.concatenate((np.zeros(
# [c.batch_size, 1, c.num_mels], dtype=np.float32),
# mel_input[:, 1:, :]), axis=1)
#except:
# raise TypeError("not same dimension")
# convert inputs to variables
text_input_var = Variable(text_input)
mel_spec_var = Variable(mel_input)
linear_spec_var = Variable(linear_input, volatile=True)
# sort sequence by length.
# TODO: might be unnecessary
sorted_lengths, indices = torch.sort(
text_lengths.view(-1), dim=0, descending=True)
sorted_lengths = sorted_lengths.long().numpy()
text_input_var = text_input_var[indices]
mel_spec_var = mel_spec_var[indices]
linear_spec_var = linear_spec_var[indices]
if use_cuda:
text_input_var = Variable(torch.from_numpy(text_input).type(
torch.cuda.LongTensor), requires_grad=False).cuda()
mel_input_var = Variable(torch.from_numpy(mel_input).type(
torch.cuda.FloatTensor), requires_grad=False).cuda()
mel_spec_var = Variable(torch.from_numpy(mel_input).type(
torch.cuda.FloatTensor), requires_grad=False).cuda()
linear_spec_var = Variable(torch.from_numpy(magnitude_input)
.type(torch.cuda.FloatTensor), requires_grad=False).cuda()
else:
text_input_var = Variable(torch.from_numpy(text_input).type(
torch.LongTensor), requires_grad=False)
mel_input_var = Variable(torch.from_numpy(mel_input).type(
torch.FloatTensor), requires_grad=False)
mel_spec_var = Variable(torch.from_numpy(
mel_input).type(torch.FloatTensor), requires_grad=False)
linear_spec_var = Variable(torch.from_numpy(
magnitude_input).type(torch.FloatTensor),
requires_grad=False)
text_input_var = text_input_var.cuda()
mel_spec_var = mel_spec_var.cuda()
linear_spec_var = linear_spec_var.cuda()
mel_output, linear_output, alignments =\
model.forward(text_input_var, mel_input_var)
model.forward(text_input_var, mel_spec_var,
input_lengths= torch.autograd.Variable(torch.cuda.LongTensor(sorted_lengths)))
mel_loss = criterion(mel_output, mel_spec_var)
linear_loss = torch.abs(linear_output - linear_spec_var)
linear_loss = 0.5 * \
torch.mean(linear_loss) + 0.5 * \
torch.mean(linear_loss[:, :n_priority_freq, :])
#linear_loss = torch.abs(linear_output - linear_spec_var)
#linear_loss = 0.5 * \
#torch.mean(linear_loss) + 0.5 * \
#torch.mean(linear_loss[:, :n_priority_freq, :])
linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
linear_spec_var[: ,: ,:n_priority_freq])
loss = mel_loss + linear_loss
loss = loss.cuda()
start_time = time.time()
# loss = loss.cuda()
loss.backward()
nn.utils.clip_grad_norm(model.parameters(), 1.)
grad_norm = nn.utils.clip_grad_norm(model.parameters(), 1.) ## TODO: maybe no need
optimizer.step()
time_per_step = time.time() - start_time
progbar.update(i, values=[('total_loss', loss.data[0]),
('linear_loss', linear_loss.data[0]),
('mel_loss', mel_loss.data[0])])
step_time = time.time() - start_time
epoch_time += step_time
progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
('linear_loss', linear_loss.data[0]),
('mel_loss', mel_loss.data[0]),
('grad_norm', grad_norm)])
# Plot Learning Stats
tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step)
tb.add_scalar('Loss/LinearLoss', linear_loss.data[0],
current_step)
tb.add_scalar('Loss/MelLoss', mel_loss.data[0], current_step)
tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
current_step)
tb.add_scalar('Params/GradNorm', grad_norm, current_step)
tb.add_scalar('Time/StepTime', step_time, current_step)
align_img = alignments[0].data.cpu().numpy()
align_img = plot_alignment(align_img)
tb.add_image('Attn/Alignment', align_img, current_step)
if current_step % c.save_step == 0:
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
checkpoint_path = os.path.join(OUT_PATH, checkpoint_path)
save_checkpoint({'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'step': current_step,
'total_loss': loss.data[0],
'linear_loss': linear_loss.data[0],
'mel_loss': mel_loss.data[0],
'date': datetime.date.today().strftime("%B %d, %Y")},
checkpoint_path)
print(" > Checkpoint is saved : {}".format(checkpoint_path))
if current_step in c.decay_step:
optimizer = adjust_learning_rate(optimizer, current_step)
if c.checkpoint:
# save model
save_checkpoint(model, optimizer, linear_loss.data[0],
best_loss, OUT_PATH,
current_step, epoch)
# Diagnostic visualizations
const_spec = linear_output[0].data.cpu().numpy()
gt_spec = linear_spec_var[0].data.cpu().numpy()
const_spec = plot_spectrogram(const_spec, dataset.ap)
gt_spec = plot_spectrogram(gt_spec, dataset.ap)
tb.add_image('Spec/Reconstruction', const_spec, current_step)
tb.add_image('Spec/GroundTruth', gt_spec, current_step)
align_img = alignments[0].data.cpu().numpy()
align_img = plot_alignment(align_img)
tb.add_image('Attn/Alignment', align_img, current_step)
# Sample audio
audio_signal = linear_output[0].data.cpu().numpy()
dataset.ap.griffin_lim_iters = 60
audio_signal = dataset.ap.inv_spectrogram(audio_signal.T)
try:
tb.add_audio('SampleAudio', audio_signal, current_step,
sample_rate=c.sample_rate)
except:
print("\n > Error at audio signal on TB!!")
print(audio_signal.max())
print(audio_signal.min())
def adjust_learning_rate(optimizer, step):
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
if step == 500000:
for param_group in optimizer.param_groups:
param_group['lr'] = 0.0005
# average loss after the epoch
avg_epoch_loss = np.mean(
progbar.sum_values['linear_loss'][0] / max(1, progbar.sum_values['linear_loss'][1]))
best_loss = save_best_model(model, optimizer, avg_epoch_loss,
best_loss, OUT_PATH,
current_step, epoch)
elif step == 1000000:
for param_group in optimizer.param_groups:
param_group['lr'] = 0.0003
elif step == 2000000:
for param_group in optimizer.param_groups:
param_group['lr'] = 0.0001
return optimizer
#lr_scheduler.step(loss.data[0])
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
epoch_time = 0
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--restore_step', type=int,
help='Global step to restore checkpoint', default=128)
help='Global step to restore checkpoint', default=0)
parser.add_argument('--restore_path', type=str,
help='Folder path to checkpoints', default=0)
parser.add_argument('--config_path', type=str,
help='path to config file for training',)
args = parser.parse_args()

Binary data
utils/.data.py.swp

Binary file not shown.

Binary data
utils/.generic_utils.py.swp

Binary file not shown.

View File

@ -1,108 +1,124 @@
import os
import librosa
import pickle
import numpy as np
from scipy import signal
_mel_basis = None
def save_wav(wav, path):
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
librosa.output.write_wav(path, wav.astype(np.int16), c.sample_rate)
class AudioProcessor(object):
def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, preemphasis, ref_level_db, num_freq, power,
griffin_lim_iters=None):
self.sample_rate = sample_rate
self.num_mels = num_mels
self.min_level_db = min_level_db
self.frame_shift_ms = frame_shift_ms
self.frame_length_ms = frame_length_ms
self.preemphasis = preemphasis
self.ref_level_db = ref_level_db
self.num_freq = num_freq
self.power = power
self.griffin_lim_iters = griffin_lim_iters
def _linear_to_mel(spectrogram):
global _mel_basis
if _mel_basis is None:
_mel_basis = _build_mel_basis()
return np.dot(_mel_basis, spectrogram)
def save_wav(self, wav, path):
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
librosa.output.write_wav(path, wav.astype(np.int16), self.sample_rate)
def _build_mel_basis():
n_fft = (c.num_freq - 1) * 2
return librosa.filters.mel(c.sample_rate, n_fft, n_mels=c.num_mels)
def _linear_to_mel(self, spectrogram):
global _mel_basis
if _mel_basis is None:
_mel_basis = self._build_mel_basis()
return np.dot(_mel_basis, spectrogram)
def _normalize(S):
return np.clip((S - c.min_level_db) / -c.min_level_db, 0, 1)
def _build_mel_basis(self, ):
n_fft = (self.num_freq - 1) * 2
return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
def _denormalize(S):
return (np.clip(S, 0, 1) * -c.min_level_db) + c.min_level_db
def _normalize(self, S):
return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
def _stft_parameters():
n_fft = (c.num_freq - 1) * 2
hop_length = int(c.frame_shift_ms / 1000 * c.sample_rate)
win_length = int(c.frame_length_ms / 1000 * c.sample_rate)
return n_fft, hop_length, win_length
def _denormalize(self, S):
return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
def _amp_to_db(x):
return 20 * np.log10(np.maximum(1e-5, x))
def _stft_parameters(self, ):
n_fft = (self.num_freq - 1) * 2
hop_length = int(self.frame_shift_ms / 1000 * self.sample_rate)
win_length = int(self.frame_length_ms / 1000 * self.sample_rate)
return n_fft, hop_length, win_length
def _db_to_amp(x):
return np.power(10.0, x * 0.05)
def _amp_to_db(self, x):
return 20 * np.log10(np.maximum(1e-5, x))
def preemphasis(x):
return signal.lfilter([1, -c.preemphasis], [1], x)
def _db_to_amp(self, x):
return np.power(10.0, x * 0.05)
def inv_preemphasis(x):
return signal.lfilter([1], [1, -c.preemphasis], x)
def apply_preemphasis(self, x):
return signal.lfilter([1, -self.preemphasis], [1], x)
def spectrogram(y):
D = _stft(preemphasis(y))
S = _amp_to_db(np.abs(D)) - c.ref_level_db
return _normalize(S)
def apply_inv_preemphasis(self, x):
return signal.lfilter([1], [1, -self.preemphasis], x)
def inv_spectrogram(spectrogram):
'''Converts spectrogram to waveform using librosa'''
S = _denormalize(spectrogram)
S = _db_to_amp(S + c.ref_level_db) # Convert back to linear
# Reconstruct phase
return inv_preemphasis(_griffin_lim(S ** c.power))
def spectrogram(self, y):
D = self._stft(self.apply_preemphasis(y))
S = self._amp_to_db(np.abs(D)) - self.ref_level_db
return self._normalize(S)
def _griffin_lim(S):
'''librosa implementation of Griffin-Lim
Based on https://github.com/librosa/librosa/issues/434
'''
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
S_complex = np.abs(S).astype(np.complex)
y = _istft(S_complex * angles)
for i in range(c.griffin_lim_iters):
angles = np.exp(1j * np.angle(_stft(y)))
y = _istft(S_complex * angles)
return y
def inv_spectrogram(self, spectrogram):
'''Converts spectrogram to waveform using librosa'''
S = self._denormalize(spectrogram)
S = self._db_to_amp(S + self.ref_level_db) # Convert back to linear
# Reconstruct phase
return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
def _istft(y):
_, hop_length, win_length = _stft_parameters()
return librosa.istft(y, hop_length=hop_length, win_length=win_length)
def _griffin_lim(self, S):
'''librosa implementation of Griffin-Lim
Based on https://github.com/librosa/librosa/issues/434
'''
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
S_complex = np.abs(S).astype(np.complex)
y = self._istft(S_complex * angles)
for i in range(self.griffin_lim_iters):
angles = np.exp(1j * np.angle(self._stft(y)))
y = self._istft(S_complex * angles)
return y
def melspectrogram(y):
D = _stft(preemphasis(y))
S = _amp_to_db(_linear_to_mel(np.abs(D)))
return _normalize(S)
def melspectrogram(self, y):
D = self._stft(self.apply_preemphasis(y))
S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
return self._normalize(S)
def _stft(y):
n_fft, hop_length, win_length = _stft_parameters()
return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
def _stft(self, y):
n_fft, hop_length, win_length = self._stft_parameters()
return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
def _istft(self, y):
_, hop_length, win_length = self._stft_parameters()
return librosa.istft(y, hop_length=hop_length, win_length=win_length)
def find_endpoint(wav, threshold_db=-40, min_silence_sec=0.8):
window_length = int(c.sample_rate * min_silence_sec)
hop_length = int(window_length / 4)
threshold = _db_to_amp(threshold_db)
for x in range(hop_length, len(wav) - window_length, hop_length):
if np.max(wav[x:x + window_length]) < threshold:
return x + hop_length
return len(wav)
def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
window_length = int(self.sample_rate * min_silence_sec)
hop_length = int(window_length / 4)
threshold = self._db_to_amp(threshold_db)
for x in range(hop_length, len(wav) - window_length, hop_length):
if np.max(wav[x:x + window_length]) < threshold:
return x + hop_length
return len(wav)
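Not part of the patch: a minimal usage sketch of the new ```AudioProcessor``` class, with parameter values taken from the example config and a synthetic waveform.
```
# Illustrative usage of utils.audio.AudioProcessor after the refactor.
import numpy as np
from TTS.utils.audio import AudioProcessor

ap = AudioProcessor(sample_rate=20000, num_mels=80, min_level_db=-100,
                    frame_shift_ms=12.5, frame_length_ms=50, preemphasis=0.97,
                    ref_level_db=20, num_freq=1025, power=1.5,
                    griffin_lim_iters=60)

wav = np.random.uniform(-1, 1, 20000).astype(np.float32)   # 1 second of noise
linear = ap.spectrogram(wav)         # (num_freq, frames) normalized log spectrogram
mel = ap.melspectrogram(wav)         # (num_mels, frames)
audio = ap.inv_spectrogram(linear)   # Griffin-Lim reconstruction back to a waveform
```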

View File

@ -3,7 +3,10 @@ import numpy as np
def pad_data(x, length):
_pad = 0
return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
assert x.ndim == 1
return np.pad(x, (0, length - x.shape[0]),
mode='constant',
constant_values=_pad)
def prepare_data(inputs):
@ -11,8 +14,8 @@ def prepare_data(inputs):
return np.stack([pad_data(x, max_len) for x in inputs])
def pad_per_step(inputs, outputs_per_step):
def pad_per_step(inputs, pad_len):
timesteps = inputs.shape[-1]
return np.pad(inputs, [[0, 0], [0, 0],
[0, outputs_per_step - (timesteps % outputs_per_step)]],
[0, pad_len]],
mode='constant', constant_values=0.0)
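Not part of the patch: the new ```pad_per_step``` contract in one line; the caller now passes an explicit ```pad_len``` instead of the old ```outputs_per_step``` remainder logic.
```
# Illustrative only: pad_per_step zero-pads the last (time) axis by pad_len.
import numpy as np
from TTS.utils.data import pad_per_step

mel = np.ones((2, 80, 23), dtype=np.float32)   # (batch, num_mels, timesteps)
print(pad_per_step(mel, pad_len=3).shape)      # (2, 80, 26)
```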

View File

@ -5,6 +5,7 @@ import time
import shutil
import datetime
import json
import torch
import numpy as np
@ -34,8 +35,9 @@ def remove_experiment_folder(experiment_path):
checkpoint_files = glob.glob(experiment_path+"/*.pth.tar")
if len(checkpoint_files) < 1:
shutil.rmtree(experiment_path)
print(" ! Run is removed from {}".format(experiment_path))
if os.path.exists(experiment_path):
shutil.rmtree(experiment_path)
print(" ! Run is removed from {}".format(experiment_path))
else:
print(" ! Run is kept in {}".format(experiment_path))
@ -46,10 +48,44 @@ def copy_config_file(config_file, path):
shutil.copyfile(config_file, out_path)
def save_checkpoint(state, filename='checkpoint.pth.tar'):
torch.save(state, filename)
def save_checkpoint(model, optimizer, model_loss, best_loss, out_path,
current_step, epoch):
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
checkpoint_path = os.path.join(out_path, checkpoint_path)
print("\n | > Checkpoint saving : {}".format(checkpoint_path))
state = {'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'step': current_step,
'epoch': epoch,
'linear_loss': model_loss,
'date': datetime.date.today().strftime("%B %d, %Y")}
torch.save(state, checkpoint_path)
def save_best_model(model, optimizer, model_loss, best_loss, out_path,
current_step, epoch):
if model_loss < best_loss:
state = {'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'step': current_step,
'epoch': epoch,
'linear_loss': model_loss,
'date': datetime.date.today().strftime("%B %d, %Y")}
best_loss = model_loss
bestmodel_path = 'best_model.pth.tar'
bestmodel_path = os.path.join(out_path, bestmodel_path)
print("\n | > Best model saving with loss {0:.2f} : {1:}".format(model_loss, bestmodel_path))
torch.save(state, bestmodel_path)
return best_loss
def lr_decay(init_lr, global_step):
warmup_steps = 4000.0
step = global_step + 1.
lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
step**-0.5)
return lr
class Progbar(object):
"""Displays a progress bar.
# Arguments
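Not part of the patch: a small numeric check of the Noam-style schedule computed by ```lr_decay``` above. The rate ramps up over the first 4000 warmup steps, peaks near ```init_lr```, and then decays with the inverse square root of the step.
```
# Illustrative only: the warmup/decay curve implemented by lr_decay().
import numpy as np

def lr_decay(init_lr, global_step, warmup_steps=4000.0):
    step = global_step + 1.0
    return init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
                                                    step**-0.5)

for s in (0, 1000, 4000, 16000, 64000):
    print(s, lr_decay(0.003, s))   # peaks around step 4000, then falls as step**-0.5
```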

View File

@ -1,8 +1,8 @@
#-*- coding: utf-8 -*-
import re
from Tacotron.utils.text import cleaners
from Tacotron.utils.text.symbols import symbols
from TTS.utils.text import cleaners
from TTS.utils.text.symbols import symbols
# Mappings from symbol to numeric ID and vice versa:

View File

@ -7,7 +7,7 @@ Defines the set of symbols used in text input to the model.
The default is a set of ASCII characters that works well for English or text that has been run
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
'''
from Tacotron.utils.text import cmudict
from TTS.utils.text import cmudict
_pad = '_'
_eos = '~'

35
utils/visual.py Normal file
View File

@ -0,0 +1,35 @@
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
def plot_alignment(alignment, info=None):
fig, ax = plt.subplots()
im = ax.imshow(alignment.T, aspect='auto', origin='lower',
interpolation='none')
fig.colorbar(im, ax=ax)
xlabel = 'Decoder timestep'
if info is not None:
xlabel += '\n\n' + info
plt.xlabel(xlabel)
plt.ylabel('Encoder timestep')
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def plot_spectrogram(linear_output, audio):
spectrogram = audio._denormalize(linear_output)
fig = plt.figure(figsize=(16, 10))
plt.imshow(spectrogram.T, aspect="auto", origin="lower")
plt.colorbar()
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
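Not part of the patch: rendering a dummy alignment matrix with the new helper, as the training loop does before logging the image to tensorboard. The import path assumes the repo root is importable as the ```TTS``` package.
```
# Illustrative only: plot_alignment turns an attention matrix into an RGB image array.
import numpy as np
from TTS.utils.visual import plot_alignment

alignment = np.random.rand(60, 120)   # dummy (decoder_steps, encoder_steps) weights
img = plot_alignment(alignment)       # uint8 array of shape (height, width, 3)
print(img.shape, img.dtype)
```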