Adding down-sampling and stereo mixing features for AudioDecoder (#420)
* initial draft * second * third * polishing * fix the M_PI name in LINUX platform * fix bessel function issue * add a unit test case * fix the unit test name
This commit is contained in:
Родитель
ad0fd98221
Коммит
2fa0b710ea
|
@ -249,7 +249,7 @@ inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ con
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
|
inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
|
||||||
T* data;
|
T* data = nullptr;
|
||||||
ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
|
ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
|
@ -157,8 +157,7 @@ class OrtPyFunction:
|
||||||
|
|
||||||
x = args[idx]
|
x = args[idx]
|
||||||
ts_x = np.array(x) if isinstance(x, (int, float, bool)) else x
|
ts_x = np.array(x) if isinstance(x, (int, float, bool)) else x
|
||||||
# an annoying bug is numpy by default is int32, while pytorch is int64.
|
# numpy by default is int32 in some platforms, sometimes it is int64.
|
||||||
# so cast the input here automatically.
|
|
||||||
feed[i_.name] = \
|
feed[i_.name] = \
|
||||||
ts_x.astype(
|
ts_x.astype(
|
||||||
np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
|
np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
|
||||||
|
|
|
@ -52,7 +52,7 @@ class HFTokenizerConverter(CustomOpConverter):
|
||||||
*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
|
*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
|
||||||
attrs.update(**kwargs)
|
attrs.update(**kwargs)
|
||||||
return attrs
|
return attrs
|
||||||
|
|
||||||
def roberta_tokenizer(self, **kwargs):
|
def roberta_tokenizer(self, **kwargs):
|
||||||
hf_roberta_tokenizer = self.tokenizer
|
hf_roberta_tokenizer = self.tokenizer
|
||||||
attrs = {'vocab': json.dumps(
|
attrs = {'vocab': json.dumps(
|
||||||
|
@ -62,4 +62,4 @@ class HFTokenizerConverter(CustomOpConverter):
|
||||||
attrs['merges'] = '\n'.join("{} {}".format(
|
attrs['merges'] = '\n'.join("{} {}".format(
|
||||||
*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
|
*sorted_merges[n_]) for n_ in range(len(sorted_merges)))
|
||||||
attrs.update(**kwargs)
|
attrs.update(**kwargs)
|
||||||
return attrs
|
return attrs
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
|
|
||||||
#include <list>
|
#include <list>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
#define DR_FLAC_IMPLEMENTATION
|
#define DR_FLAC_IMPLEMENTATION
|
||||||
#include "dr_flac.h"
|
#include "dr_flac.h"
|
||||||
#define DR_MP3_IMPLEMENTATION 1
|
#define DR_MP3_IMPLEMENTATION 1
|
||||||
|
@ -16,12 +17,17 @@
|
||||||
#include "dr_wav.h"
|
#include "dr_wav.h"
|
||||||
|
|
||||||
#include <gsl/util>
|
#include <gsl/util>
|
||||||
|
#include "narrow.h"
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
#include "string_tensor.h"
|
#include "string_tensor.h"
|
||||||
|
#include "sampling.h"
|
||||||
|
|
||||||
struct KernelAudioDecoder : public BaseKernel {
|
struct KernelAudioDecoder : public BaseKernel {
|
||||||
public:
|
public:
|
||||||
KernelAudioDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
KernelAudioDecoder(const OrtApi& api, const OrtKernelInfo& info)
|
||||||
|
: BaseKernel(api, info),
|
||||||
|
downsample_rate_(TryToGetAttributeWithDefault<int64_t>("downsampling_rate", 0)),
|
||||||
|
stereo_mixer_(TryToGetAttributeWithDefault<int64_t>("stereo_to_mono", 0)) {
|
||||||
}
|
}
|
||||||
|
|
||||||
enum class AudioStreamType {
|
enum class AudioStreamType {
|
||||||
|
@ -47,7 +53,8 @@ struct KernelAudioDecoder : public BaseKernel {
|
||||||
if (pos == format_mapping.end()) {
|
if (pos == format_mapping.end()) {
|
||||||
ORTX_CXX_API_THROW(MakeString(
|
ORTX_CXX_API_THROW(MakeString(
|
||||||
"[AudioDecoder]: Unknown audio stream format: ", str_format),
|
"[AudioDecoder]: Unknown audio stream format: ", str_format),
|
||||||
ORT_INVALID_ARGUMENT); }
|
ORT_INVALID_ARGUMENT);
|
||||||
|
}
|
||||||
stream_format = pos->second;
|
stream_format = pos->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -104,12 +111,16 @@ struct KernelAudioDecoder : public BaseKernel {
|
||||||
|
|
||||||
int64_t total_buf_size = 0;
|
int64_t total_buf_size = 0;
|
||||||
std::list<std::vector<float>> lst_frames;
|
std::list<std::vector<float>> lst_frames;
|
||||||
|
int64_t orig_sample_rate = 0;
|
||||||
|
int64_t orig_channels = 0;
|
||||||
|
|
||||||
if (stream_format == AudioStreamType::kMP3) {
|
if (stream_format == AudioStreamType::kMP3) {
|
||||||
auto mp3_obj_ptr = std::make_unique<drmp3>();
|
auto mp3_obj_ptr = std::make_unique<drmp3>();
|
||||||
if (!drmp3_init_memory(mp3_obj_ptr.get(), p_data, input_dim.Size(), nullptr)) {
|
if (!drmp3_init_memory(mp3_obj_ptr.get(), p_data, input_dim.Size(), nullptr)) {
|
||||||
ORTX_CXX_API_THROW("[AudioDecoder]: unexpected error on MP3 stream.", ORT_RUNTIME_EXCEPTION);
|
ORTX_CXX_API_THROW("[AudioDecoder]: unexpected error on MP3 stream.", ORT_RUNTIME_EXCEPTION);
|
||||||
}
|
}
|
||||||
|
orig_sample_rate = mp3_obj_ptr->sampleRate;
|
||||||
|
orig_channels = mp3_obj_ptr->channels;
|
||||||
total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, *mp3_obj_ptr);
|
total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, *mp3_obj_ptr);
|
||||||
|
|
||||||
} else if (stream_format == AudioStreamType::kFLAC) {
|
} else if (stream_format == AudioStreamType::kFLAC) {
|
||||||
|
@ -118,6 +129,8 @@ struct KernelAudioDecoder : public BaseKernel {
|
||||||
if (flac_obj == nullptr) {
|
if (flac_obj == nullptr) {
|
||||||
ORTX_CXX_API_THROW("[AudioDecoder]: unexpected error on FLAC stream.", ORT_RUNTIME_EXCEPTION);
|
ORTX_CXX_API_THROW("[AudioDecoder]: unexpected error on FLAC stream.", ORT_RUNTIME_EXCEPTION);
|
||||||
}
|
}
|
||||||
|
orig_sample_rate = flac_obj->sampleRate;
|
||||||
|
orig_channels = flac_obj->channels;
|
||||||
total_buf_size = DrReadFrames(lst_frames, drflac_read_pcm_frames_f32, *flac_obj);
|
total_buf_size = DrReadFrames(lst_frames, drflac_read_pcm_frames_f32, *flac_obj);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
@ -125,18 +138,54 @@ struct KernelAudioDecoder : public BaseKernel {
|
||||||
if (!drwav_init_memory(&wav_obj, p_data, input_dim.Size(), nullptr)) {
|
if (!drwav_init_memory(&wav_obj, p_data, input_dim.Size(), nullptr)) {
|
||||||
ORTX_CXX_API_THROW("[AudioDecoder]: unexpected error on WAV stream.", ORT_RUNTIME_EXCEPTION);
|
ORTX_CXX_API_THROW("[AudioDecoder]: unexpected error on WAV stream.", ORT_RUNTIME_EXCEPTION);
|
||||||
}
|
}
|
||||||
|
orig_sample_rate = wav_obj.sampleRate;
|
||||||
|
orig_channels = wav_obj.channels;
|
||||||
total_buf_size = DrReadFrames(lst_frames, drwav_read_pcm_frames_f32, wav_obj);
|
total_buf_size = DrReadFrames(lst_frames, drwav_read_pcm_frames_f32, wav_obj);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> dim_out = {1, total_buf_size};
|
if (downsample_rate_ != 0 &&
|
||||||
OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size());
|
orig_sample_rate < downsample_rate_) {
|
||||||
float* p_output = ort_.GetTensorMutableData<float>(v);
|
ORTX_CXX_API_THROW("[AudioDecoder]: only down sampling supported.", ORT_INVALID_ARGUMENT);
|
||||||
|
}
|
||||||
|
|
||||||
|
// join all frames
|
||||||
|
std::vector<float> buf;
|
||||||
|
buf.resize(total_buf_size);
|
||||||
int64_t offset = 0;
|
int64_t offset = 0;
|
||||||
for (auto& _b : lst_frames) {
|
for (auto& _b : lst_frames) {
|
||||||
std::copy(_b.begin(), _b.end(), p_output + offset);
|
std::copy(_b.begin(), _b.end(), buf.begin() + offset);
|
||||||
offset += _b.size();
|
offset += _b.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mix the stereo channels into mono channel
|
||||||
|
if (stereo_mixer_ && orig_channels > 1) {
|
||||||
|
if (buf.size() > 1) {
|
||||||
|
for (size_t i = 0; i < buf.size() / 2; ++i) {
|
||||||
|
buf[i] = (buf[i * 2] + buf[i * 2 + 1]) / 2;
|
||||||
|
}
|
||||||
|
buf.resize(buf.size() / 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (downsample_rate_ != 0 &&
|
||||||
|
downsample_rate_ != orig_sample_rate) {
|
||||||
|
// A lowpass filter on buf audio data to remove high frequency noise
|
||||||
|
ButterworthLowpass filter(1.0f * orig_sample_rate, 0.5f * downsample_rate_);
|
||||||
|
std::vector<float> filtered_buf = filter.Process(buf);
|
||||||
|
// downsample the audio data
|
||||||
|
KaiserWindowInterpolation::Process(filtered_buf, buf,
|
||||||
|
1.0f * orig_sample_rate, 1.0f * downsample_rate_);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> dim_out = {1, ort_extensions::narrow<int64_t>(buf.size())};
|
||||||
|
OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size());
|
||||||
|
float* p_output = ort_.GetTensorMutableData<float>(v);
|
||||||
|
std::copy(buf.begin(), buf.end(), p_output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int64_t downsample_rate_ = 0;
|
||||||
|
int64_t stereo_mixer_ = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpAudioDecoder : OrtW::CustomOpBase<CustomOpAudioDecoder, KernelAudioDecoder> {
|
struct CustomOpAudioDecoder : OrtW::CustomOpBase<CustomOpAudioDecoder, KernelAudioDecoder> {
|
||||||
|
|
|
@ -0,0 +1,121 @@
|
||||||
|
// Copyright (c) Microsoft Corporation.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdint>
#include <vector>

#include "narrow.h"
|
||||||
|
|
||||||
|
// First-order Butterworth low-pass filter
// (https://en.wikipedia.org/wiki/Butterworth_filter), discretised with the
// bilinear transform and frequency pre-warping.
//
// With K = tan(pi * cutoff / sample_rate) the transfer function is
//   H(z) = K * (1 + z^-1) / ((K + 1) + (K - 1) * z^-1)
// which has unity gain at DC and a zero at the Nyquist frequency.
//
// The filter keeps the previous input/output sample as state, so samples must
// be fed in order and a fresh instance should be used for each independent signal.
class ButterworthLowpass {
 public:
  // sample_rate:      rate of the signal being filtered, in Hz.
  // cutoff_frequency: corner frequency in Hz, expected in (0, sample_rate / 2).
  //
  // Parameters outside that range cannot be realised as a stable low-pass;
  // in that case the filter is left as a pass-through (output == input)
  // instead of producing NaN or unstable coefficients.
  ButterworthLowpass(float sample_rate, float cutoff_frequency)
      : x_prev_(0.0f), y_prev_(0.0f), a0_(1.0f), a1_(0.0f), b1_(0.0f) {
    const bool realisable = sample_rate > 0.0f &&
                            cutoff_frequency > 0.0f &&
                            cutoff_frequency < 0.5f * sample_rate;
    if (!realisable) {
      return;  // keep the pass-through coefficients set above
    }

    const float k = std::tan(3.14159265359f * cutoff_frequency / sample_rate);
    // Rounding can push the tangent argument past pi/2 when the cutoff sits
    // right below Nyquist; a non-positive K would give an unstable pole.
    if (!(k > 0.0f)) {
      return;
    }

    a0_ = k / (k + 1.0f);
    a1_ = a0_;
    b1_ = (k - 1.0f) / (k + 1.0f);
  }

  // Filters a single sample and advances the filter state.
  // Implements y[n] = a0 * x[n] + a1 * x[n-1] - b1 * y[n-1].
  float Process(float input) {
    float output = a0_ * input + a1_ * x_prev_ - b1_ * y_prev_;
    x_prev_ = input;
    y_prev_ = output;
    return output;
  }

  // Filters a whole signal in order and returns a vector of the same length.
  // The state carries over from any previous Process() call on this instance.
  std::vector<float> Process(const std::vector<float>& inputSignal) {
    std::vector<float> outputSignal(inputSignal.size());
    for (size_t i = 0; i < inputSignal.size(); ++i) {
      outputSignal[i] = Process(inputSignal[i]);
    }
    return outputSignal;
  }

 private:
  float x_prev_, y_prev_;  // previous input / output sample
  float a0_, a1_, b1_;     // feed-forward (a0, a1) and feedback (b1) coefficients
};
|
||||||
|
|
||||||
|
// Kaiser-windowed sinc resampler.
// https://ccrma.stanford.edu/~jos/sasp/Kaiser_Window.html
//
// Every output sample is a weighted sum of the input samples around its
// (fractional) source position; each weight is sinc(distance) multiplied by a
// Kaiser window laid over the contributing samples. The weights are not
// normalised.
class KaiserWindowInterpolation {
 private:
  // Kaiser window parameters, empirically
  constexpr static double kBeta = 6.0;  // Beta controls the width of the transition band

 public:
  // Resamples `input` from inputSampleRate to outputSampleRate into `output`.
  //
  // `output` is resized to ceil(input.size() * outputSampleRate / inputSampleRate)
  // and fully overwritten. It must not be the same object as `input`, because
  // it is resized before `input` is read.
  // An empty input, or a rate ratio that is not a positive finite number,
  // produces an empty output.
  static void Process(const std::vector<float>& input, std::vector<float>& output, float inputSampleRate, float outputSampleRate) {
    // Downsampling factor
    const float factor = outputSampleRate / inputSampleRate;
    if (input.empty() || !(factor > 0.0f) || !std::isfinite(factor)) {
      output.clear();
      return;
    }

    const double MY_PI = 3.14159265359;
    const int64_t inputSize = static_cast<int64_t>(input.size());

    // Number of output samples. Deliberately kept in single precision so the
    // output length is the same as it has always been for a given input.
    const int64_t outputSize =
        static_cast<int64_t>(std::ceil(static_cast<float>(input.size()) * factor));
    output.resize(static_cast<size_t>(outputSize));

    // Half-width (in input samples) of the interpolation kernel; it does not
    // depend on the output position, so compute it once.
    const int64_t range = static_cast<int64_t>(std::ceil(kBeta / (2.0f * factor)));

    // The Kaiser window only depends on the number of taps, which is constant
    // away from the signal edges; rebuild it only when the tap count changes.
    std::vector<double> window;

    for (int64_t i = 0; i < outputSize; ++i) {
      // Fractional source position. Double precision is required here: a
      // float cannot hold the sub-sample position once the index reaches the
      // millions, which is only tens of seconds of audio.
      const double index = static_cast<double>(i) / factor;
      const int64_t integerPart = static_cast<int64_t>(index);

      // Range of input samples contributing to this output sample
      const int64_t startSample = std::max<int64_t>(0, integerPart - range);
      const int64_t endSample = std::min<int64_t>(inputSize - 1, integerPart + range);
      if (endSample < startSample) {
        output[static_cast<size_t>(i)] = 0.0f;  // no input sample in reach
        continue;
      }

      const size_t taps = static_cast<size_t>(endSample - startSample + 1);
      if (window.size() != taps) {
        window = KaiserWin(taps);
      }

      // Perform the interpolation
      double interpolatedValue = 0.0;
      for (int64_t j = startSample; j <= endSample; ++j) {
        const double distance = std::abs(static_cast<double>(j) - index);
        const double sincValue =
            (distance < 1e-6) ? 1.0 : std::sin(MY_PI * distance) / (MY_PI * distance);
        interpolatedValue += input[static_cast<size_t>(j)] *
                             (window[static_cast<size_t>(j - startSample)] * sincValue);
      }

      output[static_cast<size_t>(i)] = static_cast<float>(interpolatedValue);
    }
  }

 private:
  // std::cyl_bessel_i is not available for every platform.
  // Power-series evaluation of the modified Bessel function of the first kind,
  // order zero: sum over n of ((x / 2)^n / n!)^2. Terms are added until the
  // next one is below 1e-8 of the running sum.
  static double cyl_bessel_i0(double x) {
    double sum = 0.0;
    double term = 1.0;
    double x_squared = x * x / 4.0;
    int n = 0;
    double tolerance = 1e-8;

    while (term > tolerance * sum) {
      sum += term;
      n += 1;
      term *= x_squared / (n * n);
    }

    return sum;
  }

  // Kaiser Window function
  // Returns window_length samples of I0(beta * sqrt(1 - x^2)) / I0(beta) with
  // x running linearly from -1 to 1. A zero length gives an empty window.
  static std::vector<double> KaiserWin(size_t window_length) {
    if (window_length == 1) {
      // A single tap is its own centre, where the window equals 1. The general
      // formula below would evaluate 0 / 0 for this length.
      return std::vector<double>(1, 1.0);
    }

    std::vector<double> window(window_length);
    static const double i0_beta = cyl_bessel_i0(kBeta);

    for (size_t i = 0; i < window_length; i++) {
      const double x = 2.0 * i / (window_length - 1.0) - 1.0;
      // Clamp so that rounding can never hand a negative value to sqrt.
      const double bessel_value = cyl_bessel_i0(kBeta * std::sqrt(std::max(0.0, 1.0 - x * x)));
      window[i] = bessel_value / i0_beta;
    }

    return window;
  }
};
|
Двоичный файл не отображается.
|
@ -7,7 +7,7 @@ from onnx import checker, helper, onnx_pb as onnx_proto
|
||||||
from onnxruntime_extensions import PyOrtFunction, util
|
from onnxruntime_extensions import PyOrtFunction, util
|
||||||
|
|
||||||
|
|
||||||
class TestBpeTokenizer(unittest.TestCase):
|
class TestAudioCodec(unittest.TestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls) -> None:
|
def setUpClass(cls) -> None:
|
||||||
cls.test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
|
cls.test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
|
||||||
|
@ -47,6 +47,13 @@ class TestBpeTokenizer(unittest.TestCase):
|
||||||
np.asarray([np.max(pcm_tensor), np.average(pcm_tensor), np.min(pcm_tensor)]),
|
np.asarray([np.max(pcm_tensor), np.average(pcm_tensor), np.min(pcm_tensor)]),
|
||||||
np.asarray([np.max(self.raw_data), np.average(self.raw_data), np.min(self.raw_data)]), atol=1e-01)
|
np.asarray([np.max(self.raw_data), np.average(self.raw_data), np.min(self.raw_data)]), atol=1e-01)
|
||||||
|
|
||||||
|
def test_decoder_resampling(self):
    """Decode jfk.flac with downsampling_rate=16000 and stereo_to_mono=1 and check the output shape."""
    flac_path = util.get_test_data_file('data', 'jfk.flac')
    encoded_bytes = bytearray(util.read_file(flac_path, mode='rb'))
    audio_decoder = PyOrtFunction.from_customop(
        'AudioDecoder', cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)
    # The decoder is fed the raw file bytes with a leading axis of size 1.
    batched_input = np.expand_dims(np.asarray(encoded_bytes), axis=(0,))
    pcm_tensor = audio_decoder(batched_input)
    self.assertEqual(pcm_tensor.shape, (1, 176000))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -8,7 +8,6 @@ import re
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from onnx import numpy_helper
|
from onnx import numpy_helper
|
||||||
from transformers import WhisperProcessor
|
from transformers import WhisperProcessor
|
||||||
|
|
||||||
|
@ -142,7 +141,8 @@ def _torch_export(*arg, **kwargs):
|
||||||
|
|
||||||
def preprocessing(audio_data):
|
def preprocessing(audio_data):
|
||||||
if USE_AUDIO_DECODER:
|
if USE_AUDIO_DECODER:
|
||||||
decoder = PyOrtFunction.from_customop("AudioDecoder", cpu_only=True)
|
decoder = PyOrtFunction.from_customop(
|
||||||
|
"AudioDecoder", cpu_only=True, downsampling_rate=SAMPLE_RATE, stereo_to_mono=1)
|
||||||
audio_pcm = torch.from_numpy(decoder(audio_data))
|
audio_pcm = torch.from_numpy(decoder(audio_data))
|
||||||
else:
|
else:
|
||||||
audio_pcm = torch.from_numpy(audio_data)
|
audio_pcm = torch.from_numpy(audio_data)
|
||||||
|
@ -172,7 +172,7 @@ def preprocessing(audio_data):
|
||||||
return pre_f(audio_data)
|
return pre_f(audio_data)
|
||||||
else:
|
else:
|
||||||
pre_full = onnx.compose.merge_models(
|
pre_full = onnx.compose.merge_models(
|
||||||
decoder.onnx_model,
|
decoder.onnx_model,
|
||||||
pre_model,
|
pre_model,
|
||||||
io_map=[("floatPCM", "audio_pcm")])
|
io_map=[("floatPCM", "audio_pcm")])
|
||||||
pre_f = PyOrtFunction.from_model(pre_full, cpu_only=True)
|
pre_f = PyOrtFunction.from_model(pre_full, cpu_only=True)
|
||||||
|
@ -198,11 +198,14 @@ def merge_models(core: str, output_model: str, audio_data):
|
||||||
make_node('Cast', ['sequences'], ["generated_ids"], to=onnx.TensorProto.INT64),
|
make_node('Cast', ['sequences'], ["generated_ids"], to=onnx.TensorProto.INT64),
|
||||||
bpe_decoder_node
|
bpe_decoder_node
|
||||||
])
|
])
|
||||||
onnx.save_model(m_all, output_model,
|
try:
|
||||||
save_as_external_data=True,
|
onnx.save_model(m_all, output_model)
|
||||||
all_tensors_to_one_file=True,
|
except ValueError:
|
||||||
location=f"{os.path.basename(output_model)}.data",
|
onnx.save_model(m_all, output_model,
|
||||||
convert_attribute=True)
|
save_as_external_data=True,
|
||||||
|
all_tensors_to_one_file=True,
|
||||||
|
location=f"{os.path.basename(output_model)}.data",
|
||||||
|
convert_attribute=True)
|
||||||
print(f"The final merged model was saved as: {output_model}")
|
print(f"The final merged model was saved as: {output_model}")
|
||||||
|
|
||||||
print("Verify the final model...")
|
print("Verify the final model...")
|
||||||
|
|
Загрузка…
Ссылка в новой задаче