Mirror of https://github.com/mozilla/TTS.git

Merge pull request #3 from mozilla/testing_docing

Tests and some changes on the architecture.

Commit ab74c3e0cf
Changed file (name not shown in this view):

@@ -116,3 +116,5 @@ venv.bak/
 *.pth.tar
 result/
 
+# jupyter dummy files
+core
README.md — 50 changed lines
@@ -15,7 +15,53 @@ Highly recommended to use [miniconda](https://conda.io/miniconda.html) for easier installation.
 * TODO
 
 ## Data
-TODO
+Currently TTS provides data loaders for:
+
+- [LJ Speech](https://keithito.com/LJ-Speech-Dataset/)
 
 ## Training the network
-TODO
+To run your own training, define a ```config.json``` file (simple template below) and call it with the command below.
+
+```train.py --config_path config.json```
+
+To use a specific set of GPUs:
+
+```CUDA_VISIBLE_DEVICES="0,1,4" train.py --config_path config.json```
+
+Each run creates an experiment folder, named with the corresponding date and time, under the folder you set in ```config.json```. If there is no checkpoint yet under that folder, it is removed when you press Ctrl+C.
+
+Example ```config.json```:
+
+```
+{
+  // Data loading parameters
+  "num_mels": 80,
+  "num_freq": 1024,
+  "sample_rate": 20000,
+  "frame_length_ms": 50.0,
+  "frame_shift_ms": 12.5,
+  "preemphasis": 0.97,
+  "min_level_db": -100,
+  "ref_level_db": 20,
+  "hidden_size": 128,
+  "embedding_size": 256,
+  "text_cleaner": "english_cleaners",
+
+  // Training parameters
+  "epochs": 2000,
+  "lr": 0.001,
+  "lr_patience": 2,   // lr_scheduler.ReduceLROnPlateau().patience
+  "lr_decay": 0.5,    // lr_scheduler.ReduceLROnPlateau().factor
+  "batch_size": 256,
+  "griffin_lim_iters": 60,
+  "power": 1.5,
+  "r": 5,             // number of decoder outputs for Tacotron
+
+  // Number of data loader processes
+  "num_loader_workers": 8,
+
+  // Experiment logging parameters
+  "save_step": 200,
+  "data_path": "/path/to/KeithIto/LJSpeech-1.0",
+  "output_path": "/path/to/my_experiment",
+  "log_dir": "/path/to/my/tensorboard/logs/"
+}
+```
config.json — 34 changed lines
@@ -1,29 +1,31 @@
 {
 "num_mels": 80,
-"num_freq": 1024,
+"num_freq": 1025,
 "sample_rate": 20000,
-"frame_length_ms": 50.0,
+"frame_length_ms": 50,
 "frame_shift_ms": 12.5,
 "preemphasis": 0.97,
 "min_level_db": -100,
 "ref_level_db": 20,
 "hidden_size": 128,
 "embedding_size": 256,
-
-"epochs": 10000,
-"lr": 0.01,
-"decay_step": [500000, 1000000, 2000000],
-"batch_size": 128,
-"max_iters": 200,
-"griffinf_lim_iters": 60,
-"power": 1.5,
-"r": 5,
-
-"log_step": 100,
-"save_step": 2000,
-
 "text_cleaner": "english_cleaners",
+
+"epochs": 2000,
+"lr": 0.003,
+"lr_patience": 5,
+"lr_decay": 0.5,
+"batch_size": 180,
+"r": 5,
+
+"griffin_lim_iters": 60,
+"power": 1.5,
+
+"num_loader_workers": 32,
+
+"checkpoint": false,
+"save_step": 69,
 "data_path": "/data/shared/KeithIto/LJSpeech-1.0",
-"output_path": "result"
+"output_path": "result",
+
+"log_dir": "/home/erogol/projects/TTS/logs/"
 }
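Note: the new `lr_patience` and `lr_decay` keys map onto `torch.optim.lr_scheduler.ReduceLROnPlateau` (`patience` and `factor` respectively), as the README comments above state; train.py below actually keeps this scheduler commented out and drives the learning rate with `lr_decay()` instead. A minimal wiring sketch of the scheduler, assuming the Adam optimizer from train.py (the linear model here is a stand-in, not part of the commit):

```
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = nn.Linear(10, 10)                             # stand-in for Tacotron
optimizer = optim.Adam(model.parameters(), lr=0.003)  # "lr"
scheduler = ReduceLROnPlateau(optimizer, factor=0.5,  # "lr_decay"
                              patience=5)             # "lr_patience"

for epoch in range(20):
    val_loss = 1.0                 # dummy plateaued validation loss
    scheduler.step(val_loss)       # halves lr after 5 epochs without improvement

print(optimizer.param_groups[0]['lr'])  # lr has been reduced from 0.003
```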
Binary data: datasets/.LJSpeech.py.swp — binary file not shown.
@@ -2,28 +2,34 @@ import pandas as pd
 import os
 import numpy as np
 import collections
+import librosa
+import torch
 from torch.utils.data import Dataset
 
-from Tacotron.utils.text import text_to_sequence
-from Tacotron.utils.audio import *
-from Tacotron.utils.data import prepare_data, pad_data, pad_per_step
+from TTS.utils.text import text_to_sequence
+from TTS.utils.audio import AudioProcessor
+from TTS.utils.data import prepare_data, pad_data, pad_per_step
 
 
 class LJSpeechDataset(Dataset):
 
     def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
-                 cleaners):
+                 text_cleaner, num_mels, min_level_db, frame_shift_ms,
+                 frame_length_ms, preemphasis, ref_level_db, num_freq, power):
         self.frames = pd.read_csv(csv_file, sep='|', header=None)
         self.root_dir = root_dir
         self.outputs_per_step = outputs_per_step
         self.sample_rate = sample_rate
-        self.cleaners = cleaners
+        self.cleaners = text_cleaner
+        self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
+                                 frame_length_ms, preemphasis, ref_level_db, num_freq,
+                                 power)
         print(" > Reading LJSpeech from - {}".format(root_dir))
         print(" | > Number of instances : {}".format(len(self.frames)))
 
     def load_wav(self, filename):
         try:
-            audio = librosa.load(filename, sr=self.sample_rate)
+            audio = librosa.core.load(filename, sr=self.sample_rate)
             return audio
         except RuntimeError as e:
             print(" !! Cannot read file : {}".format(filename))
@@ -37,36 +43,54 @@ class LJSpeechDataset(Dataset):
         text = self.frames.ix[idx, 1]
         text = np.asarray(text_to_sequence(text, [self.cleaners]), dtype=np.int32)
         wav = np.asarray(self.load_wav(wav_name)[0], dtype=np.float32)
-        sample = {'text': text, 'wav': wav}
+        sample = {'text': text, 'wav': wav, 'item_idx': self.frames.ix[idx, 0]}
         return sample
 
+    def get_dummy_data(self):
+        return torch.autograd.Variable(torch.ones(16, 143)).type(torch.LongTensor)
+
     def collate_fn(self, batch):
 
         # Puts each data field into a tensor with outer dimension batch size
         if isinstance(batch[0], collections.Mapping):
             keys = list()
 
-            text = [d['text'] for d in batch]
             wav = [d['wav'] for d in batch]
+            item_idxs = [d['item_idx'] for d in batch]
+            text = [d['text'] for d in batch]
+
+            text_lenghts = np.array([len(x) for x in text])
+            max_text_len = np.max(text_lenghts)
 
             # PAD sequences with largest length of the batch
             text = prepare_data(text).astype(np.int32)
             wav = prepare_data(wav)
 
-            magnitude = np.array([spectrogram(w) for w in wav])
-            mel = np.array([melspectrogram(w) for w in wav])
+            linear = np.array([self.ap.spectrogram(w).astype('float32') for w in wav])
+            mel = np.array([self.ap.melspectrogram(w).astype('float32') for w in wav])
+            assert mel.shape[2] == linear.shape[2]
             timesteps = mel.shape[2]
 
             # PAD with zeros that can be divided by outputs per step
-            if timesteps % self.outputs_per_step != 0:
-                magnitude = pad_per_step(magnitude, self.outputs_per_step)
-                mel = pad_per_step(mel, self.outputs_per_step)
+            if (timesteps + 1) % self.outputs_per_step != 0:
+                pad_len = self.outputs_per_step - \
+                    ((timesteps + 1) % self.outputs_per_step)
+                pad_len += 1
+            else:
+                pad_len = 1
+            linear = pad_per_step(linear, pad_len)
+            mel = pad_per_step(mel, pad_len)
 
-            # reshape jombo
-            magnitude = magnitude.transpose(0, 2, 1)
+            # reshape jumbo
+            linear = linear.transpose(0, 2, 1)
             mel = mel.transpose(0, 2, 1)
 
-            return text, magnitude, mel
+            # convert things to pytorch
+            text_lenghts = torch.LongTensor(text_lenghts)
+            text = torch.LongTensor(text)
+            linear = torch.FloatTensor(linear)
+            mel = torch.FloatTensor(mel)
+            return text, text_lenghts, linear, mel, item_idxs[0]
 
         raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
                          found {}"
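The `pad_len` arithmetic above guarantees two things: the padded length is divisible by `outputs_per_step` (`r` in config.json), and there is always at least one trailing all-zero frame, which `is_end_of_frames()` in layers/tacotron.py later uses to detect the end of synthesis. A minimal sketch of the arithmetic, assuming r = 5 (illustrative values, not part of the commit):

```
# Minimal sketch of the pad_len arithmetic in collate_fn, assuming r = 5.
outputs_per_step = 5              # "r" in config.json

for timesteps in (37, 39, 40):
    if (timesteps + 1) % outputs_per_step != 0:
        pad_len = outputs_per_step - ((timesteps + 1) % outputs_per_step)
        pad_len += 1              # always keep >= 1 trailing zero frame
    else:
        pad_len = 1
    total = timesteps + pad_len
    assert total % outputs_per_step == 0 and pad_len >= 1
    print(timesteps, "->", total)  # 37 -> 40, 39 -> 40, 40 -> 45
```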
New file (name not shown in this view):

@@ -0,0 +1,29 @@
+{
+"num_mels": 80,
+"num_freq": 1024,
+"sample_rate": 20000,
+"frame_length_ms": 50.0,
+"frame_shift_ms": 12.5,
+"preemphasis": 0.97,
+"min_level_db": -100,
+"ref_level_db": 20,
+"hidden_size": 128,
+"embedding_size": 256,
+"text_cleaner": "english_cleaners",
+
+"epochs": 200,
+"lr": 0.01,
+"lr_patience": 2,
+"lr_decay": 0.5,
+"batch_size": 32,
+"griffinf_lim_iters": 60,
+"power": 1.5,
+"r": 5,
+
+"num_loader_workers": 16,
+
+"save_step": 1,
+"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
+"output_path": "result",
+"log_dir": "/home/erogol/projects/TTS/logs/"
+}
Binary data: layers/.attention.py.swp — binary file not shown.
Binary data: layers/.tacotron.py.swp — binary file not shown.
@@ -11,11 +11,11 @@ class BahdanauAttention(nn.Module):
         self.tanh = nn.Tanh()
         self.v = nn.Linear(dim, 1, bias=False)
 
-    def forward(self, query, processed_memory):
+    def forward(self, query, processed_inputs):
         """
         Args:
             query: (batch, 1, dim) or (batch, dim)
-            processed_memory: (batch, max_time, dim)
+            processed_inputs: (batch, max_time, dim)
         """
         if query.dim() == 2:
             # insert time-axis for broadcasting
@@ -24,63 +24,71 @@ class BahdanauAttention(nn.Module):
         processed_query = self.query_layer(query)
 
         # (batch, max_time, 1)
-        alignment = self.v(self.tanh(processed_query + processed_memory))
+        alignment = self.v(self.tanh(processed_query + processed_inputs))
 
         # (batch, max_time)
         return alignment.squeeze(-1)
 
 
-def get_mask_from_lengths(memory, memory_lengths):
+def get_mask_from_lengths(inputs, inputs_lengths):
     """Get mask tensor from list of lengths
 
     Args:
-        memory: (batch, max_time, dim)
-        memory_lengths: array like
+        inputs: (batch, max_time, dim)
+        inputs_lengths: array like
     """
-    mask = memory.data.new(memory.size(0), memory.size(1)).byte().zero_()
-    for idx, l in enumerate(memory_lengths):
+    mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_()
+    for idx, l in enumerate(inputs_lengths):
         mask[idx][:l] = 1
     return ~mask
 
 
 class AttentionWrapper(nn.Module):
-    def __init__(self, rnn_cell, attention_mechanism,
+    def __init__(self, rnn_cell, alignment_model,
                  score_mask_value=-float("inf")):
         super(AttentionWrapper, self).__init__()
         self.rnn_cell = rnn_cell
-        self.attention_mechanism = attention_mechanism
+        self.alignment_model = alignment_model
         self.score_mask_value = score_mask_value
 
-    def forward(self, query, attention, cell_state, memory,
-                processed_memory=None, mask=None, memory_lengths=None):
-        if processed_memory is None:
-            processed_memory = memory
-        if memory_lengths is not None and mask is None:
-            mask = get_mask_from_lengths(memory, memory_lengths)
-
-        # Concat input query and previous attention context
-        cell_input = torch.cat((query, attention), -1)
-
-        # Feed it to RNN
-        cell_output = self.rnn_cell(cell_input, cell_state)
+    def forward(self, query, context_vec, cell_state, inputs,
+                processed_inputs=None, mask=None, inputs_lengths=None):
+        if processed_inputs is None:
+            processed_inputs = inputs
+        if inputs_lengths is not None and mask is None:
+            mask = get_mask_from_lengths(inputs, inputs_lengths)
 
         # Alignment
         # (batch, max_time)
-        alignment = self.attention_mechanism(cell_output, processed_memory)
+        # e_{ij} = a(s_{i-1}, h_j)
+        alignment = self.alignment_model(cell_state, processed_inputs)
 
         if mask is not None:
             mask = mask.view(query.size(0), -1)
             alignment.data.masked_fill_(mask, self.score_mask_value)
 
-        # Normalize attention weight
-        alignment = F.softmax(alignment, dim=0)
+        # Normalize alignment weights
+        alignment = F.softmax(alignment, dim=-1)
 
         # Attention context vector
         # (batch, 1, dim)
-        attention = torch.bmm(alignment.unsqueeze(1), memory)
-
-        # (batch, dim)
-        attention = attention.squeeze(1)
-
-        return cell_output, attention, alignment
+        # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
+        context_vec = torch.bmm(alignment.unsqueeze(1), inputs)
+        # (batch, dim)
+        context_vec = context_vec.squeeze(1)
+
+        # Concat input query and the attention context vector
+        cell_input = torch.cat((query, context_vec), -1)
+
+        # Feed it to the RNN
+        # s_i = f(y_{i-1}, c_{i}, s_{i-1})
+        cell_output = self.rnn_cell(cell_input, cell_state)
+
+        return cell_output, context_vec, alignment
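The `e_{ij}` and `c_i` comments added above are the standard Bahdanau (additive) attention equations, with s_{i-1} the attention-RNN state and h_j the encoder outputs; written out:

```
e_{ij} = v^{\top} \tanh(W s_{i-1} + U h_j), \qquad
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})}, \qquad
c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
```

In the code, `query_layer` applies W, `processed_inputs` holds the precomputed U h_j, `self.v` is v, and `F.softmax(alignment, dim=-1)` produces the alpha weights.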
@@ -7,35 +7,59 @@ from .attention import BahdanauAttention, AttentionWrapper
 from .attention import get_mask_from_lengths
 
 class Prenet(nn.Module):
-    def __init__(self, in_dim, sizes=[256, 128]):
+    r"""Prenet as explained at https://arxiv.org/abs/1703.10135.
+    It creates as many layers as given by 'out_features'.
+
+    Args:
+        in_features (int): size of the input vector
+        out_features (int or list): size of each output sample.
+            If it is a list, a new layer is created for each value.
+    """
+
+    def __init__(self, in_features, out_features=[256, 128]):
         super(Prenet, self).__init__()
-        in_sizes = [in_dim] + sizes[:-1]
+        in_features = [in_features] + out_features[:-1]
         self.layers = nn.ModuleList(
             [nn.Linear(in_size, out_size)
-             for (in_size, out_size) in zip(in_sizes, sizes)])
+             for (in_size, out_size) in zip(in_features, out_features)])
         self.relu = nn.ReLU()
         self.dropout = nn.Dropout(0.5)
 
     def forward(self, inputs):
         for linear in self.layers:
             inputs = self.dropout(self.relu(linear(inputs)))
         return inputs
 
 
 class BatchNormConv1d(nn.Module):
-    def __init__(self, in_dim, out_dim, kernel_size, stride, padding,
+    r"""A wrapper for Conv1d with BatchNorm. It sets the activation
+    function between the Conv and BatchNorm layers. The BatchNorm layer
+    is initialized with the TF default values for momentum and eps.
+
+    Args:
+        in_channels: size of each input sample
+        out_channels: size of each output sample
+        kernel_size: kernel size of conv filters
+        stride: stride of conv filters
+        padding: padding of conv filters
+        activation: activation function set b/w Conv1d and BatchNorm
+
+    Shapes:
+        - input: batch x dims
+        - output: batch x dims
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                  activation=None):
         super(BatchNormConv1d, self).__init__()
-        self.conv1d = nn.Conv1d(in_dim, out_dim,
+        self.conv1d = nn.Conv1d(in_channels, out_channels,
                                 kernel_size=kernel_size,
                                 stride=stride, padding=padding, bias=False)
         # Following tensorflow's default parameters
-        self.bn = nn.BatchNorm1d(out_dim, momentum=0.99, eps=1e-3)
+        self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
         self.activation = activation
 
     def forward(self, x):
         x = self.conv1d(x)
         if self.activation is not None:
             x = self.activation(x)
         return self.bn(x)
 
@@ -62,135 +86,180 @@ class CBHG(nn.Module):
         - 1-d convolution banks
         - Highway networks + residual connections
         - Bidirectional gated recurrent units
+
+    Args:
+        in_features (int): sample size
+        K (int): max filter size in conv bank
+        projections (list): conv channel sizes for conv projections
+        num_highways (int): number of highway layers
+
+    Shapes:
+        - input: batch x time x dim
+        - output: batch x time x dim*2
     """
 
-    def __init__(self, in_dim, K=16, projections=[128, 128]):
+    def __init__(self, in_features, K=16, projections=[128, 128], num_highways=4):
         super(CBHG, self).__init__()
-        self.in_dim = in_dim
+        self.in_features = in_features
         self.relu = nn.ReLU()
+
+        # list of conv1d banks with filter sizes k=1...K
+        # TODO: try dilated layers instead
         self.conv1d_banks = nn.ModuleList(
-            [BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1,
+            [BatchNormConv1d(in_features, in_features, kernel_size=k, stride=1,
                              padding=k // 2, activation=self.relu)
              for k in range(1, K + 1)])
+
+        # max pooling of conv bank
+        # TODO: try average pooling OR larger kernel size
         self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
 
-        in_sizes = [K * in_dim] + projections[:-1]
-        activations = [self.relu] * (len(projections) - 1) + [None]
-        self.conv1d_projections = nn.ModuleList(
-            [BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
-                             padding=1, activation=ac)
-             for (in_size, out_size, ac) in zip(
-                 in_sizes, projections, activations)])
+        out_features = [K * in_features] + projections[:-1]
+        activations = [self.relu] * (len(projections) - 1)
+        activations += [None]
+
+        # setup conv1d projection layers
+        layer_set = []
+        for (in_size, out_size, ac) in zip(out_features, projections, activations):
+            layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
+                                    padding=1, activation=ac)
+            layer_set.append(layer)
+        self.conv1d_projections = nn.ModuleList(layer_set)
 
-        self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False)
+        # setup Highway layers
+        self.pre_highway = nn.Linear(projections[-1], in_features, bias=False)
         self.highways = nn.ModuleList(
-            [Highway(in_dim, in_dim) for _ in range(4)])
+            [Highway(in_features, in_features) for _ in range(num_highways)])
 
+        # bi-directional GRU layer
         self.gru = nn.GRU(
-            in_dim, in_dim, 1, batch_first=True, bidirectional=True)
+            in_features, in_features, 1, batch_first=True, bidirectional=True)
 
-    def forward(self, inputs, input_lengths=None):
-        # (B, T_in, in_dim)
+    def forward(self, inputs):
+        # (B, T_in, in_features)
         x = inputs
 
         # Needed to perform conv1d on time-axis
-        # (B, in_dim, T_in)
-        if x.size(-1) == self.in_dim:
+        # (B, in_features, T_in)
+        if x.size(-1) == self.in_features:
             x = x.transpose(1, 2)
 
         T = x.size(-1)
 
-        # (B, in_dim*K, T_in)
+        # (B, in_features*K, T_in)
         # Concat conv1d bank outputs
-        x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1)
-        assert x.size(1) == self.in_dim * len(self.conv1d_banks)
+        outs = []
+        for conv1d in self.conv1d_banks:
+            out = conv1d(x)
+            out = out[:, :, :T]
+            outs.append(out)
+
+        x = torch.cat(outs, dim=1)
+        assert x.size(1) == self.in_features * len(self.conv1d_banks)
 
         x = self.max_pool1d(x)[:, :, :T]
 
         for conv1d in self.conv1d_projections:
             x = conv1d(x)
 
-        # (B, T_in, in_dim)
+        # (B, T_in, in_features)
         # Back to the original shape
         x = x.transpose(1, 2)
 
-        if x.size(-1) != self.in_dim:
+        if x.size(-1) != self.in_features:
             x = self.pre_highway(x)
 
         # Residual connection
+        # TODO: try residual scaling as in Deep Voice 3
+        # TODO: try plain residual layers
         x += inputs
         for highway in self.highways:
             x = highway(x)
 
-        if input_lengths is not None:
-            x = nn.utils.rnn.pack_padded_sequence(
-                x, input_lengths, batch_first=True)
-
-        # (B, T_in, in_dim*2)
+        # (B, T_in, in_features*2)
+        # TODO: replace GRU with convolution as in Deep Voice 3
         self.gru.flatten_parameters()
         outputs, _ = self.gru(x)
 
-        if input_lengths is not None:
-            outputs, _ = nn.utils.rnn.pad_packed_sequence(
-                outputs, batch_first=True)
-
         return outputs
 
 
 class Encoder(nn.Module):
-    def __init__(self, in_dim):
+    r"""Encapsulate Prenet and CBHG modules for the encoder."""
+
+    def __init__(self, in_features):
         super(Encoder, self).__init__()
-        self.prenet = Prenet(in_dim, sizes=[256, 128])
+        self.prenet = Prenet(in_features, out_features=[256, 128])
         self.cbhg = CBHG(128, K=16, projections=[128, 128])
 
-    def forward(self, inputs, input_lengths=None):
+    def forward(self, inputs):
+        r"""
+        Args:
+            inputs (FloatTensor): embedding features
+
+        Shapes:
+            - inputs: batch x time x in_features
+            - outputs: batch x time x 128*2
+        """
         inputs = self.prenet(inputs)
-        return self.cbhg(inputs, input_lengths)
+        return self.cbhg(inputs)
 
 
 class Decoder(nn.Module):
-    def __init__(self, memory_dim, r):
+    r"""Decoder module.
+
+    Args:
+        in_features (int): input vector (encoder output) sample size.
+        memory_dim (int): memory vector (prev. time-step output) sample size.
+        r (int): number of outputs per time step.
+        eps (float): threshold for detecting the end of a sentence.
+    """
+
+    def __init__(self, in_features, memory_dim, r, eps=0.2):
         super(Decoder, self).__init__()
+        self.max_decoder_steps = 200
         self.memory_dim = memory_dim
+        self.eps = eps
         self.r = r
-        self.prenet = Prenet(memory_dim * r, sizes=[256, 128])
-        # attetion RNN
+        # input -> |Linear| -> processed_inputs
+        self.input_layer = nn.Linear(in_features, 256, bias=False)
+        # memory -> |Prenet| -> processed_memory
+        self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
+        # processed_inputs, processed_memory -> |Attention| -> attention, alignment, RNN state
         self.attention_rnn = AttentionWrapper(
-            nn.GRUCell(256 + 128, 256),
+            nn.GRUCell(in_features + 128, 256),
             BahdanauAttention(256)
         )
-        self.memory_layer = nn.Linear(256, 256, bias=False)
-
-        # concat and project context and attention vectors
-        # (prenet_out + attention context) -> output
-        self.project_to_decoder_in = nn.Linear(512, 256)
-
+        # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
+        self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
+        # decoder_RNN_input -> |RNN| -> RNN_state
         self.decoder_rnns = nn.ModuleList(
             [nn.GRUCell(256, 256) for _ in range(2)])
+        # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
-        self.max_decoder_steps = 200
 
-    def forward(self, decoder_inputs, memory=None, memory_lengths=None):
-        """
+    def forward(self, inputs, memory=None, memory_lengths=None):
+        r"""
         Decoder forward step.
 
         If decoder inputs are not given (e.g., at testing time), as noted in
-        Tacotron paper, greedy decoding is adapted.
+        the Tacotron paper, greedy decoding is adopted.
 
         Args:
-            decoder_inputs: Encoder outputs. (B, T_encoder, dim)
-            memory: Decoder memory. i.e., mel-spectrogram. If None (at eval-time),
+            inputs: Encoder outputs.
+            memory: Decoder memory (autoregression). If None (at eval-time),
                 decoder outputs are used as decoder inputs.
             memory_lengths: Encoder output (memory) lengths. If not None, used for
                 attention masking.
-        """
-        B = decoder_inputs.size(0)
 
-        processed_memory = self.memory_layer(decoder_inputs)
+        Shapes:
+            - inputs: batch x time x encoder_out_dim
+            - memory: batch x #mel_specs x mel_spec_dim
+        """
+        B = inputs.size(0)
+
+        # TODO: take this segment into the Attention module.
+        processed_inputs = self.input_layer(inputs)
         if memory_lengths is not None:
-            mask = get_mask_from_lengths(processed_memory, memory_lengths)
+            mask = get_mask_from_lengths(processed_inputs, memory_lengths)
         else:
             mask = None
 
@@ -198,6 +267,7 @@ class Decoder(nn.Module):
         greedy = memory is None
 
         if memory is not None:
+
             # Grouping multiple frames if necessary
             if memory.size(-1) == self.memory_dim:
                 memory = memory.view(B, memory.size(1) // self.r, -1)
@@ -206,18 +276,18 @@ class Decoder(nn.Module):
                     self.memory_dim, self.r)
             T_decoder = memory.size(1)
 
-        # go frames - 0 frames tarting the sequence
-        initial_input = Variable(
-            decoder_inputs.data.new(B, self.memory_dim * self.r).zero_())
+        # go frame - a zero frame starting the sequence
+        initial_memory = Variable(
+            inputs.data.new(B, self.memory_dim * self.r).zero_())
 
         # Init decoder states
         attention_rnn_hidden = Variable(
-            decoder_inputs.data.new(B, 256).zero_())
+            inputs.data.new(B, 256).zero_())
         decoder_rnn_hiddens = [Variable(
-            decoder_inputs.data.new(B, 256).zero_())
+            inputs.data.new(B, 256).zero_())
             for _ in range(len(self.decoder_rnns))]
-        current_attention = Variable(
-            decoder_inputs.data.new(B, 256).zero_())
+        current_context_vec = Variable(
+            inputs.data.new(B, 256).zero_())
 
         # Time first (T_decoder, B, memory_dim)
         if memory is not None:
@@ -227,21 +297,21 @@ class Decoder(nn.Module):
         alignments = []
 
         t = 0
-        current_input = initial_input
+        memory_input = initial_memory
         while True:
             if t > 0:
-                current_input = outputs[-1] if greedy else memory[t - 1]
+                memory_input = outputs[-1] if greedy else memory[t - 1]
             # Prenet
-            current_input = self.prenet(current_input)
+            processed_memory = self.prenet(memory_input)
 
             # Attention RNN
-            attention_rnn_hidden, current_attention, alignment = self.attention_rnn(
-                current_input, current_attention, attention_rnn_hidden,
-                decoder_inputs, processed_memory=processed_memory, mask=mask)
+            attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
+                processed_memory, current_context_vec, attention_rnn_hidden,
+                inputs, processed_inputs=processed_inputs, mask=mask)
 
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
-                torch.cat((attention_rnn_hidden, current_attention), -1))
+                torch.cat((attention_rnn_hidden, current_context_vec), -1))
 
             # Pass through the decoder RNNs
             for idx in range(len(self.decoder_rnns)):
@@ -261,10 +331,11 @@ class Decoder(nn.Module):
             t += 1
 
             if greedy:
-                if t > 1 and is_end_of_frames(output):
+                if t > 1 and is_end_of_frames(output, self.eps):
                     break
                 elif t > self.max_decoder_steps:
-                    print("Warning! doesn't seems to be converged")
+                    print(" !! Decoder stopped with 'max_decoder_steps'. \
+                          Something is probably wrong.")
                     break
             else:
                 if t >= T_decoder:
@@ -279,5 +350,5 @@ class Decoder(nn.Module):
         return outputs, alignments
 
 
-def is_end_of_frames(output, eps=0.2):
+def is_end_of_frames(output, eps=0.1):  # 0.2
     return (output.data <= eps).all()
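The `memory.view(B, memory.size(1) // self.r, -1)` call above groups the ground-truth mel frames so the decoder consumes and emits r frames per step, matching `proj_to_mel`'s output size of `memory_dim * r`. A small shape sketch with illustrative values (not from the commit):

```
import torch

B, T, mel_dim, r = 2, 40, 80, 5
memory = torch.rand(B, T, mel_dim)            # ground-truth mel spectrogram
grouped = memory.view(B, T // r, mel_dim * r)
print(grouped.shape)                          # torch.Size([2, 8, 400])
# Each decoder step now handles r stacked frames, so a 40-frame target
# is decoded in 8 steps instead of 40.
```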
Binary data: models/.tacotron.py.swp — binary file not shown.
@@ -2,8 +2,9 @@
 import torch
 from torch.autograd import Variable
 from torch import nn
-from utils.text.symbols import symbols
-from Tacotron.layers.tacotron import Prenet, Encoder, Decoder, CBHG
+from TTS.utils.text.symbols import symbols
+from TTS.layers.tacotron import Prenet, Encoder, Decoder, CBHG
 
 
 class Tacotron(nn.Module):
     def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
@@ -15,10 +16,12 @@ class Tacotron(nn.Module):
         self.use_memory_mask = use_memory_mask
         self.embedding = nn.Embedding(len(symbols), embedding_dim,
                                       padding_idx=padding_idx)
+        print(" | > Number of characters : {}".format(len(symbols)))
+
         # Trying smaller std
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(embedding_dim)
-        self.decoder = Decoder(mel_dim, r)
+        self.decoder = Decoder(256, mel_dim, r)
 
         self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
         self.last_linear = nn.Linear(mel_dim * 2, freq_dim)
@@ -28,7 +31,7 @@ class Tacotron(nn.Module):
 
         inputs = self.embedding(characters)
         # (B, T', in_dim)
-        encoder_outputs = self.encoder(inputs, input_lengths)
+        encoder_outputs = self.encoder(inputs)
 
         if self.use_memory_mask:
             memory_lengths = input_lengths
File diff suppressed because one or more lines are too long.
New file (name not shown in this view):

@@ -0,0 +1,51 @@
+import io
+import librosa
+import torch
+import numpy as np
+from TTS.utils.text import text_to_sequence
+from matplotlib import pylab as plt
+
+hop_length = 250
+
+def create_speech(m, s, CONFIG, use_cuda, ap):
+    text_cleaner = [CONFIG.text_cleaner]
+    seq = np.array(text_to_sequence(s, text_cleaner))
+
+    # mel = np.zeros([seq.shape[0], CONFIG.num_mels, 1], dtype=np.float32)
+
+    if use_cuda:
+        chars_var = torch.autograd.Variable(torch.from_numpy(seq), volatile=True).unsqueeze(0).cuda()
+        # mel_var = torch.autograd.Variable(torch.from_numpy(mel).type(torch.cuda.FloatTensor), volatile=True).cuda()
+    else:
+        chars_var = torch.autograd.Variable(torch.from_numpy(seq), volatile=True).unsqueeze(0)
+        # mel_var = torch.autograd.Variable(torch.from_numpy(mel).type(torch.FloatTensor), volatile=True)
+
+    mel_out, linear_out, alignments = m.forward(chars_var)
+    linear_out = linear_out[0].data.cpu().numpy()
+    alignment = alignments[0].cpu().data.numpy()
+    spec = ap._denormalize(linear_out)
+    wav = ap.inv_spectrogram(linear_out.T)
+    wav = wav[:ap.find_endpoint(wav)]
+    out = io.BytesIO()
+    ap.save_wav(wav, out)
+    return wav, alignment, spec
+
+
+def visualize(alignment, spectrogram, CONFIG):
+    label_fontsize = 16
+    plt.figure(figsize=(16, 16))
+
+    plt.subplot(2, 1, 1)
+    plt.imshow(alignment.T, aspect="auto", origin="lower", interpolation=None)
+    plt.xlabel("Decoder timestep", fontsize=label_fontsize)
+    plt.ylabel("Encoder timestep", fontsize=label_fontsize)
+    plt.colorbar()
+
+    plt.subplot(2, 1, 2)
+    librosa.display.specshow(spectrogram.T, sr=CONFIG.sample_rate,
+                             hop_length=hop_length, x_axis="time", y_axis="linear")
+    plt.xlabel("Time", fontsize=label_fontsize)
+    plt.ylabel("Hz", fontsize=label_fontsize)
+    plt.tight_layout()
+    plt.colorbar()
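A hypothetical end-to-end use of the new `create_speech` helper, assuming a trained checkpoint on disk and the constructor signatures from this commit; the paths, the checkpoint key, and the module the helper lives in are assumptions, not part of the diff:

```
import torch
from TTS.models.tacotron import Tacotron
from TTS.utils.generic_utils import load_config
from TTS.utils.audio import AudioProcessor
# assumed module path; the new helper file's name is not visible in this diff
from TTS.utils.synthesis import create_speech

CONFIG = load_config('config.json')
# argument order mirrors the AudioProcessor call in datasets/LJSpeech.py
ap = AudioProcessor(CONFIG.sample_rate, CONFIG.num_mels, CONFIG.min_level_db,
                    CONFIG.frame_shift_ms, CONFIG.frame_length_ms,
                    CONFIG.preemphasis, CONFIG.ref_level_db, CONFIG.num_freq,
                    CONFIG.power)
# argument order mirrors the Tacotron call in train.py
model = Tacotron(CONFIG.embedding_size, CONFIG.hidden_size,
                 CONFIG.num_mels, CONFIG.num_freq, CONFIG.r)
checkpoint = torch.load('checkpoint_200.pth.tar', map_location='cpu')
model.load_state_dict(checkpoint['model'])
model.eval()

wav, alignment, spec = create_speech(model, "Hello world.", CONFIG, False, ap)
```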
New file (name not shown in this view):

@@ -0,0 +1,5 @@
+librosa
+inflect
+unidecode
+tensorboard
+tensorboardX
synthesis.py — 16 changed lines
@@ -38,17 +38,11 @@ def main(args):
 
     # Sentences for generation
     sentences = [
-        "And it is worth mention in passing that, as an example of fine typography,",
-        # From July 8, 2017 New York Times:
-        'Scientists at the CERN laboratory say they have discovered a new particle.',
-        'There’s a way to measure the acute emotional intelligence that has never gone out of style.',
-        'President Trump met with other leaders at the Group of 20 conference.',
-        'The Senate\'s bill to repeal and replace the Affordable Care Act is now imperiled.',
-        # From Google's Tacotron example page:
-        'Generative adversarial network or variational auto-encoder.',
-        'The buses aren\'t the problem, they actually provide a solution.',
-        'Does the quick brown fox jump over the lazy dog?',
-        'Talib Kweli confirmed to AllHipHop that he will be releasing an album in the next year.',
+        "I try my best to translate text to speech. But I know I need more work.",
+        "The new Firefox. Fast for good.",
+        "Technology is continually providing us with new ways to create and publish stories.",
+        "For these stories to achieve their full impact, it requires tools.",
+        "I am an alien and I am here to destroy your world."
     ]
 
     # Synthesis and save to wav files
New file (name not shown in this view):

@@ -0,0 +1,60 @@
+import unittest
+import torch as T
+
+from TTS.layers.tacotron import Prenet, CBHG, Decoder, Encoder
+
+
+class PrenetTests(unittest.TestCase):
+
+    def test_in_out(self):
+        layer = Prenet(128, out_features=[256, 128])
+        dummy_input = T.autograd.Variable(T.rand(4, 128))
+
+        print(layer)
+        output = layer(dummy_input)
+        assert output.shape[0] == 4
+        assert output.shape[1] == 128
+
+
+class CBHGTests(unittest.TestCase):
+
+    def test_in_out(self):
+        layer = CBHG(128, K=6, projections=[128, 128], num_highways=2)
+        dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
+
+        print(layer)
+        output = layer(dummy_input)
+        assert output.shape[0] == 4
+        assert output.shape[1] == 8
+        assert output.shape[2] == 256
+
+
+class DecoderTests(unittest.TestCase):
+
+    def test_in_out(self):
+        layer = Decoder(in_features=128, memory_dim=32, r=5)
+        dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
+        dummy_memory = T.autograd.Variable(T.rand(4, 120, 32))
+
+        print(layer)
+        output, alignment = layer(dummy_input, dummy_memory)
+        print(output.shape)
+        assert output.shape[0] == 4
+        assert output.shape[1] == 120 / 5
+        assert output.shape[2] == 32 * 5
+
+
+class EncoderTests(unittest.TestCase):
+
+    def test_in_out(self):
+        layer = Encoder(128)
+        dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
+
+        print(layer)
+        output = layer(dummy_input)
+        print(output.shape)
+        assert output.shape[0] == 4
+        assert output.shape[1] == 8
+        assert output.shape[2] == 256  # 128 * 2 BiRNN
New file (name not shown in this view):

@@ -0,0 +1,93 @@
+import os
+import unittest
+import numpy as np
+
+from torch.utils.data import DataLoader
+from TTS.utils.generic_utils import load_config
+from TTS.datasets.LJSpeech import LJSpeechDataset
+
+
+file_path = os.path.dirname(os.path.realpath(__file__))
+c = load_config(os.path.join(file_path, 'test_config.json'))
+
+
+class TestDataset(unittest.TestCase):
+
+    def __init__(self, *args, **kwargs):
+        super(TestDataset, self).__init__(*args, **kwargs)
+        self.max_loader_iter = 4
+
+    def test_loader(self):
+        dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
+                                  os.path.join(c.data_path, 'wavs'),
+                                  c.r,
+                                  c.sample_rate,
+                                  c.text_cleaner,
+                                  c.num_mels,
+                                  c.min_level_db,
+                                  c.frame_shift_ms,
+                                  c.frame_length_ms,
+                                  c.preemphasis,
+                                  c.ref_level_db,
+                                  c.num_freq,
+                                  c.power
+                                  )
+
+        dataloader = DataLoader(dataset, batch_size=c.batch_size,
+                                shuffle=True, collate_fn=dataset.collate_fn,
+                                drop_last=True, num_workers=c.num_loader_workers)
+
+        for i, data in enumerate(dataloader):
+            if i == self.max_loader_iter:
+                break
+            text_input = data[0]
+            text_lengths = data[1]
+            linear_input = data[2]
+            mel_input = data[3]
+            item_idx = data[4]
+
+            neg_values = text_input[text_input < 0]
+            check_count = len(neg_values)
+            assert check_count == 0, \
+                " !! Negative values in text_input: {}".format(check_count)
+            # TODO: more assertions here
+            assert linear_input.shape[0] == c.batch_size
+            assert mel_input.shape[0] == c.batch_size
+            assert mel_input.shape[2] == c.num_mels
+
+    def test_padding(self):
+        dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
+                                  os.path.join(c.data_path, 'wavs'),
+                                  1,
+                                  c.sample_rate,
+                                  c.text_cleaner,
+                                  c.num_mels,
+                                  c.min_level_db,
+                                  c.frame_shift_ms,
+                                  c.frame_length_ms,
+                                  c.preemphasis,
+                                  c.ref_level_db,
+                                  c.num_freq,
+                                  c.power
+                                  )
+
+        dataloader = DataLoader(dataset, batch_size=1,
+                                shuffle=True, collate_fn=dataset.collate_fn,
+                                drop_last=True, num_workers=c.num_loader_workers)
+
+        for i, data in enumerate(dataloader):
+            if i == self.max_loader_iter:
+                break
+            text_input = data[0]
+            text_lengths = data[1]
+            linear_input = data[2]
+            mel_input = data[3]
+            item_idx = data[4]
+
+            # check the last time step to be zero padded
+            assert mel_input[0, -1].sum() == 0
+            assert mel_input[0, -2].sum() != 0
+            assert linear_input[0, -1].sum() == 0
+            assert linear_input[0, -2].sum() != 0
New file (name not shown in this view):

@@ -0,0 +1,30 @@
+{
+"num_mels": 80,
+"num_freq": 1025,
+"sample_rate": 20000,
+"frame_length_ms": 50,
+"frame_shift_ms": 12.5,
+"preemphasis": 0.97,
+"min_level_db": -100,
+"ref_level_db": 20,
+"hidden_size": 128,
+"embedding_size": 256,
+"text_cleaner": "english_cleaners",
+
+"epochs": 2000,
+"lr": 0.003,
+"lr_patience": 5,
+"lr_decay": 0.5,
+"batch_size": 2,
+"r": 5,
+
+"griffin_lim_iters": 60,
+"power": 1.5,
+
+"num_loader_workers": 4,
+
+"save_step": 200,
+"data_path": "/data/shared/KeithIto/LJSpeech-1.0",
+"output_path": "result",
+"log_dir": "/home/erogol/projects/TTS/logs/"
+}
train.py — 252 changed lines
@@ -1,28 +1,33 @@
 import os
 import sys
 import time
+import datetime
 import shutil
 import torch
 import signal
 import argparse
 import importlib
+import pickle
 import numpy as np
 
 import torch.nn as nn
 from torch import optim
+from torch import onnx
 from torch.autograd import Variable
 from torch.utils.data import DataLoader
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from tensorboardX import SummaryWriter
 
 from utils.generic_utils import (Progbar, remove_experiment_folder,
                                  create_experiment_folder, save_checkpoint,
-                                 load_config)
+                                 save_best_model, load_config, lr_decay)
 from utils.model import get_param_size
+from utils.visual import plot_alignment, plot_spectrogram
 from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
 
 use_cuda = torch.cuda.is_available()
 
 
 def main(args):
 
     # setup output paths and read configs
@@ -33,39 +38,73 @@ def main(args):
     CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
     shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
 
+    # save config to a tmp place to be loaded by subsequent modules.
+    file_name = str(os.getpid())
+    tmp_path = os.path.join("/tmp/", file_name+'_tts')
+    pickle.dump(c, open(tmp_path, "wb"))
+
+    # setup tensorboard
+    LOG_DIR = OUT_PATH
+    tb = SummaryWriter(LOG_DIR)
+
     # Ctrl+C handler to remove empty experiment folder
     def signal_handler(signal, frame):
         print(" !! Pressed Ctrl+C !!")
         remove_experiment_folder(OUT_PATH)
-        sys.exit(0)
+        sys.exit(1)
     signal.signal(signal.SIGINT, signal_handler)
 
+    # Setup the dataset
     dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
                               os.path.join(c.data_path, 'wavs'),
                               c.r,
                               c.sample_rate,
-                              c.text_cleaner
+                              c.text_cleaner,
+                              c.num_mels,
+                              c.min_level_db,
+                              c.frame_shift_ms,
+                              c.frame_length_ms,
+                              c.preemphasis,
+                              c.ref_level_db,
+                              c.num_freq,
+                              c.power
                               )
 
+    dataloader = DataLoader(dataset, batch_size=c.batch_size,
+                            shuffle=True, collate_fn=dataset.collate_fn,
+                            drop_last=True, num_workers=c.num_loader_workers)
+
+    # setup the model
     model = Tacotron(c.embedding_size,
                      c.hidden_size,
                      c.num_mels,
                      c.num_freq,
                      c.r)
 
+    # plot model on tensorboard
+    dummy_input = dataset.get_dummy_data()
+
+    ## TODO: onnx does not support RNN fully yet
+    # model_proto_path = os.path.join(OUT_PATH, "model.proto")
+    # onnx.export(model, dummy_input, model_proto_path, verbose=True)
+    # tb.add_graph_onnx(model_proto_path)
+
     if use_cuda:
         model = nn.DataParallel(model.cuda())
 
     optimizer = optim.Adam(model.parameters(), lr=c.lr)
 
-    try:
+    if args.restore_step:
         checkpoint = torch.load(os.path.join(
-            CHECKPOINT_PATH, 'checkpoint_%d.pth.tar' % args.restore_step))
+            args.restore_path, 'checkpoint_%d.pth.tar' % args.restore_step))
         model.load_state_dict(checkpoint['model'])
         optimizer.load_state_dict(checkpoint['optimizer'])
         print("\n > Model restored from step %d\n" % args.restore_step)
-    except:
-        print("\n > Starting a new training\n")
+        start_epoch = checkpoint['step'] // len(dataloader)
+        best_loss = checkpoint['linear_loss']
+    else:
+        start_epoch = 0
+        print("\n > Starting a new training")
 
     model = model.train()
 
@@ -79,112 +118,153 @@ def main(args):
 
     n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
 
-    for epoch in range(c.epochs):
+    #lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay,
+    #                                 patience=c.lr_patience, verbose=True)
+    epoch_time = 0
+    best_loss = float('inf')
+    for epoch in range(0, c.epochs):
 
-        dataloader = DataLoader(dataset, batch_size=c.batch_size,
-                                shuffle=True, collate_fn=dataset.collate_fn,
-                                drop_last=True, num_workers=32)
+        print("\n | > Epoch {}/{}".format(epoch, c.epochs))
         progbar = Progbar(len(dataset) / c.batch_size)
 
-        for i, data in enumerate(dataloader):
-            text_input = data[0]
-            magnitude_input = data[1]
-            mel_input = data[2]
+        for num_iter, data in enumerate(dataloader):
+            start_time = time.time()
 
-            current_step = i + args.restore_step + epoch * len(dataloader) + 1
+            text_input = data[0]
+            text_lengths = data[1]
+            linear_input = data[2]
+            mel_input = data[3]
+
+            current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1
+
+            # setup lr
+            current_lr = lr_decay(c.lr, current_step)
+            for params_group in optimizer.param_groups:
+                params_group['lr'] = current_lr
+
             optimizer.zero_grad()
 
-            try:
-                mel_input = np.concatenate((np.zeros(
-                    [c.batch_size, 1, c.num_mels], dtype=np.float32),
-                    mel_input[:, 1:, :]), axis=1)
-            except:
-                raise TypeError("not same dimension")
+            # Add a single frame of zeros to Mel Specs for better end detection
+            #try:
+            #    mel_input = np.concatenate((np.zeros(
+            #        [c.batch_size, 1, c.num_mels], dtype=np.float32),
+            #        mel_input[:, 1:, :]), axis=1)
+            #except:
+            #    raise TypeError("not same dimension")
+
+            # convert inputs to variables
+            text_input_var = Variable(text_input)
+            mel_spec_var = Variable(mel_input)
+            linear_spec_var = Variable(linear_input, volatile=True)
+
+            # sort sequence by length.
+            # TODO: might be unnecessary
+            sorted_lengths, indices = torch.sort(
+                text_lengths.view(-1), dim=0, descending=True)
+            sorted_lengths = sorted_lengths.long().numpy()
+
+            text_input_var = text_input_var[indices]
+            mel_spec_var = mel_spec_var[indices]
+            linear_spec_var = linear_spec_var[indices]
 
             if use_cuda:
-                text_input_var = Variable(torch.from_numpy(text_input).type(
-                    torch.cuda.LongTensor), requires_grad=False).cuda()
-                mel_input_var = Variable(torch.from_numpy(mel_input).type(
-                    torch.cuda.FloatTensor), requires_grad=False).cuda()
-                mel_spec_var = Variable(torch.from_numpy(mel_input).type(
-                    torch.cuda.FloatTensor), requires_grad=False).cuda()
-                linear_spec_var = Variable(torch.from_numpy(magnitude_input)
-                    .type(torch.cuda.FloatTensor), requires_grad=False).cuda()
-            else:
-                text_input_var = Variable(torch.from_numpy(text_input).type(
-                    torch.LongTensor), requires_grad=False)
-                mel_input_var = Variable(torch.from_numpy(mel_input).type(
-                    torch.FloatTensor), requires_grad=False)
-                mel_spec_var = Variable(torch.from_numpy(
-                    mel_input).type(torch.FloatTensor), requires_grad=False)
-                linear_spec_var = Variable(torch.from_numpy(
-                    magnitude_input).type(torch.FloatTensor),
-                    requires_grad=False)
+                text_input_var = text_input_var.cuda()
+                mel_spec_var = mel_spec_var.cuda()
+                linear_spec_var = linear_spec_var.cuda()
 
             mel_output, linear_output, alignments =\
-                model.forward(text_input_var, mel_input_var)
+                model.forward(text_input_var, mel_spec_var,
+                              input_lengths=torch.autograd.Variable(torch.cuda.LongTensor(sorted_lengths)))
 
             mel_loss = criterion(mel_output, mel_spec_var)
-            linear_loss = torch.abs(linear_output - linear_spec_var)
-            linear_loss = 0.5 * \
-                torch.mean(linear_loss) + 0.5 * \
-                torch.mean(linear_loss[:, :n_priority_freq, :])
+            #linear_loss = torch.abs(linear_output - linear_spec_var)
+            #linear_loss = 0.5 * \
+            #    torch.mean(linear_loss) + 0.5 * \
+            #    torch.mean(linear_loss[:, :n_priority_freq, :])
+            linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
+                + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
+                                  linear_spec_var[:, :, :n_priority_freq])
             loss = mel_loss + linear_loss
-            loss = loss.cuda()
-
-            start_time = time.time()
+            # loss = loss.cuda()
 
             loss.backward()
-            nn.utils.clip_grad_norm(model.parameters(), 1.)
+            grad_norm = nn.utils.clip_grad_norm(model.parameters(), 1.)  ## TODO: maybe no need
             optimizer.step()
 
-            time_per_step = time.time() - start_time
-            progbar.update(i, values=[('total_loss', loss.data[0]),
-                                      ('linear_loss', linear_loss.data[0]),
-                                      ('mel_loss', mel_loss.data[0])])
+            step_time = time.time() - start_time
+            epoch_time += step_time
+
+            progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
+                                               ('linear_loss', linear_loss.data[0]),
+                                               ('mel_loss', mel_loss.data[0]),
+                                               ('grad_norm', grad_norm)])
+
+            # Plot Learning Stats
+            tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step)
+            tb.add_scalar('Loss/LinearLoss', linear_loss.data[0],
+                          current_step)
+            tb.add_scalar('Loss/MelLoss', mel_loss.data[0], current_step)
+            tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
+                          current_step)
+            tb.add_scalar('Params/GradNorm', grad_norm, current_step)
+            tb.add_scalar('Time/StepTime', step_time, current_step)
+
+            align_img = alignments[0].data.cpu().numpy()
+            align_img = plot_alignment(align_img)
+            tb.add_image('Attn/Alignment', align_img, current_step)
 
             if current_step % c.save_step == 0:
-                checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
-                checkpoint_path = os.path.join(OUT_PATH, checkpoint_path)
-                save_checkpoint({'model': model.state_dict(),
-                                 'optimizer': optimizer.state_dict(),
-                                 'step': current_step,
-                                 'total_loss': loss.data[0],
-                                 'linear_loss': linear_loss.data[0],
-                                 'mel_loss': mel_loss.data[0],
-                                 'date': datetime.date.today().strftime("%B %d, %Y")},
-                                checkpoint_path)
-                print(" > Checkpoint is saved : {}".format(checkpoint_path))
 
-            if current_step in c.decay_step:
-                optimizer = adjust_learning_rate(optimizer, current_step)
+                if c.checkpoint:
+                    # save model
+                    save_checkpoint(model, optimizer, linear_loss.data[0],
+                                    best_loss, OUT_PATH,
+                                    current_step, epoch)
+
+                # Diagnostic visualizations
|
||||||
|
const_spec = linear_output[0].data.cpu().numpy()
|
||||||
|
gt_spec = linear_spec_var[0].data.cpu().numpy()
|
||||||
|
|
||||||
|
const_spec = plot_spectrogram(const_spec, dataset.ap)
|
||||||
|
gt_spec = plot_spectrogram(gt_spec, dataset.ap)
|
||||||
|
tb.add_image('Spec/Reconstruction', const_spec, current_step)
|
||||||
|
tb.add_image('Spec/GroundTruth', gt_spec, current_step)
|
||||||
|
|
||||||
|
align_img = alignments[0].data.cpu().numpy()
|
||||||
|
align_img = plot_alignment(align_img)
|
||||||
|
tb.add_image('Attn/Alignment', align_img, current_step)
|
||||||
|
|
||||||
|
# Sample audio
|
||||||
|
audio_signal = linear_output[0].data.cpu().numpy()
|
||||||
|
dataset.ap.griffin_lim_iters = 60
|
||||||
|
audio_signal = dataset.ap.inv_spectrogram(audio_signal.T)
|
||||||
|
try:
|
||||||
|
tb.add_audio('SampleAudio', audio_signal, current_step,
|
||||||
|
sample_rate=c.sample_rate)
|
||||||
|
except:
|
||||||
|
print("\n > Error at audio signal on TB!!")
|
||||||
|
print(audio_signal.max())
|
||||||
|
print(audio_signal.min())
|
||||||
|
|
||||||
|
|
||||||
def adjust_learning_rate(optimizer, step):
|
# average loss after the epoch
|
||||||
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
|
avg_epoch_loss = np.mean(
|
||||||
if step == 500000:
|
progbar.sum_values['linear_loss'][0] / max(1, progbar.sum_values['linear_loss'][1]))
|
||||||
for param_group in optimizer.param_groups:
|
best_loss = save_best_model(model, optimizer, avg_epoch_loss,
|
||||||
param_group['lr'] = 0.0005
|
best_loss, OUT_PATH,
|
||||||
|
current_step, epoch)
|
||||||
|
|
||||||
elif step == 1000000:
|
#lr_scheduler.step(loss.data[0])
|
||||||
for param_group in optimizer.param_groups:
|
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
|
||||||
param_group['lr'] = 0.0003
|
epoch_time = 0
|
||||||
|
|
||||||
elif step == 2000000:
|
|
||||||
for param_group in optimizer.param_groups:
|
|
||||||
param_group['lr'] = 0.0001
|
|
||||||
|
|
||||||
return optimizer
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--restore_step', type=int,
|
parser.add_argument('--restore_step', type=int,
|
||||||
help='Global step to restore checkpoint', default=128)
|
help='Global step to restore checkpoint', default=0)
|
||||||
|
parser.add_argument('--restore_path', type=str,
|
||||||
|
help='Folder path to checkpoints', default=0)
|
||||||
parser.add_argument('--config_path', type=str,
|
parser.add_argument('--config_path', type=str,
|
||||||
help='path to config file for training',)
|
help='path to config file for training',)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
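The reworked linear loss above weights the full-band error and the low-frequency error equally, so the model is pushed hardest on the perceptually important bottom of the spectrum. Below is a minimal, self-contained sketch of that computation. It assumes `criterion` is `nn.L1Loss()` and that `n_priority_freq` counts the spectrogram bins below roughly 3 kHz; both are defined elsewhere in train.py and may differ there. It is written against a current PyTorch API rather than the Variable-era one in the diff.

import torch
import torch.nn as nn

criterion = nn.L1Loss()
sample_rate, num_freq = 20000, 1025
# assumed derivation: number of spectrogram bins below 3000 Hz
n_priority_freq = int(3000 / (sample_rate * 0.5) * num_freq)

linear_output = torch.rand(2, 100, num_freq)  # (batch, frames, freq bins)
linear_spec = torch.rand(2, 100, num_freq)

# equal weight on the full band and on the priority (low-frequency) band
linear_loss = 0.5 * criterion(linear_output, linear_spec) \
    + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                      linear_spec[:, :, :n_priority_freq])
print(linear_loss.item())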
Binary data: utils/.data.py.swp (binary file not shown)
Binary data: utils/.generic_utils.py.swp (binary file not shown)
154 utils/audio.py

@@ -1,108 +1,124 @@
+import os
 import librosa
+import pickle
 import numpy as np
 from scipy import signal
 
 _mel_basis = None
 
 
-def save_wav(wav, path):
-    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
-    librosa.output.write_wav(path, wav.astype(np.int16), c.sample_rate)
+class AudioProcessor(object):
+
+    def __init__(self, sample_rate, num_mels, min_level_db, frame_shift_ms,
+                 frame_length_ms, preemphasis, ref_level_db, num_freq, power,
+                 griffin_lim_iters=None):
+        self.sample_rate = sample_rate
+        self.num_mels = num_mels
+        self.min_level_db = min_level_db
+        self.frame_shift_ms = frame_shift_ms
+        self.frame_length_ms = frame_length_ms
+        self.preemphasis = preemphasis
+        self.ref_level_db = ref_level_db
+        self.num_freq = num_freq
+        self.power = power
+        self.griffin_lim_iters = griffin_lim_iters
+
+    def save_wav(self, wav, path):
+        wav *= 32767 / max(0.01, np.max(np.abs(wav)))
+        librosa.output.write_wav(path, wav.astype(np.int16), self.sample_rate)
 
 
-def _linear_to_mel(spectrogram):
-    global _mel_basis
-    if _mel_basis is None:
-        _mel_basis = _build_mel_basis()
-    return np.dot(_mel_basis, spectrogram)
+    def _linear_to_mel(self, spectrogram):
+        global _mel_basis
+        if _mel_basis is None:
+            _mel_basis = self._build_mel_basis()
+        return np.dot(_mel_basis, spectrogram)
 
 
-def _build_mel_basis():
-    n_fft = (c.num_freq - 1) * 2
-    return librosa.filters.mel(c.sample_rate, n_fft, n_mels=c.num_mels)
+    def _build_mel_basis(self, ):
+        n_fft = (self.num_freq - 1) * 2
+        return librosa.filters.mel(self.sample_rate, n_fft, n_mels=self.num_mels)
 
 
-def _normalize(S):
-    return np.clip((S - c.min_level_db) / -c.min_level_db, 0, 1)
+    def _normalize(self, S):
+        return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
 
 
-def _denormalize(S):
-    return (np.clip(S, 0, 1) * -c.min_level_db) + c.min_level_db
+    def _denormalize(self, S):
+        return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
 
 
-def _stft_parameters():
-    n_fft = (c.num_freq - 1) * 2
-    hop_length = int(c.frame_shift_ms / 1000 * c.sample_rate)
-    win_length = int(c.frame_length_ms / 1000 * c.sample_rate)
-    return n_fft, hop_length, win_length
+    def _stft_parameters(self, ):
+        n_fft = (self.num_freq - 1) * 2
+        hop_length = int(self.frame_shift_ms / 1000 * self.sample_rate)
+        win_length = int(self.frame_length_ms / 1000 * self.sample_rate)
+        return n_fft, hop_length, win_length
 
 
-def _amp_to_db(x):
-    return 20 * np.log10(np.maximum(1e-5, x))
+    def _amp_to_db(self, x):
+        return 20 * np.log10(np.maximum(1e-5, x))
 
 
-def _db_to_amp(x):
-    return np.power(10.0, x * 0.05)
+    def _db_to_amp(self, x):
+        return np.power(10.0, x * 0.05)
 
 
-def preemphasis(x):
-    return signal.lfilter([1, -c.preemphasis], [1], x)
+    def apply_preemphasis(self, x):
+        return signal.lfilter([1, -self.preemphasis], [1], x)
 
 
-def inv_preemphasis(x):
-    return signal.lfilter([1], [1, -c.preemphasis], x)
+    def apply_inv_preemphasis(self, x):
+        return signal.lfilter([1], [1, -self.preemphasis], x)
 
 
-def spectrogram(y):
-    D = _stft(preemphasis(y))
-    S = _amp_to_db(np.abs(D)) - c.ref_level_db
-    return _normalize(S)
+    def spectrogram(self, y):
+        D = self._stft(self.apply_preemphasis(y))
+        S = self._amp_to_db(np.abs(D)) - self.ref_level_db
+        return self._normalize(S)
 
 
-def inv_spectrogram(spectrogram):
-    '''Converts spectrogram to waveform using librosa'''
-    S = _denormalize(spectrogram)
-    S = _db_to_amp(S + c.ref_level_db)  # Convert back to linear
-    # Reconstruct phase
-    return inv_preemphasis(_griffin_lim(S ** c.power))
+    def inv_spectrogram(self, spectrogram):
+        '''Converts spectrogram to waveform using librosa'''
+        S = self._denormalize(spectrogram)
+        S = self._db_to_amp(S + self.ref_level_db)  # Convert back to linear
+        # Reconstruct phase
+        return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
 
 
-def _griffin_lim(S):
-    '''librosa implementation of Griffin-Lim
-    Based on https://github.com/librosa/librosa/issues/434
-    '''
-    angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
-    S_complex = np.abs(S).astype(np.complex)
-    y = _istft(S_complex * angles)
-    for i in range(c.griffin_lim_iters):
-        angles = np.exp(1j * np.angle(_stft(y)))
-        y = _istft(S_complex * angles)
-    return y
+    def _griffin_lim(self, S):
+        '''librosa implementation of Griffin-Lim
+        Based on https://github.com/librosa/librosa/issues/434
+        '''
+        angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
+        S_complex = np.abs(S).astype(np.complex)
+        y = self._istft(S_complex * angles)
+        for i in range(self.griffin_lim_iters):
+            angles = np.exp(1j * np.angle(self._stft(y)))
+            y = self._istft(S_complex * angles)
+        return y
 
 
-def _istft(y):
-    _, hop_length, win_length = _stft_parameters()
-    return librosa.istft(y, hop_length=hop_length, win_length=win_length)
+    def melspectrogram(self, y):
+        D = self._stft(self.apply_preemphasis(y))
+        S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
+        return self._normalize(S)
 
 
-def melspectrogram(y):
-    D = _stft(preemphasis(y))
-    S = _amp_to_db(_linear_to_mel(np.abs(D)))
-    return _normalize(S)
+    def _stft(self, y):
+        n_fft, hop_length, win_length = self._stft_parameters()
+        return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length,
+                            win_length=win_length)
 
 
-def _stft(y):
-    n_fft, hop_length, win_length = _stft_parameters()
-    return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
+    def _istft(self, y):
+        _, hop_length, win_length = self._stft_parameters()
+        return librosa.istft(y, hop_length=hop_length, win_length=win_length)
 
 
-def find_endpoint(wav, threshold_db=-40, min_silence_sec=0.8):
-    window_length = int(c.sample_rate * min_silence_sec)
-    hop_length = int(window_length / 4)
-    threshold = _db_to_amp(threshold_db)
-    for x in range(hop_length, len(wav) - window_length, hop_length):
-        if np.max(wav[x:x + window_length]) < threshold:
-            return x + hop_length
-    return len(wav)
+    def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
+        window_length = int(self.sample_rate * min_silence_sec)
+        hop_length = int(window_length / 4)
+        threshold = self._db_to_amp(threshold_db)
+        for x in range(hop_length, len(wav) - window_length, hop_length):
+            if np.max(wav[x:x + window_length]) < threshold:
+                return x + hop_length
+        return len(wav)
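For context, here is a usage sketch of the new AudioProcessor, constructed from the same fields that config.json defines. The import path and parameter values are assumptions for illustration, and it needs numpy/librosa versions contemporary with this code (the positional `librosa.filters.mel` call and `np.complex` predate current releases).

import numpy as np
from TTS.utils.audio import AudioProcessor  # import path assumed

ap = AudioProcessor(sample_rate=20000, num_mels=80, min_level_db=-100,
                    frame_shift_ms=12.5, frame_length_ms=50,
                    preemphasis=0.97, ref_level_db=20, num_freq=1025,
                    power=1.5, griffin_lim_iters=60)

wav = np.random.uniform(-1, 1, 20000).astype(np.float32)  # 1 s of noise
S = ap.spectrogram(wav)          # normalized linear spectrogram, (num_freq, frames)
M = ap.melspectrogram(wav)       # normalized mel spectrogram, (num_mels, frames)
wav_hat = ap.inv_spectrogram(S)  # Griffin-Lim phase reconstruction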
utils/data.py

@@ -3,7 +3,10 @@ import numpy as np
 
 def pad_data(x, length):
     _pad = 0
-    return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
+    assert x.ndim == 1
+    return np.pad(x, (0, length - x.shape[0]),
+                  mode='constant',
+                  constant_values=_pad)
 
 
 def prepare_data(inputs):
@@ -11,8 +14,8 @@ def prepare_data(inputs):
     return np.stack([pad_data(x, max_len) for x in inputs])
 
 
-def pad_per_step(inputs, outputs_per_step):
+def pad_per_step(inputs, pad_len):
     timesteps = inputs.shape[-1]
     return np.pad(inputs, [[0, 0], [0, 0],
-                           [0, outputs_per_step - (timesteps % outputs_per_step)]],
+                           [0, pad_len]],
                   mode='constant', constant_values=0.0)
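The new `pad_per_step` signature shifts responsibility to the caller: instead of passing `outputs_per_step` and letting the helper compute the remainder, the caller now passes the padding length directly. A toy sketch of the equivalent call, with the function body copied from the hunk above and `r` standing in for the decoder's outputs-per-step:

import numpy as np

def pad_per_step(inputs, pad_len):
    return np.pad(inputs, [[0, 0], [0, 0], [0, pad_len]],
                  mode='constant', constant_values=0.0)

batch = np.ones((2, 80, 13))         # (batch, num_mels, timesteps)
r = 5                                # decoder outputs per step
pad_len = r - (batch.shape[-1] % r)  # what the old signature computed internally
print(pad_per_step(batch, pad_len).shape)  # (2, 80, 15)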
utils/generic_utils.py

@@ -5,6 +5,7 @@ import time
 import shutil
 import datetime
 import json
+import torch
 import numpy as np
@@ -34,8 +35,9 @@ def remove_experiment_folder(experiment_path):
 
     checkpoint_files = glob.glob(experiment_path+"/*.pth.tar")
     if len(checkpoint_files) < 1:
-        shutil.rmtree(experiment_path)
-        print(" ! Run is removed from {}".format(experiment_path))
+        if os.path.exists(experiment_path):
+            shutil.rmtree(experiment_path)
+            print(" ! Run is removed from {}".format(experiment_path))
     else:
         print(" ! Run is kept in {}".format(experiment_path))
@@ -46,10 +48,44 @@ def copy_config_file(config_file, path):
     shutil.copyfile(config_file, out_path)
 
 
-def save_checkpoint(state, filename='checkpoint.pth.tar'):
-    torch.save(state, filename)
+def save_checkpoint(model, optimizer, model_loss, best_loss, out_path,
+                    current_step, epoch):
+    checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
+    checkpoint_path = os.path.join(out_path, checkpoint_path)
+    print("\n | > Checkpoint saving : {}".format(checkpoint_path))
+    state = {'model': model.state_dict(),
+             'optimizer': optimizer.state_dict(),
+             'step': current_step,
+             'epoch': epoch,
+             'linear_loss': model_loss,
+             'date': datetime.date.today().strftime("%B %d, %Y")}
+    torch.save(state, checkpoint_path)
+
+
+def save_best_model(model, optimizer, model_loss, best_loss, out_path,
+                    current_step, epoch):
+    if model_loss < best_loss:
+        state = {'model': model.state_dict(),
+                 'optimizer': optimizer.state_dict(),
+                 'step': current_step,
+                 'epoch': epoch,
+                 'linear_loss': model_loss,
+                 'date': datetime.date.today().strftime("%B %d, %Y")}
+        best_loss = model_loss
+        bestmodel_path = 'best_model.pth.tar'
+        bestmodel_path = os.path.join(out_path, bestmodel_path)
+        print("\n | > Best model saving with loss {0:.2f} : {1:}".format(
+            model_loss, bestmodel_path))
+        torch.save(state, bestmodel_path)
+    return best_loss
+
+
+def lr_decay(init_lr, global_step):
+    warmup_steps = 4000.0
+    step = global_step + 1.
+    lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
+                                                  step**-0.5)
+    return lr
 
 
 class Progbar(object):
     """Displays a progress bar.
     # Arguments
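The new `lr_decay` replaces the hard-coded decay steps that train.py dropped with a Noam-style schedule: linear warmup over the first 4000 steps, then inverse square-root decay. A quick check of its shape, with the function copied from the hunk above:

import numpy as np

def lr_decay(init_lr, global_step):
    warmup_steps = 4000.0
    step = global_step + 1.
    return init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
                                                    step**-0.5)

for s in [0, 1000, 4000, 16000, 64000]:
    print(s, lr_decay(0.001, s))
# the rate rises toward init_lr around step 4000, then decays as 1/sqrt(step)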
utils/text/__init__.py

@@ -1,8 +1,8 @@
 #-*- coding: utf-8 -*-
 
 import re
-from Tacotron.utils.text import cleaners
-from Tacotron.utils.text.symbols import symbols
+from TTS.utils.text import cleaners
+from TTS.utils.text.symbols import symbols
 
 
 # Mappings from symbol to numeric ID and vice versa:
utils/text/symbols.py

@@ -7,7 +7,7 @@ Defines the set of symbols used in text input to the model.
 The default is a set of ASCII characters that works well for English or text that has been run
 through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
 '''
-from Tacotron.utils.text import cmudict
+from TTS.utils.text import cmudict
 
 _pad = '_'
 _eos = '~'
utils/visual.py (new file)

@@ -0,0 +1,35 @@
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+
+
+def plot_alignment(alignment, info=None):
+    fig, ax = plt.subplots()
+    im = ax.imshow(alignment.T, aspect='auto', origin='lower',
+                   interpolation='none')
+    fig.colorbar(im, ax=ax)
+    xlabel = 'Decoder timestep'
+    if info is not None:
+        xlabel += '\n\n' + info
+    plt.xlabel(xlabel)
+    plt.ylabel('Encoder timestep')
+    plt.tight_layout()
+    fig.canvas.draw()
+    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
+
+
+def plot_spectrogram(linear_output, audio):
+    spectrogram = audio._denormalize(linear_output)
+    fig = plt.figure(figsize=(16, 10))
+    plt.imshow(spectrogram.T, aspect="auto", origin="lower")
+    plt.colorbar()
+    plt.tight_layout()
+    fig.canvas.draw()
+    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
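A hypothetical smoke test for the new plotting helpers: both render to an RGB uint8 array, which is the form `tb.add_image` consumes in train.py. The module path is assumed from the repo layout, and `np.fromstring`/`tostring_rgb` require a matplotlib of this code's era.

import numpy as np
from TTS.utils.visual import plot_alignment  # module path assumed

alignment = np.random.rand(60, 45)  # (decoder timesteps, encoder timesteps)
img = plot_alignment(alignment, info='step 200')
print(img.shape, img.dtype)         # (height, width, 3) uint8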