зеркало из https://github.com/mozilla/TTS.git
Merge branch 'split-train-val'
This commit is contained in:
Коммит
7f740d8e3f
12
config.json
12
config.json
|
@ -12,18 +12,18 @@
|
||||||
"text_cleaner": "english_cleaners",
|
"text_cleaner": "english_cleaners",
|
||||||
|
|
||||||
"epochs": 2000,
|
"epochs": 2000,
|
||||||
"lr": 0.003,
|
"lr": 0.0006,
|
||||||
"batch_size": 180,
|
"warmup_steps": 4000,
|
||||||
|
"batch_size": 32,
|
||||||
"r": 5,
|
"r": 5,
|
||||||
|
|
||||||
"griffin_lim_iters": 60,
|
"griffin_lim_iters": 60,
|
||||||
"power": 1.5,
|
"power": 1.5,
|
||||||
|
|
||||||
"num_loader_workers": 32,
|
"num_loader_workers": 16,
|
||||||
|
|
||||||
"checkpoint": false,
|
"checkpoint": false,
|
||||||
"save_step": 69,
|
"save_step": 69,
|
||||||
"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
|
"data_path": "/run/shm/erogol/LJSpeech-1.0",
|
||||||
"output_path": "result",
|
"output_path": "result"
|
||||||
"log_dir": "/home/erogol/projects/TTS/logs/"
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
import pandas as pd
|
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import collections
|
import collections
|
||||||
|
@ -16,16 +15,18 @@ class LJSpeechDataset(Dataset):
|
||||||
def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
|
def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
|
||||||
text_cleaner, num_mels, min_level_db, frame_shift_ms,
|
text_cleaner, num_mels, min_level_db, frame_shift_ms,
|
||||||
frame_length_ms, preemphasis, ref_level_db, num_freq, power):
|
frame_length_ms, preemphasis, ref_level_db, num_freq, power):
|
||||||
self.frames = pd.read_csv(csv_file, sep='|', header=None)
|
|
||||||
|
with open(csv_file, "r") as f:
|
||||||
|
self.frames = [line.split('|') for line in f]
|
||||||
self.root_dir = root_dir
|
self.root_dir = root_dir
|
||||||
self.outputs_per_step = outputs_per_step
|
self.outputs_per_step = outputs_per_step
|
||||||
self.sample_rate = sample_rate
|
self.sample_rate = sample_rate
|
||||||
self.cleaners = text_cleaner
|
self.cleaners = text_cleaner
|
||||||
self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
|
self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
|
||||||
frame_length_ms, preemphasis, ref_level_db, num_freq, power
|
frame_length_ms, preemphasis, ref_level_db, num_freq, power)
|
||||||
)
|
|
||||||
print(" > Reading LJSpeech from - {}".format(root_dir))
|
print(" > Reading LJSpeech from - {}".format(root_dir))
|
||||||
print(" | > Number of instances : {}".format(len(self.frames)))
|
print(" | > Number of instances : {}".format(len(self.frames)))
|
||||||
|
self._sort_frames()
|
||||||
|
|
||||||
def load_wav(self, filename):
|
def load_wav(self, filename):
|
||||||
try:
|
try:
|
||||||
|
@ -34,22 +35,44 @@ class LJSpeechDataset(Dataset):
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
print(" !! Cannot read file : {}".format(filename))
|
print(" !! Cannot read file : {}".format(filename))
|
||||||
|
|
||||||
|
def _sort_frames(self):
|
||||||
|
r"""Sort sequences in ascending order"""
|
||||||
|
lengths = np.array([len(ins[1]) for ins in self.frames])
|
||||||
|
|
||||||
|
print(" | > Max length sequence {}".format(np.max(lengths)))
|
||||||
|
print(" | > Min length sequence {}".format(np.min(lengths)))
|
||||||
|
print(" | > Avg length sequence {}".format(np.mean(lengths)))
|
||||||
|
|
||||||
|
idxs = np.argsort(lengths)
|
||||||
|
new_frames = [None] * len(lengths)
|
||||||
|
for i, idx in enumerate(idxs):
|
||||||
|
new_frames[i] = self.frames[idx]
|
||||||
|
self.frames = new_frames
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.frames)
|
return len(self.frames)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
wav_name = os.path.join(self.root_dir,
|
wav_name = os.path.join(self.root_dir,
|
||||||
self.frames.ix[idx, 0]) + '.wav'
|
self.frames[idx][0]) + '.wav'
|
||||||
text = self.frames.ix[idx, 1]
|
text = self.frames[idx][1]
|
||||||
text = np.asarray(text_to_sequence(text, [self.cleaners]), dtype=np.int32)
|
text = np.asarray(text_to_sequence(text, [self.cleaners]), dtype=np.int32)
|
||||||
wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
|
wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
|
||||||
sample = {'text': text, 'wav': wav, 'item_idx': self.frames.ix[idx, 0]}
|
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
def get_dummy_data(self):
|
def get_dummy_data(self):
|
||||||
|
r"""Get a dummy input for testing"""
|
||||||
return torch.autograd.Variable(torch.ones(16, 143)).type(torch.LongTensor)
|
return torch.autograd.Variable(torch.ones(16, 143)).type(torch.LongTensor)
|
||||||
|
|
||||||
def collate_fn(self, batch):
|
def collate_fn(self, batch):
|
||||||
|
r"""
|
||||||
|
Perform preprocessing and create a final data batch:
|
||||||
|
1. PAD sequences with the longest sequence in the batch
|
||||||
|
2. Convert Audio signal to Spectrograms.
|
||||||
|
3. PAD sequences that can be divided by r.
|
||||||
|
4. Convert Numpy to Torch tensors.
|
||||||
|
"""
|
||||||
|
|
||||||
# Puts each data field into a tensor with outer dimension batch size
|
# Puts each data field into a tensor with outer dimension batch size
|
||||||
if isinstance(batch[0], collections.Mapping):
|
if isinstance(batch[0], collections.Mapping):
|
||||||
|
|
|
@ -5,26 +5,27 @@ from torch.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
class BahdanauAttention(nn.Module):
|
class BahdanauAttention(nn.Module):
|
||||||
def __init__(self, dim):
|
def __init__(self, annot_dim, query_dim, hidden_dim):
|
||||||
super(BahdanauAttention, self).__init__()
|
super(BahdanauAttention, self).__init__()
|
||||||
self.query_layer = nn.Linear(dim, dim, bias=False)
|
self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
|
||||||
self.tanh = nn.Tanh()
|
self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
|
||||||
self.v = nn.Linear(dim, 1, bias=False)
|
self.v = nn.Linear(hidden_dim, 1, bias=False)
|
||||||
|
|
||||||
def forward(self, query, processed_inputs):
|
def forward(self, annots, query):
|
||||||
"""
|
"""
|
||||||
Args:
|
Shapes:
|
||||||
query: (batch, 1, dim) or (batch, dim)
|
- query: (batch, 1, dim) or (batch, dim)
|
||||||
processed_inputs: (batch, max_time, dim)
|
- annots: (batch, max_time, dim)
|
||||||
"""
|
"""
|
||||||
if query.dim() == 2:
|
if query.dim() == 2:
|
||||||
# insert time-axis for broadcasting
|
# insert time-axis for broadcasting
|
||||||
query = query.unsqueeze(1)
|
query = query.unsqueeze(1)
|
||||||
# (batch, 1, dim)
|
# (batch, 1, dim)
|
||||||
processed_query = self.query_layer(query)
|
processed_query = self.query_layer(query)
|
||||||
|
processed_annots = self.annot_layer(annots)
|
||||||
|
|
||||||
# (batch, max_time, 1)
|
# (batch, max_time, 1)
|
||||||
alignment = self.v(self.tanh(processed_query + processed_inputs))
|
alignment = self.v(nn.functional.tanh(processed_query + processed_annots))
|
||||||
|
|
||||||
# (batch, max_time)
|
# (batch, max_time)
|
||||||
return alignment.squeeze(-1)
|
return alignment.squeeze(-1)
|
||||||
|
@ -34,7 +35,7 @@ def get_mask_from_lengths(inputs, inputs_lengths):
|
||||||
"""Get mask tensor from list of length
|
"""Get mask tensor from list of length
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs: (batch, max_time, dim)
|
inputs: Tensor in size (batch, max_time, dim)
|
||||||
inputs_lengths: array like
|
inputs_lengths: array like
|
||||||
"""
|
"""
|
||||||
mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_()
|
mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_()
|
||||||
|
@ -43,52 +44,48 @@ def get_mask_from_lengths(inputs, inputs_lengths):
|
||||||
return ~mask
|
return ~mask
|
||||||
|
|
||||||
|
|
||||||
class AttentionWrapper(nn.Module):
|
class AttentionRNN(nn.Module):
|
||||||
def __init__(self, rnn_cell, alignment_model,
|
def __init__(self, out_dim, annot_dim, memory_dim,
|
||||||
score_mask_value=-float("inf")):
|
score_mask_value=-float("inf")):
|
||||||
super(AttentionWrapper, self).__init__()
|
super(AttentionRNN, self).__init__()
|
||||||
self.rnn_cell = rnn_cell
|
self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, out_dim)
|
||||||
self.alignment_model = alignment_model
|
self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
|
||||||
self.score_mask_value = score_mask_value
|
self.score_mask_value = score_mask_value
|
||||||
|
|
||||||
def forward(self, query, context_vec, cell_state, inputs,
|
def forward(self, memory, context, rnn_state, annotations,
|
||||||
processed_inputs=None, mask=None, inputs_lengths=None):
|
mask=None, annotations_lengths=None):
|
||||||
|
|
||||||
if processed_inputs is None:
|
if annotations_lengths is not None and mask is None:
|
||||||
processed_inputs = inputs
|
mask = get_mask_from_lengths(annotations, annotations_lengths)
|
||||||
|
|
||||||
if inputs_lengths is not None and mask is None:
|
|
||||||
mask = get_mask_from_lengths(inputs, inputs_lengths)
|
|
||||||
|
|
||||||
# Alignment
|
# Alignment
|
||||||
# (batch, max_time)
|
# (batch, max_time)
|
||||||
# e_{ij} = a(s_{i-1}, h_j)
|
# e_{ij} = a(s_{i-1}, h_j)
|
||||||
# import ipdb
|
alignment = self.alignment_model(annotations, rnn_state)
|
||||||
# ipdb.set_trace()
|
|
||||||
alignment = self.alignment_model(cell_state, processed_inputs)
|
|
||||||
|
|
||||||
|
# TODO: needs recheck.
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
mask = mask.view(query.size(0), -1)
|
mask = mask.view(query.size(0), -1)
|
||||||
alignment.data.masked_fill_(mask, self.score_mask_value)
|
alignment.data.masked_fill_(mask, self.score_mask_value)
|
||||||
|
|
||||||
# Normalize context_vec weight
|
# Normalize context weight
|
||||||
alignment = F.softmax(alignment, dim=-1)
|
alignment = F.softmax(alignment, dim=-1)
|
||||||
|
|
||||||
# Attention context vector
|
# Attention context vector
|
||||||
# (batch, 1, dim)
|
# (batch, 1, dim)
|
||||||
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
|
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
|
||||||
context_vec = torch.bmm(alignment.unsqueeze(1), inputs)
|
context = torch.bmm(alignment.unsqueeze(1), annotations)
|
||||||
context_vec = context_vec.squeeze(1)
|
context = context.squeeze(1)
|
||||||
|
|
||||||
# Concat input query and previous context_vec context
|
# Concat input query and previous context context
|
||||||
cell_input = torch.cat((query, context_vec), -1)
|
rnn_input = torch.cat((memory, context), -1)
|
||||||
#cell_input = cell_input.unsqueeze(1)
|
#rnn_input = rnn_input.unsqueeze(1)
|
||||||
|
|
||||||
# Feed it to RNN
|
# Feed it to RNN
|
||||||
# s_i = f(y_{i-1}, c_{i}, s_{i-1})
|
# s_i = f(y_{i-1}, c_{i}, s_{i-1})
|
||||||
cell_output = self.rnn_cell(cell_input, cell_state)
|
rnn_output = self.rnn_cell(rnn_input, rnn_state)
|
||||||
|
|
||||||
context_vec = context_vec.squeeze(1)
|
context = context.squeeze(1)
|
||||||
return cell_output, context_vec, alignment
|
return rnn_output, context, alignment
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ import torch
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from .attention import BahdanauAttention, AttentionWrapper
|
from .attention import AttentionRNN
|
||||||
from .attention import get_mask_from_lengths
|
from .attention import get_mask_from_lengths
|
||||||
|
|
||||||
class Prenet(nn.Module):
|
class Prenet(nn.Module):
|
||||||
|
@ -153,7 +153,7 @@ class CBHG(nn.Module):
|
||||||
out = conv1d(x)
|
out = conv1d(x)
|
||||||
out = out[:, :, :T]
|
out = out[:, :, :T]
|
||||||
outs.append(out)
|
outs.append(out)
|
||||||
|
|
||||||
x = torch.cat(outs, dim=1)
|
x = torch.cat(outs, dim=1)
|
||||||
assert x.size(1) == self.in_features * len(self.conv1d_banks)
|
assert x.size(1) == self.in_features * len(self.conv1d_banks)
|
||||||
|
|
||||||
|
@ -219,15 +219,10 @@ class Decoder(nn.Module):
|
||||||
self.memory_dim = memory_dim
|
self.memory_dim = memory_dim
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.r = r
|
self.r = r
|
||||||
# input -> |Linear| -> processed_inputs
|
|
||||||
self.input_layer = nn.Linear(in_features, 256, bias=False)
|
|
||||||
# memory -> |Prenet| -> processed_memory
|
# memory -> |Prenet| -> processed_memory
|
||||||
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
|
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
|
||||||
# processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
|
# processed_inputs, processed_memory -> |Attention| -> Attention, Alignment, RNN_State
|
||||||
self.attention_rnn = AttentionWrapper(
|
self.attention_rnn = AttentionRNN(256, in_features, 128)
|
||||||
nn.GRUCell(in_features + 128, 256),
|
|
||||||
BahdanauAttention(256)
|
|
||||||
)
|
|
||||||
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
|
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
|
||||||
self.project_to_decoder_in = nn.Linear(256+in_features, 256)
|
self.project_to_decoder_in = nn.Linear(256+in_features, 256)
|
||||||
# decoder_RNN_input -> |RNN| -> RNN_state
|
# decoder_RNN_input -> |RNN| -> RNN_state
|
||||||
|
@ -236,7 +231,7 @@ class Decoder(nn.Module):
|
||||||
# RNN_state -> |Linear| -> mel_spec
|
# RNN_state -> |Linear| -> mel_spec
|
||||||
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
||||||
|
|
||||||
def forward(self, inputs, memory=None, memory_lengths=None):
|
def forward(self, inputs, memory=None, input_lengths=None):
|
||||||
r"""
|
r"""
|
||||||
Decoder forward step.
|
Decoder forward step.
|
||||||
|
|
||||||
|
@ -245,9 +240,9 @@ class Decoder(nn.Module):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs: Encoder outputs.
|
inputs: Encoder outputs.
|
||||||
memory: Decoder memory (autoregression. If None (at eval-time),
|
memory (None): Decoder memory (autoregression. If None (at eval-time),
|
||||||
decoder outputs are used as decoder inputs.
|
decoder outputs are used as decoder inputs.
|
||||||
memory_lengths: Encoder output (memory) lengths. If not None, used for
|
input_lengths (None): input lengths, used for
|
||||||
attention masking.
|
attention masking.
|
||||||
|
|
||||||
Shapes:
|
Shapes:
|
||||||
|
@ -256,12 +251,11 @@ class Decoder(nn.Module):
|
||||||
"""
|
"""
|
||||||
B = inputs.size(0)
|
B = inputs.size(0)
|
||||||
|
|
||||||
# TODO: take this segment into Attention module.
|
|
||||||
processed_inputs = self.input_layer(inputs)
|
# if input_lengths is not None:
|
||||||
if memory_lengths is not None:
|
# mask = get_mask_from_lengths(processed_inputs, input_lengths)
|
||||||
mask = get_mask_from_lengths(processed_inputs, memory_lengths)
|
# else:
|
||||||
else:
|
# mask = None
|
||||||
mask = None
|
|
||||||
|
|
||||||
# Run greedy decoding if memory is None
|
# Run greedy decoding if memory is None
|
||||||
greedy = memory is None
|
greedy = memory is None
|
||||||
|
@ -301,13 +295,14 @@ class Decoder(nn.Module):
|
||||||
while True:
|
while True:
|
||||||
if t > 0:
|
if t > 0:
|
||||||
memory_input = outputs[-1] if greedy else memory[t - 1]
|
memory_input = outputs[-1] if greedy else memory[t - 1]
|
||||||
|
|
||||||
# Prenet
|
# Prenet
|
||||||
processed_memory = self.prenet(memory_input)
|
processed_memory = self.prenet(memory_input)
|
||||||
|
|
||||||
# Attention RNN
|
# Attention RNN
|
||||||
attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
|
attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
|
||||||
processed_memory, current_context_vec, attention_rnn_hidden,
|
processed_memory, current_context_vec, attention_rnn_hidden,
|
||||||
inputs, processed_inputs=processed_inputs, mask=mask)
|
inputs)
|
||||||
|
|
||||||
# Concat RNN output and attention context vector
|
# Concat RNN output and attention context vector
|
||||||
decoder_input = self.project_to_decoder_in(
|
decoder_input = self.project_to_decoder_in(
|
||||||
|
@ -350,5 +345,5 @@ class Decoder(nn.Module):
|
||||||
return outputs, alignments
|
return outputs, alignments
|
||||||
|
|
||||||
|
|
||||||
def is_end_of_frames(output, eps=0.1): #0.2
|
def is_end_of_frames(output, eps=0.2): #0.2
|
||||||
return (output.data <= eps).all()
|
return (output.data <= eps).all()
|
||||||
|
|
|
@ -9,11 +9,11 @@ from TTS.layers.tacotron import Prenet, Encoder, Decoder, CBHG
|
||||||
class Tacotron(nn.Module):
|
class Tacotron(nn.Module):
|
||||||
def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
|
def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
|
||||||
freq_dim=1025, r=5, padding_idx=None,
|
freq_dim=1025, r=5, padding_idx=None,
|
||||||
use_memory_mask=False):
|
use_atten_mask=False):
|
||||||
super(Tacotron, self).__init__()
|
super(Tacotron, self).__init__()
|
||||||
self.mel_dim = mel_dim
|
self.mel_dim = mel_dim
|
||||||
self.linear_dim = linear_dim
|
self.linear_dim = linear_dim
|
||||||
self.use_memory_mask = use_memory_mask
|
self.use_atten_mask = use_atten_mask
|
||||||
self.embedding = nn.Embedding(len(symbols), embedding_dim,
|
self.embedding = nn.Embedding(len(symbols), embedding_dim,
|
||||||
padding_idx=padding_idx)
|
padding_idx=padding_idx)
|
||||||
print(" | > Embedding dim : {}".format(len(symbols)))
|
print(" | > Embedding dim : {}".format(len(symbols)))
|
||||||
|
@ -33,13 +33,12 @@ class Tacotron(nn.Module):
|
||||||
# (B, T', in_dim)
|
# (B, T', in_dim)
|
||||||
encoder_outputs = self.encoder(inputs)
|
encoder_outputs = self.encoder(inputs)
|
||||||
|
|
||||||
if self.use_memory_mask:
|
if not self.use_atten_mask:
|
||||||
memory_lengths = input_lengths
|
input_lengths = None
|
||||||
else:
|
|
||||||
memory_lengths = None
|
|
||||||
# (B, T', mel_dim*r)
|
# (B, T', mel_dim*r)
|
||||||
mel_outputs, alignments = self.decoder(
|
mel_outputs, alignments = self.decoder(
|
||||||
encoder_outputs, mel_specs, memory_lengths=memory_lengths)
|
encoder_outputs, mel_specs, input_lengths=input_lengths)
|
||||||
|
|
||||||
# Post net processing below
|
# Post net processing below
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
import unittest
|
||||||
|
import torch as T
|
||||||
|
|
||||||
|
from TTS.utils.generic_utils import save_checkpoint, save_best_model
|
||||||
|
from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder
|
||||||
|
|
||||||
|
OUT_PATH = '/tmp/test.pth.tar'
|
||||||
|
|
||||||
|
class ModelSavingTests(unittest.TestCase):
|
||||||
|
|
||||||
|
def save_checkpoint_test(self):
|
||||||
|
# create a dummy model
|
||||||
|
model = Prenet(128, out_features=[256, 128])
|
||||||
|
model = T.nn.DataParallel(layer)
|
||||||
|
|
||||||
|
# save the model
|
||||||
|
save_checkpoint(model, None, 100,
|
||||||
|
OUTPATH, 1, 1)
|
||||||
|
|
||||||
|
# load the model to CPU
|
||||||
|
model_dict = torch.load(MODEL_PATH, map_location=lambda storage,
|
||||||
|
loc: storage)
|
||||||
|
model.load_state_dict(model_dict['model'])
|
||||||
|
|
||||||
|
def save_best_model_test(self):
|
||||||
|
# create a dummy model
|
||||||
|
model = Prenet(256, out_features=[256, 256])
|
||||||
|
model = T.nn.DataParallel(layer)
|
||||||
|
|
||||||
|
# save the model
|
||||||
|
best_loss = save_best_model(model, None, 0,
|
||||||
|
100, OUT_PATH,
|
||||||
|
10, 1)
|
||||||
|
|
||||||
|
# load the model to CPU
|
||||||
|
model_dict = torch.load(MODEL_PATH, map_location=lambda storage,
|
||||||
|
loc: storage)
|
||||||
|
model.load_state_dict(model_dict['model'])
|
522
train.py
522
train.py
|
@ -21,42 +21,275 @@ from tensorboardX import SummaryWriter
|
||||||
from utils.generic_utils import (Progbar, remove_experiment_folder,
|
from utils.generic_utils import (Progbar, remove_experiment_folder,
|
||||||
create_experiment_folder, save_checkpoint,
|
create_experiment_folder, save_checkpoint,
|
||||||
save_best_model, load_config, lr_decay,
|
save_best_model, load_config, lr_decay,
|
||||||
count_parameters)
|
count_parameters, check_update)
|
||||||
from utils.model import get_param_size
|
from utils.model import get_param_size
|
||||||
from utils.visual import plot_alignment, plot_spectrogram
|
from utils.visual import plot_alignment, plot_spectrogram
|
||||||
from datasets.LJSpeech import LJSpeechDataset
|
from datasets.LJSpeech import LJSpeechDataset
|
||||||
from models.tacotron import Tacotron
|
from models.tacotron import Tacotron
|
||||||
|
|
||||||
|
|
||||||
use_cuda = torch.cuda.is_available()
|
use_cuda = torch.cuda.is_available()
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--restore_path', type=str,
|
||||||
|
help='Folder path to checkpoints', default=0)
|
||||||
|
parser.add_argument('--config_path', type=str,
|
||||||
|
help='path to config file for training',)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# setup output paths and read configs
|
||||||
|
c = load_config(args.config_path)
|
||||||
|
_ = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
OUT_PATH = os.path.join(_, c.output_path)
|
||||||
|
OUT_PATH = create_experiment_folder(OUT_PATH)
|
||||||
|
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
|
||||||
|
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
|
||||||
|
|
||||||
|
# save config to tmp place to be loaded by subsequent modules.
|
||||||
|
file_name = str(os.getpid())
|
||||||
|
tmp_path = os.path.join("/tmp/", file_name+'_tts')
|
||||||
|
pickle.dump(c, open(tmp_path, "wb"))
|
||||||
|
|
||||||
|
# setup tensorboard
|
||||||
|
LOG_DIR = OUT_PATH
|
||||||
|
tb = SummaryWriter(LOG_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
def signal_handler(signal, frame):
|
||||||
|
"""Ctrl+C handler to remove empty experiment folder"""
|
||||||
|
print(" !! Pressed Ctrl+C !!")
|
||||||
|
remove_experiment_folder(OUT_PATH)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def train(model, criterion, data_loader, optimizer, epoch):
|
||||||
|
model = model.train()
|
||||||
|
epoch_time = 0
|
||||||
|
avg_linear_loss = 0
|
||||||
|
avg_mel_loss = 0
|
||||||
|
|
||||||
|
print(" | > Epoch {}/{}".format(epoch, c.epochs))
|
||||||
|
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
|
||||||
|
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
|
||||||
|
for num_iter, data in enumerate(data_loader):
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# setup input data
|
||||||
|
text_input = data[0]
|
||||||
|
text_lengths = data[1]
|
||||||
|
linear_input = data[2]
|
||||||
|
mel_input = data[3]
|
||||||
|
|
||||||
|
current_step = num_iter + args.restore_step + epoch * len(data_loader) + 1
|
||||||
|
|
||||||
|
# setup lr
|
||||||
|
current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
|
||||||
|
for params_group in optimizer.param_groups:
|
||||||
|
params_group['lr'] = current_lr
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# convert inputs to variables
|
||||||
|
text_input_var = Variable(text_input)
|
||||||
|
mel_spec_var = Variable(mel_input)
|
||||||
|
linear_spec_var = Variable(linear_input, volatile=True)
|
||||||
|
|
||||||
|
# sort sequence by length for curriculum learning
|
||||||
|
# TODO: might be unnecessary
|
||||||
|
sorted_lengths, indices = torch.sort(
|
||||||
|
text_lengths.view(-1), dim=0, descending=True)
|
||||||
|
sorted_lengths = sorted_lengths.long().numpy()
|
||||||
|
text_input_var = text_input_var[indices]
|
||||||
|
mel_spec_var = mel_spec_var[indices]
|
||||||
|
linear_spec_var = linear_spec_var[indices]
|
||||||
|
|
||||||
|
# dispatch data to GPU
|
||||||
|
if use_cuda:
|
||||||
|
text_input_var = text_input_var.cuda()
|
||||||
|
mel_spec_var = mel_spec_var.cuda()
|
||||||
|
linear_spec_var = linear_spec_var.cuda()
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
mel_output, linear_output, alignments =\
|
||||||
|
model.forward(text_input_var, mel_spec_var,
|
||||||
|
input_lengths= torch.autograd.Variable(torch.cuda.LongTensor(sorted_lengths)))
|
||||||
|
|
||||||
|
# loss computation
|
||||||
|
mel_loss = criterion(mel_output, mel_spec_var)
|
||||||
|
linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
|
||||||
|
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
||||||
|
linear_spec_var[: ,: ,:n_priority_freq])
|
||||||
|
loss = mel_loss + linear_loss
|
||||||
|
|
||||||
|
# backpass and check the grad norm
|
||||||
|
loss.backward()
|
||||||
|
grad_norm, skip_flag = check_update(model, 0.5, 100)
|
||||||
|
if skip_flag:
|
||||||
|
optimizer.zero_grad()
|
||||||
|
print(" | > Iteration skipped!!")
|
||||||
|
continue
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
step_time = time.time() - start_time
|
||||||
|
epoch_time += step_time
|
||||||
|
|
||||||
|
# update
|
||||||
|
progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
|
||||||
|
('linear_loss', linear_loss.data[0]),
|
||||||
|
('mel_loss', mel_loss.data[0]),
|
||||||
|
('grad_norm', grad_norm)])
|
||||||
|
|
||||||
|
# Plot Training Iter Stats
|
||||||
|
tb.add_scalar('TrainIterLoss/TotalLoss', loss.data[0], current_step)
|
||||||
|
tb.add_scalar('TrainIterLoss/LinearLoss', linear_loss.data[0],
|
||||||
|
current_step)
|
||||||
|
tb.add_scalar('TrainIterLoss/MelLoss', mel_loss.data[0], current_step)
|
||||||
|
tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
|
||||||
|
current_step)
|
||||||
|
tb.add_scalar('Params/GradNorm', grad_norm, current_step)
|
||||||
|
tb.add_scalar('Time/StepTime', step_time, current_step)
|
||||||
|
|
||||||
|
if current_step % c.save_step == 0:
|
||||||
|
if c.checkpoint:
|
||||||
|
# save model
|
||||||
|
save_checkpoint(model, optimizer, linear_loss.data[0],
|
||||||
|
OUT_PATH, current_step, epoch)
|
||||||
|
|
||||||
|
# Diagnostic visualizations
|
||||||
|
const_spec = linear_output[0].data.cpu().numpy()
|
||||||
|
gt_spec = linear_spec_var[0].data.cpu().numpy()
|
||||||
|
|
||||||
|
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
|
||||||
|
gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap)
|
||||||
|
tb.add_image('Visual/Reconstruction', const_spec, current_step)
|
||||||
|
tb.add_image('Visual/GroundTruth', gt_spec, current_step)
|
||||||
|
|
||||||
|
align_img = alignments[0].data.cpu().numpy()
|
||||||
|
align_img = plot_alignment(align_img)
|
||||||
|
tb.add_image('Visual/Alignment', align_img, current_step)
|
||||||
|
|
||||||
|
# Sample audio
|
||||||
|
audio_signal = linear_output[0].data.cpu().numpy()
|
||||||
|
data_loader.dataset.ap.griffin_lim_iters = 60
|
||||||
|
audio_signal = data_loader.dataset.ap.inv_spectrogram(audio_signal.T)
|
||||||
|
try:
|
||||||
|
tb.add_audio('SampleAudio', audio_signal, current_step,
|
||||||
|
sample_rate=c.sample_rate)
|
||||||
|
except:
|
||||||
|
print("\n > Error at audio signal on TB!!")
|
||||||
|
print(audio_signal.max())
|
||||||
|
print(audio_signal.min())
|
||||||
|
|
||||||
|
|
||||||
|
avg_linear_loss /= (num_iter + 1)
|
||||||
|
avg_mel_loss /= (num_iter + 1)
|
||||||
|
avg_total_loss = avg_mel_loss + avg_linear_loss
|
||||||
|
|
||||||
|
# Plot Training Epoch Stats
|
||||||
|
tb.add_scalar('TrainEpochLoss/TotalLoss', loss.data[0], current_step)
|
||||||
|
tb.add_scalar('TrainEpochLoss/LinearLoss', linear_loss.data[0], current_step)
|
||||||
|
tb.add_scalar('TrainEpochLoss/MelLoss', mel_loss.data[0], current_step)
|
||||||
|
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
|
||||||
|
epoch_time = 0
|
||||||
|
|
||||||
|
return avg_linear_loss, current_step
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(model, criterion, data_loader, current_step):
|
||||||
|
model = model.train()
|
||||||
|
epoch_time = 0
|
||||||
|
|
||||||
|
print(" | > Validation")
|
||||||
|
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
|
||||||
|
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
|
||||||
|
|
||||||
|
avg_linear_loss = 0
|
||||||
|
avg_mel_loss = 0
|
||||||
|
|
||||||
|
for num_iter, data in enumerate(data_loader):
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# setup input data
|
||||||
|
text_input = data[0]
|
||||||
|
text_lengths = data[1]
|
||||||
|
linear_input = data[2]
|
||||||
|
mel_input = data[3]
|
||||||
|
|
||||||
|
# convert inputs to variables
|
||||||
|
text_input_var = Variable(text_input)
|
||||||
|
mel_spec_var = Variable(mel_input)
|
||||||
|
linear_spec_var = Variable(linear_input, volatile=True)
|
||||||
|
|
||||||
|
# dispatch data to GPU
|
||||||
|
if use_cuda:
|
||||||
|
text_input_var = text_input_var.cuda()
|
||||||
|
mel_spec_var = mel_spec_var.cuda()
|
||||||
|
linear_spec_var = linear_spec_var.cuda()
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
mel_output, linear_output, alignments =\
|
||||||
|
model.forward(text_input_var, mel_spec_var)
|
||||||
|
|
||||||
|
# loss computation
|
||||||
|
mel_loss = criterion(mel_output, mel_spec_var)
|
||||||
|
linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
|
||||||
|
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
||||||
|
linear_spec_var[: ,: ,:n_priority_freq])
|
||||||
|
loss = mel_loss + linear_loss
|
||||||
|
|
||||||
|
step_time = time.time() - start_time
|
||||||
|
epoch_time += step_time
|
||||||
|
|
||||||
|
# update
|
||||||
|
progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
|
||||||
|
('linear_loss', linear_loss.data[0]),
|
||||||
|
('mel_loss', mel_loss.data[0])])
|
||||||
|
|
||||||
|
avg_linear_loss += linear_loss.data[0]
|
||||||
|
avg_mel_loss += mel_loss.data[0]
|
||||||
|
|
||||||
|
# Diagnostic visualizations
|
||||||
|
idx = np.random.randint(mel_input.shape[0])
|
||||||
|
const_spec = linear_output[idx].data.cpu().numpy()
|
||||||
|
gt_spec = linear_spec_var[idx].data.cpu().numpy()
|
||||||
|
align_img = alignments[idx].data.cpu().numpy()
|
||||||
|
|
||||||
|
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
|
||||||
|
gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap)
|
||||||
|
align_img = plot_alignment(align_img)
|
||||||
|
|
||||||
|
tb.add_image('ValVisual/Reconstruction', const_spec, current_step)
|
||||||
|
tb.add_image('ValVisual/GroundTruth', gt_spec, current_step)
|
||||||
|
tb.add_image('ValVisual/ValidationAlignment', align_img, current_step)
|
||||||
|
|
||||||
|
# Sample audio
|
||||||
|
audio_signal = linear_output[idx].data.cpu().numpy()
|
||||||
|
data_loader.dataset.ap.griffin_lim_iters = 60
|
||||||
|
audio_signal = data_loader.dataset.ap.inv_spectrogram(audio_signal.T)
|
||||||
|
try:
|
||||||
|
tb.add_audio('ValSampleAudio', audio_signal, current_step,
|
||||||
|
sample_rate=c.sample_rate)
|
||||||
|
except:
|
||||||
|
print(" | > Error at audio signal on TB!!")
|
||||||
|
print(audio_signal.max())
|
||||||
|
print(audio_signal.min())
|
||||||
|
|
||||||
|
# compute average losses
|
||||||
|
avg_linear_loss /= (num_iter + 1)
|
||||||
|
avg_mel_loss /= (num_iter + 1)
|
||||||
|
avg_total_loss = avg_mel_loss + avg_linear_loss
|
||||||
|
|
||||||
|
# Plot Learning Stats
|
||||||
|
tb.add_scalar('ValEpochLoss/TotalLoss', avg_total_loss, current_step)
|
||||||
|
tb.add_scalar('ValEpochLoss/LinearLoss', avg_linear_loss, current_step)
|
||||||
|
tb.add_scalar('ValEpochLoss/MelLoss', avg_mel_loss, current_step)
|
||||||
|
return avg_linear_loss
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
|
|
||||||
# setup output paths and read configs
|
|
||||||
c = load_config(args.config_path)
|
|
||||||
_ = os.path.dirname(os.path.realpath(__file__))
|
|
||||||
OUT_PATH = os.path.join(_, c.output_path)
|
|
||||||
OUT_PATH = create_experiment_folder(OUT_PATH)
|
|
||||||
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
|
|
||||||
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
|
|
||||||
|
|
||||||
# save config to tmp place to be loaded by subsequent modules.
|
|
||||||
file_name = str(os.getpid())
|
|
||||||
tmp_path = os.path.join("/tmp/", file_name+'_tts')
|
|
||||||
pickle.dump(c, open(tmp_path, "wb"))
|
|
||||||
|
|
||||||
# setup tensorboard
|
|
||||||
LOG_DIR = OUT_PATH
|
|
||||||
tb = SummaryWriter(LOG_DIR)
|
|
||||||
|
|
||||||
# Ctrl+C handler to remove empty experiment folder
|
|
||||||
def signal_handler(signal, frame):
|
|
||||||
print(" !! Pressed Ctrl+C !!")
|
|
||||||
remove_experiment_folder(OUT_PATH)
|
|
||||||
sys.exit(1)
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
|
||||||
|
|
||||||
# Setup the dataset
|
# Setup the dataset
|
||||||
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
|
train_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_train.csv'),
|
||||||
os.path.join(c.data_path, 'wavs'),
|
os.path.join(c.data_path, 'wavs'),
|
||||||
c.r,
|
c.r,
|
||||||
c.sample_rate,
|
c.sample_rate,
|
||||||
|
@ -71,204 +304,77 @@ def main(args):
|
||||||
c.power
|
c.power
|
||||||
)
|
)
|
||||||
|
|
||||||
dataloader = DataLoader(dataset, batch_size=c.batch_size,
|
train_loader = DataLoader(train_dataset, batch_size=c.batch_size,
|
||||||
shuffle=True, collate_fn=dataset.collate_fn,
|
shuffle=False, collate_fn=train_dataset.collate_fn,
|
||||||
drop_last=True, num_workers=c.num_loader_workers)
|
drop_last=False, num_workers=c.num_loader_workers,
|
||||||
|
pin_memory=True)
|
||||||
|
|
||||||
|
val_dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata_val.csv'),
|
||||||
|
os.path.join(c.data_path, 'wavs'),
|
||||||
|
c.r,
|
||||||
|
c.sample_rate,
|
||||||
|
c.text_cleaner,
|
||||||
|
c.num_mels,
|
||||||
|
c.min_level_db,
|
||||||
|
c.frame_shift_ms,
|
||||||
|
c.frame_length_ms,
|
||||||
|
c.preemphasis,
|
||||||
|
c.ref_level_db,
|
||||||
|
c.num_freq,
|
||||||
|
c.power
|
||||||
|
)
|
||||||
|
|
||||||
|
val_loader = DataLoader(val_dataset, batch_size=c.batch_size,
|
||||||
|
shuffle=False, collate_fn=val_dataset.collate_fn,
|
||||||
|
drop_last=False, num_workers= 4,
|
||||||
|
pin_memory=True)
|
||||||
|
|
||||||
# setup the model
|
|
||||||
model = Tacotron(c.embedding_size,
|
model = Tacotron(c.embedding_size,
|
||||||
c.hidden_size,
|
c.hidden_size,
|
||||||
c.num_mels,
|
c.num_mels,
|
||||||
c.num_freq,
|
c.num_freq,
|
||||||
c.r)
|
c.r,
|
||||||
|
use_atten_mask=True)
|
||||||
# plot model on tensorboard
|
|
||||||
dummy_input = dataset.get_dummy_data()
|
|
||||||
|
|
||||||
## TODO: onnx does not support RNN fully yet
|
|
||||||
# model_proto_path = os.path.join(OUT_PATH, "model.proto")
|
|
||||||
# onnx.export(model, dummy_input, model_proto_path, verbose=True)
|
|
||||||
# tb.add_graph_onnx(model_proto_path)
|
|
||||||
|
|
||||||
if use_cuda:
|
|
||||||
model = nn.DataParallel(model.cuda())
|
|
||||||
|
|
||||||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||||
|
|
||||||
if args.restore_step:
|
|
||||||
checkpoint = torch.load(os.path.join(
|
|
||||||
args.restore_path, 'checkpoint_%d.pth.tar' % args.restore_step))
|
|
||||||
model.load_state_dict(checkpoint['model'])
|
|
||||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
|
||||||
print("\n > Model restored from step %d\n" % args.restore_step)
|
|
||||||
start_epoch = checkpoint['step'] // len(dataloader)
|
|
||||||
best_loss = checkpoint['linear_loss']
|
|
||||||
else:
|
|
||||||
start_epoch = 0
|
|
||||||
print("\n > Starting a new training")
|
|
||||||
|
|
||||||
num_params = count_parameters(model)
|
|
||||||
print(" | > Model has {} parameters".format(num_params))
|
|
||||||
|
|
||||||
model = model.train()
|
|
||||||
|
|
||||||
if not os.path.exists(CHECKPOINT_PATH):
|
|
||||||
os.mkdir(CHECKPOINT_PATH)
|
|
||||||
|
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
criterion = nn.L1Loss().cuda()
|
criterion = nn.L1Loss().cuda()
|
||||||
else:
|
else:
|
||||||
criterion = nn.L1Loss()
|
criterion = nn.L1Loss()
|
||||||
|
|
||||||
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
|
if args.restore_path:
|
||||||
|
checkpoint = torch.load(args.restore_path)
|
||||||
|
model.load_state_dict(checkpoint['model'])
|
||||||
|
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||||
|
print("\n > Model restored from step %d\n" % checkpoint['step'])
|
||||||
|
start_epoch = checkpoint['step'] // len(train_loader)
|
||||||
|
best_loss = checkpoint['linear_loss']
|
||||||
|
start_epoch = 0
|
||||||
|
args.restore_step = checkpoint['step']
|
||||||
|
else:
|
||||||
|
args.restore_step = 0
|
||||||
|
print("\n > Starting a new training")
|
||||||
|
|
||||||
#lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay,
|
if use_cuda:
|
||||||
# patience=c.lr_patience, verbose=True)
|
model = nn.DataParallel(model.cuda())
|
||||||
epoch_time = 0
|
|
||||||
best_loss = float('inf')
|
num_params = count_parameters(model)
|
||||||
|
print(" | > Model has {} parameters".format(num_params))
|
||||||
|
|
||||||
|
if not os.path.exists(CHECKPOINT_PATH):
|
||||||
|
os.mkdir(CHECKPOINT_PATH)
|
||||||
|
|
||||||
|
if 'best_loss' not in locals():
|
||||||
|
best_loss = float('inf')
|
||||||
|
|
||||||
for epoch in range(0, c.epochs):
|
for epoch in range(0, c.epochs):
|
||||||
|
train_loss, current_step = train(model, criterion, train_loader, optimizer, epoch)
|
||||||
print("\n | > Epoch {}/{}".format(epoch, c.epochs))
|
val_loss = evaluate(model, criterion, val_loader, current_step)
|
||||||
progbar = Progbar(len(dataset) / c.batch_size)
|
best_loss = save_best_model(model, optimizer, val_loss,
|
||||||
|
|
||||||
for num_iter, data in enumerate(dataloader):
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
text_input = data[0]
|
|
||||||
text_lengths = data[1]
|
|
||||||
linear_input = data[2]
|
|
||||||
mel_input = data[3]
|
|
||||||
|
|
||||||
current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1
|
|
||||||
|
|
||||||
# setup lr
|
|
||||||
current_lr = lr_decay(c.lr, current_step)
|
|
||||||
for params_group in optimizer.param_groups:
|
|
||||||
params_group['lr'] = current_lr
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
# Add a single frame of zeros to Mel Specs for better end detection
|
|
||||||
#try:
|
|
||||||
# mel_input = np.concatenate((np.zeros(
|
|
||||||
# [c.batch_size, 1, c.num_mels], dtype=np.float32),
|
|
||||||
# mel_input[:, 1:, :]), axis=1)
|
|
||||||
#except:
|
|
||||||
# raise TypeError("not same dimension")
|
|
||||||
|
|
||||||
# convert inputs to variables
|
|
||||||
text_input_var = Variable(text_input)
|
|
||||||
mel_spec_var = Variable(mel_input)
|
|
||||||
linear_spec_var = Variable(linear_input, volatile=True)
|
|
||||||
|
|
||||||
# sort sequence by length.
|
|
||||||
# TODO: might be unnecessary
|
|
||||||
sorted_lengths, indices = torch.sort(
|
|
||||||
text_lengths.view(-1), dim=0, descending=True)
|
|
||||||
sorted_lengths = sorted_lengths.long().numpy()
|
|
||||||
|
|
||||||
text_input_var = text_input_var[indices]
|
|
||||||
mel_spec_var = mel_spec_var[indices]
|
|
||||||
linear_spec_var = linear_spec_var[indices]
|
|
||||||
|
|
||||||
if use_cuda:
|
|
||||||
text_input_var = text_input_var.cuda()
|
|
||||||
mel_spec_var = mel_spec_var.cuda()
|
|
||||||
linear_spec_var = linear_spec_var.cuda()
|
|
||||||
|
|
||||||
mel_output, linear_output, alignments =\
|
|
||||||
model.forward(text_input_var, mel_spec_var,
|
|
||||||
input_lengths= torch.autograd.Variable(torch.cuda.LongTensor(sorted_lengths)))
|
|
||||||
|
|
||||||
mel_loss = criterion(mel_output, mel_spec_var)
|
|
||||||
#linear_loss = torch.abs(linear_output - linear_spec_var)
|
|
||||||
#linear_loss = 0.5 * \
|
|
||||||
#torch.mean(linear_loss) + 0.5 * \
|
|
||||||
#torch.mean(linear_loss[:, :n_priority_freq, :])
|
|
||||||
linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
|
|
||||||
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
|
||||||
linear_spec_var[: ,: ,:n_priority_freq])
|
|
||||||
loss = mel_loss + linear_loss
|
|
||||||
# loss = loss.cuda()
|
|
||||||
|
|
||||||
loss.backward()
|
|
||||||
grad_norm = nn.utils.clip_grad_norm(model.parameters(), 1.) ## TODO: maybe no need
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
step_time = time.time() - start_time
|
|
||||||
epoch_time += step_time
|
|
||||||
|
|
||||||
progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
|
|
||||||
('linear_loss', linear_loss.data[0]),
|
|
||||||
('mel_loss', mel_loss.data[0]),
|
|
||||||
('grad_norm', grad_norm)])
|
|
||||||
|
|
||||||
# Plot Learning Stats
|
|
||||||
tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step)
|
|
||||||
tb.add_scalar('Loss/LinearLoss', linear_loss.data[0],
|
|
||||||
current_step)
|
|
||||||
tb.add_scalar('Loss/MelLoss', mel_loss.data[0], current_step)
|
|
||||||
tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
|
|
||||||
current_step)
|
|
||||||
tb.add_scalar('Params/GradNorm', grad_norm, current_step)
|
|
||||||
tb.add_scalar('Time/StepTime', step_time, current_step)
|
|
||||||
|
|
||||||
align_img = alignments[0].data.cpu().numpy()
|
|
||||||
align_img = plot_alignment(align_img)
|
|
||||||
tb.add_image('Attn/Alignment', align_img, current_step)
|
|
||||||
|
|
||||||
if current_step % c.save_step == 0:
|
|
||||||
|
|
||||||
if c.checkpoint:
|
|
||||||
# save model
|
|
||||||
save_checkpoint(model, optimizer, linear_loss.data[0],
|
|
||||||
OUT_PATH, current_step, epoch)
|
|
||||||
|
|
||||||
# Diagnostic visualizations
|
|
||||||
const_spec = linear_output[0].data.cpu().numpy()
|
|
||||||
gt_spec = linear_spec_var[0].data.cpu().numpy()
|
|
||||||
|
|
||||||
const_spec = plot_spectrogram(const_spec, dataset.ap)
|
|
||||||
gt_spec = plot_spectrogram(gt_spec, dataset.ap)
|
|
||||||
tb.add_image('Spec/Reconstruction', const_spec, current_step)
|
|
||||||
tb.add_image('Spec/GroundTruth', gt_spec, current_step)
|
|
||||||
|
|
||||||
align_img = alignments[0].data.cpu().numpy()
|
|
||||||
align_img = plot_alignment(align_img)
|
|
||||||
tb.add_image('Attn/Alignment', align_img, current_step)
|
|
||||||
|
|
||||||
# Sample audio
|
|
||||||
audio_signal = linear_output[0].data.cpu().numpy()
|
|
||||||
dataset.ap.griffin_lim_iters = 60
|
|
||||||
audio_signal = dataset.ap.inv_spectrogram(audio_signal.T)
|
|
||||||
try:
|
|
||||||
tb.add_audio('SampleAudio', audio_signal, current_step,
|
|
||||||
sample_rate=c.sample_rate)
|
|
||||||
except:
|
|
||||||
print("\n > Error at audio signal on TB!!")
|
|
||||||
print(audio_signal.max())
|
|
||||||
print(audio_signal.min())
|
|
||||||
|
|
||||||
|
|
||||||
# average loss after the epoch
|
|
||||||
avg_epoch_loss = np.mean(
|
|
||||||
progbar.sum_values['linear_loss'][0] / max(1, progbar.sum_values['linear_loss'][1]))
|
|
||||||
best_loss = save_best_model(model, optimizer, avg_epoch_loss,
|
|
||||||
best_loss, OUT_PATH,
|
best_loss, OUT_PATH,
|
||||||
current_step, epoch)
|
current_step, epoch)
|
||||||
|
|
||||||
#lr_scheduler.step(loss.data[0])
|
|
||||||
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
|
|
||||||
epoch_time = 0
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
parser.add_argument('--restore_step', type=int,
|
|
||||||
help='Global step to restore checkpoint', default=0)
|
|
||||||
parser.add_argument('--restore_path', type=str,
|
|
||||||
help='Folder path to checkpoints', default=0)
|
|
||||||
parser.add_argument('--config_path', type=str,
|
|
||||||
help='path to config file for training',)
|
|
||||||
args = parser.parse_args()
|
|
||||||
main(args)
|
main(args)
|
||||||
|
|
|
@ -7,6 +7,7 @@ import datetime
|
||||||
import json
|
import json
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
class AttrDict(dict):
|
class AttrDict(dict):
|
||||||
|
@ -94,8 +95,21 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
|
||||||
return best_loss
|
return best_loss
|
||||||
|
|
||||||
|
|
||||||
def lr_decay(init_lr, global_step):
|
def check_update(model, grad_clip, grad_top):
|
||||||
warmup_steps = 4000.0
|
r'''Check model gradient against unexpected jumps and failures'''
|
||||||
|
skip_flag = False
|
||||||
|
grad_norm = torch.nn.utils.clip_grad_norm(model.parameters(), grad_clip)
|
||||||
|
if np.isinf(grad_norm):
|
||||||
|
print(" | > Gradient is INF !!")
|
||||||
|
skip_flag = True
|
||||||
|
elif grad_norm > grad_top:
|
||||||
|
print(" | > Gradient is above the top limit !!")
|
||||||
|
skip_flag = True
|
||||||
|
return grad_norm, skip_flag
|
||||||
|
|
||||||
|
|
||||||
|
def lr_decay(init_lr, global_step, warmup_steps):
|
||||||
|
r'''from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py'''
|
||||||
step = global_step + 1.
|
step = global_step + 1.
|
||||||
lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
|
lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
|
||||||
step**-0.5)
|
step**-0.5)
|
||||||
|
@ -197,13 +211,13 @@ class Progbar(object):
|
||||||
eta_format = '%ds' % eta
|
eta_format = '%ds' % eta
|
||||||
|
|
||||||
info = ' - ETA: %s' % eta_format
|
info = ' - ETA: %s' % eta_format
|
||||||
|
|
||||||
|
if time_per_unit >= 1:
|
||||||
|
info += ' %.0fs/step' % time_per_unit
|
||||||
|
elif time_per_unit >= 1e-3:
|
||||||
|
info += ' %.0fms/step' % (time_per_unit * 1e3)
|
||||||
else:
|
else:
|
||||||
if time_per_unit >= 1:
|
info += ' %.0fus/step' % (time_per_unit * 1e6)
|
||||||
info += ' %.0fs/step' % time_per_unit
|
|
||||||
elif time_per_unit >= 1e-3:
|
|
||||||
info += ' %.0fms/step' % (time_per_unit * 1e3)
|
|
||||||
else:
|
|
||||||
info += ' %.0fus/step' % (time_per_unit * 1e6)
|
|
||||||
|
|
||||||
for k in self.unique_values:
|
for k in self.unique_values:
|
||||||
info += ' - %s:' % k
|
info += ' - %s:' % k
|
||||||
|
|
|
@ -5,7 +5,7 @@ import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
def plot_alignment(alignment, info=None):
|
def plot_alignment(alignment, info=None):
|
||||||
fig, ax = plt.subplots()
|
fig, ax = plt.subplots(figsize=(16,10))
|
||||||
im = ax.imshow(alignment.T, aspect='auto', origin='lower',
|
im = ax.imshow(alignment.T, aspect='auto', origin='lower',
|
||||||
interpolation='none')
|
interpolation='none')
|
||||||
fig.colorbar(im, ax=ax)
|
fig.colorbar(im, ax=ax)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче