Merge branch 'split-train-val'

This commit is contained in:
Eren Golge 2018-03-08 03:06:48 -08:00
Родитель d79e49f302 b4032e8dff
Коммит 7f740d8e3f
9 изменённых файлов: 462 добавлений и 290 удалений

Просмотреть файл

@ -12,18 +12,18 @@
"text_cleaner": "english_cleaners",
"epochs": 2000,
"lr": 0.003,
"batch_size": 180,
"lr": 0.0006,
"warmup_steps": 4000,
"batch_size": 32,
"r": 5,
"griffin_lim_iters": 60,
"power": 1.5,
"num_loader_workers": 32,
"num_loader_workers": 16,
"checkpoint": false,
"save_step": 69,
"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
"output_path": "result",
"log_dir": "/home/erogol/projects/TTS/logs/"
"data_path": "/run/shm/erogol/LJSpeech-1.0",
"output_path": "result"
}

Просмотреть файл

@ -1,4 +1,3 @@
import pandas as pd
import os
import numpy as np
import collections
@ -16,16 +15,18 @@ class LJSpeechDataset(Dataset):
def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
text_cleaner, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, preemphasis, ref_level_db, num_freq, power):
self.frames = pd.read_csv(csv_file, sep='|', header=None)
with open(csv_file, "r") as f:
self.frames = [line.split('|') for line in f]
self.root_dir = root_dir
self.outputs_per_step = outputs_per_step
self.sample_rate = sample_rate
self.cleaners = text_cleaner
self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
frame_length_ms, preemphasis, ref_level_db, num_freq, power
)
frame_length_ms, preemphasis, ref_level_db, num_freq, power)
print(" > Reading LJSpeech from - {}".format(root_dir))
print(" | > Number of instances : {}".format(len(self.frames)))
self._sort_frames()
def load_wav(self, filename):
try:
@ -34,22 +35,44 @@ class LJSpeechDataset(Dataset):
except RuntimeError as e:
print(" !! Cannot read file : {}".format(filename))
def _sort_frames(self):
r"""Sort sequences in ascending order"""
lengths = np.array([len(ins[1]) for ins in self.frames])
print(" | > Max length sequence {}".format(np.max(lengths)))
print(" | > Min length sequence {}".format(np.min(lengths)))
print(" | > Avg length sequence {}".format(np.mean(lengths)))
idxs = np.argsort(lengths)
new_frames = [None] * len(lengths)
for i, idx in enumerate(idxs):
new_frames[i] = self.frames[idx]
self.frames = new_frames
def __len__(self):
return len(self.frames)
def __getitem__(self, idx):
wav_name = os.path.join(self.root_dir,
self.frames.ix[idx, 0]) + '.wav'
text = self.frames.ix[idx, 1]
self.frames[idx][0]) + '.wav'
text = self.frames[idx][1]
text = np.asarray(text_to_sequence(text, [self.cleaners]), dtype=np.int32)
wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
sample = {'text': text, 'wav': wav, 'item_idx': self.frames.ix[idx, 0]}
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
return sample
def get_dummy_data(self):
r"""Get a dummy input for testing"""
return torch.autograd.Variable(torch.ones(16, 143)).type(torch.LongTensor)
def collate_fn(self, batch):
r"""
Perform preprocessing and create a final data batch:
1. PAD sequences with the longest sequence in the batch
2. Convert Audio signal to Spectrograms.
3. PAD sequences that can be divided by r.
4. Convert Numpy to Torch tensors.
"""
# Puts each data field into a tensor with outer dimension batch size
if isinstance(batch[0], collections.Mapping):

Просмотреть файл

@ -5,26 +5,27 @@ from torch.nn import functional as F
class BahdanauAttention(nn.Module):
def __init__(self, dim):
def __init__(self, annot_dim, query_dim, hidden_dim):
super(BahdanauAttention, self).__init__()
self.query_layer = nn.Linear(dim, dim, bias=False)
self.tanh = nn.Tanh()
self.v = nn.Linear(dim, 1, bias=False)
self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
self.v = nn.Linear(hidden_dim, 1, bias=False)
def forward(self, query, processed_inputs):
def forward(self, annots, query):
"""
Args:
query: (batch, 1, dim) or (batch, dim)
processed_inputs: (batch, max_time, dim)
Shapes:
- query: (batch, 1, dim) or (batch, dim)
- annots: (batch, max_time, dim)
"""
if query.dim() == 2:
# insert time-axis for broadcasting
query = query.unsqueeze(1)
# (batch, 1, dim)
processed_query = self.query_layer(query)
processed_annots = self.annot_layer(annots)
# (batch, max_time, 1)
alignment = self.v(self.tanh(processed_query + processed_inputs))
alignment = self.v(nn.functional.tanh(processed_query + processed_annots))
# (batch, max_time)
return alignment.squeeze(-1)
@ -34,7 +35,7 @@ def get_mask_from_lengths(inputs, inputs_lengths):
"""Get mask tensor from list of length
Args:
inputs: (batch, max_time, dim)
inputs: Tensor in size (batch, max_time, dim)
inputs_lengths: array like
"""
mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_()
@ -43,52 +44,48 @@ def get_mask_from_lengths(inputs, inputs_lengths):
return ~mask
class AttentionWrapper(nn.Module):
def __init__(self, rnn_cell, alignment_model,
class AttentionRNN(nn.Module):
def __init__(self, out_dim, annot_dim, memory_dim,
score_mask_value=-float("inf")):
super(AttentionWrapper, self).__init__()
self.rnn_cell = rnn_cell
self.alignment_model = alignment_model
super(AttentionRNN, self).__init__()
self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, out_dim)
self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
self.score_mask_value = score_mask_value
def forward(self, query, context_vec, cell_state, inputs,
processed_inputs=None, mask=None, inputs_lengths=None):
def forward(self, memory, context, rnn_state, annotations,
mask=None, annotations_lengths=None):
if processed_inputs is None:
processed_inputs = inputs
if inputs_lengths is not None and mask is None:
mask = get_mask_from_lengths(inputs, inputs_lengths)
if annotations_lengths is not None and mask is None:
mask = get_mask_from_lengths(annotations, annotations_lengths)
# Alignment
# (batch, max_time)
# e_{ij} = a(s_{i-1}, h_j)
# import ipdb
# ipdb.set_trace()
alignment = self.alignment_model(cell_state, processed_inputs)
alignment = self.alignment_model(annotations, rnn_state)
# TODO: needs recheck.
if mask is not None:
mask = mask.view(query.size(0), -1)
alignment.data.masked_fill_(mask, self.score_mask_value)
# Normalize context_vec weight
# Normalize context weight
alignment = F.softmax(alignment, dim=-1)
# Attention context vector
# (batch, 1, dim)
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
context_vec = torch.bmm(alignment.unsqueeze(1), inputs)
context_vec = context_vec.squeeze(1)
context = torch.bmm(alignment.unsqueeze(1), annotations)
context = context.squeeze(1)
# Concat input query and previous context_vec context
cell_input = torch.cat((query, context_vec), -1)
#cell_input = cell_input.unsqueeze(1)
# Concat input query and previous context context
rnn_input = torch.cat((memory, context), -1)
#rnn_input = rnn_input.unsqueeze(1)
# Feed it to RNN
# s_i = f(y_{i-1}, c_{i}, s_{i-1})
cell_output = self.rnn_cell(cell_input, cell_state)
rnn_output = self.rnn_cell(rnn_input, rnn_state)
context_vec = context_vec.squeeze(1)
return cell_output, context_vec, alignment
context = context.squeeze(1)
return rnn_output, context, alignment

Просмотреть файл

@ -3,7 +3,7 @@ import torch
from torch.autograd import Variable
from torch import nn
from .attention import BahdanauAttention, AttentionWrapper
from .attention import AttentionRNN
from .attention import get_mask_from_lengths
class Prenet(nn.Module):
@ -153,7 +153,7 @@ class CBHG(nn.Module):
out = conv1d(x)
out = out[:, :, :T]
outs.append(out)
x = torch.cat(outs, dim=1)
assert x.size(1) == self.in_features * len(self.conv1d_banks)
@ -219,15 +219,10 @@ class Decoder(nn.Module):
self.memory_dim = memory_dim
self.eps = eps
self.r = r
# input -> |Linear| -> processed_inputs
self.input_layer = nn.Linear(in_features, 256, bias=False)
# memory -> |Prenet| -> processed_memory
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
# processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
self.attention_rnn = AttentionWrapper(
nn.GRUCell(in_features + 128, 256),
BahdanauAttention(256)
)
self.attention_rnn = AttentionRNN(256, in_features, 128)
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(256+in_features, 256)
# decoder_RNN_input -> |RNN| -> RNN_state
@ -236,7 +231,7 @@ class Decoder(nn.Module):
# RNN_state -> |Linear| -> mel_spec
self.proj_to_mel = nn.Linear(256, memory_dim * r)
def forward(self, inputs, memory=None, memory_lengths=None):
def forward(self, inputs, memory=None, input_lengths=None):
r"""
Decoder forward step.
@ -245,9 +240,9 @@ class Decoder(nn.Module):
Args:
inputs: Encoder outputs.
memory: Decoder memory (autoregression. If None (at eval-time),
memory (None): Decoder memory (autoregression. If None (at eval-time),
decoder outputs are used as decoder inputs.
memory_lengths: Encoder output (memory) lengths. If not None, used for
input_lengths (None): input lengths, used for
attention masking.
Shapes:
@ -256,12 +251,11 @@ class Decoder(nn.Module):
"""
B = inputs.size(0)
# TODO: take this segment into Attention module.
processed_inputs = self.input_layer(inputs)
if memory_lengths is not None:
mask = get_mask_from_lengths(processed_inputs, memory_lengths)
else:
mask = None
# if input_lengths is not None:
# mask = get_mask_from_lengths(processed_inputs, input_lengths)
# else:
# mask = None
# Run greedy decoding if memory is None
greedy = memory is None
@ -301,13 +295,14 @@ class Decoder(nn.Module):
while True:
if t > 0:
memory_input = outputs[-1] if greedy else memory[t - 1]
# Prenet
processed_memory = self.prenet(memory_input)
# Attention RNN
attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
processed_memory, current_context_vec, attention_rnn_hidden,
inputs, processed_inputs=processed_inputs, mask=mask)
inputs)
# Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in(
@ -350,5 +345,5 @@ class Decoder(nn.Module):
return outputs, alignments
def is_end_of_frames(output, eps=0.1): #0.2
def is_end_of_frames(output, eps=0.2): #0.2
return (output.data <= eps).all()

Просмотреть файл

@ -9,11 +9,11 @@ from TTS.layers.tacotron import Prenet, Encoder, Decoder, CBHG
class Tacotron(nn.Module):
def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
freq_dim=1025, r=5, padding_idx=None,
use_memory_mask=False):
use_atten_mask=False):
super(Tacotron, self).__init__()
self.mel_dim = mel_dim
self.linear_dim = linear_dim
self.use_memory_mask = use_memory_mask
self.use_atten_mask = use_atten_mask
self.embedding = nn.Embedding(len(symbols), embedding_dim,
padding_idx=padding_idx)
print(" | > Embedding dim : {}".format(len(symbols)))
@ -33,13 +33,12 @@ class Tacotron(nn.Module):
# (B, T', in_dim)
encoder_outputs = self.encoder(inputs)
if self.use_memory_mask:
memory_lengths = input_lengths
else:
memory_lengths = None
if not self.use_atten_mask:
input_lengths = None
# (B, T', mel_dim*r)
mel_outputs, alignments = self.decoder(
encoder_outputs, mel_specs, memory_lengths=memory_lengths)
encoder_outputs, mel_specs, input_lengths=input_lengths)
# Post net processing below

Просмотреть файл

@ -0,0 +1,38 @@
import unittest
import torch as T
from TTS.utils.generic_utils import save_checkpoint, save_best_model
from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder
OUT_PATH = '/tmp/test.pth.tar'
class ModelSavingTests(unittest.TestCase):
def save_checkpoint_test(self):
# create a dummy model
model = Prenet(128, out_features=[256, 128])
model = T.nn.DataParallel(layer)
# save the model
save_checkpoint(model, None, 100,
OUTPATH, 1, 1)
# load the model to CPU
model_dict = torch.load(MODEL_PATH, map_location=lambda storage,
loc: storage)
model.load_state_dict(model_dict['model'])
def save_best_model_test(self):
# create a dummy model
model = Prenet(256, out_features=[256, 256])
model = T.nn.DataParallel(layer)
# save the model
best_loss = save_best_model(model, None, 0,
100, OUT_PATH,
10, 1)
# load the model to CPU
model_dict = torch.load(MODEL_PATH, map_location=lambda storage,
loc: storage)
model.load_state_dict(model_dict['model'])

522
train.py
Просмотреть файл

@ -21,42 +21,275 @@ from tensorboardX import SummaryWriter
from utils.generic_utils import (Progbar, remove_experiment_folder,
create_experiment_folder, save_checkpoint,
save_best_model, load_config, lr_decay,
count_parameters)
count_parameters, check_update)
from utils.model import get_param_size
from utils.visual import plot_alignment, plot_spectrogram
from datasets.LJSpeech import LJSpeechDataset
from models.tacotron import Tacotron
use_cuda = torch.cuda.is_available()
parser = argparse.ArgumentParser()
parser.add_argument('--restore_path', type=str,
help='Folder path to checkpoints', default=0)
parser.add_argument('--config_path', type=str,
help='path to config file for training',)
args = parser.parse_args()
# setup output paths and read configs
c = load_config(args.config_path)
_ = os.path.dirname(os.path.realpath(__file__))
OUT_PATH = os.path.join(_, c.output_path)
OUT_PATH = create_experiment_folder(OUT_PATH)
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
# save config to tmp place to be loaded by subsequent modules.
file_name = str(os.getpid())
tmp_path = os.path.join("/tmp/", file_name+'_tts')
pickle.dump(c, open(tmp_path, "wb"))
# setup tensorboard
LOG_DIR = OUT_PATH
tb = SummaryWriter(LOG_DIR)
def signal_handler(signal, frame):
"""Ctrl+C handler to remove empty experiment folder"""
print(" !! Pressed Ctrl+C !!")
remove_experiment_folder(OUT_PATH)
sys.exit(1)
def train(model, criterion, data_loader, optimizer, epoch):
model = model.train()
epoch_time = 0
avg_linear_loss = 0
avg_mel_loss = 0
print(" | > Epoch {}/{}".format(epoch, c.epochs))
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
for num_iter, data in enumerate(data_loader):
start_time = time.time()
# setup input data
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
current_step = num_iter + args.restore_step + epoch * len(data_loader) + 1
# setup lr
current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
for params_group in optimizer.param_groups:
params_group['lr'] = current_lr
optimizer.zero_grad()
# convert inputs to variables
text_input_var = Variable(text_input)
mel_spec_var = Variable(mel_input)
linear_spec_var = Variable(linear_input, volatile=True)
# sort sequence by length for curriculum learning
# TODO: might be unnecessary
sorted_lengths, indices = torch.sort(
text_lengths.view(-1), dim=0, descending=True)
sorted_lengths = sorted_lengths.long().numpy()
text_input_var = text_input_var[indices]
mel_spec_var = mel_spec_var[indices]
linear_spec_var = linear_spec_var[indices]
# dispatch data to GPU
if use_cuda:
text_input_var = text_input_var.cuda()
mel_spec_var = mel_spec_var.cuda()
linear_spec_var = linear_spec_var.cuda()
# forward pass
mel_output, linear_output, alignments =\
model.forward(text_input_var, mel_spec_var,
input_lengths= torch.autograd.Variable(torch.cuda.LongTensor(sorted_lengths)))
# loss computation
mel_loss = criterion(mel_output, mel_spec_var)
linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
linear_spec_var[: ,: ,:n_priority_freq])
loss = mel_loss + linear_loss
# backpass and check the grad norm
loss.backward()
grad_norm, skip_flag = check_update(model, 0.5, 100)
if skip_flag:
optimizer.zero_grad()
print(" | > Iteration skipped!!")
continue
optimizer.step()
step_time = time.time() - start_time
epoch_time += step_time
# update
progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
('linear_loss', linear_loss.data[0]),
('mel_loss', mel_loss.data[0]),
('grad_norm', grad_norm)])
# Plot Training Iter Stats
tb.add_scalar('TrainIterLoss/TotalLoss', loss.data[0], current_step)
tb.add_scalar('TrainIterLoss/LinearLoss', linear_loss.data[0],
current_step)
tb.add_scalar('TrainIterLoss/MelLoss', mel_loss.data[0], current_step)
tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
current_step)
tb.add_scalar('Params/GradNorm', grad_norm, current_step)
tb.add_scalar('Time/StepTime', step_time, current_step)
if current_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model, optimizer, linear_loss.data[0],
OUT_PATH, current_step, epoch)
# Diagnostic visualizations
const_spec = linear_output[0].data.cpu().numpy()
gt_spec = linear_spec_var[0].data.cpu().numpy()
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap)
tb.add_image('Visual/Reconstruction', const_spec, current_step)
tb.add_image('Visual/GroundTruth', gt_spec, current_step)
align_img = alignments[0].data.cpu().numpy()
align_img = plot_alignment(align_img)
tb.add_image('Visual/Alignment', align_img, current_step)
# Sample audio
audio_signal = linear_output[0].data.cpu().numpy()
data_loader.dataset.ap.griffin_lim_iters = 60
audio_signal = data_loader.dataset.ap.inv_spectrogram(audio_signal.T)
try:
tb.add_audio('SampleAudio', audio_signal, current_step,
sample_rate=c.sample_rate)
except:
print("\n > Error at audio signal on TB!!")
print(audio_signal.max())
print(audio_signal.min())
avg_linear_loss /= (num_iter + 1)
avg_mel_loss /= (num_iter + 1)
avg_total_loss = avg_mel_loss + avg_linear_loss
# Plot Training Epoch Stats
tb.add_scalar('TrainEpochLoss/TotalLoss', loss.data[0], current_step)
tb.add_scalar('TrainEpochLoss/LinearLoss', linear_loss.data[0], current_step)
tb.add_scalar('TrainEpochLoss/MelLoss', mel_loss.data[0], current_step)
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
epoch_time = 0
return avg_linear_loss, current_step
def evaluate(model, criterion, data_loader, current_step):
model = model.train()
epoch_time = 0
print(" | > Validation")
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
avg_linear_loss = 0
avg_mel_loss = 0
for num_iter, data in enumerate(data_loader):
start_time = time.time()
# setup input data
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
# convert inputs to variables
text_input_var = Variable(text_input)
mel_spec_var = Variable(mel_input)
linear_spec_var = Variable(linear_input, volatile=True)
# dispatch data to GPU
if use_cuda:
text_input_var = text_input_var.cuda()
mel_spec_var = mel_spec_var.cuda()
linear_spec_var = linear_spec_var.cuda()
# forward pass
mel_output, linear_output, alignments =\
model.forward(text_input_var, mel_spec_var)
# loss computation
mel_loss = criterion(mel_output, mel_spec_var)
linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
linear_spec_var[: ,: ,:n_priority_freq])
loss = mel_loss + linear_loss
step_time = time.time() - start_time
epoch_time += step_time
# update
progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
('linear_loss', linear_loss.data[0]),
('mel_loss', mel_loss.data[0])])
avg_linear_loss += linear_loss.data[0]
avg_mel_loss += mel_loss.data[0]
# Diagnostic visualizations
idx = np.random.randint(mel_input.shape[0])
const_spec = linear_output[idx].data.cpu().numpy()
gt_spec = linear_spec_var[idx].data.cpu().numpy()
align_img = alignments[idx].data.cpu().numpy()
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap)
align_img = plot_alignment(align_img)
tb.add_image('ValVisual/Reconstruction', const_spec, current_step)
tb.add_image('ValVisual/GroundTruth', gt_spec, current_step)
tb.add_image('ValVisual/ValidationAlignment', align_img, current_step)
# Sample audio
audio_signal = linear_output[idx].data.cpu().numpy()
data_loader.dataset.ap.griffin_lim_iters = 60
audio_signal = data_loader.dataset.ap.inv_spectrogram(audio_signal.T)
try:
tb.add_audio('ValSampleAudio', audio_signal, current_step,
sample_rate=c.sample_rate)
except:
print(" | > Error at audio signal on TB!!")
print(audio_signal.max())
print(audio_signal.min())
# compute average losses
avg_linear_loss /= (num_iter + 1)
avg_mel_loss /= (num_iter + 1)
avg_total_loss = avg_mel_loss + avg_linear_loss
# Plot Learning Stats
tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss, current_step)
tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step)
tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
return avg_linear_loss
def main(args):
# setup output paths and read configs
c = load_config(args.config_path)
_ = os.path.dirname(os.path.realpath(__file__))
OUT_PATH = os.path.join(_, c.output_path)
OUT_PATH = create_experiment_folder(OUT_PATH)
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
# save config to tmp place to be loaded by subsequent modules.
file_name = str(os.getpid())
tmp_path = os.path.join("/tmp/", file_name+'_tts')
pickle.dump(c, open(tmp_path, "wb"))
# setup tensorboard
LOG_DIR = OUT_PATH
tb = SummaryWriter(LOG_DIR)
# Ctrl+C handler to remove empty experiment folder
def signal_handler(signal, frame):
print(" !! Pressed Ctrl+C !!")
remove_experiment_folder(OUT_PATH)
sys.exit(1)
signal.signal(signal.SIGINT, signal_handler)
# Setup the dataset
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'),
os.path.join(c.data_path, 'wavs'),
c.r,
c.sample_rate,
@ -71,204 +304,77 @@ def main(args):
c.power
)
dataloader = DataLoader(dataset, batch_size=c.batch_size,
shuffle=True, collate_fn=dataset.collate_fn,
drop_last=True, num_workers=c.num_loader_workers)
train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
shuffle=False, collate_fn=train_dataset.collate_fn,
drop_last=False, num_workers=c.num_loader_workers,
pin_memory=True)
val_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_val.csv'),
os.path.join(c.data_path, 'wavs'),
c.r,
c.sample_rate,
c.text_cleaner,
c.num_mels,
c.min_level_db,
c.frame_shift_ms,
c.frame_length_ms,
c.preemphasis,
c.ref_level_db,
c.num_freq,
c.power
)
val_loader = DataLoader(val_dataset, batch_size=c.batch_size,
shuffle=False, collate_fn=val_dataset.collate_fn,
drop_last=False, num_workers= 4,
pin_memory=True)
# setup the model
model = Tacotron(c.embedding_size,
c.hidden_size,
c.num_mels,
c.num_freq,
c.r)
# plot model on tensorboard
dummy_input = dataset.get_dummy_data()
## TODO: onnx does not support RNN fully yet
# model_proto_path = os.path.join(OUT_PATH, "model.proto")
# onnx.export(model, dummy_input, model_proto_path, verbose=True)
# tb.add_graph_onnx(model_proto_path)
if use_cuda:
model = nn.DataParallel(model.cuda())
c.r,
use_atten_mask=True)
optimizer = optim.Adam(model.parameters(), lr=c.lr)
if args.restore_step:
checkpoint = torch.load(os.path.join(
args.restore_path, 'checkpoint_%d.pth.tar' % args.restore_step))
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("\n > Model restored from step %d\n" % args.restore_step)
start_epoch = checkpoint['step'] // len(dataloader)
best_loss = checkpoint['linear_loss']
else:
start_epoch = 0
print("\n > Starting a new training")
num_params = count_parameters(model)
print(" | > Model has {} parameters".format(num_params))
model = model.train()
if not os.path.exists(CHECKPOINT_PATH):
os.mkdir(CHECKPOINT_PATH)
if use_cuda:
criterion = nn.L1Loss().cuda()
else:
criterion = nn.L1Loss()
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
if args.restore_path:
checkpoint = torch.load(args.restore_path)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("\n > Model restored from step %d\n" % checkpoint['step'])
start_epoch = checkpoint['step'] // len(train_loader)
best_loss = checkpoint['linear_loss']
start_epoch = 0
args.restore_step = checkpoint['step']
else:
args.restore_step = 0
print("\n > Starting a new training")
#lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay,
# patience=c.lr_patience, verbose=True)
epoch_time = 0
best_loss = float('inf')
if use_cuda:
model = nn.DataParallel(model.cuda())
num_params = count_parameters(model)
print(" | > Model has {} parameters".format(num_params))
if not os.path.exists(CHECKPOINT_PATH):
os.mkdir(CHECKPOINT_PATH)
if 'best_loss' not in locals():
best_loss = float('inf')
for epoch in range(0, c.epochs):
print("\n | > Epoch {}/{}".format(epoch, c.epochs))
progbar = Progbar(len(dataset) / c.batch_size)
for num_iter, data in enumerate(dataloader):
start_time = time.time()
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1
# setup lr
current_lr = lr_decay(c.lr, current_step)
for params_group in optimizer.param_groups:
params_group['lr'] = current_lr
optimizer.zero_grad()
# Add a single frame of zeros to Mel Specs for better end detection
#try:
# mel_input = np.concatenate((np.zeros(
# [c.batch_size, 1, c.num_mels], dtype=np.float32),
# mel_input[:, 1:, :]), axis=1)
#except:
# raise TypeError("not same dimension")
# convert inputs to variables
text_input_var = Variable(text_input)
mel_spec_var = Variable(mel_input)
linear_spec_var = Variable(linear_input, volatile=True)
# sort sequence by length.
# TODO: might be unnecessary
sorted_lengths, indices = torch.sort(
text_lengths.view(-1), dim=0, descending=True)
sorted_lengths = sorted_lengths.long().numpy()
text_input_var = text_input_var[indices]
mel_spec_var = mel_spec_var[indices]
linear_spec_var = linear_spec_var[indices]
if use_cuda:
text_input_var = text_input_var.cuda()
mel_spec_var = mel_spec_var.cuda()
linear_spec_var = linear_spec_var.cuda()
mel_output, linear_output, alignments =\
model.forward(text_input_var, mel_spec_var,
input_lengths= torch.autograd.Variable(torch.cuda.LongTensor(sorted_lengths)))
mel_loss = criterion(mel_output, mel_spec_var)
#linear_loss = torch.abs(linear_output - linear_spec_var)
#linear_loss = 0.5 * \
#torch.mean(linear_loss) + 0.5 * \
#torch.mean(linear_loss[:, :n_priority_freq, :])
linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
linear_spec_var[: ,: ,:n_priority_freq])
loss = mel_loss + linear_loss
# loss = loss.cuda()
loss.backward()
grad_norm = nn.utils.clip_grad_norm(model.parameters(), 1.) ## TODO: maybe no need
optimizer.step()
step_time = time.time() - start_time
epoch_time += step_time
progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
('linear_loss', linear_loss.data[0]),
('mel_loss', mel_loss.data[0]),
('grad_norm', grad_norm)])
# Plot Learning Stats
tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step)
tb.add_scalar('Loss/LinearLoss', linear_loss.data[0],
current_step)
tb.add_scalar('Loss/MelLoss', mel_loss.data[0], current_step)
tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
current_step)
tb.add_scalar('Params/GradNorm', grad_norm, current_step)
tb.add_scalar('Time/StepTime', step_time, current_step)
align_img = alignments[0].data.cpu().numpy()
align_img = plot_alignment(align_img)
tb.add_image('Attn/Alignment', align_img, current_step)
if current_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model, optimizer, linear_loss.data[0],
OUT_PATH, current_step, epoch)
# Diagnostic visualizations
const_spec = linear_output[0].data.cpu().numpy()
gt_spec = linear_spec_var[0].data.cpu().numpy()
const_spec = plot_spectrogram(const_spec, dataset.ap)
gt_spec = plot_spectrogram(gt_spec, dataset.ap)
tb.add_image('Spec/Reconstruction', const_spec, current_step)
tb.add_image('Spec/GroundTruth', gt_spec, current_step)
align_img = alignments[0].data.cpu().numpy()
align_img = plot_alignment(align_img)
tb.add_image('Attn/Alignment', align_img, current_step)
# Sample audio
audio_signal = linear_output[0].data.cpu().numpy()
dataset.ap.griffin_lim_iters = 60
audio_signal = dataset.ap.inv_spectrogram(audio_signal.T)
try:
tb.add_audio('SampleAudio', audio_signal, current_step,
sample_rate=c.sample_rate)
except:
print("\n > Error at audio signal on TB!!")
print(audio_signal.max())
print(audio_signal.min())
# average loss after the epoch
avg_epoch_loss = np.mean(
progbar.sum_values['linear_loss'][0] / max(1, progbar.sum_values['linear_loss'][1]))
best_loss = save_best_model(model, optimizer, avg_epoch_loss,
train_loss, current_step = train(model, criterion, train_loader, optimizer, epoch)
val_loss = evaluate(model, criterion, val_loader, current_step)
best_loss = save_best_model(model, optimizer, val_loss,
best_loss, OUT_PATH,
current_step, epoch)
#lr_scheduler.step(loss.data[0])
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
epoch_time = 0
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--restore_step', type=int,
help='Global step to restore checkpoint', default=0)
parser.add_argument('--restore_path', type=str,
help='Folder path to checkpoints', default=0)
parser.add_argument('--config_path', type=str,
help='path to config file for training',)
args = parser.parse_args()
signal.signal(signal.SIGINT, signal_handler)
main(args)

Просмотреть файл

@ -7,6 +7,7 @@ import datetime
import json
import torch
import numpy as np
from collections import OrderedDict
class AttrDict(dict):
@ -94,8 +95,21 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
return best_loss
def lr_decay(init_lr, global_step):
warmup_steps = 4000.0
def check_update(model, grad_clip, grad_top):
r'''Check model gradient against unexpected jumps and failures'''
skip_flag = False
grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), grad_clip)
if np.isinf(grad_norm):
print(" | > Gradient is INF !!")
skip_flag = True
elif grad_norm > grad_top:
print(" | > Gradient is above the top limit !!")
skip_flag = True
return grad_norm, skip_flag
def lr_decay(init_lr, global_step, warmup_steps):
r'''from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py'''
step = global_step + 1.
lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
step**-0.5)
@ -197,13 +211,13 @@ class Progbar(object):
eta_format = '%ds' % eta
info = ' - ETA: %s' % eta_format
if time_per_unit >= 1:
info += ' %.0fs/step' % time_per_unit
elif time_per_unit >= 1e-3:
info += ' %.0fms/step' % (time_per_unit * 1e3)
else:
if time_per_unit >= 1:
info += ' %.0fs/step' % time_per_unit
elif time_per_unit >= 1e-3:
info += ' %.0fms/step' % (time_per_unit * 1e3)
else:
info += ' %.0fus/step' % (time_per_unit * 1e6)
info += ' %.0fus/step' % (time_per_unit * 1e6)
for k in self.unique_values:
info += ' - %s:' % k

Просмотреть файл

@ -5,7 +5,7 @@ import matplotlib.pyplot as plt
def plot_alignment(alignment, info=None):
fig, ax = plt.subplots()
fig, ax = plt.subplots(figsize=(16,10))
im = ax.imshow(alignment.T, aspect='auto', origin='lower',
interpolation='none')
fig.colorbar(im, ax=ax)