speaker encoder implementation

This commit is contained in:
Eren Golge 2019-11-01 12:23:03 +01:00
Parent ec579d02a1
Commit 6906adc39e
13 changed files: 1179 additions and 0 deletions

12
speaker_encoder/README.md Normal file

@@ -0,0 +1,12 @@
### Speaker embedding (Experimental)
This is an implementation of https://arxiv.org/abs/1710.10467. The model computes voice/speaker embeddings, so you can generate d-vectors for multi-speaker TTS or prune bad samples from your TTS dataset. Below is an example showing embedding results for various speakers. You can generate the same plot with the provided notebook.
![](https://user-images.githubusercontent.com/1402048/64603079-7fa5c100-d3c8-11e9-88e7-88a00d0e37d1.png)
To run the code, you need to follow the same flow as in TTS.
- Define 'config.json' for your needs. Note that the audio parameters should match your TTS model.
- Example training call: ```python speaker_encoder/train.py --config_path speaker_encoder/config.json --data_path ~/Data/Libri-TTS/train-clean-360```
- Generate embedding vectors: ```python speaker_encoder/compute_embeddings.py --use_cuda true /model/path/best_model.pth.tar model/config/path/config.json dataset/path/ output_path```. This script parses all .wav files under the given dataset path and mirrors the folder structure under the output path, writing one embedding file per wav (see the sketch after this list for how to consume them).
- Watch training on Tensorboard as in TTS
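
A minimal sketch (not part of this commit) of consuming the generated `.npy` d-vectors, for example to compare two utterances by cosine similarity; the file paths below are placeholders:

```
import numpy as np

# load two d-vectors written by compute_embeddings.py (hypothetical paths)
emb_a = np.load("output_path/speaker_a/utt_001.npy").squeeze()
emb_b = np.load("output_path/speaker_b/utt_042.npy").squeeze()

# per-window embeddings are L2-normalized by the model, but the averaged vector is not,
# so divide by the norms to get a proper cosine similarity
cos_sim = float(np.dot(emb_a, emb_b) / (np.linalg.norm(emb_a) * np.linalg.norm(emb_b)))
print(f"cosine similarity: {cos_sim:.3f}")
```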

64
speaker_encoder/compute_embeddings.py Normal file

@@ -0,0 +1,64 @@
import argparse
import glob
import os
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from TTS.datasets.preprocess import get_preprocessor_by_name
from TTS.speaker_encoder.dataset import MyDataset
from TTS.speaker_encoder.model import SpeakerEncoder
from TTS.speaker_encoder.visual import plot_embeddings
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import load_config
parser = argparse.ArgumentParser(
description='Compute embedding vectors for each wav file in a dataset. ')
parser.add_argument(
'model_path',
type=str,
help='Path to model outputs (checkpoint, tensorboard etc.).')
parser.add_argument(
'config_path',
type=str,
help='Path to config file for training.',
)
parser.add_argument(
'data_path',
type=str,
help='Defines the data path. It overwrites config.json.')
parser.add_argument(
'output_path',
type=str,
help='path for training outputs.')
parser.add_argument(
'--use_cuda', type=bool, help='flag to set cuda.', default=False
)
args = parser.parse_args()
c = load_config(args.config_path)
ap = AudioProcessor(**c['audio'])
wav_files = glob.glob(args.data_path + '/**/*.wav', recursive=True)
output_files = [wav_file.replace(args.data_path, args.output_path).replace(
'.wav', '.npy') for wav_file in wav_files]
for output_file in output_files:
os.makedirs(os.path.dirname(output_file), exist_ok=True)
model = SpeakerEncoder(**c.model)
model.load_state_dict(torch.load(args.model_path)['model'])
model.eval()
if args.use_cuda:
model.cuda()
for idx, wav_file in enumerate(tqdm(wav_files)):
mel_spec = ap.melspectrogram(ap.load_wav(wav_file)).T
mel_spec = torch.FloatTensor(mel_spec[None, :, :])
if args.use_cuda:
mel_spec = mel_spec.cuda()
embedd = model.compute_embedding(mel_spec)
np.save(output_files[idx], embedd.detach().cpu().numpy())

58
speaker_encoder/config.json Normal file

@@ -0,0 +1,58 @@
{
"run_name": "libritts_360-half",
"run_description": "train speaker encoder for libritts 360",
"audio": {
// Audio processing parameters
"num_mels": 40, // size of the mel spec frame.
"num_freq": 1025, // number of stft frequency levels. Size of the linear spectogram frame.
"sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
"frame_length_ms": 50, // stft window length in ms.
"frame_shift_ms": 12.5, // stft window hop-lengh in ms.
"preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
"min_level_db": -100, // normalization range
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
// Normalization parameters
"signal_norm": true, // normalize the spec values in range [0, 1]
"symmetric_norm": true, // move normalization to range [-1, 1]
"max_norm": 4, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
"clip_norm": true, // clip normalized values into the range.
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
"do_trim_silence": false // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
},
"reinit_layers": [],
"grad_clip": 3.0, // upper limit for gradients for clipping.
"epochs": 1000, // total number of epochs to train.
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
"steps_plot_stats": 10, // number of steps to plot embeddings.
"num_speakers_in_batch": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"wd": 0.000001, // Weight decay weight.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
"print_step": 1, // Number of steps to log traning on console.
"output_path": "/media/erogol/data_ssd/Models/libri_tts/speaker_encoder/", // DATASET-RELATED: output path for all training outputs.
"model": {
"input_dim": 40,
"proj_dim": 128,
"lstm_dim": 384,
"num_lstm_layers": 3
},
"datasets":
[
{
"name": "libri_tts",
"path": "/home/erogol/Data/Libri-TTS/train-clean-360/",
"meta_file_train": null,
"meta_file_val": null
},
{
"name": "libri_tts",
"path": "/home/erogol/Data/Libri-TTS/train-clean-100/",
"meta_file_train": null,
"meta_file_val": null
}
]
}

128
speaker_encoder/dataset.py Normal file

@@ -0,0 +1,128 @@
import os
import numpy as np
import collections
import torch
import random
from torch.utils.data import Dataset
from TTS.utils.text import text_to_sequence, phoneme_to_sequence, pad_with_eos_bos
from TTS.utils.data import prepare_data, prepare_tensor, prepare_stop_target
class MyDataset(Dataset):
def __init__(self,
ap,
meta_data,
voice_len=1.6,
num_speakers_in_batch=64,
num_utter_per_speaker=10,
skip_speakers=False,
verbose=False):
"""
Args:
ap (TTS.utils.AudioProcessor): audio processor object.
meta_data (list): list of dataset instances.
voice_len (float): voice segment length in seconds.
verbose (bool): print diagnostic information.
"""
self.items = meta_data
self.sample_rate = ap.sample_rate
self.voice_len = voice_len
self.seq_len = int(voice_len * self.sample_rate)
self.num_utter_per_speaker = num_utter_per_speaker
self.skip_speakers = skip_speakers
self.ap = ap
self.verbose = verbose
self.__parse_items()
if self.verbose:
print("\n > DataLoader initialization")
print(f" | > Number of instances : {len(self.items)}")
print(f" | > Sequence length: {self.seq_len}")
print(f" | > Num speakers: {len(self.speakers)}")
def load_wav(self, filename):
audio = self.ap.load_wav(filename)
return audio
def load_data(self, idx):
text, wav_file, speaker_name = self.items[idx]
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
mel = self.ap.melspectrogram(wav).astype('float32')
# sample seq_len
assert len(text) > 0, self.items[idx][1]
assert wav.size > 0, self.items[idx][1]
sample = {
'mel': mel,
'item_idx': self.items[idx][1],
'speaker_name': speaker_name
}
return sample
def __parse_items(self):
"""
Find unique speaker ids and build a dict mapping each speaker id to its utterances.
"""
speakers = list(set([item[-1] for item in self.items]))
self.speaker_to_utters = {}
self.speakers = []
for speaker in speakers:
speaker_utters = [item[1] for item in self.items if item[2] == speaker]
if len(speaker_utters) < self.num_utter_per_speaker and self.skip_speakers:
print(f" [!] Skipped speaker {speaker}. Not enough utterances {self.num_utter_per_speaker} vs {len(speaker_utters)}.")
else:
self.speakers.append(speaker)
self.speaker_to_utters[speaker] = speaker_utters
def __len__(self):
return int(1e+10)
def __sample_speaker(self):
speaker = random.sample(self.speakers, 1)[0]
if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
utters = random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker)
else:
utters = random.sample(self.speaker_to_utters[speaker], self.num_utter_per_speaker)
return speaker, utters
def __sample_speaker_utterances(self, speaker):
"""
Sample all M utterances for the given speaker.
"""
feats = []
labels = []
for idx in range(self.num_utter_per_speaker):
# TODO:dummy but works
while True:
if len(self.speaker_to_utters[speaker]) > 0:
utter = random.sample(self.speaker_to_utters[speaker], 1)[0]
else:
self.speakers.remove(speaker)
speaker, _ = self.__sample_speaker()
continue
wav = self.load_wav(utter)
if wav.shape[0] - self.seq_len > 0:
break
else:
self.speaker_to_utters[speaker].remove(utter)
offset = random.randint(0, wav.shape[0] - self.seq_len)
mel = self.ap.melspectrogram(wav[offset:offset+self.seq_len])
feats.append(torch.FloatTensor(mel))
labels.append(speaker)
return feats, labels
def __getitem__(self, idx):
speaker, _ = self.__sample_speaker()
return speaker
def collate_fn(self, batch):
labels = []
feats = []
for speaker in batch:
feats_, labels_ = self.__sample_speaker_utterances(speaker)
labels.append(labels_)
feats.extend(feats_)
feats = torch.stack(feats)
return feats.transpose(1, 2), labels
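
As a rough usage sketch (not part of the commit, mirroring how train.py wires things together, and assuming the TTS package is importable and a LibriTTS copy is available at the placeholder path), the dataset is combined with a DataLoader through its collate_fn:

```
from torch.utils.data import DataLoader
from TTS.datasets.preprocess import libri_tts
from TTS.speaker_encoder.dataset import MyDataset
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import load_config

c = load_config("speaker_encoder/config.json")
ap = AudioProcessor(**c['audio'])
items = libri_tts("/path/to/Libri-TTS/train-clean-360/")  # [text, wav_path, speaker] entries

dataset = MyDataset(ap, items, voice_len=1.6, num_utter_per_speaker=10)
# each batch element is a speaker; collate_fn samples 10 utterances for each of them
loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=0,
                    collate_fn=dataset.collate_fn)
feats, labels = next(iter(loader))  # feats: (32 * 10, num_frames, num_mels)
```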

41
speaker_encoder/generic_utils.py Normal file

@@ -0,0 +1,41 @@
import os
import datetime
import torch
def save_checkpoint(model, optimizer, model_loss, out_path,
current_step, epoch):
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
checkpoint_path = os.path.join(out_path, checkpoint_path)
print(" | | > Checkpoint saving : {}".format(checkpoint_path))
new_state_dict = model.state_dict()
state = {
'model': new_state_dict,
'optimizer': optimizer.state_dict() if optimizer is not None else None,
'step': current_step,
'epoch': epoch,
'GE2Eloss': model_loss,
'date': datetime.date.today().strftime("%B %d, %Y"),
}
torch.save(state, checkpoint_path)
def save_best_model(model, optimizer, model_loss, best_loss, out_path,
current_step):
if model_loss < best_loss:
new_state_dict = model.state_dict()
state = {
'model': new_state_dict,
'optimizer': optimizer.state_dict(),
'step': current_step,
'GE2Eloss': model_loss,
'date': datetime.date.today().strftime("%B %d, %Y"),
}
best_loss = model_loss
bestmodel_path = 'best_model.pth.tar'
bestmodel_path = os.path.join(out_path, bestmodel_path)
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(
model_loss, bestmodel_path))
torch.save(state, bestmodel_path)
return best_loss
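
A small sketch (not part of the commit) of reading back a checkpoint written by `save_checkpoint`; the path is a placeholder:

```
import torch

state = torch.load("checkpoint_1000.pth.tar", map_location="cpu")
print(state['step'], state['epoch'], state['GE2Eloss'], state['date'])
# model weights live under state['model'], optimizer state under state['optimizer']
```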

104
speaker_encoder/loss.py Normal file

@@ -0,0 +1,104 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# adapted from https://github.com/cvqluu/GE2E-Loss
class GE2ELoss(nn.Module):
def __init__(self, init_w=10.0, init_b=-5.0, loss_method='softmax'):
'''
Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1]
Accepts an input of size (N, M, D)
where N is the number of speakers in the batch,
M is the number of utterances per speaker,
and D is the dimensionality of the embedding vector (e.g. d-vector)
Args:
- init_w (float): defines the initial value of w in Equation (5) of [1]
- init_b (float): defines the initial value of b in Equation (5) of [1]
- loss_method (str): 'softmax' or 'contrast', selecting which of the two loss variants from [1] to use
'''
super(GE2ELoss, self).__init__()
self.w = nn.Parameter(torch.tensor(init_w))
self.b = nn.Parameter(torch.tensor(init_b))
self.loss_method = loss_method
assert self.loss_method in ['softmax', 'contrast']
if self.loss_method == 'softmax':
self.embed_loss = self.embed_loss_softmax
if self.loss_method == 'contrast':
self.embed_loss = self.embed_loss_contrast
def calc_new_centroids(self, dvecs, centroids, spkr, utt):
'''
Calculates the new centroids excluding the reference utterance
'''
excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt+1:]))
excl = torch.mean(excl, 0)
new_centroids = []
for i, centroid in enumerate(centroids):
if i == spkr:
new_centroids.append(excl)
else:
new_centroids.append(centroid)
return torch.stack(new_centroids)
def calc_cosine_sim(self, dvecs, centroids):
'''
Make the cosine similarity matrix with dims (N,M,N)
'''
cos_sim_matrix = []
for spkr_idx, speaker in enumerate(dvecs):
cs_row = []
for utt_idx, utterance in enumerate(speaker):
new_centroids = self.calc_new_centroids(
dvecs, centroids, spkr_idx, utt_idx)
# vector based cosine similarity for speed
cs_row.append(torch.clamp(torch.mm(utterance.unsqueeze(1).transpose(0, 1), new_centroids.transpose(
0, 1)) / (torch.norm(utterance) * torch.norm(new_centroids, dim=1)), 1e-6))
cs_row = torch.cat(cs_row, dim=0)
cos_sim_matrix.append(cs_row)
return torch.stack(cos_sim_matrix)
def embed_loss_softmax(self, dvecs, cos_sim_matrix):
'''
Calculates the loss on each embedding $L(e_{ji})$ by taking softmax
'''
N, M, _ = dvecs.shape
L = []
for j in range(N):
L_row = []
for i in range(M):
L_row.append(-F.log_softmax(cos_sim_matrix[j, i], 0)[j])
L_row = torch.stack(L_row)
L.append(L_row)
return torch.stack(L)
def embed_loss_contrast(self, dvecs, cos_sim_matrix):
'''
Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid
'''
N, M, _ = dvecs.shape
L = []
for j in range(N):
L_row = []
for i in range(M):
centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i])
excl_centroids_sigmoids = torch.cat(
(centroids_sigmoids[:j], centroids_sigmoids[j+1:]))
L_row.append(
1. - torch.sigmoid(cos_sim_matrix[j, i, j]) + torch.max(excl_centroids_sigmoids))
L_row = torch.stack(L_row)
L.append(L_row)
return torch.stack(L)
def forward(self, dvecs):
'''
Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
'''
centroids = torch.mean(dvecs, 1)
cos_sim_matrix = self.calc_cosine_sim(dvecs, centroids)
self.w.data = torch.clamp(self.w.data, min=1e-6)  # torch.clamp is not in-place; reassign to keep the scale positive
cos_sim_matrix = self.w * cos_sim_matrix + self.b
L = self.embed_loss(dvecs, cos_sim_matrix)
return L.mean()
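
For reference, a minimal sketch (not part of the commit, assuming the TTS package is importable) of feeding SpeakerEncoder outputs into GE2ELoss; the reshape mirrors what train.py does, and the sizes are made up for illustration:

```
import torch
from TTS.speaker_encoder.loss import GE2ELoss
from TTS.speaker_encoder.model import SpeakerEncoder

num_speakers, num_utters, num_frames, num_mels = 4, 10, 160, 40
model = SpeakerEncoder(input_dim=num_mels, proj_dim=128, lstm_dim=384, num_lstm_layers=3)
criterion = GE2ELoss(loss_method='softmax')

# the dataset collate_fn stacks utterances as (num_speakers * num_utters, T, num_mels)
inputs = torch.rand(num_speakers * num_utters, num_frames, num_mels)
d_vectors = model(inputs)                                         # (N * M, proj_dim)
loss = criterion(d_vectors.view(num_speakers, num_utters, -1))    # GE2E expects (N, M, D)
loss.backward()
```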

87
speaker_encoder/model.py Normal file

@@ -0,0 +1,87 @@
import torch
from torch import nn
class LSTMWithProjection(nn.Module):
def __init__(self, input_size, hidden_size, proj_size):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.proj_size = proj_size
self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
self.linear = nn.Linear(hidden_size, proj_size, bias=False)
def forward(self, x):
self.lstm.flatten_parameters()
o, (h, c) = self.lstm(x)
return self.linear(o)
class SpeakerEncoder(nn.Module):
def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3):
super().__init__()
layers = []
layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim))
for _ in range(num_lstm_layers-1):
layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim))
self.layers = nn.Sequential(*layers)
self._init_layers()
def _init_layers(self):
for name, param in self.layers.named_parameters():
if 'bias' in name:
nn.init.constant_(param, 0.0)
elif 'weight' in name:
nn.init.xavier_normal_(param)
def forward(self, x):
# TODO: implement state passing for lstms
d = self.layers(x)
d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
return d
def inference(self, x):
d = self.layers.forward(x)
d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
return d
def compute_embedding(self, x, num_frames=160, overlap=0.5):
"""
Generate an embedding for a single utterance by averaging embeddings of overlapping sliding windows.
x: 1xTxD
"""
num_overlap = int(num_frames * overlap)
max_len = x.shape[1]
embed = None
cur_iter = 0
for offset in range(0, max_len, num_frames - num_overlap):
cur_iter += 1
end_offset = min(x.shape[1], offset + num_frames)
frames = x[:, offset:end_offset]
if embed is None:
embed = self.inference(frames)
else:
embed += self.inference(frames)
return embed / cur_iter
def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5):
"""
Generate embeddings for a batch of utterances
x: BxTxD
"""
num_overlap = int(num_frames * overlap)  # cast to int so it can be used as a range() step
max_len = x.shape[1]
embed = None
num_iters = seq_lens / (num_frames - num_overlap)
cur_iter = 0
for offset in range(0, max_len, num_frames - num_overlap):
cur_iter += 1
end_offset = min(x.shape[1], offset + num_frames)
frames = x[:, offset:end_offset]
if embed is None:
embed = self.inference(frames)
else:
embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :])
return embed / num_iters
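
To make the sliding-window averaging in `compute_embedding` concrete, a short sketch (not part of the commit, assuming the TTS package is importable) with the default `num_frames=160` and `overlap=0.5`:

```
import torch
from TTS.speaker_encoder.model import SpeakerEncoder

model = SpeakerEncoder(input_dim=40, proj_dim=128, lstm_dim=384, num_lstm_layers=3).eval()
mel = torch.rand(1, 400, 40)  # one utterance: 400 frames, 40 mel bands

# hop = num_frames - num_frames * overlap = 80 frames, so windows start at
# 0, 80, 160, 240, 320; windows running past frame 400 are truncated
with torch.no_grad():
    embed = model.compute_embedding(mel, num_frames=160, overlap=0.5)  # (1, 128)
```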

File diffs are hidden because one or more lines are too long.

80
speaker_encoder/tests.py Normal file

@@ -0,0 +1,80 @@
import os
import unittest
import torch as T
from TTS.speaker_encoder.model import SpeakerEncoder
from TTS.speaker_encoder.loss import GE2ELoss
from TTS.speaker_encoder.dataset import MyDataset
from TTS.utils.audio import AudioProcessor
from torch.utils.data import DataLoader
from TTS.datasets.preprocess import libri_tts
from TTS.utils.generic_utils import load_config
file_path = os.path.dirname(os.path.realpath(__file__)) + "/../tests/"
c = load_config(os.path.join(file_path, 'test_config.json'))
class SpeakerEncoderTests(unittest.TestCase):
def test_in_out(self):
dummy_input = T.rand(4, 20, 80) # B x T x D
dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
model = SpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3)
# computing d vectors
output = model.forward(dummy_input)
assert output.shape[0] == 4
assert output.shape[1] == 256
output = model.inference(dummy_input)
assert output.shape[0] == 4
assert output.shape[1] == 256
# compute d vectors by passing LSTM hidden
# output = model.forward(dummy_input, dummy_hidden)
# assert output.shape[0] == 4
# assert output.shape[1] == 20
# assert output.shape[2] == 256
# check normalization
output_norm = T.nn.functional.normalize(output, dim=1, p=2)
assert_diff = (output_norm - output).sum().item()
assert output.type() == 'torch.FloatTensor'
assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
# compute d for a given batch
dummy_input = T.rand(1, 240, 80) # B x T x D
output = model.compute_embedding(dummy_input, num_frames=160, overlap=0.5)
assert output.shape[0] == 1
assert output.shape[1] == 256
assert len(output.shape) == 2
class GE2ELossTests(unittest.TestCase):
def test_in_out(self):
# check random input
dummy_input = T.rand(4, 5, 64) # num_speaker x num_utterance x dim
loss = GE2ELoss(loss_method='softmax')
output = loss.forward(dummy_input)
assert output.item() >= 0.
# check an all-ones input (identical d-vectors for every speaker)
dummy_input = T.ones(4, 5, 64) # num_speaker x num_utterance x dim
loss = GE2ELoss(loss_method='softmax')
output = loss.forward(dummy_input)
# check speaker loss with orthogonal d-vectors
dummy_input = T.empty(3, 64)
dummy_input = T.nn.init.orthogonal_(dummy_input)
dummy_input = T.cat([dummy_input[0].repeat(5, 1, 1).transpose(0, 1), dummy_input[1].repeat(5, 1, 1).transpose(0, 1), dummy_input[2].repeat(5, 1, 1).transpose(0, 1)]) # num_speaker x num_utterance x dim
loss = GE2ELoss(loss_method='softmax')
output = loss.forward(dummy_input)
assert output.item() < 0.005
# class LoaderTest(unittest.TestCase):
# def test_output(self):
# items = libri_tts("/home/erogol/Data/Libri-TTS/train-clean-360/")
# ap = AudioProcessor(**c['audio'])
# dataset = MyDataset(ap, items, 1.6, 64, 10)
# loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn)
# count = 0
# for mel, spk in loader:
# print(mel.shape)
# if count == 4:
# break
# count += 1

315
speaker_encoder/train.py Normal file

@@ -0,0 +1,315 @@
import argparse
import os
import sys
import time
import traceback
import torch
from torch import optim
from torch.utils.data import DataLoader
from TTS.datasets.preprocess import load_meta_data
from TTS.speaker_encoder.dataset import MyDataset
from TTS.speaker_encoder.generic_utils import save_best_model, save_checkpoint
from TTS.speaker_encoder.loss import GE2ELoss
from TTS.speaker_encoder.model import SpeakerEncoder
from TTS.speaker_encoder.visual import plot_embeddings
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import (NoamLR, check_update, copy_config_file,
count_parameters,
create_experiment_folder, get_git_branch,
gradual_training_scheduler, load_config,
remove_experiment_folder, set_init_dict,
setup_model, split_dataset)
from TTS.utils.logger import Logger
from TTS.utils.radam import RAdam
from TTS.utils.visual import plot_alignment, plot_spectrogram
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.manual_seed(54321)
use_cuda = torch.cuda.is_available()
num_gpus = torch.cuda.device_count()
print(" > Using CUDA: ", use_cuda)
print(" > Number of GPUs: ", num_gpus)
def setup_loader(ap, is_val=False, verbose=False):
global meta_data_train
global meta_data_eval
if "meta_data_train" not in globals():
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
if is_val:
loader = None
else:
dataset = MyDataset(ap,
meta_data_eval if is_val else meta_data_train,
voice_len=1.6,
num_utter_per_speaker=10,
skip_speakers=False,
verbose=verbose)
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(dataset,
batch_size=c.num_speakers_in_batch,
shuffle=False,
num_workers=0,
collate_fn=dataset.collate_fn)
return loader
def train(model, criterion, optimizer, scheduler, ap, global_step):
data_loader = setup_loader(ap, is_val=False, verbose=True)
model.train()
epoch_time = 0
best_loss = float('inf')
avg_loss = 0
end_time = time.time()
for num_iter, data in enumerate(data_loader):
start_time = time.time()
# setup input data
inputs = data[0]
labels = data[1]
loader_time = time.time() - end_time
global_step += 1
# setup lr
if c.lr_decay:
scheduler.step()
optimizer.zero_grad()
# dispatch data to GPU
if use_cuda:
inputs = inputs.cuda(non_blocking=True)
# labels = labels.cuda(non_blocking=True)
# forward pass model
outputs = model(inputs)
# loss computation
loss = criterion(
outputs.view(c.num_speakers_in_batch,
outputs.shape[0] // c.num_speakers_in_batch, -1))
loss.backward()
grad_norm, _ = check_update(model, c.grad_clip)
optimizer.step()
step_time = time.time() - start_time
epoch_time += step_time
avg_loss = 0.01 * loss.item(
) + 0.99 * avg_loss if avg_loss != 0 else loss.item()
current_lr = optimizer.param_groups[0]['lr']
if global_step % c.steps_plot_stats == 0:
# Plot Training Epoch Stats
train_stats = {
"GE2Eloss": avg_loss,
"lr": current_lr,
"grad_norm": grad_norm,
"step_time": step_time
}
tb_logger.tb_train_epoch_stats(global_step, train_stats)
figures = {
# FIXME: not constant
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(),
10),
}
tb_logger.tb_train_figures(global_step, figures)
if global_step % c.print_step == 0:
print(
" | > Step:{} Loss:{:.5f} AvgLoss:{:.5f} GradNorm:{:.5f} "
"StepTime:{:.2f} LoaderTime:{:.2f} LR:{:.6f}".format(
global_step, loss.item(), avg_loss, grad_norm, step_time,
loader_time, current_lr),
flush=True)
# save best model
best_loss = save_best_model(model, optimizer, avg_loss, best_loss,
OUT_PATH, global_step)
end_time = time.time()
return avg_loss, global_step
# def evaluate(model, criterion, ap, global_step, epoch):
# data_loader = setup_loader(ap, is_val=True)
# model.eval()
# epoch_time = 0
# avg_loss = 0
# print("\n > Validation")
# with torch.no_grad():
# if data_loader is not None:
# for num_iter, data in enumerate(data_loader):
# start_time = time.time()
# # setup input data
# inputs = data[0]
# labels = data[1]
# # dispatch data to GPU
# if use_cuda:
# inputs = inputs.cuda()
# # labels = labels.cuda()
# # forward pass
# outputs = model.forward(inputs)
# # loss computation
# loss = criterion(outputs.reshape(
# c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1))
# step_time = time.time() - start_time
# epoch_time += step_time
# if num_iter % c.print_step == 0:
# print(
# " | > Loss: {:.5f} ".format(loss.item()),
# flush=True)
# avg_loss += float(loss.item())
# eval_figures = {
# "prediction": plot_spectrogram(const_spec, ap),
# "ground_truth": plot_spectrogram(gt_spec, ap),
# "alignment": plot_alignment(align_img)
# }
# tb_logger.tb_eval_figures(global_step, eval_figures)
# # Sample audio
# if c.model in ["Tacotron", "TacotronGST"]:
# eval_audio = ap.inv_spectrogram(const_spec.T)
# else:
# eval_audio = ap.inv_mel_spectrogram(const_spec.T)
# tb_logger.tb_eval_audios(
# global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])
# # compute average losses
# avg_loss /= (num_iter + 1)
# # Plot Validation Stats
# epoch_stats = {"GE2Eloss": avg_loss}
# tb_logger.tb_eval_stats(global_step, epoch_stats)
# return avg_loss
# FIXME: move args definition/parsing inside of main?
def main(args): # pylint: disable=redefined-outer-name
ap = AudioProcessor(**c.audio)
model = SpeakerEncoder(input_dim=40,
proj_dim=128,
lstm_dim=384,
num_lstm_layers=3)
optimizer = RAdam(model.parameters(), lr=c.lr)
criterion = GE2ELoss(loss_method='softmax')
if args.restore_path:
checkpoint = torch.load(args.restore_path)
try:
# TODO: fix optimizer init, model.cuda() needs to be called before
# optimizer restore
# optimizer.load_state_dict(checkpoint['optimizer'])
if c.reinit_layers:
raise RuntimeError
model.load_state_dict(checkpoint['model'])
except:
print(" > Partial model initialization.")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint, c)
model.load_state_dict(model_dict)
del model_dict
for group in optimizer.param_groups:
group['lr'] = c.lr
print(" > Model restored from step %d" % checkpoint['step'],
flush=True)
args.restore_step = checkpoint['step']
else:
args.restore_step = 0
if use_cuda:
model = model.cuda()
criterion.cuda()
if c.lr_decay:
scheduler = NoamLR(optimizer,
warmup_steps=c.warmup_steps,
last_epoch=args.restore_step - 1)
else:
scheduler = None
num_params = count_parameters(model)
print("\n > Model has {} parameters".format(num_params), flush=True)
global_step = args.restore_step
train_loss, global_step = train(model, criterion, optimizer, scheduler, ap,
global_step)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--restore_path',
type=str,
help='Path to model outputs (checkpoint, tensorboard etc.).',
default='')
parser.add_argument(
'--config_path',
type=str,
help='Path to config file for training.',
)
parser.add_argument('--debug',
type=bool,
default=True,
help='Do not verify commit integrity to run training.')
parser.add_argument(
'--data_path',
type=str,
default='',
help='Defines the data path. It overwrites config.json.')
parser.add_argument('--output_path',
type=str,
help='path for training outputs.',
default='')
parser.add_argument('--output_folder',
type=str,
default='',
help='folder name for training outputs.')
args = parser.parse_args()
# setup output paths and read configs
c = load_config(args.config_path)
_ = os.path.dirname(os.path.realpath(__file__))
if args.data_path != '':
c.data_path = args.data_path
if args.output_path == '':
OUT_PATH = os.path.join(_, c.output_path)
else:
OUT_PATH = args.output_path
if args.output_folder == '':
OUT_PATH = create_experiment_folder(OUT_PATH, c.run_name, args.debug)
else:
OUT_PATH = os.path.join(OUT_PATH, args.output_folder)
new_fields = {}
if args.restore_path:
new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch()
copy_config_file(args.config_path, os.path.join(OUT_PATH, 'config.json'),
new_fields)
LOG_DIR = OUT_PATH
tb_logger = Logger(LOG_DIR)
try:
main(args)
except KeyboardInterrupt:
remove_experiment_folder(OUT_PATH)
try:
sys.exit(0)
except SystemExit:
os._exit(0) # pylint: disable=protected-access
except Exception: # pylint: disable=broad-except
remove_experiment_folder(OUT_PATH)
traceback.print_exc()
sys.exit(1)

Binary data
speaker_encoder/umap.png Normal file

Binary file not shown. (PNG, 23 KiB)

40
speaker_encoder/visual.py Normal file

@@ -0,0 +1,40 @@
import umap
import numpy as np
import matplotlib
matplotlib.use('Agg')  # select a non-interactive backend before importing pyplot
import matplotlib.pyplot as plt
colormap = np.array([
[76, 255, 0],
[0, 127, 70],
[255, 0, 0],
[255, 217, 38],
[0, 135, 255],
[165, 0, 165],
[255, 167, 255],
[0, 255, 255],
[255, 96, 38],
[142, 76, 0],
[33, 0, 127],
[0, 0, 0],
[183, 183, 183],
], dtype=np.float) / 255
def plot_embeddings(embeddings, num_utter_per_speaker):
embeddings = embeddings[:10*num_utter_per_speaker]
model = umap.UMAP()
projection = model.fit_transform(embeddings)
num_speakers = embeddings.shape[0] // num_utter_per_speaker
ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_speaker)
colors = [colormap[i] for i in ground_truth]
fig, ax = plt.subplots(figsize=(16, 10))
im = ax.scatter(projection[:, 0], projection[:, 1], c=colors)
plt.gca().set_aspect("equal", "datalim")
plt.title("UMAP projection")
plt.tight_layout()
plt.savefig("umap")
return fig