add diarization demo for unispeech_sat pre-training model

This commit is contained in:
czy97 2021-11-22 16:25:24 +08:00
Parent 2816e682dc
Commit 822afbabc6
13 changed files with 1946 additions and 0 deletions

View File

@@ -0,0 +1,34 @@
## Pre-training Representations for Speaker Diarization
### Downstream Model
[EEND-vector-clustering](https://arxiv.org/abs/2105.09040)
### Pre-trained models
- Note that the diarization system is trained on 8 kHz audio data.
| Model | 2 spk DER | 3 spk DER | 4 spk DER | 5 spk DER | 6 spk DER | ALL spk DER |
| ------------------------------------------------------------ | --------- | --------- | --------- | --------- | --------- | ----------- |
| EEND-vector-clustering | 7.96 | 11.93 | 16.38 | 21.21 | 23.1 | 12.49 |
| [**UniSpeech-SAT large**](https://drive.google.com/file/d/16OwIyOk2uYm0aWtSPaS0S12xE8RxF7k_/view?usp=sharing) | 5.93 | 10.66 | 12.90 | 16.48 | 23.25 | 10.92 |
### How to use?
#### Environment Setup
1. `pip install -r requirements.txt`
2. Install the fairseq code
   - For UniSpeech-SAT large, install the [UniSpeech-SAT](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech-SAT) fairseq code.
#### Example
1. First, download a pre-trained model from the table above to `checkpoint_path`.
2. Then, run the following command (the wav file is multi-talker speech simulated from the LibriSpeech corpus):
   ```bash
   python diarization.py --wav_path tmp/mix_0000496.wav --model_init $checkpoint_path
   ```
3. The output will be written to `out.rttm` by default.
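Each line written to `out.rttm` follows the layout of the `fmt` string in `make_rttm` (`SPEAKER <session> 1 <start> <duration> <NA> <NA> <session>_<spkid> <NA>`, with times in seconds). A minimal sketch of reading the result back (the printed values are illustrative):

```python
# Minimal sketch: parse out.rttm produced by diarization.py.
with open("out.rttm") as f:
    for line in f:
        fields = line.split()
        start, dur, spk = float(fields[3]), float(fields[4]), fields[7]
        print(f"{spk}: {start:.2f}s - {start + dur:.2f}s")
```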

View File

@@ -0,0 +1,31 @@
# Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT).
# All rights reserved
# inference options
est_nspk: 1
sil_spk_th: 0.05
ahc_dis_th: 1.0
clink_dis: 1.0e+4
model:
  n_speakers: 3
  all_n_speakers: 0
  feat_dim: 1024
  n_units: 256
  n_heads: 8
  n_layers: 6
  dropout_rate: 0.1
  spk_emb_dim: 256
  sr: 8000
  frame_shift: 320
  frame_size: 200
  context_size: 0
  subsampling: 1
  feat_type: "config/unispeech_sat.th"
  feature_selection: "hidden_states"
  interpolate_mode: "linear"
dataset:
  chunk_size: 750
  frame_shift: 320
  sampling_rate: 8000
  subsampling: 1
  num_speakers: 3
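For reference, `diarization.py` reads this config through `utils.parse_config_or_kwargs`, a thin wrapper around `yaml.load`. With `frame_shift: 320` at `sr: 8000`, a `chunk_size` of 750 frames corresponds to 750 * 320 / 8000 = 30 s of audio per chunk. A minimal sketch of inspecting the parsed values, assuming this file is the `config/infer_est_nspk1.yaml` referenced by the script's defaults:

```python
# Minimal sketch: load the inference config the same way parse_config_or_kwargs does.
import yaml

with open("config/infer_est_nspk1.yaml") as f:
    conf = yaml.load(f, Loader=yaml.FullLoader)

print(conf["est_nspk"])            # 1
print(conf["model"]["feat_dim"])   # 1024
chunk_sec = conf["dataset"]["chunk_size"] * conf["model"]["frame_shift"] / conf["dataset"]["sampling_rate"]
print(chunk_sec)                   # 30.0 seconds of audio per chunk
```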

Binary data
UniSpeech-SAT/speaker_diarization/config/unispeech_sat.th Normal file

Binary file not shown.

View File

@@ -0,0 +1,321 @@
import sys
import h5py
import soundfile as sf
import fire
import math
import yamlargparse
import numpy as np
from torch.utils.data import DataLoader
import torch
from utils.utils import parse_config_or_kwargs
from utils.dataset import DiarizationDataset
from models.models import TransformerDiarization
from scipy.signal import medfilt
from sklearn.cluster import AgglomerativeClustering
from scipy.spatial import distance
from utils.kaldi_data import KaldiData
def get_cl_sil(args, acti, cls_num):
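    """Scan the per-chunk activities and collect constraints for clustering:
    cannot-link pairs (active speaker slots within the same chunk, which must
    map to different global speakers) and silence slots (slots whose mean
    activity is below args.sil_spk_th). Indices refer to the flattened
    (chunk, local speaker slot) axis."""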
n_chunks = len(acti)
mean_acti = np.array([np.mean(acti[i], axis=0)
for i in range(n_chunks)]).flatten()
n = args.num_speakers
sil_spk_th = args.sil_spk_th
cl_lst = []
sil_lst = []
for chunk_idx in range(n_chunks):
if cls_num is not None:
if args.num_speakers > cls_num:
mean_acti_bi = np.array([mean_acti[n * chunk_idx + s_loc_idx]
for s_loc_idx in range(n)])
min_idx = np.argmin(mean_acti_bi)
mean_acti[n * chunk_idx + min_idx] = 0.0
for s_loc_idx in range(n):
a = n * chunk_idx + (s_loc_idx + 0) % n
b = n * chunk_idx + (s_loc_idx + 1) % n
if mean_acti[a] > sil_spk_th and mean_acti[b] > sil_spk_th:
cl_lst.append((a, b))
else:
if mean_acti[a] <= sil_spk_th:
sil_lst.append(a)
return cl_lst, sil_lst
def clustering(args, svec, cls_num, ahc_dis_th, cl_lst, sil_lst):
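    """Cluster chunk-level speaker vectors with constrained AHC: silence slots
    are removed, cannot-link pairs are pushed apart by args.clink_dis in the
    distance matrix, and silence labels are re-inserted afterwards.
    Returns (clslab, cls_num), where clslab has shape (n_chunks, num_speakers)."""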
org_svec_len = len(svec)
svec = np.delete(svec, sil_lst, 0)
# update cl_lst idx
_tbl = [i - sum(sil < i for sil in sil_lst) for i in range(org_svec_len)]
cl_lst = [(_tbl[_cl[0]], _tbl[_cl[1]]) for _cl in cl_lst]
distMat = distance.cdist(svec, svec, metric='euclidean')
for cl in cl_lst:
distMat[cl[0], cl[1]] = args.clink_dis
distMat[cl[1], cl[0]] = args.clink_dis
clusterer = AgglomerativeClustering(
n_clusters=cls_num,
affinity='precomputed',
linkage='average',
distance_threshold=ahc_dis_th)
clusterer.fit(distMat)
if cls_num is not None:
print("oracle n_clusters is known")
else:
print("oracle n_clusters is unknown")
print("estimated n_clusters by constraind AHC: {}"
.format(len(np.unique(clusterer.labels_))))
cls_num = len(np.unique(clusterer.labels_))
sil_lab = cls_num
insert_sil_lab = [sil_lab for i in range(len(sil_lst))]
insert_sil_lab_idx = [sil_lst[i] - i for i in range(len(sil_lst))]
print("insert_sil_lab : {}".format(insert_sil_lab))
print("insert_sil_lab_idx : {}".format(insert_sil_lab_idx))
clslab = np.insert(clusterer.labels_,
insert_sil_lab_idx,
insert_sil_lab).reshape(-1, args.num_speakers)
print("clslab : {}".format(clslab))
return clslab, cls_num
def merge_act_max(act, i, j):
for k in range(len(act)):
act[k, i] = max(act[k, i], act[k, j])
act[k, j] = 0.0
return act
def merge_acti_clslab(args, acti, clslab, cls_num):
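    """Within each chunk, merge local speaker slots that received the same
    cluster label (taking the frame-wise maximum of their activities) and mark
    the duplicates as silence."""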
sil_lab = cls_num
for i in range(len(clslab)):
_lab = clslab[i].reshape(-1, 1)
distM = distance.cdist(_lab, _lab, metric='euclidean').astype(np.int64)
for j in range(len(distM)):
distM[j][:j] = -1
idx_lst = np.where(np.count_nonzero(distM == 0, axis=1) > 1)
merge_done = []
for j in idx_lst[0]:
for k in (np.where(distM[j] == 0))[0]:
if j != k and clslab[i, j] != sil_lab and k not in merge_done:
print("merge : (i, j, k) == ({}, {}, {})".format(i, j, k))
acti[i] = merge_act_max(acti[i], j, k)
clslab[i, k] = sil_lab
merge_done.append(j)
return acti, clslab
def stitching(args, acti, clslab, cls_num):
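    """Permute the speaker slots of every chunk so that column c always
    corresponds to global cluster c, then drop the silence column, yielding
    per-chunk activities that can be concatenated consistently over time."""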
n_chunks = len(acti)
s_loc = args.num_speakers
sil_lab = cls_num
s_tot = max(cls_num, s_loc-1)
# Extend the max value of s_loc_idx to s_tot+1
add_acti = []
for chunk_idx in range(n_chunks):
zeros = np.zeros((len(acti[chunk_idx]), s_tot+1))
if s_tot+1 > s_loc:
zeros[:, :-(s_tot+1-s_loc)] = acti[chunk_idx]
else:
zeros = acti[chunk_idx]
add_acti.append(zeros)
acti = np.array(add_acti)
out_chunks = []
for chunk_idx in range(n_chunks):
# Make sloci2lab_dct.
# key: s_loc_idx
# value: estimated label by clustering or sil_lab
cls_set = set()
for s_loc_idx in range(s_tot+1):
cls_set.add(s_loc_idx)
sloci2lab_dct = {}
for s_loc_idx in range(s_tot+1):
if s_loc_idx < s_loc:
sloci2lab_dct[s_loc_idx] = clslab[chunk_idx][s_loc_idx]
if clslab[chunk_idx][s_loc_idx] in cls_set:
cls_set.remove(clslab[chunk_idx][s_loc_idx])
else:
if clslab[chunk_idx][s_loc_idx] != sil_lab:
raise ValueError
else:
sloci2lab_dct[s_loc_idx] = list(cls_set)[s_loc_idx-s_loc]
# Sort by label value
sloci2lab_lst = sorted(sloci2lab_dct.items(), key=lambda x: x[1])
# Select sil_lab_idx
sil_lab_idx = None
for idx_lab in sloci2lab_lst:
if idx_lab[1] == sil_lab:
sil_lab_idx = idx_lab[0]
break
if sil_lab_idx is None:
raise ValueError
# Get swap_idx
# [idx of label(0), idx of label(1), ..., idx of label(s_tot)]
swap_idx = [sil_lab_idx for j in range(s_tot+1)]
for lab in range(s_tot+1):
for idx_lab in sloci2lab_lst:
if lab == idx_lab[1]:
swap_idx[lab] = idx_lab[0]
print("swap_idx {}".format(swap_idx))
swap_acti = acti[chunk_idx][:, swap_idx]
swap_acti = np.delete(swap_acti, sil_lab, 1)
out_chunks.append(swap_acti)
return out_chunks
def prediction(num_speakers, net, wav_list, chunk_len_list):
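    """Run the diarization network chunk by chunk on GPU and collect the
    frame-level activities, the per-slot speaker vectors, and the valid length
    of every chunk."""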
acti_lst = []
svec_lst = []
len_list = []
with torch.no_grad():
for wav, chunk_len in zip(wav_list, chunk_len_list):
wav = wav.to('cuda')
outputs = net.batch_estimate(torch.unsqueeze(wav, 0))
ys = outputs[0]
for i in range(num_speakers):
spkivecs = outputs[i+1]
svec_lst.append(spkivecs[0].cpu().detach().numpy())
acti = ys[0][-chunk_len:].cpu().detach().numpy()
acti_lst.append(acti)
len_list.append(chunk_len)
    acti_arr = np.concatenate(acti_lst, axis=0)  # total_len x num_speakers
svec_arr = np.stack(svec_lst) # (chunk_num x num_speakers) x emb_dim
len_arr = np.array(len_list) # chunk_num
return acti_arr, svec_arr, len_arr
def cluster(args, conf, acti_arr, svec_arr, len_arr):
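    """Split the concatenated activities back into chunks, derive the
    cannot-link / silence constraints, run constrained AHC over the speaker
    vectors, then merge and stitch the chunks into frame-level activities with
    globally consistent speaker columns."""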
acti_list = []
n_chunks = len_arr.shape[0]
start = 0
for i in range(n_chunks):
chunk_len = len_arr[i]
acti_list.append(acti_arr[start: start+chunk_len])
start += chunk_len
acti = np.array(acti_list)
svec = svec_arr
# initialize clustering setting
cls_num = None
ahc_dis_th = args.ahc_dis_th
# Get cannot-link index list and silence index list
cl_lst, sil_lst = get_cl_sil(args, acti, cls_num)
n_samples = n_chunks * args.num_speakers - len(sil_lst)
min_n_samples = 2
if cls_num is not None:
min_n_samples = cls_num
if n_samples >= min_n_samples:
# clustering (if cls_num is None, update cls_num)
clslab, cls_num =\
clustering(args, svec, cls_num, ahc_dis_th, cl_lst, sil_lst)
# merge
acti, clslab = merge_acti_clslab(args, acti, clslab, cls_num)
# stitching
out_chunks = stitching(args, acti, clslab, cls_num)
else:
out_chunks = acti
outdata = np.vstack(out_chunks)
    # return the stitched results
return outdata
def make_rttm(args, conf, cluster_data):
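    """Binarize the activities with args.threshold, optionally median-filter
    them, and write the resulting speech segments of every speaker to
    args.out_rttm_file in RTTM format."""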
args.frame_shift = conf['model']['frame_shift']
args.subsampling = conf['model']['subsampling']
args.sampling_rate = conf['dataset']['sampling_rate']
with open(args.out_rttm_file, 'w') as wf:
a = np.where(cluster_data > args.threshold, 1, 0)
if args.median > 1:
a = medfilt(a, (args.median, 1))
for spkid, frames in enumerate(a.T):
frames = np.pad(frames, (1, 1), 'constant')
changes, = np.where(np.diff(frames, axis=0) != 0)
fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
for s, e in zip(changes[::2], changes[1::2]):
print(fmt.format(
args.session,
s * args.frame_shift * args.subsampling / args.sampling_rate,
(e - s) * args.frame_shift * args.subsampling / args.sampling_rate,
args.session + "_" + str(spkid)), file=wf)
def main(args):
conf = parse_config_or_kwargs(args.config_path)
num_speakers = conf['dataset']['num_speakers']
args.num_speakers = num_speakers
# Prepare model
model_parameter_dict = torch.load(args.model_init)['model']
model_all_n_speakers = model_parameter_dict["embed.weight"].shape[0]
conf['model']['all_n_speakers'] = model_all_n_speakers
net = TransformerDiarization(**conf['model'])
net.load_state_dict(model_parameter_dict, strict=False)
net.eval()
net = net.to("cuda")
audio, sr = sf.read(args.wav_path, dtype="float32")
audio_len = audio.shape[0]
chunk_size, frame_shift, subsampling = conf['dataset']['chunk_size'], conf['model']['frame_shift'], conf['model']['subsampling']
scale_ratio = int(frame_shift * subsampling)
chunk_audio_size = chunk_size * scale_ratio
wav_list, chunk_len_list = [], []
for i in range(0, math.ceil(1.0 * audio_len / chunk_audio_size)):
start, end = i*chunk_audio_size, (i+1)*chunk_audio_size
if end > audio_len:
chunk_len_list.append(int((audio_len-start) / scale_ratio))
end = audio_len
start = max(0, audio_len - chunk_audio_size)
else:
chunk_len_list.append(chunk_size)
wav_list.append(audio[start:end])
wav_list = [torch.from_numpy(wav).float() for wav in wav_list]
acti_arr, svec_arr, len_arr = prediction(num_speakers, net, wav_list, chunk_len_list)
cluster_data = cluster(args, conf, acti_arr, svec_arr, len_arr)
make_rttm(args, conf, cluster_data)
if __name__ == '__main__':
parser = yamlargparse.ArgumentParser(description='decoding')
parser.add_argument('--wav_path',
help='the input wav path',
default="tmp/mix_0000496.wav")
parser.add_argument('--config_path',
help='config file path',
default="config/infer_est_nspk1.yaml")
parser.add_argument('--model_init',
help='model initialize path',
default="")
parser.add_argument('--sil_spk_th', default=0.05, type=float)
parser.add_argument('--ahc_dis_th', default=1.0, type=float)
parser.add_argument('--clink_dis', default=1.0e+4, type=float)
    parser.add_argument('--session', default='Anonymous', help='session name written in the output rttm')
parser.add_argument('--out_rttm_file', default='out.rttm', help='the output rttm file')
parser.add_argument('--threshold', default=0.4, type=float)
parser.add_argument('--median', default=25, type=int)
args = parser.parse_args()
main(args)

View File

@@ -0,0 +1,391 @@
# Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT).
# All rights reserved
import sys
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchaudio.transforms as trans
from collections import OrderedDict
from itertools import permutations
from models.transformer import TransformerEncoder
from .utils import UpstreamExpert
class GradMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
return res
@staticmethod
def backward(ctx, grad):
return grad * ctx.scale, None
"""
P: number of permutation
T: number of frames
C: number of speakers (classes)
B: mini-batch size
"""
def batch_pit_loss_parallel(outputs, labels, ilens=None):
""" calculate the batch pit loss parallelly
Args:
outputs (torch.Tensor): B x T x C
labels (torch.Tensor): B x T x C
ilens (torch.Tensor): B
Returns:
perm (torch.Tensor): permutation for outputs (Batch, num_spk)
loss
"""
if ilens is None:
mask, scale = 1.0, outputs.shape[1]
else:
scale = torch.unsqueeze(torch.LongTensor(ilens), 1).to(outputs.device)
mask = outputs.new_zeros(outputs.size()[:-1])
for i, chunk_len in enumerate(ilens):
mask[i, :chunk_len] += 1.0
mask /= scale
def loss_func(output, label):
# return torch.mean(F.binary_cross_entropy_with_logits(output, label, reduction='none'), dim=tuple(range(1, output.dim())))
return torch.sum(F.binary_cross_entropy_with_logits(output, label, reduction='none') * mask, dim=-1)
def pair_loss(outputs, labels, permutation):
return sum([loss_func(outputs[:,:,s], labels[:,:,t]) for s, t in enumerate(permutation)]) / len(permutation)
device = outputs.device
num_spk = outputs.shape[-1]
all_permutations = list(permutations(range(num_spk)))
losses = torch.stack([pair_loss(outputs, labels, p) for p in all_permutations], dim=1)
loss, perm = torch.min(losses, dim=1)
perm = torch.index_select(torch.tensor(all_permutations, device=device, dtype=torch.long), 0, perm)
return torch.mean(loss), perm
def fix_state_dict(state_dict):
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('module.'):
# remove 'module.' of DataParallel
k = k[7:]
if k.startswith('net.'):
# remove 'net.' of PadertorchModel
k = k[4:]
new_state_dict[k] = v
return new_state_dict
class TransformerDiarization(nn.Module):
def __init__(self,
n_speakers,
all_n_speakers,
feat_dim,
n_units,
n_heads,
n_layers,
dropout_rate,
spk_emb_dim,
sr=8000,
frame_shift=256,
frame_size=1024,
context_size=0,
subsampling=1,
feat_type='fbank',
feature_selection='default',
interpolate_mode='linear',
update_extract=False,
feature_grad_mult=1.0
):
super(TransformerDiarization, self).__init__()
self.context_size = context_size
self.subsampling = subsampling
self.feat_type = feat_type
self.feature_selection = feature_selection
self.sr = sr
self.frame_shift = frame_shift
self.interpolate_mode = interpolate_mode
self.update_extract = update_extract
self.feature_grad_mult = feature_grad_mult
if feat_type == 'fbank':
self.feature_extract = trans.MelSpectrogram(sample_rate=sr,
n_fft=frame_size,
win_length=frame_size,
hop_length=frame_shift,
f_min=0.0,
f_max=sr // 2,
pad=0,
n_mels=feat_dim)
else:
self.feature_extract = UpstreamExpert(feat_type)
# self.feature_extract = torch.hub.load('s3prl/s3prl', 'hubert_local', ckpt=feat_type)
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
self.feat_num = self.get_feat_num()
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
# for param in self.feature_extract.parameters():
# param.requires_grad = False
self.resample = trans.Resample(orig_freq=sr, new_freq=16000)
if feat_type != 'fbank' and feat_type != 'mfcc':
freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer', 'spk_proj', 'layer_norm_for_extract']
for name, param in self.feature_extract.named_parameters():
for freeze_val in freeze_list:
if freeze_val in name:
param.requires_grad = False
break
if not self.update_extract:
for param in self.feature_extract.parameters():
param.requires_grad = False
self.instance_norm = nn.InstanceNorm1d(feat_dim)
feat_dim = feat_dim * (self.context_size*2 + 1)
self.enc = TransformerEncoder(
feat_dim, n_layers, n_units, h=n_heads, dropout_rate=dropout_rate)
self.linear = nn.Linear(n_units, n_speakers)
for i in range(n_speakers):
setattr(self, '{}{:d}'.format("linear", i), nn.Linear(n_units, spk_emb_dim))
self.n_speakers = n_speakers
self.embed = nn.Embedding(all_n_speakers, spk_emb_dim)
self.alpha = nn.Parameter(torch.rand(1)[0] + torch.Tensor([0.5])[0])
self.beta = nn.Parameter(torch.rand(1)[0] + torch.Tensor([0.5])[0])
def get_feat_num(self):
self.feature_extract.eval()
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
with torch.no_grad():
features = self.feature_extract(wav)
select_feature = features[self.feature_selection]
if isinstance(select_feature, (list, tuple)):
return len(select_feature)
else:
return 1
def fix_except_embedding(self, requires_grad=False):
for name, param in self.named_parameters():
if 'embed' not in name:
param.requires_grad = requires_grad
def modfy_emb(self, weight):
self.embed = nn.Embedding.from_pretrained(weight)
def splice(self, data, context_size):
# data: B x feat_dim x time_len
data = torch.unsqueeze(data, -1)
kernel_size = context_size*2 + 1
splice_data = F.unfold(data, kernel_size=(kernel_size, 1), padding=(context_size, 0))
return splice_data
def get_feat(self, xs):
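        """Extract frame-level features from the raw waveform batch. For
        'fbank', log-Mel features are computed directly; otherwise the S3PRL
        upstream (e.g. UniSpeech-SAT) is run on 16 kHz-resampled audio and its
        hidden states are combined with learned layer weights. Features are
        then instance-normalized, context-spliced, subsampled, and interpolated
        to the expected number of frames per chunk."""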
wav_len = xs.shape[-1]
chunk_size = int(wav_len / self.frame_shift)
chunk_size = int(chunk_size / self.subsampling)
self.feature_extract.eval()
if self.update_extract:
xs = self.resample(xs)
feature = self.feature_extract([sample for sample in xs])
else:
with torch.no_grad():
if self.feat_type == 'fbank':
feature = self.feature_extract(xs) + 1e-6 # B x feat_dim x time_len
feature = feature.log()
else:
xs = self.resample(xs)
feature = self.feature_extract([sample for sample in xs])
if self.feat_type != "fbank" and self.feat_type != "mfcc":
feature = feature[self.feature_selection]
if isinstance(feature, (list, tuple)):
feature = torch.stack(feature, dim=0)
else:
feature = feature.unsqueeze(0)
norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
feature = (norm_weights * feature).sum(dim=0)
feature = torch.transpose(feature, 1, 2) + 1e-6
feature = self.instance_norm(feature)
feature = self.splice(feature, self.context_size)
feature = feature[:, :, ::self.subsampling]
feature = F.interpolate(feature, chunk_size, mode=self.interpolate_mode)
feature = torch.transpose(feature, 1, 2)
if self.feature_grad_mult != 1.0:
feature = GradMultiply.apply(feature, self.feature_grad_mult)
return feature
def forward(self, inputs):
if isinstance(inputs, list):
xs = inputs[0]
else:
xs = inputs
feature = self.get_feat(xs)
pad_shape = feature.shape
emb = self.enc(feature)
ys = self.linear(emb)
ys = ys.reshape(pad_shape[0], pad_shape[1], -1)
spksvecs = []
for i in range(self.n_speakers):
spkivecs = getattr(self, '{}{:d}'.format("linear", i))(emb)
spkivecs = spkivecs.reshape(pad_shape[0], pad_shape[1], -1)
spksvecs.append(spkivecs)
return ys, spksvecs
def get_loss(self, inputs, ys, spksvecs, cal_spk_loss=True):
ts = inputs[1]
ss = inputs[2]
ns = inputs[3]
ilens = inputs[4]
ilens = [ilen.item() for ilen in ilens]
pit_loss, sigmas = batch_pit_loss_parallel(ys, ts, ilens)
if cal_spk_loss:
spk_loss = self.spk_loss_parallel(spksvecs, ys, ts, ss, sigmas, ns, ilens)
else:
spk_loss = torch.tensor(0.0).to(pit_loss.device)
alpha = torch.clamp(self.alpha, min=sys.float_info.epsilon)
return {'spk_loss':spk_loss,
'pit_loss': pit_loss}
def batch_estimate(self, xs):
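        """Run inference on a batch of waveforms and return, per chunk, the
        sigmoid speaker activities followed by one length-normalized,
        activity-weighted speaker vector for each local speaker slot."""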
out = self(xs)
ys = out[0]
spksvecs = out[1]
spksvecs = list(zip(*spksvecs))
outputs = [
self.estimate(spksvec, y)
for (spksvec, y) in zip(spksvecs, ys)]
outputs = list(zip(*outputs))
return outputs
def batch_estimate_with_perm(self, xs, ts, ilens=None):
out = self(xs)
ys = out[0]
if ts[0].shape[1] > ys[0].shape[1]:
# e.g. the case of training 3-spk model with 4-spk data
add_dim = ts[0].shape[1] - ys[0].shape[1]
y_device = ys[0].device
zeros = [torch.zeros(ts[0].shape).to(y_device)
for i in range(len(ts))]
_ys = []
for zero, y in zip(zeros, ys):
_zero = zero
_zero[:, :-add_dim] = y
_ys.append(_zero)
_, sigmas = batch_pit_loss_parallel(_ys, ts, ilens)
else:
_, sigmas = batch_pit_loss_parallel(ys, ts, ilens)
spksvecs = out[1]
spksvecs = list(zip(*spksvecs))
outputs = [self.estimate(spksvec, y)
for (spksvec, y) in zip(spksvecs, ys)]
outputs = list(zip(*outputs))
zs = outputs[0]
if ts[0].shape[1] > ys[0].shape[1]:
# e.g. the case of training 3-spk model with 4-spk data
add_dim = ts[0].shape[1] - ys[0].shape[1]
z_device = zs[0].device
zeros = [torch.zeros(ts[0].shape).to(z_device)
for i in range(len(ts))]
_zs = []
for zero, z in zip(zeros, zs):
_zero = zero
_zero[:, :-add_dim] = z
_zs.append(_zero)
zs = _zs
outputs[0] = zs
outputs.append(sigmas)
# outputs: [zs, nmz_wavg_spk0vecs, nmz_wavg_spk1vecs, ..., sigmas]
return outputs
def estimate(self, spksvec, y):
outputs = []
z = torch.sigmoid(y.transpose(1, 0))
outputs.append(z.transpose(1, 0))
for spkid, spkvec in enumerate(spksvec):
norm_spkvec_inv = 1.0 / torch.norm(spkvec, dim=1)
# Normalize speaker vectors before weighted average
spkvec = torch.mul(
spkvec.transpose(1, 0), norm_spkvec_inv
).transpose(1, 0)
wavg_spkvec = torch.mul(
spkvec.transpose(1, 0), z[spkid]
).transpose(1, 0)
sum_wavg_spkvec = torch.sum(wavg_spkvec, dim=0)
nmz_wavg_spkvec = sum_wavg_spkvec / torch.norm(sum_wavg_spkvec)
outputs.append(nmz_wavg_spkvec)
# outputs: [z, nmz_wavg_spk0vec, nmz_wavg_spk1vec, ...]
return outputs
def spk_loss_parallel(self, spksvecs, ys, ts, ss, sigmas, ns, ilens):
'''
spksvecs (List[torch.Tensor, ...]): [B x T x emb_dim, ...]
ys (torch.Tensor): B x T x 3
ts (torch.Tensor): B x T x 3
ss (torch.Tensor): B x 3
sigmas (torch.Tensor): B x 3
ns (torch.Tensor): B x total_spk_num x 1
ilens (List): B
'''
chunk_spk_num = len(spksvecs) # 3
len_mask = ys.new_zeros((ys.size()[:-1])) # B x T
for i, len_val in enumerate(ilens):
len_mask[i,:len_val] += 1.0
ts = ts * len_mask.unsqueeze(-1)
len_mask = len_mask.repeat((chunk_spk_num, 1)) # B*3 x T
spk_vecs = torch.cat(spksvecs, dim=0) # B*3 x T x emb_dim
# Normalize speaker vectors before weighted average
spk_vecs = F.normalize(spk_vecs, dim=-1)
ys = torch.permute(torch.sigmoid(ys), dims=(2, 0, 1)) # 3 x B x T
ys = ys.reshape(-1, ys.shape[-1]).unsqueeze(-1) # B*3 x T x 1
weight_spk_vec = ys * spk_vecs # B*3 x T x emb_dim
weight_spk_vec *= len_mask.unsqueeze(-1)
sum_spk_vec = torch.sum(weight_spk_vec, dim=1) # B*3 x emb_dim
norm_spk_vec = F.normalize(sum_spk_vec, dim=1)
embeds = F.normalize(self.embed(ns[0]).squeeze(), dim=1) # total_spk_num x emb_dim
dist = torch.cdist(norm_spk_vec, embeds) # B*3 x total_spk_num
logits = -1.0 * torch.add(torch.clamp(self.alpha, min=sys.float_info.epsilon) * torch.pow(dist, 2), self.beta)
label = torch.gather(ss, 1, sigmas).transpose(0, 1).reshape(-1, 1).squeeze() # B*3
label[label==-1] = 0
valid_spk_mask = torch.gather(torch.sum(ts, dim=1), 1, sigmas).transpose(0, 1) # 3 x B
valid_spk_mask = (torch.flatten(valid_spk_mask) > 0).float() # B*3
valid_spk_loss_num = torch.sum(valid_spk_mask).item()
if valid_spk_loss_num > 0:
loss = F.cross_entropy(logits, label, reduction='none') * valid_spk_mask / valid_spk_loss_num
            # if the line below is used instead, the loss matches batch_spk_loss
# loss = F.cross_entropy(logits, label, reduction='none') * valid_spk_mask / valid_spk_mask.shape[0]
return torch.sum(loss)
else:
return torch.tensor(0.0).to(ys.device)

View File

@@ -0,0 +1,147 @@
# Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT).
# All rights reserved
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler
class NoamScheduler(_LRScheduler):
""" learning rate scheduler used in the transformer
See https://arxiv.org/pdf/1706.03762.pdf
lrate = d_model**(-0.5) * \
min(step_num**(-0.5), step_num*warmup_steps**(-1.5))
Scaling factor is implemented as in
http://nlp.seas.harvard.edu/2018/04/03/attention.html#optimizer
"""
def __init__(
self, optimizer, d_model, warmup_steps, tot_step, scale,
last_epoch=-1
):
self.d_model = d_model
self.warmup_steps = warmup_steps
self.tot_step = tot_step
self.scale = scale
super(NoamScheduler, self).__init__(optimizer, last_epoch)
def get_lr(self):
self.last_epoch = max(1, self.last_epoch)
step_num = self.last_epoch
val = self.scale * self.d_model ** (-0.5) * \
min(step_num ** (-0.5), step_num * self.warmup_steps ** (-1.5))
return [base_lr / base_lr * val for base_lr in self.base_lrs]
class MultiHeadSelfAttention(nn.Module):
""" Multi head "self" attention layer
"""
def __init__(self, n_units, h=8, dropout_rate=0.1):
super(MultiHeadSelfAttention, self).__init__()
self.linearQ = nn.Linear(n_units, n_units)
self.linearK = nn.Linear(n_units, n_units)
self.linearV = nn.Linear(n_units, n_units)
self.linearO = nn.Linear(n_units, n_units)
self.d_k = n_units // h
self.h = h
self.dropout = nn.Dropout(p=dropout_rate)
# attention for plot
self.att = None
def forward(self, x, batch_size):
# x: (BT, F)
q = self.linearQ(x).reshape(batch_size, -1, self.h, self.d_k)
k = self.linearK(x).reshape(batch_size, -1, self.h, self.d_k)
v = self.linearV(x).reshape(batch_size, -1, self.h, self.d_k)
scores = torch.matmul(
q.transpose(1, 2), k.permute(0, 2, 3, 1)) / np.sqrt(self.d_k)
# scores: (B, h, T, T) = (B, h, T, d_k) x (B, h, d_k, T)
self.att = F.softmax(scores, dim=3)
p_att = self.dropout(self.att)
x = torch.matmul(p_att, v.transpose(1, 2))
x = x.transpose(1, 2).reshape(-1, self.h * self.d_k)
return self.linearO(x)
class PositionwiseFeedForward(nn.Module):
""" Positionwise feed-forward layer
"""
def __init__(self, n_units, d_units, dropout_rate):
super(PositionwiseFeedForward, self).__init__()
self.linear1 = nn.Linear(n_units, d_units)
self.linear2 = nn.Linear(d_units, n_units)
self.dropout = nn.Dropout(p=dropout_rate)
def forward(self, x):
return self.linear2(self.dropout(F.relu(self.linear1(x))))
class PositionalEncoding(nn.Module):
""" Positional encoding function
"""
def __init__(self, n_units, dropout_rate, max_len):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout_rate)
positions = np.arange(0, max_len, dtype='f')[:, None]
dens = np.exp(
np.arange(0, n_units, 2, dtype='f') * -(np.log(10000.) / n_units))
self.enc = np.zeros((max_len, n_units), dtype='f')
self.enc[:, ::2] = np.sin(positions * dens)
self.enc[:, 1::2] = np.cos(positions * dens)
self.scale = np.sqrt(n_units)
def forward(self, x):
        # add the positional encoding for the first T frames (self.enc: (max_len, n_units))
        pos = torch.from_numpy(self.enc[: x.shape[1]]).to(x.device)
        x = x * self.scale + pos
return self.dropout(x)
class TransformerEncoder(nn.Module):
def __init__(self, idim, n_layers, n_units,
e_units=2048, h=8, dropout_rate=0.1):
super(TransformerEncoder, self).__init__()
self.linear_in = nn.Linear(idim, n_units)
# self.lnorm_in = nn.LayerNorm(n_units)
self.pos_enc = PositionalEncoding(n_units, dropout_rate, 5000)
self.n_layers = n_layers
self.dropout = nn.Dropout(p=dropout_rate)
for i in range(n_layers):
setattr(self, '{}{:d}'.format("lnorm1_", i),
nn.LayerNorm(n_units))
setattr(self, '{}{:d}'.format("self_att_", i),
MultiHeadSelfAttention(n_units, h, dropout_rate))
setattr(self, '{}{:d}'.format("lnorm2_", i),
nn.LayerNorm(n_units))
setattr(self, '{}{:d}'.format("ff_", i),
PositionwiseFeedForward(n_units, e_units, dropout_rate))
self.lnorm_out = nn.LayerNorm(n_units)
def forward(self, x):
# x: (B, T, F) ... batch, time, (mel)freq
BT_size = x.shape[0] * x.shape[1]
# e: (BT, F)
e = self.linear_in(x.reshape(BT_size, -1))
# Encoder stack
for i in range(self.n_layers):
# layer normalization
e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e)
# self-attention
s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0])
# residual
e = e + self.dropout(s)
# layer normalization
e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e)
# positionwise feed-forward
s = getattr(self, '{}{:d}'.format("ff_", i))(e)
# residual
e = e + self.dropout(s)
# final layer normalization
# output: (BT, F)
return self.lnorm_out(e)

View File

@@ -0,0 +1,78 @@
import torch
import fairseq
from packaging import version
import torch.nn.functional as F
from fairseq import tasks
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from omegaconf import OmegaConf
from s3prl.upstream.interfaces import UpstreamBase
from torch.nn.utils.rnn import pad_sequence
def load_model(filepath):
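    """Load a fairseq checkpoint onto CPU, rebuild its task and model from the
    stored configuration, and return (model, cfg, task)."""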
state = torch.load(filepath, map_location=lambda storage, loc: storage)
# state = load_checkpoint_to_cpu(filepath)
state["cfg"] = OmegaConf.create(state["cfg"])
if "args" in state and state["args"] is not None:
cfg = convert_namespace_to_omegaconf(state["args"])
elif "cfg" in state and state["cfg"] is not None:
cfg = state["cfg"]
else:
raise RuntimeError(
f"Neither args nor cfg exist in state keys = {state.keys()}"
)
task = tasks.setup_task(cfg.task)
if "task_state" in state:
task.load_state_dict(state["task_state"])
model = task.build_model(cfg.model)
return model, cfg, task
###################
# UPSTREAM EXPERT #
###################
class UpstreamExpert(UpstreamBase):
def __init__(self, ckpt, **kwargs):
super().__init__(**kwargs)
assert version.parse(fairseq.__version__) > version.parse(
"0.10.2"
), "Please install the fairseq master branch."
model, cfg, task = load_model(ckpt)
self.model = model
self.task = task
if len(self.hooks) == 0:
module_name = "self.model.encoder.layers"
for module_id in range(len(eval(module_name))):
self.add_hook(
f"{module_name}[{module_id}]",
lambda input, output: input[0].transpose(0, 1),
)
self.add_hook("self.model.encoder", lambda input, output: output[0])
def forward(self, wavs):
if self.task.cfg.normalize:
wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]
device = wavs[0].device
wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
wav_padding_mask = ~torch.lt(
torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
wav_lengths.unsqueeze(1),
)
padded_wav = pad_sequence(wavs, batch_first=True)
features, feat_padding_mask = self.model.extract_features(
padded_wav,
padding_mask=wav_padding_mask,
mask=None,
)
return {
"default": features,
}

View File

@@ -0,0 +1,11 @@
soundfile
fire
sentencepiece
tqdm
pyyaml
h5py
yamlargparse
scikit-learn
matplotlib
torchaudio
s3prl

View File

@@ -0,0 +1 @@
/mnt/lustre/sjtu/home/czy97/workspace/sd/EEND-vec-clustering/EEND-vector-clustering/egs/mini_librispeech/v1/data/simu/wav/dev_clean_2_ns3_beta2_500/100/mix_0000496.wav

View File

@@ -0,0 +1,484 @@
# -*- coding: utf-8 -*- #
"""*********************************************************************************************"""
# FileName [ dataset.py ]
# Synopsis [ the speaker diarization dataset ]
# Source [ Refactored from https://github.com/hitachi-speech/EEND ]
# Author [ Jiatong Shi ]
# Copyright [ Copyright(c), Johns Hopkins University ]
"""*********************************************************************************************"""
###############
# IMPORTATION #
###############
import io
import os
import subprocess
import sys
# -------------#
import numpy as np
import soundfile as sf
import torch
from torch.nn.utils.rnn import pad_sequence
# -------------#
from torch.utils.data.dataset import Dataset
# -------------#
def _count_frames(data_len, size, step):
# no padding at edges, last remaining samples are ignored
return int((data_len - size + step) / step)
def _gen_frame_indices(data_length, size=2000, step=2000):
i = -1
for i in range(_count_frames(data_length, size, step)):
yield i * step, i * step + size
if i * step + size < data_length:
if data_length - (i + 1) * step > 0:
if i == -1:
yield (i + 1) * step, data_length
else:
yield data_length - size, data_length
def _gen_chunk_indices(data_len, chunk_size):
step = chunk_size
start = 0
while start < data_len:
end = min(data_len, start + chunk_size)
yield start, end
start += step
#######################
# Diarization Dataset #
#######################
class DiarizationDataset(Dataset):
def __init__(
self,
mode,
data_dir,
chunk_size=2000,
frame_shift=256,
sampling_rate=16000,
subsampling=1,
use_last_samples=True,
num_speakers=3,
filter_spk=False
):
super(DiarizationDataset, self).__init__()
self.mode = mode
self.data_dir = data_dir
self.chunk_size = chunk_size
self.frame_shift = frame_shift
self.subsampling = subsampling
self.n_speakers = num_speakers
self.chunk_indices = [] if mode != "test" else {}
self.data = KaldiData(self.data_dir)
self.all_speakers = sorted(self.data.spk2utt.keys())
self.all_n_speakers = len(self.all_speakers)
# make chunk indices: filepath, start_frame, end_frame
for rec in self.data.wavs:
data_len = int(self.data.reco2dur[rec] * sampling_rate / frame_shift)
data_len = int(data_len / self.subsampling)
if mode == "test":
self.chunk_indices[rec] = []
if mode != "test":
for st, ed in _gen_frame_indices(data_len, chunk_size, chunk_size):
self.chunk_indices.append(
(rec, st * self.subsampling, ed * self.subsampling)
)
else:
for st, ed in _gen_chunk_indices(data_len, chunk_size):
self.chunk_indices[rec].append(
(rec, st, ed)
)
if mode != "test":
if filter_spk:
self.filter_spk()
print(len(self.chunk_indices), " chunks")
else:
self.rec_list = list(self.chunk_indices.keys())
print(len(self.rec_list), " recordings")
def __len__(self):
return (
len(self.rec_list)
if type(self.chunk_indices) == dict
else len(self.chunk_indices)
)
def filter_spk(self):
        # filter out speakers that appear in spk2utt but will never be used in training,
        # i.e. speakers that only occur in chunks with more than self.n_speakers speakers
        occur_spk_set = set()
        new_chunk_indices = []  # keep only chunks with at most self.n_speakers active speakers
for idx in range(self.__len__()):
rec, st, ed = self.chunk_indices[idx]
filtered_segments = self.data.segments[rec]
            # all the speakers in this recording, not just this chunk
speakers = np.unique(
[self.data.utt2spk[seg['utt']] for seg in filtered_segments]
).tolist()
n_speakers = self.n_speakers
            # we assume that the number of speakers in each chunk is at most self.n_speakers,
            # but the whole recording may contain more than self.n_speakers speakers
if self.n_speakers < len(speakers):
n_speakers = len(speakers)
# Y: (length,), T: (frame_num, n_speakers)
Y, T = self._get_labeled_speech(rec, st, ed, n_speakers)
# the spk index exist in this chunk data
exist_spk_idx = np.sum(T, axis=0) > 0.5 # bool index
chunk_spk_num = np.sum(exist_spk_idx)
if chunk_spk_num <= self.n_speakers:
spk_arr = np.array(speakers)
valid_spk_arr = spk_arr[exist_spk_idx[:spk_arr.shape[0]]]
for spk in valid_spk_arr:
occur_spk_set.add(spk)
new_chunk_indices.append((rec, st, ed))
self.chunk_indices = new_chunk_indices
self.all_speakers = sorted(list(occur_spk_set))
self.all_n_speakers = len(self.all_speakers)
def __getitem__(self, i):
if self.mode != "test":
rec, st, ed = self.chunk_indices[i]
filtered_segments = self.data.segments[rec]
            # all the speakers in this recording, not just this chunk
speakers = np.unique(
[self.data.utt2spk[seg['utt']] for seg in filtered_segments]
).tolist()
n_speakers = self.n_speakers
            # we assume that the number of speakers in each chunk is at most self.n_speakers,
            # but the whole recording may contain more than self.n_speakers speakers
if self.n_speakers < len(speakers):
n_speakers = len(speakers)
# Y: (length,), T: (frame_num, n_speakers)
Y, T = self._get_labeled_speech(rec, st, ed, n_speakers)
# the spk index exist in this chunk data
exist_spk_idx = np.sum(T, axis=0) > 0.5 # bool index
chunk_spk_num = np.sum(exist_spk_idx)
if chunk_spk_num > self.n_speakers:
                # the number of speakers in this chunk exceeds our preset value
return None, None, None
# the map from within recording speaker index to global speaker index
S_arr = -1 * np.ones(n_speakers).astype(np.int64)
for seg in filtered_segments:
speaker_index = speakers.index(self.data.utt2spk[seg['utt']])
try:
all_speaker_index = self.all_speakers.index(
self.data.utt2spk[seg['utt']])
                except ValueError:
                    # this speaker may have been filtered out in self.filter_spk
                    all_speaker_index = -1
S_arr[speaker_index] = all_speaker_index
# If T[:, n_speakers - 1] == 0.0, then S_arr[n_speakers - 1] == -1,
# so S_arr[n_speakers - 1] is not used for training,
# e.g., in the case of training 3-spk model with 2-spk data
            # drop speakers that do not appear in this chunk and keep exactly self.n_speakers output slots
T_exist = T[:,exist_spk_idx]
T = np.zeros((T_exist.shape[0], self.n_speakers), dtype=np.int32)
T[:,:T_exist.shape[1]] = T_exist
# subsampling for Y will be done in the model forward function
T = T[::self.subsampling]
S_arr_exist = S_arr[exist_spk_idx]
S_arr = -1 * np.ones(self.n_speakers).astype(np.int64)
S_arr[:S_arr_exist.shape[0]] = S_arr_exist
n = np.arange(self.all_n_speakers, dtype=np.int64).reshape(self.all_n_speakers, 1)
return Y, T, S_arr, n, T.shape[0]
else:
len_ratio = self.frame_shift * self.subsampling
chunks = self.chunk_indices[self.rec_list[i]]
Ys = []
chunk_len_list = []
for (rec, st, ed) in chunks:
chunk_len = ed - st
if chunk_len != self.chunk_size:
st = max(0, ed - self.chunk_size)
Y, _ = self.data.load_wav(rec, st * len_ratio, ed * len_ratio)
Ys.append(Y)
chunk_len_list.append(chunk_len)
return Ys, self.rec_list[i], chunk_len_list
def get_allnspk(self):
return self.all_n_speakers
def _get_labeled_speech(
self, rec, start, end, n_speakers=None, use_speaker_id=False
):
"""Extracts speech chunks and corresponding labels
Extracts speech chunks and corresponding diarization labels for
given recording id and start/end times
Args:
rec (str): recording id
start (int): start frame index
end (int): end frame index
n_speakers (int): number of speakers
if None, the value is given from data
Returns:
data: speech chunk
(n_samples)
T: label
                (n_frames, n_speakers)-shaped np.int32 array.
"""
data, rate = self.data.load_wav(
rec, start * self.frame_shift, end * self.frame_shift
)
frame_num = end - start
filtered_segments = self.data.segments[rec]
# filtered_segments = self.data.segments[self.data.segments['rec'] == rec]
speakers = np.unique(
[self.data.utt2spk[seg["utt"]] for seg in filtered_segments]
).tolist()
if n_speakers is None:
n_speakers = len(speakers)
T = np.zeros((frame_num, n_speakers), dtype=np.int32)
if use_speaker_id:
all_speakers = sorted(self.data.spk2utt.keys())
S = np.zeros((frame_num, len(all_speakers)), dtype=np.int32)
for seg in filtered_segments:
speaker_index = speakers.index(self.data.utt2spk[seg["utt"]])
if use_speaker_id:
all_speaker_index = all_speakers.index(self.data.utt2spk[seg["utt"]])
start_frame = np.rint(seg["st"] * rate / self.frame_shift).astype(int)
end_frame = np.rint(seg["et"] * rate / self.frame_shift).astype(int)
rel_start = rel_end = None
if start <= start_frame and start_frame < end:
rel_start = start_frame - start
if start < end_frame and end_frame <= end:
rel_end = end_frame - start
if rel_start is not None or rel_end is not None:
T[rel_start:rel_end, speaker_index] = 1
if use_speaker_id:
S[rel_start:rel_end, all_speaker_index] = 1
if use_speaker_id:
return data, T, S
else:
return data, T
def collate_fn(self, batch):
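        """Drop invalid samples (chunks with too many speakers), pad the
        waveforms and frame-level labels to the longest item in the batch, and
        stack the speaker labels, global speaker indices, and chunk lengths."""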
valid_samples = [sample for sample in batch if sample[0] is not None]
wav_list, binary_label_list, spk_label_list= [], [], []
all_spk_idx_list, len_list = [], []
for sample in valid_samples:
wav_list.append(torch.from_numpy(sample[0]).float())
binary_label_list.append(torch.from_numpy(sample[1]).long())
spk_label_list.append(torch.from_numpy(sample[2]).long())
all_spk_idx_list.append(torch.from_numpy(sample[3]).long())
len_list.append(sample[4])
wav_batch = pad_sequence(wav_list, batch_first=True, padding_value=0.0)
binary_label_batch = pad_sequence(binary_label_list, batch_first=True, padding_value=1).long()
spk_label_batch = torch.stack(spk_label_list)
all_spk_idx_batch = torch.stack(all_spk_idx_list)
len_batch = torch.LongTensor(len_list)
return wav_batch, binary_label_batch.float(), spk_label_batch, all_spk_idx_batch, len_batch
def collate_fn_infer(self, batch):
assert len(batch) == 1 # each batch should contain one recording
Ys, rec, chunk_len_list = batch[0]
wav_list = [torch.from_numpy(Y).float() for Y in Ys]
return wav_list, rec, chunk_len_list
#######################
# Kaldi-style Dataset #
#######################
class KaldiData:
"""This class holds data in kaldi-style directory."""
def __init__(self, data_dir):
"""Load kaldi data directory."""
self.data_dir = data_dir
self.segments = self._load_segments_rechash(
os.path.join(self.data_dir, "segments")
)
self.utt2spk = self._load_utt2spk(os.path.join(self.data_dir, "utt2spk"))
self.wavs = self._load_wav_scp(os.path.join(self.data_dir, "wav.scp"))
self.reco2dur = self._load_reco2dur(os.path.join(self.data_dir, "reco2dur"))
self.spk2utt = self._load_spk2utt(os.path.join(self.data_dir, "spk2utt"))
def load_wav(self, recid, start=0, end=None):
"""Load wavfile given recid, start time and end time."""
data, rate = self._load_wav(self.wavs[recid], start, end)
return data, rate
def _load_segments(self, segments_file):
"""Load segments file as array."""
if not os.path.exists(segments_file):
return None
return np.loadtxt(
segments_file,
dtype=[("utt", "object"), ("rec", "object"), ("st", "f"), ("et", "f")],
ndmin=1,
)
def _load_segments_hash(self, segments_file):
"""Load segments file as dict with uttid index."""
ret = {}
if not os.path.exists(segments_file):
return None
for line in open(segments_file):
utt, rec, st, et = line.strip().split()
ret[utt] = (rec, float(st), float(et))
return ret
def _load_segments_rechash(self, segments_file):
"""Load segments file as dict with recid index."""
ret = {}
if not os.path.exists(segments_file):
return None
for line in open(segments_file):
utt, rec, st, et = line.strip().split()
if rec not in ret:
ret[rec] = []
ret[rec].append({"utt": utt, "st": float(st), "et": float(et)})
return ret
def _load_wav_scp(self, wav_scp_file):
"""Return dictionary { rec: wav_rxfilename }."""
if os.path.exists(wav_scp_file):
lines = [line.strip().split(None, 1) for line in open(wav_scp_file)]
return {x[0]: x[1] for x in lines}
else:
wav_dir = os.path.join(self.data_dir, "wav")
return {
os.path.splitext(filename)[0]: os.path.join(wav_dir, filename)
for filename in sorted(os.listdir(wav_dir))
}
def _load_wav(self, wav_rxfilename, start=0, end=None):
"""This function reads audio file and return data in numpy.float32 array.
"lru_cache" holds recently loaded audio so that can be called
many times on the same audio file.
OPTIMIZE: controls lru_cache size for random access,
considering memory size
"""
if wav_rxfilename.endswith("|"):
# input piped command
p = subprocess.Popen(
wav_rxfilename[:-1],
shell=True,
stdout=subprocess.PIPE,
)
data, samplerate = sf.read(
io.BytesIO(p.stdout.read()),
dtype="float32",
)
# cannot seek
data = data[start:end]
elif wav_rxfilename == "-":
# stdin
data, samplerate = sf.read(sys.stdin, dtype="float32")
# cannot seek
data = data[start:end]
else:
# normal wav file
data, samplerate = sf.read(wav_rxfilename, start=start, stop=end)
return data, samplerate
def _load_utt2spk(self, utt2spk_file):
"""Returns dictionary { uttid: spkid }."""
lines = [line.strip().split(None, 1) for line in open(utt2spk_file)]
return {x[0]: x[1] for x in lines}
def _load_spk2utt(self, spk2utt_file):
"""Returns dictionary { spkid: list of uttids }."""
if not os.path.exists(spk2utt_file):
return None
lines = [line.strip().split() for line in open(spk2utt_file)]
return {x[0]: x[1:] for x in lines}
def _load_reco2dur(self, reco2dur_file):
"""Returns dictionary { recid: duration }."""
if not os.path.exists(reco2dur_file):
return None
lines = [line.strip().split(None, 1) for line in open(reco2dur_file)]
return {x[0]: float(x[1]) for x in lines}
def _process_wav(self, wav_rxfilename, process):
"""This function returns preprocessed wav_rxfilename.
Args:
wav_rxfilename:
input
process:
command which can be connected via pipe, use stdin and stdout
Returns:
wav_rxfilename: output piped command
"""
if wav_rxfilename.endswith("|"):
# input piped command
return wav_rxfilename + process + "|"
# stdin "-" or normal file
return "cat {0} | {1} |".format(wav_rxfilename, process)
def _extract_segments(self, wavs, segments=None):
"""This function returns generator of segmented audio.
Yields (utterance id, numpy.float32 array).
TODO?: sampling rate is not converted.
"""
if segments is not None:
# segments should be sorted by rec-id
for seg in segments:
wav = wavs[seg["rec"]]
data, samplerate = self.load_wav(wav)
st_sample = np.rint(seg["st"] * samplerate).astype(int)
et_sample = np.rint(seg["et"] * samplerate).astype(int)
yield seg["utt"], data[st_sample:et_sample]
else:
# segments file not found,
# wav.scp is used as segmented audio list
for rec in wavs:
data, samplerate = self.load_wav(wavs[rec])
yield rec, data
if __name__ == "__main__":
args = {
'mode': 'train',
'data_dir': "/mnt/lustre/sjtu/home/czy97/workspace/sd/EEND-vec-clustering/EEND-vector-clustering/egs/mini_librispeech/v1/data/simu/data/train_clean_5_ns3_beta2_500",
'chunk_size': 2001,
'frame_shift': 256,
'sampling_rate': 8000,
'num_speakers':3
}
torch.manual_seed(6)
dataset = DiarizationDataset(**args)
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=dataset.collate_fn)
data_iter = iter(dataloader)
# wav_batch, binary_label_batch, spk_label_batch, all_spk_idx_batch, len_batch = next(data_iter)
data = next(data_iter)
for val in data:
print(val.shape)
# from torch.utils.data import DataLoader
# dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=dataset.collate_fn_infer)
# data_iter = iter(dataloader)
# wav_list, binary_label_list, rec = next(data_iter)

View File

@@ -0,0 +1,162 @@
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
# Licensed under the MIT license.
#
# This library provides utilities for kaldi-style data directory.
from __future__ import print_function
import os
import sys
import numpy as np
import subprocess
import soundfile as sf
import io
from functools import lru_cache
def load_segments(segments_file):
""" load segments file as array """
if not os.path.exists(segments_file):
return None
return np.loadtxt(
segments_file,
dtype=[('utt', 'object'),
('rec', 'object'),
('st', 'f'),
('et', 'f')],
ndmin=1)
def load_segments_hash(segments_file):
ret = {}
if not os.path.exists(segments_file):
return None
for line in open(segments_file):
utt, rec, st, et = line.strip().split()
ret[utt] = (rec, float(st), float(et))
return ret
def load_segments_rechash(segments_file):
ret = {}
if not os.path.exists(segments_file):
return None
for line in open(segments_file):
utt, rec, st, et = line.strip().split()
if rec not in ret:
ret[rec] = []
ret[rec].append({'utt':utt, 'st':float(st), 'et':float(et)})
return ret
def load_wav_scp(wav_scp_file):
""" return dictionary { rec: wav_rxfilename } """
lines = [line.strip().split(None, 1) for line in open(wav_scp_file)]
return {x[0]: x[1] for x in lines}
@lru_cache(maxsize=1)
def load_wav(wav_rxfilename, start=0, end=None):
""" This function reads audio file and return data in numpy.float32 array.
"lru_cache" holds recently loaded audio so that can be called
many times on the same audio file.
OPTIMIZE: controls lru_cache size for random access,
considering memory size
"""
if wav_rxfilename.endswith('|'):
# input piped command
p = subprocess.Popen(wav_rxfilename[:-1], shell=True,
stdout=subprocess.PIPE)
data, samplerate = sf.read(io.BytesIO(p.stdout.read()),
dtype='float32')
# cannot seek
data = data[start:end]
elif wav_rxfilename == '-':
# stdin
data, samplerate = sf.read(sys.stdin, dtype='float32')
# cannot seek
data = data[start:end]
else:
# normal wav file
data, samplerate = sf.read(wav_rxfilename, start=start, stop=end)
return data, samplerate
def load_utt2spk(utt2spk_file):
""" returns dictionary { uttid: spkid } """
lines = [line.strip().split(None, 1) for line in open(utt2spk_file)]
return {x[0]: x[1] for x in lines}
def load_spk2utt(spk2utt_file):
""" returns dictionary { spkid: list of uttids } """
if not os.path.exists(spk2utt_file):
return None
lines = [line.strip().split() for line in open(spk2utt_file)]
return {x[0]: x[1:] for x in lines}
def load_reco2dur(reco2dur_file):
""" returns dictionary { recid: duration } """
if not os.path.exists(reco2dur_file):
return None
lines = [line.strip().split(None, 1) for line in open(reco2dur_file)]
return {x[0]: float(x[1]) for x in lines}
def process_wav(wav_rxfilename, process):
""" This function returns preprocessed wav_rxfilename
Args:
wav_rxfilename: input
process: command which can be connected via pipe,
use stdin and stdout
Returns:
wav_rxfilename: output piped command
"""
if wav_rxfilename.endswith('|'):
# input piped command
return wav_rxfilename + process + "|"
else:
# stdin "-" or normal file
return "cat {} | {} |".format(wav_rxfilename, process)
def extract_segments(wavs, segments=None):
""" This function returns generator of segmented audio as
(utterance id, numpy.float32 array)
TODO?: sampling rate is not converted.
"""
if segments is not None:
# segments should be sorted by rec-id
for seg in segments:
wav = wavs[seg['rec']]
data, samplerate = load_wav(wav)
st_sample = np.rint(seg['st'] * samplerate).astype(int)
et_sample = np.rint(seg['et'] * samplerate).astype(int)
yield seg['utt'], data[st_sample:et_sample]
else:
# segments file not found,
# wav.scp is used as segmented audio list
for rec in wavs:
data, samplerate = load_wav(wavs[rec])
yield rec, data
class KaldiData:
def __init__(self, data_dir):
self.data_dir = data_dir
self.segments = load_segments_rechash(
os.path.join(self.data_dir, 'segments'))
self.utt2spk = load_utt2spk(
os.path.join(self.data_dir, 'utt2spk'))
self.wavs = load_wav_scp(
os.path.join(self.data_dir, 'wav.scp'))
self.reco2dur = load_reco2dur(
os.path.join(self.data_dir, 'reco2dur'))
self.spk2utt = load_spk2utt(
os.path.join(self.data_dir, 'spk2utt'))
def load_wav(self, recid, start=0, end=None):
data, rate = load_wav(
self.wavs[recid], start, end)
return data, rate

View File

@@ -0,0 +1,97 @@
#!/usr/bin/env bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
# Arnab Ghoshal, Karel Vesely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
if [ "${!argpos}" == "--config" ]; then
argpos_plus1=$((argpos+1))
config=${!argpos_plus1}
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
. $config # source the config file.
fi
done
###
### Now we process the command line options
###
while true; do
[ -z "${1:-}" ] && break; # break if there are no arguments
case "$1" in
# If the enclosing script is called with --help option, print the help
# message and exit. Scripts should put help messages in $help_message
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
else printf "$help_message\n" 1>&2 ; fi;
exit 0 ;;
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
exit 1 ;;
# If the first command-line argument begins with "--" (e.g. --foo-bar),
# then work out the variable name as $name, which will equal "foo_bar".
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
      # Next we test whether the variable in question is undefined -- if so it's
# an invalid option and we die. Note: $0 evaluates to the name of the
# enclosing script.
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
# is undefined. We then have to wrap this test inside "eval" because
# foo_bar is itself inside a variable ($name).
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
oldval="`eval echo \\$$name`";
# Work out whether we seem to be expecting a Boolean argument.
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
was_bool=true;
else
was_bool=false;
fi
# Set the variable to the right value-- the escaped quotes make it work if
# the option had spaces, like --cmd "queue.pl -sync y"
eval $name=\"$2\";
# Check that Boolean-valued arguments are really Boolean.
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
exit 1;
fi
shift 2;
;;
*) break;
esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
true; # so this script returns exit code 0.

View File

@@ -0,0 +1,189 @@
import os
import struct
import logging
import torch
import math
import numpy as np
import random
import yaml
import torch.distributed as dist
import torch.nn.functional as F
# ------------------------------ Logger ------------------------------
# log to console or a file
def get_logger(
name,
format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
date_format="%Y-%m-%d %H:%M:%S",
file=False):
"""
Get python logger instance
"""
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
# file or console
handler = logging.StreamHandler() if not file else logging.FileHandler(
name)
handler.setLevel(logging.INFO)
formatter = logging.Formatter(fmt=format_str, datefmt=date_format)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
# log to console and file at the same time
def get_logger_2(
name,
format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
date_format="%Y-%m-%d %H:%M:%S"):
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
# Create handlers
c_handler = logging.StreamHandler()
f_handler = logging.FileHandler(name)
c_handler.setLevel(logging.INFO)
f_handler.setLevel(logging.INFO)
# Create formatters and add it to handlers
c_format = logging.Formatter(fmt=format_str, datefmt=date_format)
f_format = logging.Formatter(fmt=format_str, datefmt=date_format)
c_handler.setFormatter(c_format)
f_handler.setFormatter(f_format)
# Add handlers to the logger
logger.addHandler(c_handler)
logger.addHandler(f_handler)
return logger
# ------------------------------ Logger ------------------------------
# ------------------------------ Pytorch Distributed Training ------------------------------
def getoneNode():
nodelist = os.environ['SLURM_JOB_NODELIST']
nodelist = nodelist.strip().split(',')[0]
import re
text = re.split('[-\[\]]', nodelist)
if ('' in text):
text.remove('')
return text[0] + '-' + text[1] + '-' + text[2]
def dist_init(host_addr, rank, local_rank, world_size, port=23456):
host_addr_full = 'tcp://' + host_addr + ':' + str(port)
dist.init_process_group("nccl", init_method=host_addr_full,
rank=rank, world_size=world_size)
num_gpus = torch.cuda.device_count()
# torch.cuda.set_device(local_rank)
assert dist.is_initialized()
def cleanup():
dist.destroy_process_group()
def average_gradients(model, world_size):
size = float(world_size)
for param in model.parameters():
if (param.requires_grad and param.grad is not None):
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
param.grad.data /= size
def data_reduce(data):
dist.all_reduce(data, op=dist.ReduceOp.SUM)
return data / torch.distributed.get_world_size()
# ------------------------------ Pytorch Distributed Training ------------------------------
# ------------------------------ Hyper-parameter Dynamic Change ------------------------------
def reduce_lr(optimizer, initial_lr, final_lr, current_iter, max_iter, coeff=1.0):
current_lr = coeff * math.exp((current_iter / max_iter) * math.log(final_lr / initial_lr)) * initial_lr
for param_group in optimizer.param_groups:
param_group['lr'] = current_lr
def get_reduce_lr(initial_lr, final_lr, current_iter, max_iter):
current_lr = math.exp((current_iter / max_iter) * math.log(final_lr / initial_lr)) * initial_lr
return current_lr
def set_lr(optimizer, lr):
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# ------------------------------ Hyper-parameter Dynamic Change ------------------------------
# ---------------------- About Configuration --------------------
def parse_config_or_kwargs(config_file, **kwargs):
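    """Load a YAML config file and let the passed keyword arguments override
    its top-level keys."""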
with open(config_file) as con_read:
yaml_config = yaml.load(con_read, Loader=yaml.FullLoader)
# passed kwargs will override yaml config
return dict(yaml_config, **kwargs)
def store_yaml(config_file, store_path, **kwargs):
with open(config_file, 'r') as f:
config_lines = f.readlines()
keys_list = list(kwargs.keys())
with open(store_path, 'w') as f:
for line in config_lines:
if ':' in line and line.split(':')[0] in keys_list:
key = line.split(':')[0]
line = '{}: {}\n'.format(key, kwargs[key])
f.write(line)
# ---------------------- About Configuration --------------------
def check_dir(dir):
if not os.path.exists(dir):
os.mkdir(dir)
def set_seed(seed=66):
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# when the model was saved with a "module." prefix (e.g. by DataParallel),
# we remove it here
def correct_key(state_dict):
keys = list(state_dict.keys())
if 'module' not in keys[0]:
return state_dict
else:
new_state_dict = {}
for key in keys:
new_key = '.'.join(key.split('.')[1:])
new_state_dict[new_key] = state_dict[key]
return new_state_dict
def validate_path(dir_name):
"""
:param dir_name: Create the directory if it doesn't exist
:return: None
"""
dir_name = os.path.dirname(dir_name) # get the path
if not os.path.exists(dir_name) and (dir_name != ''):
os.makedirs(dir_name)
def get_lr(optimizer):
for param_group in optimizer.param_groups:
return param_group['lr']