Mirror of https://github.com/microsoft/UniSpeech.git
add diarization demo for unispeech_sat pre-training model
This commit is contained in:
Parent: 2816e682dc
Commit: 822afbabc6
@@ -0,0 +1,34 @@
## Pre-training Representations for Speaker Diarization

### Downstream Model

[EEND-vector-clustering](https://arxiv.org/abs/2105.09040)

### Pre-trained models

- Note that the diarization system is trained on 8 kHz audio data.

| Model | 2-spk DER | 3-spk DER | 4-spk DER | 5-spk DER | 6-spk DER | All-spk DER |
| ------------------------------------------------------------ | --------- | --------- | --------- | --------- | --------- | ----------- |
| EEND-vector-clustering | 7.96 | 11.93 | 16.38 | 21.21 | 23.10 | 12.49 |
| [**UniSpeech-SAT large**](https://drive.google.com/file/d/16OwIyOk2uYm0aWtSPaS0S12xE8RxF7k_/view?usp=sharing) | 5.93 | 10.66 | 12.90 | 16.48 | 23.25 | 10.92 |

### How to use?

#### Environment Setup

1. `pip install -r requirements.txt`
2. Install the fairseq code.
   - For UniSpeech-SAT large, install the [UniSpeech-SAT](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech-SAT) fairseq code, for example as sketched below.
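A minimal install sketch, assuming an editable pip install works from the `UniSpeech-SAT` directory of the cloned repository (the exact location of `setup.py` is an assumption; check the UniSpeech-SAT README for the authoritative steps):

```bash
# Assumed layout: adjust the path if setup.py lives elsewhere in your checkout.
git clone https://github.com/microsoft/UniSpeech.git
cd UniSpeech/UniSpeech-SAT
pip install --editable ./
```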
#### Example

1. First, download the pre-trained model from the table above to `checkpoint_path`.
2. Then run the command below.
   - The wav file is multi-talker speech simulated from the Librispeech corpus.
3. The output is written to `out.rttm` by default.

```bash
python diarization.py --wav_path tmp/mix_0000496.wav --model_init $checkpoint_path
```
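Each line of `out.rttm` follows the RTTM convention written by `make_rttm` in `diarization.py`: `SPEAKER <session> 1 <start> <duration> <NA> <NA> <session>_<speaker-index> <NA>`, where start and duration are in seconds; the session name defaults to `Anonymous` and can be changed with `--session`.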
@@ -0,0 +1,31 @@
# Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT).
# All rights reserved

# inference options
est_nspk: 1
sil_spk_th: 0.05
ahc_dis_th: 1.0
clink_dis: 1.0e+4
model:
  n_speakers: 3
  all_n_speakers: 0
  feat_dim: 1024
  n_units: 256
  n_heads: 8
  n_layers: 6
  dropout_rate: 0.1
  spk_emb_dim: 256
  sr: 8000
  frame_shift: 320
  frame_size: 200
  context_size: 0
  subsampling: 1
  feat_type: "config/unispeech_sat.th"
  feature_selection: "hidden_states"
  interpolate_mode: "linear"
dataset:
  chunk_size: 750
  frame_shift: 320
  sampling_rate: 8000
  subsampling: 1
  num_speakers: 3
|
Двоичный файл не отображается.
|
@ -0,0 +1,321 @@
|
|||
import sys
|
||||
import h5py
|
||||
import soundfile as sf
|
||||
import fire
|
||||
import math
|
||||
import yamlargparse
|
||||
import numpy as np
|
||||
from torch.utils.data import DataLoader
|
||||
import torch
|
||||
from utils.utils import parse_config_or_kwargs
|
||||
from utils.dataset import DiarizationDataset
|
||||
from models.models import TransformerDiarization
|
||||
from scipy.signal import medfilt
|
||||
from sklearn.cluster import AgglomerativeClustering
|
||||
from scipy.spatial import distance
|
||||
from utils.kaldi_data import KaldiData
|
||||
|
||||
|
||||
def get_cl_sil(args, acti, cls_num):
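    # Build clustering constraints from the chunk-level speaker activities: cannot-link
    # pairs (two local speaker slots that are both active in a chunk must not be merged)
    # and a list of silent local-speaker slots to exclude from clustering.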
|
||||
n_chunks = len(acti)
|
||||
mean_acti = np.array([np.mean(acti[i], axis=0)
|
||||
for i in range(n_chunks)]).flatten()
|
||||
n = args.num_speakers
|
||||
sil_spk_th = args.sil_spk_th
|
||||
|
||||
cl_lst = []
|
||||
sil_lst = []
|
||||
for chunk_idx in range(n_chunks):
|
||||
if cls_num is not None:
|
||||
if args.num_speakers > cls_num:
|
||||
mean_acti_bi = np.array([mean_acti[n * chunk_idx + s_loc_idx]
|
||||
for s_loc_idx in range(n)])
|
||||
min_idx = np.argmin(mean_acti_bi)
|
||||
mean_acti[n * chunk_idx + min_idx] = 0.0
|
||||
|
||||
for s_loc_idx in range(n):
|
||||
a = n * chunk_idx + (s_loc_idx + 0) % n
|
||||
b = n * chunk_idx + (s_loc_idx + 1) % n
|
||||
if mean_acti[a] > sil_spk_th and mean_acti[b] > sil_spk_th:
|
||||
cl_lst.append((a, b))
|
||||
else:
|
||||
if mean_acti[a] <= sil_spk_th:
|
||||
sil_lst.append(a)
|
||||
|
||||
return cl_lst, sil_lst
|
||||
|
||||
|
||||
def clustering(args, svec, cls_num, ahc_dis_th, cl_lst, sil_lst):
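    # Constrained agglomerative clustering of the chunk-level speaker embeddings:
    # silent slots are removed, cannot-link pairs are pushed apart with a large
    # distance (args.clink_dis), and silence labels are re-inserted afterwards.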
|
||||
org_svec_len = len(svec)
|
||||
svec = np.delete(svec, sil_lst, 0)
|
||||
|
||||
# update cl_lst idx
|
||||
_tbl = [i - sum(sil < i for sil in sil_lst) for i in range(org_svec_len)]
|
||||
cl_lst = [(_tbl[_cl[0]], _tbl[_cl[1]]) for _cl in cl_lst]
|
||||
|
||||
distMat = distance.cdist(svec, svec, metric='euclidean')
|
||||
for cl in cl_lst:
|
||||
distMat[cl[0], cl[1]] = args.clink_dis
|
||||
distMat[cl[1], cl[0]] = args.clink_dis
|
||||
|
||||
clusterer = AgglomerativeClustering(
|
||||
n_clusters=cls_num,
|
||||
affinity='precomputed',
|
||||
linkage='average',
|
||||
distance_threshold=ahc_dis_th)
|
||||
clusterer.fit(distMat)
|
||||
|
||||
if cls_num is not None:
|
||||
print("oracle n_clusters is known")
|
||||
else:
|
||||
print("oracle n_clusters is unknown")
|
||||
        print("estimated n_clusters by constrained AHC: {}"
|
||||
.format(len(np.unique(clusterer.labels_))))
|
||||
cls_num = len(np.unique(clusterer.labels_))
|
||||
|
||||
sil_lab = cls_num
|
||||
insert_sil_lab = [sil_lab for i in range(len(sil_lst))]
|
||||
insert_sil_lab_idx = [sil_lst[i] - i for i in range(len(sil_lst))]
|
||||
print("insert_sil_lab : {}".format(insert_sil_lab))
|
||||
print("insert_sil_lab_idx : {}".format(insert_sil_lab_idx))
|
||||
clslab = np.insert(clusterer.labels_,
|
||||
insert_sil_lab_idx,
|
||||
insert_sil_lab).reshape(-1, args.num_speakers)
|
||||
print("clslab : {}".format(clslab))
|
||||
|
||||
return clslab, cls_num
|
||||
|
||||
|
||||
def merge_act_max(act, i, j):
|
||||
for k in range(len(act)):
|
||||
act[k, i] = max(act[k, i], act[k, j])
|
||||
act[k, j] = 0.0
|
||||
return act
|
||||
|
||||
|
||||
def merge_acti_clslab(args, acti, clslab, cls_num):
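    # If two local speaker slots in the same chunk were assigned the same global cluster,
    # merge their activities (element-wise max) and relabel the duplicate slot as silence.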
|
||||
sil_lab = cls_num
|
||||
for i in range(len(clslab)):
|
||||
_lab = clslab[i].reshape(-1, 1)
|
||||
distM = distance.cdist(_lab, _lab, metric='euclidean').astype(np.int64)
|
||||
for j in range(len(distM)):
|
||||
distM[j][:j] = -1
|
||||
idx_lst = np.where(np.count_nonzero(distM == 0, axis=1) > 1)
|
||||
merge_done = []
|
||||
for j in idx_lst[0]:
|
||||
for k in (np.where(distM[j] == 0))[0]:
|
||||
if j != k and clslab[i, j] != sil_lab and k not in merge_done:
|
||||
print("merge : (i, j, k) == ({}, {}, {})".format(i, j, k))
|
||||
acti[i] = merge_act_max(acti[i], j, k)
|
||||
clslab[i, k] = sil_lab
|
||||
merge_done.append(j)
|
||||
|
||||
return acti, clslab
|
||||
|
||||
|
||||
def stitching(args, acti, clslab, cls_num):
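    # Reorder the per-chunk activity columns according to the global cluster labels so the
    # same column index refers to the same speaker across chunks, then drop the silence column.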
|
||||
n_chunks = len(acti)
|
||||
s_loc = args.num_speakers
|
||||
sil_lab = cls_num
|
||||
s_tot = max(cls_num, s_loc-1)
|
||||
|
||||
# Extend the max value of s_loc_idx to s_tot+1
|
||||
add_acti = []
|
||||
for chunk_idx in range(n_chunks):
|
||||
zeros = np.zeros((len(acti[chunk_idx]), s_tot+1))
|
||||
if s_tot+1 > s_loc:
|
||||
zeros[:, :-(s_tot+1-s_loc)] = acti[chunk_idx]
|
||||
else:
|
||||
zeros = acti[chunk_idx]
|
||||
add_acti.append(zeros)
|
||||
acti = np.array(add_acti)
|
||||
|
||||
out_chunks = []
|
||||
for chunk_idx in range(n_chunks):
|
||||
# Make sloci2lab_dct.
|
||||
# key: s_loc_idx
|
||||
# value: estimated label by clustering or sil_lab
|
||||
cls_set = set()
|
||||
for s_loc_idx in range(s_tot+1):
|
||||
cls_set.add(s_loc_idx)
|
||||
|
||||
sloci2lab_dct = {}
|
||||
for s_loc_idx in range(s_tot+1):
|
||||
if s_loc_idx < s_loc:
|
||||
sloci2lab_dct[s_loc_idx] = clslab[chunk_idx][s_loc_idx]
|
||||
if clslab[chunk_idx][s_loc_idx] in cls_set:
|
||||
cls_set.remove(clslab[chunk_idx][s_loc_idx])
|
||||
else:
|
||||
if clslab[chunk_idx][s_loc_idx] != sil_lab:
|
||||
raise ValueError
|
||||
else:
|
||||
sloci2lab_dct[s_loc_idx] = list(cls_set)[s_loc_idx-s_loc]
|
||||
|
||||
# Sort by label value
|
||||
sloci2lab_lst = sorted(sloci2lab_dct.items(), key=lambda x: x[1])
|
||||
|
||||
# Select sil_lab_idx
|
||||
sil_lab_idx = None
|
||||
for idx_lab in sloci2lab_lst:
|
||||
if idx_lab[1] == sil_lab:
|
||||
sil_lab_idx = idx_lab[0]
|
||||
break
|
||||
if sil_lab_idx is None:
|
||||
raise ValueError
|
||||
|
||||
# Get swap_idx
|
||||
# [idx of label(0), idx of label(1), ..., idx of label(s_tot)]
|
||||
swap_idx = [sil_lab_idx for j in range(s_tot+1)]
|
||||
for lab in range(s_tot+1):
|
||||
for idx_lab in sloci2lab_lst:
|
||||
if lab == idx_lab[1]:
|
||||
swap_idx[lab] = idx_lab[0]
|
||||
|
||||
print("swap_idx {}".format(swap_idx))
|
||||
swap_acti = acti[chunk_idx][:, swap_idx]
|
||||
swap_acti = np.delete(swap_acti, sil_lab, 1)
|
||||
out_chunks.append(swap_acti)
|
||||
|
||||
return out_chunks
|
||||
|
||||
|
||||
def prediction(num_speakers, net, wav_list, chunk_len_list):
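    # Run the diarization network chunk by chunk, collecting frame-level speaker activities
    # and one speaker embedding per local speaker slot for each chunk.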
|
||||
acti_lst = []
|
||||
svec_lst = []
|
||||
len_list = []
|
||||
|
||||
with torch.no_grad():
|
||||
for wav, chunk_len in zip(wav_list, chunk_len_list):
|
||||
wav = wav.to('cuda')
|
||||
outputs = net.batch_estimate(torch.unsqueeze(wav, 0))
|
||||
ys = outputs[0]
|
||||
|
||||
for i in range(num_speakers):
|
||||
spkivecs = outputs[i+1]
|
||||
svec_lst.append(spkivecs[0].cpu().detach().numpy())
|
||||
|
||||
acti = ys[0][-chunk_len:].cpu().detach().numpy()
|
||||
acti_lst.append(acti)
|
||||
len_list.append(chunk_len)
|
||||
|
||||
    acti_arr = np.concatenate(acti_lst, axis=0)  # total_len x num_speakers
|
||||
svec_arr = np.stack(svec_lst) # (chunk_num x num_speakers) x emb_dim
|
||||
len_arr = np.array(len_list) # chunk_num
|
||||
|
||||
return acti_arr, svec_arr, len_arr
|
||||
|
||||
def cluster(args, conf, acti_arr, svec_arr, len_arr):
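    # Split the concatenated activities back into chunks, cluster the chunk-level speaker
    # embeddings with constrained AHC, then merge and stitch the per-chunk results into
    # globally consistent speaker activities.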
|
||||
|
||||
acti_list = []
|
||||
n_chunks = len_arr.shape[0]
|
||||
start = 0
|
||||
for i in range(n_chunks):
|
||||
chunk_len = len_arr[i]
|
||||
acti_list.append(acti_arr[start: start+chunk_len])
|
||||
start += chunk_len
|
||||
acti = np.array(acti_list)
|
||||
svec = svec_arr
|
||||
|
||||
# initialize clustering setting
|
||||
cls_num = None
|
||||
ahc_dis_th = args.ahc_dis_th
|
||||
# Get cannot-link index list and silence index list
|
||||
cl_lst, sil_lst = get_cl_sil(args, acti, cls_num)
|
||||
|
||||
n_samples = n_chunks * args.num_speakers - len(sil_lst)
|
||||
min_n_samples = 2
|
||||
if cls_num is not None:
|
||||
min_n_samples = cls_num
|
||||
|
||||
if n_samples >= min_n_samples:
|
||||
# clustering (if cls_num is None, update cls_num)
|
||||
clslab, cls_num =\
|
||||
clustering(args, svec, cls_num, ahc_dis_th, cl_lst, sil_lst)
|
||||
# merge
|
||||
acti, clslab = merge_acti_clslab(args, acti, clslab, cls_num)
|
||||
# stitching
|
||||
out_chunks = stitching(args, acti, clslab, cls_num)
|
||||
else:
|
||||
out_chunks = acti
|
||||
|
||||
outdata = np.vstack(out_chunks)
|
||||
    # Return the stitched results
|
||||
return outdata
|
||||
|
||||
def make_rttm(args, conf, cluster_data):
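    # Threshold (and optionally median-filter) the speaker activities and write the
    # resulting speech segments to args.out_rttm_file in RTTM format.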
|
||||
args.frame_shift = conf['model']['frame_shift']
|
||||
args.subsampling = conf['model']['subsampling']
|
||||
args.sampling_rate = conf['dataset']['sampling_rate']
|
||||
|
||||
with open(args.out_rttm_file, 'w') as wf:
|
||||
a = np.where(cluster_data > args.threshold, 1, 0)
|
||||
if args.median > 1:
|
||||
a = medfilt(a, (args.median, 1))
|
||||
for spkid, frames in enumerate(a.T):
|
||||
frames = np.pad(frames, (1, 1), 'constant')
|
||||
changes, = np.where(np.diff(frames, axis=0) != 0)
|
||||
fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
|
||||
for s, e in zip(changes[::2], changes[1::2]):
|
||||
print(fmt.format(
|
||||
args.session,
|
||||
s * args.frame_shift * args.subsampling / args.sampling_rate,
|
||||
(e - s) * args.frame_shift * args.subsampling / args.sampling_rate,
|
||||
args.session + "_" + str(spkid)), file=wf)
|
||||
|
||||
def main(args):
|
||||
conf = parse_config_or_kwargs(args.config_path)
|
||||
num_speakers = conf['dataset']['num_speakers']
|
||||
args.num_speakers = num_speakers
|
||||
|
||||
# Prepare model
|
||||
model_parameter_dict = torch.load(args.model_init)['model']
|
||||
model_all_n_speakers = model_parameter_dict["embed.weight"].shape[0]
|
||||
conf['model']['all_n_speakers'] = model_all_n_speakers
|
||||
net = TransformerDiarization(**conf['model'])
|
||||
net.load_state_dict(model_parameter_dict, strict=False)
|
||||
net.eval()
|
||||
net = net.to("cuda")
|
||||
|
||||
audio, sr = sf.read(args.wav_path, dtype="float32")
|
||||
audio_len = audio.shape[0]
|
||||
chunk_size, frame_shift, subsampling = conf['dataset']['chunk_size'], conf['model']['frame_shift'], conf['model']['subsampling']
|
||||
scale_ratio = int(frame_shift * subsampling)
|
||||
chunk_audio_size = chunk_size * scale_ratio
|
||||
wav_list, chunk_len_list = [], []
|
||||
for i in range(0, math.ceil(1.0 * audio_len / chunk_audio_size)):
|
||||
start, end = i*chunk_audio_size, (i+1)*chunk_audio_size
|
||||
if end > audio_len:
|
||||
chunk_len_list.append(int((audio_len-start) / scale_ratio))
|
||||
end = audio_len
|
||||
start = max(0, audio_len - chunk_audio_size)
|
||||
else:
|
||||
chunk_len_list.append(chunk_size)
|
||||
wav_list.append(audio[start:end])
|
||||
wav_list = [torch.from_numpy(wav).float() for wav in wav_list]
|
||||
|
||||
acti_arr, svec_arr, len_arr = prediction(num_speakers, net, wav_list, chunk_len_list)
|
||||
cluster_data = cluster(args, conf, acti_arr, svec_arr, len_arr)
|
||||
make_rttm(args, conf, cluster_data)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = yamlargparse.ArgumentParser(description='decoding')
|
||||
parser.add_argument('--wav_path',
|
||||
help='the input wav path',
|
||||
default="tmp/mix_0000496.wav")
|
||||
parser.add_argument('--config_path',
|
||||
help='config file path',
|
||||
default="config/infer_est_nspk1.yaml")
|
||||
parser.add_argument('--model_init',
|
||||
help='model initialize path',
|
||||
default="")
|
||||
parser.add_argument('--sil_spk_th', default=0.05, type=float)
|
||||
parser.add_argument('--ahc_dis_th', default=1.0, type=float)
|
||||
parser.add_argument('--clink_dis', default=1.0e+4, type=float)
|
||||
parser.add_argument('--session', default='Anonymous', help='the name of the output speaker')
|
||||
parser.add_argument('--out_rttm_file', default='out.rttm', help='the output rttm file')
|
||||
parser.add_argument('--threshold', default=0.4, type=float)
|
||||
parser.add_argument('--median', default=25, type=int)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
@@ -0,0 +1,391 @@
# Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT).
|
||||
# All rights reserved
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
import torchaudio.transforms as trans
|
||||
from collections import OrderedDict
|
||||
from itertools import permutations
|
||||
from models.transformer import TransformerEncoder
|
||||
from .utils import UpstreamExpert
|
||||
|
||||
|
||||
class GradMultiply(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, scale):
|
||||
ctx.scale = scale
|
||||
res = x.new(x)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
return grad * ctx.scale, None
|
||||
|
||||
|
||||
|
||||
"""
|
||||
P: number of permutations
|
||||
T: number of frames
|
||||
C: number of speakers (classes)
|
||||
B: mini-batch size
|
||||
"""
|
||||
|
||||
|
||||
def batch_pit_loss_parallel(outputs, labels, ilens=None):
    """ Calculate the batch PIT (permutation-invariant training) loss in parallel.
|
||||
Args:
|
||||
outputs (torch.Tensor): B x T x C
|
||||
labels (torch.Tensor): B x T x C
|
||||
ilens (torch.Tensor): B
|
||||
Returns:
|
||||
perm (torch.Tensor): permutation for outputs (Batch, num_spk)
|
||||
loss
|
||||
"""
|
||||
|
||||
if ilens is None:
|
||||
mask, scale = 1.0, outputs.shape[1]
|
||||
else:
|
||||
scale = torch.unsqueeze(torch.LongTensor(ilens), 1).to(outputs.device)
|
||||
mask = outputs.new_zeros(outputs.size()[:-1])
|
||||
for i, chunk_len in enumerate(ilens):
|
||||
mask[i, :chunk_len] += 1.0
|
||||
mask /= scale
|
||||
|
||||
def loss_func(output, label):
|
||||
# return torch.mean(F.binary_cross_entropy_with_logits(output, label, reduction='none'), dim=tuple(range(1, output.dim())))
|
||||
return torch.sum(F.binary_cross_entropy_with_logits(output, label, reduction='none') * mask, dim=-1)
|
||||
|
||||
def pair_loss(outputs, labels, permutation):
|
||||
return sum([loss_func(outputs[:,:,s], labels[:,:,t]) for s, t in enumerate(permutation)]) / len(permutation)
|
||||
|
||||
device = outputs.device
|
||||
num_spk = outputs.shape[-1]
|
||||
all_permutations = list(permutations(range(num_spk)))
|
||||
losses = torch.stack([pair_loss(outputs, labels, p) for p in all_permutations], dim=1)
|
||||
loss, perm = torch.min(losses, dim=1)
|
||||
perm = torch.index_select(torch.tensor(all_permutations, device=device, dtype=torch.long), 0, perm)
|
||||
return torch.mean(loss), perm
|
||||
|
||||
|
||||
def fix_state_dict(state_dict):
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith('module.'):
|
||||
# remove 'module.' of DataParallel
|
||||
k = k[7:]
|
||||
if k.startswith('net.'):
|
||||
# remove 'net.' of PadertorchModel
|
||||
k = k[4:]
|
||||
new_state_dict[k] = v
|
||||
return new_state_dict
|
||||
|
||||
|
||||
class TransformerDiarization(nn.Module):
|
||||
def __init__(self,
|
||||
n_speakers,
|
||||
all_n_speakers,
|
||||
feat_dim,
|
||||
n_units,
|
||||
n_heads,
|
||||
n_layers,
|
||||
dropout_rate,
|
||||
spk_emb_dim,
|
||||
sr=8000,
|
||||
frame_shift=256,
|
||||
frame_size=1024,
|
||||
context_size=0,
|
||||
subsampling=1,
|
||||
feat_type='fbank',
|
||||
feature_selection='default',
|
||||
interpolate_mode='linear',
|
||||
update_extract=False,
|
||||
feature_grad_mult=1.0
|
||||
):
|
||||
super(TransformerDiarization, self).__init__()
|
||||
self.context_size = context_size
|
||||
self.subsampling = subsampling
|
||||
self.feat_type = feat_type
|
||||
self.feature_selection = feature_selection
|
||||
self.sr = sr
|
||||
self.frame_shift = frame_shift
|
||||
self.interpolate_mode = interpolate_mode
|
||||
self.update_extract = update_extract
|
||||
self.feature_grad_mult = feature_grad_mult
|
||||
|
||||
if feat_type == 'fbank':
|
||||
self.feature_extract = trans.MelSpectrogram(sample_rate=sr,
|
||||
n_fft=frame_size,
|
||||
win_length=frame_size,
|
||||
hop_length=frame_shift,
|
||||
f_min=0.0,
|
||||
f_max=sr // 2,
|
||||
pad=0,
|
||||
n_mels=feat_dim)
|
||||
else:
|
||||
self.feature_extract = UpstreamExpert(feat_type)
|
||||
# self.feature_extract = torch.hub.load('s3prl/s3prl', 'hubert_local', ckpt=feat_type)
|
||||
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
|
||||
self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
|
||||
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
|
||||
self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
|
||||
self.feat_num = self.get_feat_num()
|
||||
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
|
||||
# for param in self.feature_extract.parameters():
|
||||
# param.requires_grad = False
|
||||
self.resample = trans.Resample(orig_freq=sr, new_freq=16000)
|
||||
|
||||
if feat_type != 'fbank' and feat_type != 'mfcc':
|
||||
freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer', 'spk_proj', 'layer_norm_for_extract']
|
||||
for name, param in self.feature_extract.named_parameters():
|
||||
for freeze_val in freeze_list:
|
||||
if freeze_val in name:
|
||||
param.requires_grad = False
|
||||
break
|
||||
if not self.update_extract:
|
||||
for param in self.feature_extract.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.instance_norm = nn.InstanceNorm1d(feat_dim)
|
||||
|
||||
feat_dim = feat_dim * (self.context_size*2 + 1)
|
||||
self.enc = TransformerEncoder(
|
||||
feat_dim, n_layers, n_units, h=n_heads, dropout_rate=dropout_rate)
|
||||
self.linear = nn.Linear(n_units, n_speakers)
|
||||
|
||||
for i in range(n_speakers):
|
||||
setattr(self, '{}{:d}'.format("linear", i), nn.Linear(n_units, spk_emb_dim))
|
||||
|
||||
self.n_speakers = n_speakers
|
||||
self.embed = nn.Embedding(all_n_speakers, spk_emb_dim)
|
||||
self.alpha = nn.Parameter(torch.rand(1)[0] + torch.Tensor([0.5])[0])
|
||||
self.beta = nn.Parameter(torch.rand(1)[0] + torch.Tensor([0.5])[0])
|
||||
|
||||
def get_feat_num(self):
|
||||
self.feature_extract.eval()
|
||||
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
|
||||
with torch.no_grad():
|
||||
features = self.feature_extract(wav)
|
||||
select_feature = features[self.feature_selection]
|
||||
if isinstance(select_feature, (list, tuple)):
|
||||
return len(select_feature)
|
||||
else:
|
||||
return 1
|
||||
|
||||
def fix_except_embedding(self, requires_grad=False):
|
||||
for name, param in self.named_parameters():
|
||||
if 'embed' not in name:
|
||||
param.requires_grad = requires_grad
|
||||
|
||||
def modfy_emb(self, weight):
|
||||
self.embed = nn.Embedding.from_pretrained(weight)
|
||||
|
||||
def splice(self, data, context_size):
|
||||
# data: B x feat_dim x time_len
|
||||
data = torch.unsqueeze(data, -1)
|
||||
kernel_size = context_size*2 + 1
|
||||
splice_data = F.unfold(data, kernel_size=(kernel_size, 1), padding=(context_size, 0))
|
||||
return splice_data
|
||||
|
||||
def get_feat(self, xs):
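        # Extract frame-level features: either log-mel filterbanks or a learned weighted sum
        # of the pre-trained model's hidden states, then normalize, splice context frames,
        # subsample, and interpolate to the expected number of frames per chunk.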
|
||||
wav_len = xs.shape[-1]
|
||||
chunk_size = int(wav_len / self.frame_shift)
|
||||
chunk_size = int(chunk_size / self.subsampling)
|
||||
|
||||
self.feature_extract.eval()
|
||||
if self.update_extract:
|
||||
xs = self.resample(xs)
|
||||
feature = self.feature_extract([sample for sample in xs])
|
||||
else:
|
||||
with torch.no_grad():
|
||||
if self.feat_type == 'fbank':
|
||||
feature = self.feature_extract(xs) + 1e-6 # B x feat_dim x time_len
|
||||
feature = feature.log()
|
||||
else:
|
||||
xs = self.resample(xs)
|
||||
feature = self.feature_extract([sample for sample in xs])
|
||||
|
||||
if self.feat_type != "fbank" and self.feat_type != "mfcc":
|
||||
feature = feature[self.feature_selection]
|
||||
if isinstance(feature, (list, tuple)):
|
||||
feature = torch.stack(feature, dim=0)
|
||||
else:
|
||||
feature = feature.unsqueeze(0)
|
||||
|
||||
norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
feature = (norm_weights * feature).sum(dim=0)
|
||||
feature = torch.transpose(feature, 1, 2) + 1e-6
|
||||
|
||||
feature = self.instance_norm(feature)
|
||||
feature = self.splice(feature, self.context_size)
|
||||
feature = feature[:, :, ::self.subsampling]
|
||||
feature = F.interpolate(feature, chunk_size, mode=self.interpolate_mode)
|
||||
feature = torch.transpose(feature, 1, 2)
|
||||
|
||||
if self.feature_grad_mult != 1.0:
|
||||
feature = GradMultiply.apply(feature, self.feature_grad_mult)
|
||||
|
||||
return feature
|
||||
|
||||
def forward(self, inputs):
|
||||
if isinstance(inputs, list):
|
||||
xs = inputs[0]
|
||||
else:
|
||||
xs = inputs
|
||||
feature = self.get_feat(xs)
|
||||
|
||||
pad_shape = feature.shape
|
||||
emb = self.enc(feature)
|
||||
ys = self.linear(emb)
|
||||
ys = ys.reshape(pad_shape[0], pad_shape[1], -1)
|
||||
|
||||
spksvecs = []
|
||||
for i in range(self.n_speakers):
|
||||
spkivecs = getattr(self, '{}{:d}'.format("linear", i))(emb)
|
||||
spkivecs = spkivecs.reshape(pad_shape[0], pad_shape[1], -1)
|
||||
spksvecs.append(spkivecs)
|
||||
|
||||
return ys, spksvecs
|
||||
|
||||
def get_loss(self, inputs, ys, spksvecs, cal_spk_loss=True):
|
||||
ts = inputs[1]
|
||||
ss = inputs[2]
|
||||
ns = inputs[3]
|
||||
ilens = inputs[4]
|
||||
ilens = [ilen.item() for ilen in ilens]
|
||||
|
||||
pit_loss, sigmas = batch_pit_loss_parallel(ys, ts, ilens)
|
||||
if cal_spk_loss:
|
||||
spk_loss = self.spk_loss_parallel(spksvecs, ys, ts, ss, sigmas, ns, ilens)
|
||||
else:
|
||||
spk_loss = torch.tensor(0.0).to(pit_loss.device)
|
||||
|
||||
alpha = torch.clamp(self.alpha, min=sys.float_info.epsilon)
|
||||
|
||||
return {'spk_loss':spk_loss,
|
||||
'pit_loss': pit_loss}
|
||||
|
||||
|
||||
def batch_estimate(self, xs):
|
||||
out = self(xs)
|
||||
ys = out[0]
|
||||
spksvecs = out[1]
|
||||
spksvecs = list(zip(*spksvecs))
|
||||
outputs = [
|
||||
self.estimate(spksvec, y)
|
||||
for (spksvec, y) in zip(spksvecs, ys)]
|
||||
outputs = list(zip(*outputs))
|
||||
|
||||
return outputs
|
||||
|
||||
def batch_estimate_with_perm(self, xs, ts, ilens=None):
|
||||
out = self(xs)
|
||||
ys = out[0]
|
||||
if ts[0].shape[1] > ys[0].shape[1]:
|
||||
# e.g. the case of training 3-spk model with 4-spk data
|
||||
add_dim = ts[0].shape[1] - ys[0].shape[1]
|
||||
y_device = ys[0].device
|
||||
zeros = [torch.zeros(ts[0].shape).to(y_device)
|
||||
for i in range(len(ts))]
|
||||
_ys = []
|
||||
for zero, y in zip(zeros, ys):
|
||||
_zero = zero
|
||||
_zero[:, :-add_dim] = y
|
||||
_ys.append(_zero)
|
||||
_, sigmas = batch_pit_loss_parallel(_ys, ts, ilens)
|
||||
else:
|
||||
_, sigmas = batch_pit_loss_parallel(ys, ts, ilens)
|
||||
spksvecs = out[1]
|
||||
spksvecs = list(zip(*spksvecs))
|
||||
outputs = [self.estimate(spksvec, y)
|
||||
for (spksvec, y) in zip(spksvecs, ys)]
|
||||
outputs = list(zip(*outputs))
|
||||
zs = outputs[0]
|
||||
|
||||
if ts[0].shape[1] > ys[0].shape[1]:
|
||||
# e.g. the case of training 3-spk model with 4-spk data
|
||||
add_dim = ts[0].shape[1] - ys[0].shape[1]
|
||||
z_device = zs[0].device
|
||||
zeros = [torch.zeros(ts[0].shape).to(z_device)
|
||||
for i in range(len(ts))]
|
||||
_zs = []
|
||||
for zero, z in zip(zeros, zs):
|
||||
_zero = zero
|
||||
_zero[:, :-add_dim] = z
|
||||
_zs.append(_zero)
|
||||
zs = _zs
|
||||
outputs[0] = zs
|
||||
outputs.append(sigmas)
|
||||
|
||||
# outputs: [zs, nmz_wavg_spk0vecs, nmz_wavg_spk1vecs, ..., sigmas]
|
||||
return outputs
|
||||
|
||||
def estimate(self, spksvec, y):
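        # Turn logits into per-frame speaker activities and compute one length-normalized,
        # activity-weighted speaker embedding per local speaker slot.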
|
||||
outputs = []
|
||||
z = torch.sigmoid(y.transpose(1, 0))
|
||||
|
||||
outputs.append(z.transpose(1, 0))
|
||||
for spkid, spkvec in enumerate(spksvec):
|
||||
norm_spkvec_inv = 1.0 / torch.norm(spkvec, dim=1)
|
||||
# Normalize speaker vectors before weighted average
|
||||
spkvec = torch.mul(
|
||||
spkvec.transpose(1, 0), norm_spkvec_inv
|
||||
).transpose(1, 0)
|
||||
wavg_spkvec = torch.mul(
|
||||
spkvec.transpose(1, 0), z[spkid]
|
||||
).transpose(1, 0)
|
||||
sum_wavg_spkvec = torch.sum(wavg_spkvec, dim=0)
|
||||
nmz_wavg_spkvec = sum_wavg_spkvec / torch.norm(sum_wavg_spkvec)
|
||||
outputs.append(nmz_wavg_spkvec)
|
||||
|
||||
# outputs: [z, nmz_wavg_spk0vec, nmz_wavg_spk1vec, ...]
|
||||
return outputs
|
||||
|
||||
def spk_loss_parallel(self, spksvecs, ys, ts, ss, sigmas, ns, ilens):
|
||||
'''
|
||||
spksvecs (List[torch.Tensor, ...]): [B x T x emb_dim, ...]
|
||||
ys (torch.Tensor): B x T x 3
|
||||
ts (torch.Tensor): B x T x 3
|
||||
ss (torch.Tensor): B x 3
|
||||
sigmas (torch.Tensor): B x 3
|
||||
ns (torch.Tensor): B x total_spk_num x 1
|
||||
ilens (List): B
|
||||
'''
|
||||
chunk_spk_num = len(spksvecs) # 3
|
||||
|
||||
len_mask = ys.new_zeros((ys.size()[:-1])) # B x T
|
||||
for i, len_val in enumerate(ilens):
|
||||
len_mask[i,:len_val] += 1.0
|
||||
ts = ts * len_mask.unsqueeze(-1)
|
||||
len_mask = len_mask.repeat((chunk_spk_num, 1)) # B*3 x T
|
||||
|
||||
spk_vecs = torch.cat(spksvecs, dim=0) # B*3 x T x emb_dim
|
||||
# Normalize speaker vectors before weighted average
|
||||
spk_vecs = F.normalize(spk_vecs, dim=-1)
|
||||
|
||||
ys = torch.permute(torch.sigmoid(ys), dims=(2, 0, 1)) # 3 x B x T
|
||||
ys = ys.reshape(-1, ys.shape[-1]).unsqueeze(-1) # B*3 x T x 1
|
||||
|
||||
weight_spk_vec = ys * spk_vecs # B*3 x T x emb_dim
|
||||
weight_spk_vec *= len_mask.unsqueeze(-1)
|
||||
sum_spk_vec = torch.sum(weight_spk_vec, dim=1) # B*3 x emb_dim
|
||||
norm_spk_vec = F.normalize(sum_spk_vec, dim=1)
|
||||
|
||||
embeds = F.normalize(self.embed(ns[0]).squeeze(), dim=1) # total_spk_num x emb_dim
|
||||
dist = torch.cdist(norm_spk_vec, embeds) # B*3 x total_spk_num
|
||||
logits = -1.0 * torch.add(torch.clamp(self.alpha, min=sys.float_info.epsilon) * torch.pow(dist, 2), self.beta)
|
||||
label = torch.gather(ss, 1, sigmas).transpose(0, 1).reshape(-1, 1).squeeze() # B*3
|
||||
label[label==-1] = 0
|
||||
valid_spk_mask = torch.gather(torch.sum(ts, dim=1), 1, sigmas).transpose(0, 1) # 3 x B
|
||||
valid_spk_mask = (torch.flatten(valid_spk_mask) > 0).float() # B*3
|
||||
|
||||
valid_spk_loss_num = torch.sum(valid_spk_mask).item()
|
||||
if valid_spk_loss_num > 0:
|
||||
loss = F.cross_entropy(logits, label, reduction='none') * valid_spk_mask / valid_spk_loss_num
|
||||
# uncomment the line below, the loss result is same as batch_spk_loss
|
||||
# loss = F.cross_entropy(logits, label, reduction='none') * valid_spk_mask / valid_spk_mask.shape[0]
|
||||
return torch.sum(loss)
|
||||
else:
|
||||
return torch.tensor(0.0).to(ys.device)
|
@@ -0,0 +1,147 @@
# Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT).
|
||||
# All rights reserved
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
|
||||
class NoamScheduler(_LRScheduler):
|
||||
""" learning rate scheduler used in the transformer
|
||||
See https://arxiv.org/pdf/1706.03762.pdf
|
||||
lrate = d_model**(-0.5) * \
|
||||
min(step_num**(-0.5), step_num*warmup_steps**(-1.5))
|
||||
Scaling factor is implemented as in
|
||||
http://nlp.seas.harvard.edu/2018/04/03/attention.html#optimizer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, optimizer, d_model, warmup_steps, tot_step, scale,
|
||||
last_epoch=-1
|
||||
):
|
||||
self.d_model = d_model
|
||||
self.warmup_steps = warmup_steps
|
||||
self.tot_step = tot_step
|
||||
self.scale = scale
|
||||
super(NoamScheduler, self).__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
self.last_epoch = max(1, self.last_epoch)
|
||||
step_num = self.last_epoch
|
||||
val = self.scale * self.d_model ** (-0.5) * \
|
||||
min(step_num ** (-0.5), step_num * self.warmup_steps ** (-1.5))
|
||||
|
||||
return [base_lr / base_lr * val for base_lr in self.base_lrs]
|
||||
|
||||
|
||||
class MultiHeadSelfAttention(nn.Module):
|
||||
""" Multi head "self" attention layer
|
||||
"""
|
||||
|
||||
def __init__(self, n_units, h=8, dropout_rate=0.1):
|
||||
super(MultiHeadSelfAttention, self).__init__()
|
||||
self.linearQ = nn.Linear(n_units, n_units)
|
||||
self.linearK = nn.Linear(n_units, n_units)
|
||||
self.linearV = nn.Linear(n_units, n_units)
|
||||
self.linearO = nn.Linear(n_units, n_units)
|
||||
self.d_k = n_units // h
|
||||
self.h = h
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
# attention for plot
|
||||
self.att = None
|
||||
|
||||
def forward(self, x, batch_size):
|
||||
# x: (BT, F)
|
||||
q = self.linearQ(x).reshape(batch_size, -1, self.h, self.d_k)
|
||||
k = self.linearK(x).reshape(batch_size, -1, self.h, self.d_k)
|
||||
v = self.linearV(x).reshape(batch_size, -1, self.h, self.d_k)
|
||||
|
||||
scores = torch.matmul(
|
||||
q.transpose(1, 2), k.permute(0, 2, 3, 1)) / np.sqrt(self.d_k)
|
||||
# scores: (B, h, T, T) = (B, h, T, d_k) x (B, h, d_k, T)
|
||||
self.att = F.softmax(scores, dim=3)
|
||||
p_att = self.dropout(self.att)
|
||||
x = torch.matmul(p_att, v.transpose(1, 2))
|
||||
x = x.transpose(1, 2).reshape(-1, self.h * self.d_k)
|
||||
|
||||
return self.linearO(x)
|
||||
|
||||
|
||||
class PositionwiseFeedForward(nn.Module):
|
||||
""" Positionwise feed-forward layer
|
||||
"""
|
||||
|
||||
def __init__(self, n_units, d_units, dropout_rate):
|
||||
super(PositionwiseFeedForward, self).__init__()
|
||||
self.linear1 = nn.Linear(n_units, d_units)
|
||||
self.linear2 = nn.Linear(d_units, n_units)
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear2(self.dropout(F.relu(self.linear1(x))))
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
""" Positional encoding function
|
||||
"""
|
||||
|
||||
def __init__(self, n_units, dropout_rate, max_len):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
positions = np.arange(0, max_len, dtype='f')[:, None]
|
||||
dens = np.exp(
|
||||
np.arange(0, n_units, 2, dtype='f') * -(np.log(10000.) / n_units))
|
||||
self.enc = np.zeros((max_len, n_units), dtype='f')
|
||||
self.enc[:, ::2] = np.sin(positions * dens)
|
||||
self.enc[:, 1::2] = np.cos(positions * dens)
|
||||
self.scale = np.sqrt(n_units)
|
||||
|
||||
def forward(self, x):
|
||||
        # torch equivalent of the original Chainer-style lookup (self.xp is a Chainer leftover)
        x = x * self.scale + torch.from_numpy(self.enc[: x.shape[1]]).to(x.device)
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(self, idim, n_layers, n_units,
|
||||
e_units=2048, h=8, dropout_rate=0.1):
|
||||
super(TransformerEncoder, self).__init__()
|
||||
self.linear_in = nn.Linear(idim, n_units)
|
||||
# self.lnorm_in = nn.LayerNorm(n_units)
|
||||
self.pos_enc = PositionalEncoding(n_units, dropout_rate, 5000)
|
||||
self.n_layers = n_layers
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
for i in range(n_layers):
|
||||
setattr(self, '{}{:d}'.format("lnorm1_", i),
|
||||
nn.LayerNorm(n_units))
|
||||
setattr(self, '{}{:d}'.format("self_att_", i),
|
||||
MultiHeadSelfAttention(n_units, h, dropout_rate))
|
||||
setattr(self, '{}{:d}'.format("lnorm2_", i),
|
||||
nn.LayerNorm(n_units))
|
||||
setattr(self, '{}{:d}'.format("ff_", i),
|
||||
PositionwiseFeedForward(n_units, e_units, dropout_rate))
|
||||
self.lnorm_out = nn.LayerNorm(n_units)
|
||||
|
||||
def forward(self, x):
|
||||
# x: (B, T, F) ... batch, time, (mel)freq
|
||||
BT_size = x.shape[0] * x.shape[1]
|
||||
# e: (BT, F)
|
||||
e = self.linear_in(x.reshape(BT_size, -1))
|
||||
# Encoder stack
|
||||
for i in range(self.n_layers):
|
||||
# layer normalization
|
||||
e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e)
|
||||
# self-attention
|
||||
s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0])
|
||||
# residual
|
||||
e = e + self.dropout(s)
|
||||
# layer normalization
|
||||
e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e)
|
||||
# positionwise feed-forward
|
||||
s = getattr(self, '{}{:d}'.format("ff_", i))(e)
|
||||
# residual
|
||||
e = e + self.dropout(s)
|
||||
# final layer normalization
|
||||
# output: (BT, F)
|
||||
return self.lnorm_out(e)
|
@@ -0,0 +1,78 @@
import torch
|
||||
import fairseq
|
||||
from packaging import version
|
||||
import torch.nn.functional as F
|
||||
from fairseq import tasks
|
||||
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
|
||||
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
||||
from omegaconf import OmegaConf
|
||||
from s3prl.upstream.interfaces import UpstreamBase
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
def load_model(filepath):
|
||||
state = torch.load(filepath, map_location=lambda storage, loc: storage)
|
||||
# state = load_checkpoint_to_cpu(filepath)
|
||||
state["cfg"] = OmegaConf.create(state["cfg"])
|
||||
|
||||
if "args" in state and state["args"] is not None:
|
||||
cfg = convert_namespace_to_omegaconf(state["args"])
|
||||
elif "cfg" in state and state["cfg"] is not None:
|
||||
cfg = state["cfg"]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Neither args nor cfg exist in state keys = {state.keys()}"
|
||||
)
|
||||
|
||||
task = tasks.setup_task(cfg.task)
|
||||
if "task_state" in state:
|
||||
task.load_state_dict(state["task_state"])
|
||||
|
||||
model = task.build_model(cfg.model)
|
||||
|
||||
return model, cfg, task
|
||||
|
||||
|
||||
###################
|
||||
# UPSTREAM EXPERT #
|
||||
###################
|
||||
class UpstreamExpert(UpstreamBase):
|
||||
def __init__(self, ckpt, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
assert version.parse(fairseq.__version__) > version.parse(
|
||||
"0.10.2"
|
||||
), "Please install the fairseq master branch."
|
||||
|
||||
model, cfg, task = load_model(ckpt)
|
||||
self.model = model
|
||||
self.task = task
|
||||
|
||||
if len(self.hooks) == 0:
|
||||
module_name = "self.model.encoder.layers"
|
||||
for module_id in range(len(eval(module_name))):
|
||||
self.add_hook(
|
||||
f"{module_name}[{module_id}]",
|
||||
lambda input, output: input[0].transpose(0, 1),
|
||||
)
|
||||
self.add_hook("self.model.encoder", lambda input, output: output[0])
|
||||
|
||||
def forward(self, wavs):
|
||||
if self.task.cfg.normalize:
|
||||
wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]
|
||||
|
||||
device = wavs[0].device
|
||||
wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
|
||||
wav_padding_mask = ~torch.lt(
|
||||
torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
|
||||
wav_lengths.unsqueeze(1),
|
||||
)
|
||||
padded_wav = pad_sequence(wavs, batch_first=True)
|
||||
|
||||
features, feat_padding_mask = self.model.extract_features(
|
||||
padded_wav,
|
||||
padding_mask=wav_padding_mask,
|
||||
mask=None,
|
||||
)
|
||||
return {
|
||||
"default": features,
|
||||
}
|
||||
|
@@ -0,0 +1,11 @@
soundfile
fire
sentencepiece
tqdm
pyyaml
h5py
yamlargparse
sklearn
matplotlib
torchaudio
s3prl
@@ -0,0 +1 @@
/mnt/lustre/sjtu/home/czy97/workspace/sd/EEND-vec-clustering/EEND-vector-clustering/egs/mini_librispeech/v1/data/simu/wav/dev_clean_2_ns3_beta2_500/100/mix_0000496.wav
@@ -0,0 +1,484 @@
# -*- coding: utf-8 -*- #
|
||||
"""*********************************************************************************************"""
|
||||
# FileName [ dataset.py ]
|
||||
# Synopsis [ the speaker diarization dataset ]
|
||||
# Source [ Refactored from https://github.com/hitachi-speech/EEND ]
|
||||
# Author [ Jiatong Shi ]
|
||||
# Copyright [ Copyright(c), Johns Hopkins University ]
|
||||
"""*********************************************************************************************"""
|
||||
|
||||
|
||||
###############
|
||||
# IMPORTATION #
|
||||
###############
|
||||
import io
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
# -------------#
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
# -------------#
|
||||
from torch.utils.data.dataset import Dataset
|
||||
# -------------#
|
||||
|
||||
|
||||
def _count_frames(data_len, size, step):
|
||||
# no padding at edges, last remaining samples are ignored
|
||||
return int((data_len - size + step) / step)
|
||||
|
||||
|
||||
def _gen_frame_indices(data_length, size=2000, step=2000):
|
||||
i = -1
|
||||
for i in range(_count_frames(data_length, size, step)):
|
||||
yield i * step, i * step + size
|
||||
|
||||
if i * step + size < data_length:
|
||||
if data_length - (i + 1) * step > 0:
|
||||
if i == -1:
|
||||
yield (i + 1) * step, data_length
|
||||
else:
|
||||
yield data_length - size, data_length
|
||||
|
||||
|
||||
def _gen_chunk_indices(data_len, chunk_size):
|
||||
step = chunk_size
|
||||
start = 0
|
||||
while start < data_len:
|
||||
end = min(data_len, start + chunk_size)
|
||||
yield start, end
|
||||
start += step
|
||||
|
||||
|
||||
#######################
|
||||
# Diarization Dataset #
|
||||
#######################
|
||||
class DiarizationDataset(Dataset):
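    # Kaldi-style speaker diarization dataset: yields chunked waveforms with frame-level
    # speaker-activity labels for training, or per-recording chunk lists for inference.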
|
||||
def __init__(
|
||||
self,
|
||||
mode,
|
||||
data_dir,
|
||||
chunk_size=2000,
|
||||
frame_shift=256,
|
||||
sampling_rate=16000,
|
||||
subsampling=1,
|
||||
use_last_samples=True,
|
||||
num_speakers=3,
|
||||
filter_spk=False
|
||||
):
|
||||
super(DiarizationDataset, self).__init__()
|
||||
|
||||
self.mode = mode
|
||||
self.data_dir = data_dir
|
||||
self.chunk_size = chunk_size
|
||||
self.frame_shift = frame_shift
|
||||
self.subsampling = subsampling
|
||||
self.n_speakers = num_speakers
|
||||
self.chunk_indices = [] if mode != "test" else {}
|
||||
|
||||
self.data = KaldiData(self.data_dir)
|
||||
self.all_speakers = sorted(self.data.spk2utt.keys())
|
||||
self.all_n_speakers = len(self.all_speakers)
|
||||
|
||||
# make chunk indices: filepath, start_frame, end_frame
|
||||
for rec in self.data.wavs:
|
||||
data_len = int(self.data.reco2dur[rec] * sampling_rate / frame_shift)
|
||||
data_len = int(data_len / self.subsampling)
|
||||
if mode == "test":
|
||||
self.chunk_indices[rec] = []
|
||||
if mode != "test":
|
||||
for st, ed in _gen_frame_indices(data_len, chunk_size, chunk_size):
|
||||
self.chunk_indices.append(
|
||||
(rec, st * self.subsampling, ed * self.subsampling)
|
||||
)
|
||||
else:
|
||||
for st, ed in _gen_chunk_indices(data_len, chunk_size):
|
||||
self.chunk_indices[rec].append(
|
||||
(rec, st, ed)
|
||||
)
|
||||
|
||||
if mode != "test":
|
||||
if filter_spk:
|
||||
self.filter_spk()
|
||||
print(len(self.chunk_indices), " chunks")
|
||||
else:
|
||||
self.rec_list = list(self.chunk_indices.keys())
|
||||
print(len(self.rec_list), " recordings")
|
||||
|
||||
def __len__(self):
|
||||
return (
|
||||
len(self.rec_list)
|
||||
if type(self.chunk_indices) == dict
|
||||
else len(self.chunk_indices)
|
||||
)
|
||||
|
||||
def filter_spk(self):
|
||||
        # Filter out speakers that appear in spk2utt but will not be used in training,
        # i.e. speakers that only occur in chunks containing more than self.n_speakers speakers.
|
||||
occur_spk_set = set()
|
||||
|
||||
        new_chunk_indices = []  # keep only chunks with at most self.n_speakers active speakers
|
||||
for idx in range(self.__len__()):
|
||||
rec, st, ed = self.chunk_indices[idx]
|
||||
|
||||
filtered_segments = self.data.segments[rec]
|
||||
            # all the speakers in this recording, not only those in the chunk
|
||||
speakers = np.unique(
|
||||
[self.data.utt2spk[seg['utt']] for seg in filtered_segments]
|
||||
).tolist()
|
||||
n_speakers = self.n_speakers
|
||||
            # We assume the number of speakers in each chunk is at most self.n_speakers,
            # but the number of speakers in the whole recording may exceed it.
|
||||
if self.n_speakers < len(speakers):
|
||||
n_speakers = len(speakers)
|
||||
|
||||
# Y: (length,), T: (frame_num, n_speakers)
|
||||
Y, T = self._get_labeled_speech(rec, st, ed, n_speakers)
|
||||
# the spk index exist in this chunk data
|
||||
exist_spk_idx = np.sum(T, axis=0) > 0.5 # bool index
|
||||
chunk_spk_num = np.sum(exist_spk_idx)
|
||||
if chunk_spk_num <= self.n_speakers:
|
||||
spk_arr = np.array(speakers)
|
||||
valid_spk_arr = spk_arr[exist_spk_idx[:spk_arr.shape[0]]]
|
||||
for spk in valid_spk_arr:
|
||||
occur_spk_set.add(spk)
|
||||
|
||||
new_chunk_indices.append((rec, st, ed))
|
||||
self.chunk_indices = new_chunk_indices
|
||||
self.all_speakers = sorted(list(occur_spk_set))
|
||||
self.all_n_speakers = len(self.all_speakers)
|
||||
|
||||
def __getitem__(self, i):
|
||||
if self.mode != "test":
|
||||
rec, st, ed = self.chunk_indices[i]
|
||||
|
||||
filtered_segments = self.data.segments[rec]
|
||||
            # all the speakers in this recording, not only those in the chunk
|
||||
speakers = np.unique(
|
||||
[self.data.utt2spk[seg['utt']] for seg in filtered_segments]
|
||||
).tolist()
|
||||
n_speakers = self.n_speakers
|
||||
            # We assume the number of speakers in each chunk is at most self.n_speakers,
            # but the number of speakers in the whole recording may exceed it.
|
||||
if self.n_speakers < len(speakers):
|
||||
n_speakers = len(speakers)
|
||||
|
||||
# Y: (length,), T: (frame_num, n_speakers)
|
||||
Y, T = self._get_labeled_speech(rec, st, ed, n_speakers)
|
||||
# the spk index exist in this chunk data
|
||||
exist_spk_idx = np.sum(T, axis=0) > 0.5 # bool index
|
||||
chunk_spk_num = np.sum(exist_spk_idx)
|
||||
if chunk_spk_num > self.n_speakers:
|
||||
# the speaker number in a chunk exceed our pre-set value
|
||||
return None, None, None
|
||||
|
||||
# the map from within recording speaker index to global speaker index
|
||||
S_arr = -1 * np.ones(n_speakers).astype(np.int64)
|
||||
for seg in filtered_segments:
|
||||
speaker_index = speakers.index(self.data.utt2spk[seg['utt']])
|
||||
try:
|
||||
all_speaker_index = self.all_speakers.index(
|
||||
self.data.utt2spk[seg['utt']])
|
||||
                except ValueError:
                    # the speaker was removed by self.filter_spk, so it has no global index
|
||||
all_speaker_index = -1
|
||||
S_arr[speaker_index] = all_speaker_index
|
||||
# If T[:, n_speakers - 1] == 0.0, then S_arr[n_speakers - 1] == -1,
|
||||
# so S_arr[n_speakers - 1] is not used for training,
|
||||
# e.g., in the case of training 3-spk model with 2-spk data
|
||||
|
||||
# filter the speaker not exist in this chunk and ensure there are self.num_speakers outputs
|
||||
T_exist = T[:,exist_spk_idx]
|
||||
T = np.zeros((T_exist.shape[0], self.n_speakers), dtype=np.int32)
|
||||
T[:,:T_exist.shape[1]] = T_exist
|
||||
# subsampling for Y will be done in the model forward function
|
||||
T = T[::self.subsampling]
|
||||
|
||||
S_arr_exist = S_arr[exist_spk_idx]
|
||||
S_arr = -1 * np.ones(self.n_speakers).astype(np.int64)
|
||||
S_arr[:S_arr_exist.shape[0]] = S_arr_exist
|
||||
|
||||
n = np.arange(self.all_n_speakers, dtype=np.int64).reshape(self.all_n_speakers, 1)
|
||||
return Y, T, S_arr, n, T.shape[0]
|
||||
else:
|
||||
len_ratio = self.frame_shift * self.subsampling
|
||||
chunks = self.chunk_indices[self.rec_list[i]]
|
||||
Ys = []
|
||||
chunk_len_list = []
|
||||
for (rec, st, ed) in chunks:
|
||||
chunk_len = ed - st
|
||||
if chunk_len != self.chunk_size:
|
||||
st = max(0, ed - self.chunk_size)
|
||||
Y, _ = self.data.load_wav(rec, st * len_ratio, ed * len_ratio)
|
||||
Ys.append(Y)
|
||||
chunk_len_list.append(chunk_len)
|
||||
return Ys, self.rec_list[i], chunk_len_list
|
||||
|
||||
def get_allnspk(self):
|
||||
return self.all_n_speakers
|
||||
|
||||
def _get_labeled_speech(
|
||||
self, rec, start, end, n_speakers=None, use_speaker_id=False
|
||||
):
|
||||
"""Extracts speech chunks and corresponding labels
|
||||
|
||||
Extracts speech chunks and corresponding diarization labels for
|
||||
given recording id and start/end times
|
||||
|
||||
Args:
|
||||
rec (str): recording id
|
||||
start (int): start frame index
|
||||
end (int): end frame index
|
||||
n_speakers (int): number of speakers
|
||||
if None, the value is given from data
|
||||
Returns:
|
||||
data: speech chunk
|
||||
(n_samples)
|
||||
T: label
|
||||
                (n_frames, n_speakers)-shaped np.int32 array.
|
||||
"""
|
||||
data, rate = self.data.load_wav(
|
||||
rec, start * self.frame_shift, end * self.frame_shift
|
||||
)
|
||||
frame_num = end - start
|
||||
filtered_segments = self.data.segments[rec]
|
||||
# filtered_segments = self.data.segments[self.data.segments['rec'] == rec]
|
||||
speakers = np.unique(
|
||||
[self.data.utt2spk[seg["utt"]] for seg in filtered_segments]
|
||||
).tolist()
|
||||
if n_speakers is None:
|
||||
n_speakers = len(speakers)
|
||||
T = np.zeros((frame_num, n_speakers), dtype=np.int32)
|
||||
|
||||
if use_speaker_id:
|
||||
all_speakers = sorted(self.data.spk2utt.keys())
|
||||
S = np.zeros((frame_num, len(all_speakers)), dtype=np.int32)
|
||||
|
||||
for seg in filtered_segments:
|
||||
speaker_index = speakers.index(self.data.utt2spk[seg["utt"]])
|
||||
if use_speaker_id:
|
||||
all_speaker_index = all_speakers.index(self.data.utt2spk[seg["utt"]])
|
||||
start_frame = np.rint(seg["st"] * rate / self.frame_shift).astype(int)
|
||||
end_frame = np.rint(seg["et"] * rate / self.frame_shift).astype(int)
|
||||
rel_start = rel_end = None
|
||||
if start <= start_frame and start_frame < end:
|
||||
rel_start = start_frame - start
|
||||
if start < end_frame and end_frame <= end:
|
||||
rel_end = end_frame - start
|
||||
if rel_start is not None or rel_end is not None:
|
||||
T[rel_start:rel_end, speaker_index] = 1
|
||||
if use_speaker_id:
|
||||
S[rel_start:rel_end, all_speaker_index] = 1
|
||||
|
||||
if use_speaker_id:
|
||||
return data, T, S
|
||||
else:
|
||||
return data, T
|
||||
|
||||
def collate_fn(self, batch):
|
||||
valid_samples = [sample for sample in batch if sample[0] is not None]
|
||||
|
||||
wav_list, binary_label_list, spk_label_list= [], [], []
|
||||
all_spk_idx_list, len_list = [], []
|
||||
for sample in valid_samples:
|
||||
wav_list.append(torch.from_numpy(sample[0]).float())
|
||||
binary_label_list.append(torch.from_numpy(sample[1]).long())
|
||||
spk_label_list.append(torch.from_numpy(sample[2]).long())
|
||||
all_spk_idx_list.append(torch.from_numpy(sample[3]).long())
|
||||
len_list.append(sample[4])
|
||||
wav_batch = pad_sequence(wav_list, batch_first=True, padding_value=0.0)
|
||||
binary_label_batch = pad_sequence(binary_label_list, batch_first=True, padding_value=1).long()
|
||||
spk_label_batch = torch.stack(spk_label_list)
|
||||
all_spk_idx_batch = torch.stack(all_spk_idx_list)
|
||||
len_batch = torch.LongTensor(len_list)
|
||||
|
||||
return wav_batch, binary_label_batch.float(), spk_label_batch, all_spk_idx_batch, len_batch
|
||||
|
||||
def collate_fn_infer(self, batch):
|
||||
assert len(batch) == 1 # each batch should contain one recording
|
||||
Ys, rec, chunk_len_list = batch[0]
|
||||
wav_list = [torch.from_numpy(Y).float() for Y in Ys]
|
||||
|
||||
return wav_list, rec, chunk_len_list
|
||||
|
||||
|
||||
#######################
|
||||
# Kaldi-style Dataset #
|
||||
#######################
|
||||
class KaldiData:
|
||||
"""This class holds data in kaldi-style directory."""
|
||||
|
||||
def __init__(self, data_dir):
|
||||
"""Load kaldi data directory."""
|
||||
self.data_dir = data_dir
|
||||
self.segments = self._load_segments_rechash(
|
||||
os.path.join(self.data_dir, "segments")
|
||||
)
|
||||
self.utt2spk = self._load_utt2spk(os.path.join(self.data_dir, "utt2spk"))
|
||||
self.wavs = self._load_wav_scp(os.path.join(self.data_dir, "wav.scp"))
|
||||
self.reco2dur = self._load_reco2dur(os.path.join(self.data_dir, "reco2dur"))
|
||||
self.spk2utt = self._load_spk2utt(os.path.join(self.data_dir, "spk2utt"))
|
||||
|
||||
def load_wav(self, recid, start=0, end=None):
|
||||
"""Load wavfile given recid, start time and end time."""
|
||||
data, rate = self._load_wav(self.wavs[recid], start, end)
|
||||
return data, rate
|
||||
|
||||
def _load_segments(self, segments_file):
|
||||
"""Load segments file as array."""
|
||||
if not os.path.exists(segments_file):
|
||||
return None
|
||||
return np.loadtxt(
|
||||
segments_file,
|
||||
dtype=[("utt", "object"), ("rec", "object"), ("st", "f"), ("et", "f")],
|
||||
ndmin=1,
|
||||
)
|
||||
|
||||
def _load_segments_hash(self, segments_file):
|
||||
"""Load segments file as dict with uttid index."""
|
||||
ret = {}
|
||||
if not os.path.exists(segments_file):
|
||||
return None
|
||||
for line in open(segments_file):
|
||||
utt, rec, st, et = line.strip().split()
|
||||
ret[utt] = (rec, float(st), float(et))
|
||||
return ret
|
||||
|
||||
def _load_segments_rechash(self, segments_file):
|
||||
"""Load segments file as dict with recid index."""
|
||||
ret = {}
|
||||
if not os.path.exists(segments_file):
|
||||
return None
|
||||
for line in open(segments_file):
|
||||
utt, rec, st, et = line.strip().split()
|
||||
if rec not in ret:
|
||||
ret[rec] = []
|
||||
ret[rec].append({"utt": utt, "st": float(st), "et": float(et)})
|
||||
return ret
|
||||
|
||||
def _load_wav_scp(self, wav_scp_file):
|
||||
"""Return dictionary { rec: wav_rxfilename }."""
|
||||
if os.path.exists(wav_scp_file):
|
||||
lines = [line.strip().split(None, 1) for line in open(wav_scp_file)]
|
||||
return {x[0]: x[1] for x in lines}
|
||||
else:
|
||||
wav_dir = os.path.join(self.data_dir, "wav")
|
||||
return {
|
||||
os.path.splitext(filename)[0]: os.path.join(wav_dir, filename)
|
||||
for filename in sorted(os.listdir(wav_dir))
|
||||
}
|
||||
|
||||
    def _load_wav(self, wav_rxfilename, start=0, end=None):
        """This function reads an audio file and returns data as a numpy.float32 array.
        "lru_cache" holds recently loaded audio so that it can be called
|
||||
many times on the same audio file.
|
||||
OPTIMIZE: controls lru_cache size for random access,
|
||||
considering memory size
|
||||
"""
|
||||
if wav_rxfilename.endswith("|"):
|
||||
# input piped command
|
||||
p = subprocess.Popen(
|
||||
wav_rxfilename[:-1],
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
)
|
||||
data, samplerate = sf.read(
|
||||
io.BytesIO(p.stdout.read()),
|
||||
dtype="float32",
|
||||
)
|
||||
# cannot seek
|
||||
data = data[start:end]
|
||||
elif wav_rxfilename == "-":
|
||||
# stdin
|
||||
data, samplerate = sf.read(sys.stdin, dtype="float32")
|
||||
# cannot seek
|
||||
data = data[start:end]
|
||||
else:
|
||||
# normal wav file
|
||||
data, samplerate = sf.read(wav_rxfilename, start=start, stop=end)
|
||||
return data, samplerate
|
||||
|
||||
def _load_utt2spk(self, utt2spk_file):
|
||||
"""Returns dictionary { uttid: spkid }."""
|
||||
lines = [line.strip().split(None, 1) for line in open(utt2spk_file)]
|
||||
return {x[0]: x[1] for x in lines}
|
||||
|
||||
def _load_spk2utt(self, spk2utt_file):
|
||||
"""Returns dictionary { spkid: list of uttids }."""
|
||||
if not os.path.exists(spk2utt_file):
|
||||
return None
|
||||
lines = [line.strip().split() for line in open(spk2utt_file)]
|
||||
return {x[0]: x[1:] for x in lines}
|
||||
|
||||
def _load_reco2dur(self, reco2dur_file):
|
||||
"""Returns dictionary { recid: duration }."""
|
||||
if not os.path.exists(reco2dur_file):
|
||||
return None
|
||||
lines = [line.strip().split(None, 1) for line in open(reco2dur_file)]
|
||||
return {x[0]: float(x[1]) for x in lines}
|
||||
|
||||
def _process_wav(self, wav_rxfilename, process):
|
||||
"""This function returns preprocessed wav_rxfilename.
|
||||
Args:
|
||||
wav_rxfilename:
|
||||
input
|
||||
process:
|
||||
command which can be connected via pipe, use stdin and stdout
|
||||
Returns:
|
||||
wav_rxfilename: output piped command
|
||||
"""
|
||||
if wav_rxfilename.endswith("|"):
|
||||
# input piped command
|
||||
return wav_rxfilename + process + "|"
|
||||
# stdin "-" or normal file
|
||||
return "cat {0} | {1} |".format(wav_rxfilename, process)
|
||||
|
||||
def _extract_segments(self, wavs, segments=None):
|
||||
"""This function returns generator of segmented audio.
|
||||
Yields (utterance id, numpy.float32 array).
|
||||
TODO?: sampling rate is not converted.
|
||||
"""
|
||||
if segments is not None:
|
||||
# segments should be sorted by rec-id
|
||||
for seg in segments:
|
||||
wav = wavs[seg["rec"]]
|
||||
data, samplerate = self.load_wav(wav)
|
||||
st_sample = np.rint(seg["st"] * samplerate).astype(int)
|
||||
et_sample = np.rint(seg["et"] * samplerate).astype(int)
|
||||
yield seg["utt"], data[st_sample:et_sample]
|
||||
else:
|
||||
# segments file not found,
|
||||
# wav.scp is used as segmented audio list
|
||||
for rec in wavs:
|
||||
data, samplerate = self.load_wav(wavs[rec])
|
||||
yield rec, data
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = {
|
||||
'mode': 'train',
|
||||
'data_dir': "/mnt/lustre/sjtu/home/czy97/workspace/sd/EEND-vec-clustering/EEND-vector-clustering/egs/mini_librispeech/v1/data/simu/data/train_clean_5_ns3_beta2_500",
|
||||
'chunk_size': 2001,
|
||||
'frame_shift': 256,
|
||||
'sampling_rate': 8000,
|
||||
'num_speakers':3
|
||||
}
|
||||
|
||||
torch.manual_seed(6)
|
||||
dataset = DiarizationDataset(**args)
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=dataset.collate_fn)
|
||||
data_iter = iter(dataloader)
|
||||
# wav_batch, binary_label_batch, spk_label_batch, all_spk_idx_batch, len_batch = next(data_iter)
|
||||
data = next(data_iter)
|
||||
for val in data:
|
||||
print(val.shape)
|
||||
|
||||
# from torch.utils.data import DataLoader
|
||||
# dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=dataset.collate_fn_infer)
|
||||
# data_iter = iter(dataloader)
|
||||
# wav_list, binary_label_list, rec = next(data_iter)
|
@@ -0,0 +1,162 @@
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
# Licensed under the MIT license.
#
# This library provides utilities for kaldi-style data directory.


from __future__ import print_function
import os
import sys
import numpy as np
import subprocess
import soundfile as sf
import io
from functools import lru_cache


def load_segments(segments_file):
    """ load segments file as array """
    if not os.path.exists(segments_file):
        return None
    return np.loadtxt(
        segments_file,
        dtype=[('utt', 'object'),
               ('rec', 'object'),
               ('st', 'f'),
               ('et', 'f')],
        ndmin=1)


def load_segments_hash(segments_file):
    ret = {}
    if not os.path.exists(segments_file):
        return None
    for line in open(segments_file):
        utt, rec, st, et = line.strip().split()
        ret[utt] = (rec, float(st), float(et))
    return ret


def load_segments_rechash(segments_file):
    ret = {}
    if not os.path.exists(segments_file):
        return None
    for line in open(segments_file):
        utt, rec, st, et = line.strip().split()
        if rec not in ret:
            ret[rec] = []
        ret[rec].append({'utt': utt, 'st': float(st), 'et': float(et)})
    return ret


def load_wav_scp(wav_scp_file):
    """ return dictionary { rec: wav_rxfilename } """
    lines = [line.strip().split(None, 1) for line in open(wav_scp_file)]
    return {x[0]: x[1] for x in lines}


@lru_cache(maxsize=1)
def load_wav(wav_rxfilename, start=0, end=None):
    """ This function reads audio file and return data in numpy.float32 array.
    "lru_cache" holds recently loaded audio so that it can be called
    many times on the same audio file.
    OPTIMIZE: controls lru_cache size for random access,
    considering memory size
    """
    if wav_rxfilename.endswith('|'):
        # input piped command
        p = subprocess.Popen(wav_rxfilename[:-1], shell=True,
                             stdout=subprocess.PIPE)
        data, samplerate = sf.read(io.BytesIO(p.stdout.read()),
                                   dtype='float32')
        # cannot seek
        data = data[start:end]
    elif wav_rxfilename == '-':
        # stdin
        data, samplerate = sf.read(sys.stdin, dtype='float32')
        # cannot seek
        data = data[start:end]
    else:
        # normal wav file
        data, samplerate = sf.read(wav_rxfilename, start=start, stop=end)
    return data, samplerate


def load_utt2spk(utt2spk_file):
    """ returns dictionary { uttid: spkid } """
    lines = [line.strip().split(None, 1) for line in open(utt2spk_file)]
    return {x[0]: x[1] for x in lines}


def load_spk2utt(spk2utt_file):
    """ returns dictionary { spkid: list of uttids } """
    if not os.path.exists(spk2utt_file):
        return None
    lines = [line.strip().split() for line in open(spk2utt_file)]
    return {x[0]: x[1:] for x in lines}


def load_reco2dur(reco2dur_file):
    """ returns dictionary { recid: duration } """
    if not os.path.exists(reco2dur_file):
        return None
    lines = [line.strip().split(None, 1) for line in open(reco2dur_file)]
    return {x[0]: float(x[1]) for x in lines}


def process_wav(wav_rxfilename, process):
    """ This function returns preprocessed wav_rxfilename
    Args:
        wav_rxfilename: input
        process: command which can be connected via pipe,
                 use stdin and stdout
    Returns:
        wav_rxfilename: output piped command
    """
    if wav_rxfilename.endswith('|'):
        # input piped command
        return wav_rxfilename + process + "|"
    else:
        # stdin "-" or normal file
        return "cat {} | {} |".format(wav_rxfilename, process)


def extract_segments(wavs, segments=None):
    """ This function returns generator of segmented audio as
    (utterance id, numpy.float32 array)
    TODO?: sampling rate is not converted.
    """
    if segments is not None:
        # segments should be sorted by rec-id
        for seg in segments:
            wav = wavs[seg['rec']]
            data, samplerate = load_wav(wav)
            st_sample = np.rint(seg['st'] * samplerate).astype(int)
            et_sample = np.rint(seg['et'] * samplerate).astype(int)
            yield seg['utt'], data[st_sample:et_sample]
    else:
        # segments file not found,
        # wav.scp is used as segmented audio list
        for rec in wavs:
            data, samplerate = load_wav(wavs[rec])
            yield rec, data


class KaldiData:
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.segments = load_segments_rechash(
            os.path.join(self.data_dir, 'segments'))
        self.utt2spk = load_utt2spk(
            os.path.join(self.data_dir, 'utt2spk'))
        self.wavs = load_wav_scp(
            os.path.join(self.data_dir, 'wav.scp'))
        self.reco2dur = load_reco2dur(
            os.path.join(self.data_dir, 'reco2dur'))
        self.spk2utt = load_spk2utt(
            os.path.join(self.data_dir, 'spk2utt'))

    def load_wav(self, recid, start=0, end=None):
        data, rate = load_wav(
            self.wavs[recid], start, end)
        return data, rate
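For orientation, a minimal sketch of how these helpers can be combined; the directory path and the sox resampling command are illustrative assumptions, not part of this repository:

```python
# Hypothetical usage of the Kaldi-style helpers above.
kaldi = KaldiData("data/train")  # expects wav.scp, utt2spk, ... inside

# Pipe every recording through sox to resample to 8 kHz before reading
# (the sox command line is an assumption for illustration).
wavs = {rec: process_wav(wav, "sox -t wav - -t wav - rate 8000")
        for rec, wav in kaldi.wavs.items()}

segments = load_segments(os.path.join("data/train", "segments"))
for utt, samples in extract_segments(wavs, segments):
    print(utt, samples.shape)
```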
@@ -0,0 +1,97 @@
#!/usr/bin/env bash

# Copyright 2012  Johns Hopkins University (Author: Daniel Povey);
#                 Arnab Ghoshal, Karel Vesely

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.


# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).


###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###

# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
  if [ "${!argpos}" == "--config" ]; then
    argpos_plus1=$((argpos+1))
    config=${!argpos_plus1}
    [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
    . $config  # source the config file.
  fi
done


###
### Now we process the command line options
###
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    # If the enclosing script is called with --help option, print the help
    # message and exit. Scripts should put help messages in $help_message
    --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
               else printf "$help_message\n" 1>&2 ; fi;
               exit 0 ;;
    --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
           exit 1 ;;
    # If the first command-line argument begins with "--" (e.g. --foo-bar),
    # then work out the variable name as $name, which will equal "foo_bar".
    --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
         # Next we test whether the variable in question is undefined -- if so it's
         # an invalid option and we die. Note: $0 evaluates to the name of the
         # enclosing script.
         # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
         # is undefined. We then have to wrap this test inside "eval" because
         # foo_bar is itself inside a variable ($name).
         eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;

         oldval="`eval echo \\$$name`";
         # Work out whether we seem to be expecting a Boolean argument.
         if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
           was_bool=true;
         else
           was_bool=false;
         fi

         # Set the variable to the right value-- the escaped quotes make it work if
         # the option had spaces, like --cmd "queue.pl -sync y"
         eval $name=\"$2\";

         # Check that Boolean-valued arguments are really Boolean.
         if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
           echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
           exit 1;
         fi
         shift 2;
         ;;
    *) break;
  esac
done


# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;


true; # so this script returns exit code 0.
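A minimal sketch of a caller script, in the same shell dialect as the file above; the option names (`stage`, `cmd`) are hypothetical placeholders, not defined anywhere in this repository, and each variable must be given a default value before the script is sourced:

```bash
#!/usr/bin/env bash
# Hypothetical caller of parse_options.sh; option names are placeholders.
stage=0
cmd="run.pl"
help_message="Usage: $0 [--stage N] [--cmd COMMAND] <data-dir>"

. ./parse_options.sh   # turns "--stage 2 --cmd queue.pl" into $stage/$cmd

echo "stage=$stage cmd=$cmd remaining args: $*"
```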
@@ -0,0 +1,189 @@
import os
import struct
import logging
import torch
import math
import numpy as np
import random
import yaml
import torch.distributed as dist
import torch.nn.functional as F


# ------------------------------ Logger ------------------------------
# log to console or a file
def get_logger(
        name,
        format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
        date_format="%Y-%m-%d %H:%M:%S",
        file=False):
    """
    Get python logger instance
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # file or console
    handler = logging.StreamHandler() if not file else logging.FileHandler(
        name)
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(fmt=format_str, datefmt=date_format)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    return logger


# log to console and file at the same time
def get_logger_2(
        name,
        format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
        date_format="%Y-%m-%d %H:%M:%S"):
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # Create handlers
    c_handler = logging.StreamHandler()
    f_handler = logging.FileHandler(name)
    c_handler.setLevel(logging.INFO)
    f_handler.setLevel(logging.INFO)

    # Create formatters and add them to handlers
    c_format = logging.Formatter(fmt=format_str, datefmt=date_format)
    f_format = logging.Formatter(fmt=format_str, datefmt=date_format)
    c_handler.setFormatter(c_format)
    f_handler.setFormatter(f_format)

    # Add handlers to the logger
    logger.addHandler(c_handler)
    logger.addHandler(f_handler)

    return logger


# ------------------------------ Logger ------------------------------


# ------------------------------ Pytorch Distributed Training ------------------------------
def getoneNode():
    nodelist = os.environ['SLURM_JOB_NODELIST']
    nodelist = nodelist.strip().split(',')[0]
    import re
    text = re.split(r'[-\[\]]', nodelist)
    if ('' in text):
        text.remove('')
    return text[0] + '-' + text[1] + '-' + text[2]


def dist_init(host_addr, rank, local_rank, world_size, port=23456):
    host_addr_full = 'tcp://' + host_addr + ':' + str(port)
    dist.init_process_group("nccl", init_method=host_addr_full,
                            rank=rank, world_size=world_size)
    num_gpus = torch.cuda.device_count()
    # torch.cuda.set_device(local_rank)
    assert dist.is_initialized()


def cleanup():
    dist.destroy_process_group()


def average_gradients(model, world_size):
    size = float(world_size)
    for param in model.parameters():
        if (param.requires_grad and param.grad is not None):
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= size


def data_reduce(data):
    dist.all_reduce(data, op=dist.ReduceOp.SUM)
    return data / torch.distributed.get_world_size()


# ------------------------------ Pytorch Distributed Training ------------------------------


# ------------------------------ Hyper-parameter Dynamic Change ------------------------------
def reduce_lr(optimizer, initial_lr, final_lr, current_iter, max_iter, coeff=1.0):
    current_lr = coeff * math.exp((current_iter / max_iter) * math.log(final_lr / initial_lr)) * initial_lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = current_lr


def get_reduce_lr(initial_lr, final_lr, current_iter, max_iter):
    current_lr = math.exp((current_iter / max_iter) * math.log(final_lr / initial_lr)) * initial_lr
    return current_lr


def set_lr(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


# ------------------------------ Hyper-parameter Dynamic Change ------------------------------


# ---------------------- About Configuration --------------------
def parse_config_or_kwargs(config_file, **kwargs):
    with open(config_file) as con_read:
        yaml_config = yaml.load(con_read, Loader=yaml.FullLoader)
    # passed kwargs will override yaml config
    return dict(yaml_config, **kwargs)


def store_yaml(config_file, store_path, **kwargs):
    with open(config_file, 'r') as f:
        config_lines = f.readlines()

    keys_list = list(kwargs.keys())
    with open(store_path, 'w') as f:
        for line in config_lines:
            if ':' in line and line.split(':')[0] in keys_list:
                key = line.split(':')[0]
                line = '{}: {}\n'.format(key, kwargs[key])
            f.write(line)


# ---------------------- About Configuration --------------------


def check_dir(dir):
    if not os.path.exists(dir):
        os.mkdir(dir)


def set_seed(seed=66):
    np.random.seed(seed)
    random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False


# when the model was stored with a "module." prefix (e.g. saved from a
# DataParallel/DDP wrapper), remove that prefix here
def correct_key(state_dict):
    keys = list(state_dict.keys())
    if 'module' not in keys[0]:
        return state_dict
    else:
        new_state_dict = {}
        for key in keys:
            new_key = '.'.join(key.split('.')[1:])
            new_state_dict[new_key] = state_dict[key]
        return new_state_dict


def validate_path(dir_name):
    """
    :param dir_name: Create the directory if it doesn't exist
    :return: None
    """
    dir_name = os.path.dirname(dir_name)  # get the path
    if not os.path.exists(dir_name) and (dir_name != ''):
        os.makedirs(dir_name)


def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']
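As a usage note, a hedged sketch of the exponential learning-rate schedule implemented by `reduce_lr`/`get_lr` above; the model, optimizer, and the 1e-3 to 1e-5 schedule are illustrative assumptions only:

```python
# Hypothetical training-loop fragment exercising the LR helpers above.
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

max_iter = 1000
for step in range(max_iter):
    reduce_lr(optimizer, initial_lr=1e-3, final_lr=1e-5,
              current_iter=step, max_iter=max_iter)
    # ... forward / backward / optimizer.step() would go here ...

print(get_lr(optimizer))  # decays smoothly towards final_lr
```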