Add support for personalized and improved models
This commit is contained in:
Родитель
4ac5283138
Коммит
82f1b17e77
Двоичный файл не отображается.
|
@ -7,13 +7,17 @@ There are two ways to use DNSMOS:
|
|||
1. Using the Web-API. The benefit here is that computation happens on the cloud and will always have the latest models.
|
||||
2. Local evaluation using the models uploaded locally to this GitHub repo. We will try to keep this model in sync with the cloud but there are no guarantees.
|
||||
|
||||
To use the Web-API:
|
||||
### To use the Web-API:
|
||||
Please complete the following form: https://forms.office.com/r/pRhyZ0mQy3
|
||||
We will send you the **AUTH_KEY** that you can insert in the **dnsmos.py** script.
|
||||
Example command for P.835 evaluation of test clips: python dnsmos --testset_dir <test clips directory> --method p835
|
||||
|
||||
To use the local evaluation method:
|
||||
### To use the local evaluation method:
|
||||
Use the **dnsmos_local.py** script.
|
||||
1. To compute a personalized MOS score (where interfering speaker is penalized) provide the '-p' argument
|
||||
Ex: python dnsmos_local.py -t C:\temp\SampleClips -o sample.csv -p
|
||||
2. To compute a regular MOS score omit the '-p' argument.
|
||||
Ex: python dnsmos_local.py -t C:\temp\SampleClips -o sample.csv
|
||||
|
||||
## Citation:
|
||||
If you have used the API for your research and development purpose, please cite the DNSMOS paper:
|
||||
|
|
|
@ -1,113 +1,162 @@
|
|||
# Usage:
|
||||
# python dnsmos_local.py -ms sig.onnx -mbo bak_ovr.onnx
|
||||
# -t .
|
||||
|
||||
# python dnsmos_local.py -t c:\temp\DNSChallenge4_Blindset -o DNSCh4_Blind.csv -p
|
||||
#
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import glob
|
||||
import os
|
||||
import csv
|
||||
import numpy as np
|
||||
import math
|
||||
import argparse
|
||||
import soundfile as sf
|
||||
import random
|
||||
|
||||
import librosa
|
||||
import onnxruntime as ort
|
||||
import numpy as np
|
||||
import numpy.polynomial.polynomial as poly
|
||||
import onnxruntime as ort
|
||||
import pandas as pd
|
||||
import soundfile as sf
|
||||
from requests import session
|
||||
from tqdm import tqdm
|
||||
|
||||
# Coefficients for polynomial fitting
|
||||
COEFS_SIG = np.array([9.651228012789436761e-01, 6.592637550310214145e-01,
|
||||
7.572372955623894730e-02])
|
||||
COEFS_BAK = np.array([-3.733460011101781717e+00,2.700114234092929166e+00,
|
||||
-1.721332907340922813e-01])
|
||||
COEFS_OVR = np.array([8.924546794696789354e-01, 6.609981731940616223e-01,
|
||||
7.600269530243179694e-02])
|
||||
SAMPLING_RATE = 16000
|
||||
INPUT_LENGTH = 9.01
|
||||
|
||||
def init_session(model_path):
|
||||
sess = ort.InferenceSession(model_path)
|
||||
return sess
|
||||
|
||||
class PickableInferenceSession: # This is a wrapper to make the current InferenceSession class pickable.
|
||||
def __init__(self, model_path):
|
||||
self.model_path = model_path
|
||||
self.sess = init_session(self.model_path)
|
||||
|
||||
def run(self, *args):
|
||||
return self.sess.run(*args)
|
||||
|
||||
def __getstate__(self):
|
||||
return {'model_path': self.model_path}
|
||||
|
||||
def __setstate__(self, values):
|
||||
self.model_path = values['model_path']
|
||||
self.sess = init_session(self.model_path)
|
||||
|
||||
class ComputeScore:
|
||||
def __init__(self, primary_model_path) -> None:
|
||||
self.onnx_sess = PickableInferenceSession(primary_model_path)
|
||||
|
||||
def get_polyfit_val(self, sig, bak, ovr, is_personalized_MOS):
|
||||
if is_personalized_MOS:
|
||||
p_ovr = np.poly1d([-0.00533021, 0.005101 , 1.18058466, -0.11236046])
|
||||
p_sig = np.poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726])
|
||||
p_bak = np.poly1d([-0.04976499, 0.44276479, -0.1644611 , 0.96883132])
|
||||
else:
|
||||
p_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535])
|
||||
p_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439 ])
|
||||
p_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546])
|
||||
|
||||
sig_poly = p_sig(sig)
|
||||
bak_poly = p_bak(bak)
|
||||
ovr_poly = p_ovr(ovr)
|
||||
|
||||
return sig_poly, bak_poly, ovr_poly
|
||||
|
||||
def __call__(self, fpath, sampling_rate, is_personalized_MOS):
|
||||
aud, input_fs = sf.read(fpath)
|
||||
fs = sampling_rate
|
||||
if input_fs != fs:
|
||||
audio = librosa.resample(aud, input_fs, fs)
|
||||
else:
|
||||
audio = aud
|
||||
actual_audio_len = len(audio)
|
||||
len_samples = int(INPUT_LENGTH*fs)
|
||||
while len(audio) < len_samples:
|
||||
audio = np.append(audio, audio)
|
||||
|
||||
num_hops = int(np.floor(len(audio)/fs) - INPUT_LENGTH)+1
|
||||
hop_len_samples = fs
|
||||
predicted_mos_sig_seg_raw = []
|
||||
predicted_mos_bak_seg_raw = []
|
||||
predicted_mos_ovr_seg_raw = []
|
||||
predicted_mos_sig_seg = []
|
||||
predicted_mos_bak_seg = []
|
||||
predicted_mos_ovr_seg = []
|
||||
|
||||
for idx in range(num_hops):
|
||||
audio_seg = audio[int(idx*hop_len_samples) : int((idx+INPUT_LENGTH)*hop_len_samples)]
|
||||
if len(audio_seg) < len_samples:
|
||||
continue
|
||||
|
||||
input_features = np.array(audio_seg).astype('float32')[np.newaxis,:]
|
||||
oi = {'input_1': input_features}
|
||||
mos_sig_raw,mos_bak_raw,mos_ovr_raw = self.onnx_sess.run(None, oi)[0][0]
|
||||
mos_sig,mos_bak,mos_ovr = self.get_polyfit_val(mos_sig_raw,mos_bak_raw,mos_ovr_raw,is_personalized_MOS)
|
||||
predicted_mos_sig_seg_raw.append(mos_sig_raw)
|
||||
predicted_mos_bak_seg_raw.append(mos_bak_raw)
|
||||
predicted_mos_ovr_seg_raw.append(mos_ovr_raw)
|
||||
predicted_mos_sig_seg.append(mos_sig)
|
||||
predicted_mos_bak_seg.append(mos_bak)
|
||||
predicted_mos_ovr_seg.append(mos_ovr)
|
||||
|
||||
clip_dict = {'filename': fpath, 'len_in_sec': actual_audio_len/fs, 'sr':fs}
|
||||
clip_dict['num_hops'] = num_hops
|
||||
clip_dict['OVRL_raw'] = np.mean(predicted_mos_ovr_seg_raw)
|
||||
clip_dict['SIG_raw'] = np.mean(predicted_mos_sig_seg_raw)
|
||||
clip_dict['BAK_raw'] = np.mean(predicted_mos_bak_seg_raw)
|
||||
clip_dict['OVRL'] = np.mean(predicted_mos_ovr_seg)
|
||||
clip_dict['SIG'] = np.mean(predicted_mos_sig_seg)
|
||||
clip_dict['BAK'] = np.mean(predicted_mos_bak_seg)
|
||||
return clip_dict
|
||||
|
||||
def main(args):
|
||||
models = glob.glob(os.path.join(args.testset_dir, "*"))
|
||||
audio_clips_list = []
|
||||
|
||||
def audio_logpowspec(audio, nfft=320, hop_length=160, sr=16000):
|
||||
powspec = (np.abs(librosa.core.stft(audio, n_fft=nfft, hop_length=hop_length)))**2
|
||||
logpowspec = np.log10(np.maximum(powspec, 10**(-12)))
|
||||
return logpowspec.T
|
||||
if args.personalized_MOS:
|
||||
primary_model_path = os.path.join('pDNSMOS', 'sig_bak_ovr.onnx')
|
||||
else:
|
||||
primary_model_path = os.path.join('DNSMOS', 'sig_bak_ovr.onnx')
|
||||
|
||||
predicted_mos_sig = []
|
||||
predicted_mos_bak = []
|
||||
predicted_mos_ovr = []
|
||||
audio_clips_list = glob.glob(os.path.join(args.testset_dir, "*.wav"))
|
||||
|
||||
session_sig = ort.InferenceSession(args.sig_model_path)
|
||||
session_bak_ovr = ort.InferenceSession(args.bak_ovr_model_path)
|
||||
compute_score = ComputeScore(primary_model_path)
|
||||
|
||||
if args.csv_path:
|
||||
csv_path = args.csv_path
|
||||
else:
|
||||
csv_path = args.run_name+'.csv'
|
||||
|
||||
with open(csv_path, mode='w', newline='') as csvfile:
|
||||
csvwriter = csv.writer(csvfile, delimiter=',',
|
||||
quoting=csv.QUOTE_MINIMAL)
|
||||
csvwriter.writerow(['filename', 'SIG', 'BAK', 'OVR'])
|
||||
rows = []
|
||||
clips = []
|
||||
clips = glob.glob(os.path.join(args.testset_dir, "*.wav"))
|
||||
is_personalized_eval = args.personalized_MOS
|
||||
desired_fs = SAMPLING_RATE
|
||||
for m in tqdm(models):
|
||||
max_recursion_depth = 10
|
||||
audio_path = os.path.join(args.testset_dir, m)
|
||||
audio_clips_list = glob.glob(os.path.join(audio_path, "*.wav"))
|
||||
while len(audio_clips_list) == 0 and max_recursion_depth > 0:
|
||||
audio_path = os.path.join(audio_path, "**")
|
||||
audio_clips_list = glob.glob(os.path.join(audio_path, "*.wav"))
|
||||
max_recursion_depth -= 1
|
||||
clips.extend(audio_clips_list)
|
||||
|
||||
for i in tqdm(range(len(audio_clips_list))):
|
||||
fpath = audio_clips_list[i]
|
||||
audio, fs = sf.read(fpath)
|
||||
if len(audio)<2*fs:
|
||||
print('Audio clip is too short. Skipped processing ',
|
||||
os.path.basename(fpath))
|
||||
continue
|
||||
len_samples = int(args.input_length*fs)
|
||||
while len(audio) < len_samples:
|
||||
audio = np.append(audio, audio)
|
||||
|
||||
num_hops = int(np.floor(len(audio)/fs) - args.input_length)+1
|
||||
hop_len_samples = fs
|
||||
predicted_mos_sig_seg = []
|
||||
predicted_mos_bak_seg = []
|
||||
predicted_mos_ovr_seg = []
|
||||
with concurrent.futures.ProcessPoolExecutor() as executor:
|
||||
future_to_url = {executor.submit(compute_score, clip, desired_fs, is_personalized_eval): clip for clip in clips}
|
||||
for future in tqdm(concurrent.futures.as_completed(future_to_url)):
|
||||
clip = future_to_url[future]
|
||||
try:
|
||||
data = future.result()
|
||||
except Exception as exc:
|
||||
print('%r generated an exception: %s' % (clip, exc))
|
||||
else:
|
||||
rows.append(data)
|
||||
|
||||
for idx in range(num_hops):
|
||||
audio_seg = audio[int(idx*hop_len_samples) : int((idx+args.input_length)*hop_len_samples)]
|
||||
input_features = np.array(audio_logpowspec(audio=audio_seg, sr=fs)).astype('float32')[np.newaxis,:,:]
|
||||
|
||||
onnx_inputs_sig = {inp.name: input_features for inp in session_sig.get_inputs()}
|
||||
mos_sig = poly.polyval(session_sig.run(None, onnx_inputs_sig), COEFS_SIG)
|
||||
|
||||
onnx_inputs_bak_ovr = {inp.name: input_features for inp in session_bak_ovr.get_inputs()}
|
||||
mos_bak_ovr = session_bak_ovr.run(None, onnx_inputs_bak_ovr)
|
||||
|
||||
mos_bak = poly.polyval(mos_bak_ovr[0][0][1], COEFS_BAK)
|
||||
mos_ovr = poly.polyval(mos_bak_ovr[0][0][2], COEFS_OVR)
|
||||
|
||||
|
||||
predicted_mos_sig_seg.append(mos_sig)
|
||||
predicted_mos_bak_seg.append(mos_bak)
|
||||
predicted_mos_ovr_seg.append(mos_ovr)
|
||||
|
||||
predicted_mos_sig.append(np.mean(predicted_mos_sig_seg))
|
||||
predicted_mos_bak.append(np.mean(predicted_mos_bak_seg))
|
||||
predicted_mos_ovr.append(np.mean(predicted_mos_ovr_seg))
|
||||
|
||||
csvwriter.writerow([os.path.basename(fpath), np.mean(predicted_mos_sig_seg),
|
||||
np.mean(predicted_mos_bak_seg), np.mean(predicted_mos_ovr_seg)])
|
||||
|
||||
print("The average SIG, BAK and OVR MOS for {0} is {1}, {2}, {3}".format(args.run_name,
|
||||
str(round(np.mean(predicted_mos_sig), 2)),
|
||||
str(round(np.mean(predicted_mos_bak), 2)),
|
||||
str(round(np.mean(predicted_mos_ovr), 2))))
|
||||
df = pd.DataFrame(rows)
|
||||
df.to_csv(csv_path)
|
||||
|
||||
if __name__=="__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-ms', "--sig_model_path", default='sig.onnx',
|
||||
help='Path to ONNX or ckpt model for SIG prediction')
|
||||
parser.add_argument('-mbo', "--bak_ovr_model_path", default='bak_ovr.onnx',
|
||||
help='Path to ONNX or ckpt model for BAK and OVR prediction')
|
||||
parser.add_argument('-t', "--testset_dir", default='.',
|
||||
help='Path to the dir containing audio clips in .wav to be evaluated')
|
||||
parser.add_argument('-o', "--csv_path", default=None, help='Dir to the csv that saves the results')
|
||||
parser.add_argument('-l', "--input_length", type=int, default=9)
|
||||
parser.add_argument('-r', "--run_name", type=str, default="dnsmos_p835_inference_sig_bak_ovr_test",
|
||||
help='Change the name depending on the test set and DNS model being evaluated')
|
||||
parser.add_argument('-p', "--personalized_MOS", action='store_true',
|
||||
help='Flag to indicate if personalized MOS score is needed or regular')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
Двоичный файл не отображается.
Загрузка…
Ссылка в новой задаче