Add support for personalized and improved models

Vishak Gopal 2022-05-23 23:35:44 -07:00
Parent: 4ac5283138
Commit: 82f1b17e77
6 changed files with 139 additions and 86 deletions


Binary data
DNSMOS/DNSMOS/sig_bak_ovr.onnx Normal file

Binary file not shown.

View file

@@ -7,13 +7,17 @@ There are two ways to use DNSMOS:
1. Using the Web-API. The benefit here is that computation happens on the cloud and will always have the latest models.
2. Local evaluation using the models uploaded locally to this GitHub repo. We will try to keep this model in sync with the cloud but there are no guarantees.
### To use the Web-API:
Please complete the following form: https://forms.office.com/r/pRhyZ0mQy3
We will send you the **AUTH_KEY** that you can insert in the **dnsmos.py** script.
Example command for P.835 evaluation of test clips: python dnsmos.py --testset_dir <test clips directory> --method p835
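The key is pasted into **dnsmos.py** before calling the API. A minimal, illustrative sketch (the exact variable name in the script may differ; `AUTH_KEY` here just mirrors the bolded name above):

```python
# dnsmos.py (illustrative excerpt): paste the key received by e-mail
AUTH_KEY = "<AUTH_KEY from the e-mail>"  # placeholder value
```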
### To use the local evaluation method:
Use the **dnsmos_local.py** script.
1. To compute a personalized MOS score (where an interfering speaker is penalized), provide the '-p' argument.
Ex: python dnsmos_local.py -t C:\temp\SampleClips -o sample.csv -p
2. To compute a regular MOS score, omit the '-p' argument.
Ex: python dnsmos_local.py -t C:\temp\SampleClips -o sample.csv
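Either way, the script writes one row per clip to the output CSV with the raw model outputs (SIG_raw, BAK_raw, OVRL_raw), the calibrated scores (SIG, BAK, OVRL), and the clip length, sample rate, and number of scored segments (see `clip_dict` in **dnsmos_local.py** below).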
## Citation:
If you have used the API for your research and development purposes, please cite the DNSMOS paper:

View file

@@ -1,113 +1,162 @@
# Usage:
# python dnsmos_local.py -t c:\temp\DNSChallenge4_Blindset -o DNSCh4_Blind.csv -p
#
import argparse
import concurrent.futures
import glob
import os
import librosa
import numpy as np
import numpy.polynomial.polynomial as poly
import onnxruntime as ort
import pandas as pd
import soundfile as sf
from requests import session
from tqdm import tqdm
SAMPLING_RATE = 16000
INPUT_LENGTH = 9.01
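# Both models expect 16 kHz input and score one 9.01-second segment per hop.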
def init_session(model_path):
    sess = ort.InferenceSession(model_path)
    return sess
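# An ort.InferenceSession holds native handles and cannot be pickled, so the
# wrapper below stores only the model path and rebuilds the session when
# unpickled; this lets scoring objects cross the ProcessPoolExecutor boundary.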
class PickableInferenceSession: # This is a wrapper to make the current InferenceSession class pickable.
    def __init__(self, model_path):
        self.model_path = model_path
        self.sess = init_session(self.model_path)

    def run(self, *args):
        return self.sess.run(*args)

    def __getstate__(self):
        return {'model_path': self.model_path}

    def __setstate__(self, values):
        self.model_path = values['model_path']
        self.sess = init_session(self.model_path)
class ComputeScore:
    def __init__(self, primary_model_path) -> None:
        self.onnx_sess = PickableInferenceSession(primary_model_path)
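    # Map the raw network outputs to calibrated MOS values; the polynomial
    # fits differ between the personalized and regular models.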
    def get_polyfit_val(self, sig, bak, ovr, is_personalized_MOS):
        if is_personalized_MOS:
            p_ovr = np.poly1d([-0.00533021, 0.005101, 1.18058466, -0.11236046])
            p_sig = np.poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726])
            p_bak = np.poly1d([-0.04976499, 0.44276479, -0.1644611, 0.96883132])
        else:
            p_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535])
            p_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439])
            p_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546])

        sig_poly = p_sig(sig)
        bak_poly = p_bak(bak)
        ovr_poly = p_ovr(ovr)

        return sig_poly, bak_poly, ovr_poly
    def __call__(self, fpath, sampling_rate, is_personalized_MOS):
        aud, input_fs = sf.read(fpath)
        fs = sampling_rate
        if input_fs != fs:
            audio = librosa.resample(aud, input_fs, fs)
        else:
            audio = aud
        actual_audio_len = len(audio)
        len_samples = int(INPUT_LENGTH*fs)
        while len(audio) < len_samples:
            audio = np.append(audio, audio)

        num_hops = int(np.floor(len(audio)/fs) - INPUT_LENGTH) + 1
        hop_len_samples = fs
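        # One 9.01 s window is scored per 1 s hop; clips shorter than a window
        # were tiled by repetition above, so at least one segment is full-length.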
        predicted_mos_sig_seg_raw = []
        predicted_mos_bak_seg_raw = []
        predicted_mos_ovr_seg_raw = []
        predicted_mos_sig_seg = []
        predicted_mos_bak_seg = []
        predicted_mos_ovr_seg = []

        for idx in range(num_hops):
            audio_seg = audio[int(idx*hop_len_samples): int((idx+INPUT_LENGTH)*hop_len_samples)]
            if len(audio_seg) < len_samples:
                continue
            input_features = np.array(audio_seg).astype('float32')[np.newaxis, :]
            oi = {'input_1': input_features}
            mos_sig_raw, mos_bak_raw, mos_ovr_raw = self.onnx_sess.run(None, oi)[0][0]
            mos_sig, mos_bak, mos_ovr = self.get_polyfit_val(mos_sig_raw, mos_bak_raw, mos_ovr_raw, is_personalized_MOS)
            predicted_mos_sig_seg_raw.append(mos_sig_raw)
            predicted_mos_bak_seg_raw.append(mos_bak_raw)
            predicted_mos_ovr_seg_raw.append(mos_ovr_raw)
            predicted_mos_sig_seg.append(mos_sig)
            predicted_mos_bak_seg.append(mos_bak)
            predicted_mos_ovr_seg.append(mos_ovr)

        clip_dict = {'filename': fpath, 'len_in_sec': actual_audio_len/fs, 'sr': fs}
        clip_dict['num_hops'] = num_hops
        clip_dict['OVRL_raw'] = np.mean(predicted_mos_ovr_seg_raw)
        clip_dict['SIG_raw'] = np.mean(predicted_mos_sig_seg_raw)
        clip_dict['BAK_raw'] = np.mean(predicted_mos_bak_seg_raw)
        clip_dict['OVRL'] = np.mean(predicted_mos_ovr_seg)
        clip_dict['SIG'] = np.mean(predicted_mos_sig_seg)
        clip_dict['BAK'] = np.mean(predicted_mos_bak_seg)
        return clip_dict
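# Example of scoring a single clip directly (hypothetical file name):
#   scorer = ComputeScore(os.path.join('DNSMOS', 'sig_bak_ovr.onnx'))
#   scores = scorer('clip.wav', SAMPLING_RATE, False)
#   scores['SIG'], scores['BAK'], scores['OVRL']  # calibrated per-clip means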
def main(args):
    models = glob.glob(os.path.join(args.testset_dir, "*"))
    audio_clips_list = []
    if args.personalized_MOS:
        primary_model_path = os.path.join('pDNSMOS', 'sig_bak_ovr.onnx')
    else:
        primary_model_path = os.path.join('DNSMOS', 'sig_bak_ovr.onnx')
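    # Paths are resolved relative to the working directory; the personalized
    # model is the one that penalizes interfering speakers.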
    compute_score = ComputeScore(primary_model_path)
    if args.csv_path:
        csv_path = args.csv_path
    else:
        csv_path = args.run_name + '.csv'
    rows = []
    clips = glob.glob(os.path.join(args.testset_dir, "*.wav"))
    is_personalized_eval = args.personalized_MOS
    desired_fs = SAMPLING_RATE
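    # If the top-level entries are directories, deepen the search one level
    # per pass (up to 10 levels) until .wav files are found.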
    for m in tqdm(models):
        max_recursion_depth = 10
        audio_path = os.path.join(args.testset_dir, m)
        audio_clips_list = glob.glob(os.path.join(audio_path, "*.wav"))
        while len(audio_clips_list) == 0 and max_recursion_depth > 0:
            audio_path = os.path.join(audio_path, "**")
            audio_clips_list = glob.glob(os.path.join(audio_path, "*.wav"))
            max_recursion_depth -= 1
        clips.extend(audio_clips_list)
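    # Score all clips in parallel; compute_score is pickled to each worker,
    # which rebuilds its ONNX session via PickableInferenceSession.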
    with concurrent.futures.ProcessPoolExecutor() as executor:
        future_to_url = {executor.submit(compute_score, clip, desired_fs, is_personalized_eval): clip for clip in clips}
        for future in tqdm(concurrent.futures.as_completed(future_to_url)):
            clip = future_to_url[future]
            try:
                data = future.result()
            except Exception as exc:
                print('%r generated an exception: %s' % (clip, exc))
            else:
                rows.append(data)
    df = pd.DataFrame(rows)
    df.to_csv(csv_path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', "--testset_dir", default='.',
                        help='Path to the dir containing audio clips in .wav to be evaluated')
    parser.add_argument('-o', "--csv_path", default=None, help='Path to the csv that saves the results')
    parser.add_argument('-r', "--run_name", type=str, default="dnsmos_p835_inference_sig_bak_ovr_test",
                        help='Used to name the output csv when -o is not given; change it per test set and DNS model evaluated')
    parser.add_argument('-p', "--personalized_MOS", action='store_true',
                        help='Flag to indicate if personalized MOS score is needed or regular')

    args = parser.parse_args()
    main(args)

Binary data
DNSMOS/pDNSMOS/sig_bak_ovr.onnx Normal file

Binary file not shown.