Adding down-sampling and stereo mixing features for AudioDecoder (#420)
* initial draft * second * third * polishing * fix the M_PI name in LINUX platform * fix bessel function issue * add a unit test case * fix the unit test name
This commit is contained in:
Родитель
ad0fd98221
Коммит
2fa0b710ea
|
@ -249,7 +249,7 @@ inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ con
|
|||
|
||||
template <typename T>
|
||||
inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
|
||||
T* data;
|
||||
T* data = nullptr;
|
||||
ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
|
||||
return data;
|
||||
}
|
||||
|
|
|
@ -157,8 +157,7 @@ class OrtPyFunction:
|
|||
|
||||
x = args[idx]
|
||||
ts_x = np.array(x) if isinstance(x, (int, float, bool)) else x
|
||||
# an annoying bug is numpy by default is int32, while pytorch is int64.
|
||||
# so cast the input here automatically.
|
||||
# numpy by default is int32 in some platforms, sometimes it is int64.
|
||||
feed[i_.name] = \
|
||||
ts_x.astype(
|
||||
np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#define DR_FLAC_IMPLEMENTATION
|
||||
#include "dr_flac.h"
|
||||
#define DR_MP3_IMPLEMENTATION 1
|
||||
|
@ -16,12 +17,17 @@
|
|||
#include "dr_wav.h"
|
||||
|
||||
#include <gsl/util>
|
||||
#include "narrow.h"
|
||||
#include "string_utils.h"
|
||||
#include "string_tensor.h"
|
||||
#include "sampling.h"
|
||||
|
||||
struct KernelAudioDecoder : public BaseKernel {
|
||||
public:
|
||||
KernelAudioDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
KernelAudioDecoder(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info),
|
||||
downsample_rate_(TryToGetAttributeWithDefault<int64_t>("downsampling_rate", 0)),
|
||||
stereo_mixer_(TryToGetAttributeWithDefault<int64_t>("stereo_to_mono", 0)) {
|
||||
}
|
||||
|
||||
enum class AudioStreamType {
|
||||
|
@ -47,7 +53,8 @@ struct KernelAudioDecoder : public BaseKernel {
|
|||
if (pos == format_mapping.end()) {
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"[AudioDecoder]: Unknown audio stream format: ", str_format),
|
||||
ORT_INVALID_ARGUMENT); }
|
||||
ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
stream_format = pos->second;
|
||||
}
|
||||
|
||||
|
@ -104,12 +111,16 @@ struct KernelAudioDecoder : public BaseKernel {
|
|||
|
||||
int64_t total_buf_size = 0;
|
||||
std::list<std::vector<float>> lst_frames;
|
||||
int64_t orig_sample_rate = 0;
|
||||
int64_t orig_channels = 0;
|
||||
|
||||
if (stream_format == AudioStreamType::kMP3) {
|
||||
auto mp3_obj_ptr = std::make_unique<drmp3>();
|
||||
if (!drmp3_init_memory(mp3_obj_ptr.get(), p_data, input_dim.Size(), nullptr)) {
|
||||
ORTX_CXX_API_THROW("[AudioDecoder]: unexpected error on MP3 stream.", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
orig_sample_rate = mp3_obj_ptr->sampleRate;
|
||||
orig_channels = mp3_obj_ptr->channels;
|
||||
total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, *mp3_obj_ptr);
|
||||
|
||||
} else if (stream_format == AudioStreamType::kFLAC) {
|
||||
|
@ -118,6 +129,8 @@ struct KernelAudioDecoder : public BaseKernel {
|
|||
if (flac_obj == nullptr) {
|
||||
ORTX_CXX_API_THROW("[AudioDecoder]: unexpected error on FLAC stream.", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
orig_sample_rate = flac_obj->sampleRate;
|
||||
orig_channels = flac_obj->channels;
|
||||
total_buf_size = DrReadFrames(lst_frames, drflac_read_pcm_frames_f32, *flac_obj);
|
||||
|
||||
} else {
|
||||
|
@ -125,18 +138,54 @@ struct KernelAudioDecoder : public BaseKernel {
|
|||
if (!drwav_init_memory(&wav_obj, p_data, input_dim.Size(), nullptr)) {
|
||||
ORTX_CXX_API_THROW("[AudioDecoder]: unexpected error on WAV stream.", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
orig_sample_rate = wav_obj.sampleRate;
|
||||
orig_channels = wav_obj.channels;
|
||||
total_buf_size = DrReadFrames(lst_frames, drwav_read_pcm_frames_f32, wav_obj);
|
||||
}
|
||||
|
||||
std::vector<int64_t> dim_out = {1, total_buf_size};
|
||||
OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size());
|
||||
float* p_output = ort_.GetTensorMutableData<float>(v);
|
||||
if (downsample_rate_ != 0 &&
|
||||
orig_sample_rate < downsample_rate_) {
|
||||
ORTX_CXX_API_THROW("[AudioDecoder]: only down sampling supported.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
// join all frames
|
||||
std::vector<float> buf;
|
||||
buf.resize(total_buf_size);
|
||||
int64_t offset = 0;
|
||||
for (auto& _b : lst_frames) {
|
||||
std::copy(_b.begin(), _b.end(), p_output + offset);
|
||||
std::copy(_b.begin(), _b.end(), buf.begin() + offset);
|
||||
offset += _b.size();
|
||||
}
|
||||
|
||||
// mix the stereo channels into mono channel
|
||||
if (stereo_mixer_ && orig_channels > 1) {
|
||||
if (buf.size() > 1) {
|
||||
for (size_t i = 0; i < buf.size() / 2; ++i) {
|
||||
buf[i] = (buf[i * 2] + buf[i * 2 + 1]) / 2;
|
||||
}
|
||||
buf.resize(buf.size() / 2);
|
||||
}
|
||||
}
|
||||
|
||||
if (downsample_rate_ != 0 &&
|
||||
downsample_rate_ != orig_sample_rate) {
|
||||
// A lowpass filter on buf audio data to remove high frequency noise
|
||||
ButterworthLowpass filter(1.0f * orig_sample_rate, 0.5f * downsample_rate_);
|
||||
std::vector<float> filtered_buf = filter.Process(buf);
|
||||
// downsample the audio data
|
||||
KaiserWindowInterpolation::Process(filtered_buf, buf,
|
||||
1.0f * orig_sample_rate, 1.0f * downsample_rate_);
|
||||
}
|
||||
|
||||
std::vector<int64_t> dim_out = {1, ort_extensions::narrow<int64_t>(buf.size())};
|
||||
OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size());
|
||||
float* p_output = ort_.GetTensorMutableData<float>(v);
|
||||
std::copy(buf.begin(), buf.end(), p_output);
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t downsample_rate_ = 0;
|
||||
int64_t stereo_mixer_ = 0;
|
||||
};
|
||||
|
||||
struct CustomOpAudioDecoder : OrtW::CustomOpBase<CustomOpAudioDecoder, KernelAudioDecoder> {
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include "narrow.h"
|
||||
|
||||
// https://en.wikipedia.org/wiki/Butterworth_filter
|
||||
class ButterworthLowpass {
|
||||
public:
|
||||
ButterworthLowpass(float sample_rate, float cutoff_frequency)
|
||||
: x_prev_(0.0f), y_prev_(0.0f) {
|
||||
float RC = 1.0f / (2.0f * 3.14159265359f * cutoff_frequency);
|
||||
float dt = 1.0f / sample_rate;
|
||||
float alpha = dt / (RC + dt);
|
||||
a0_ = alpha;
|
||||
a1_ = alpha;
|
||||
b1_ = 1 - alpha;
|
||||
}
|
||||
|
||||
float Process(float input) {
|
||||
float output = a0_ * input + a1_ * x_prev_ - b1_ * y_prev_;
|
||||
x_prev_ = input;
|
||||
y_prev_ = output;
|
||||
return output;
|
||||
}
|
||||
|
||||
std::vector<float> Process(const std::vector<float>& inputSignal) {
|
||||
std::vector<float> outputSignal(inputSignal.size());
|
||||
for (size_t i = 0; i < inputSignal.size(); ++i) {
|
||||
outputSignal[i] = Process(inputSignal[i]);
|
||||
}
|
||||
return outputSignal;
|
||||
}
|
||||
|
||||
private:
|
||||
float x_prev_, y_prev_;
|
||||
float a0_, a1_, b1_;
|
||||
};
|
||||
|
||||
// https://ccrma.stanford.edu/~jos/sasp/Kaiser_Window.html
|
||||
class KaiserWindowInterpolation {
|
||||
private:
|
||||
// Kaiser window parameters, empirically
|
||||
constexpr static double kBeta = 6.0; // Beta controls the width of the transition band
|
||||
|
||||
public:
|
||||
static void Process(const std::vector<float>& input, std::vector<float>& output, float inputSampleRate, float outputSampleRate) {
|
||||
// Downsampling factor
|
||||
float factor = outputSampleRate / inputSampleRate;
|
||||
const double MY_PI = 3.14159265359;
|
||||
|
||||
// Calculate the number of output samples
|
||||
int outputSize = static_cast<int>(std::ceil(static_cast<float>(input.size()) * factor));
|
||||
output.resize(outputSize);
|
||||
|
||||
for (int i = 0; i < outputSize; i++) {
|
||||
float index = i / factor; // Fractional index for interpolation
|
||||
|
||||
// Calculate the integer and fractional parts of the index
|
||||
int integerPart = static_cast<int>(index);
|
||||
float fractionalPart = index - integerPart;
|
||||
|
||||
// Calculate the range of input samples for interpolation
|
||||
int range = static_cast<int>(std::ceil(kBeta / (2.0f * factor)));
|
||||
int startSample = std::max(0, integerPart - range);
|
||||
int endSample = std::min(static_cast<int>(input.size()) - 1, integerPart + range);
|
||||
|
||||
// Calculate the Kaiser window weights for the input samples
|
||||
std::vector<double> weights = KaiserWin(static_cast<size_t>(endSample - startSample + 1));
|
||||
for (int j = startSample; j <= endSample; j++) {
|
||||
double distance = std::abs(j - index);
|
||||
double sincValue = (distance < 1e-6f) ? 1.0f : std::sin(MY_PI * distance) / (MY_PI * distance);
|
||||
weights[j - startSample] *= sincValue;
|
||||
}
|
||||
|
||||
// Perform the interpolation
|
||||
double interpolatedValue = 0.0f;
|
||||
for (int j = startSample; j <= endSample; j++) {
|
||||
interpolatedValue += input[j] * weights[j - startSample];
|
||||
}
|
||||
|
||||
output[i] = static_cast<float>(interpolatedValue);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// std::cyl_bessel_i is not available for every platform.
|
||||
static double cyl_bessel_i0(double x) {
|
||||
double sum = 0.0;
|
||||
double term = 1.0;
|
||||
double x_squared = x * x / 4.0;
|
||||
int n = 0;
|
||||
double tolerance = 1e-8;
|
||||
|
||||
while (term > tolerance * sum) {
|
||||
sum += term;
|
||||
n += 1;
|
||||
term *= x_squared / (n * n);
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
// Kaiser Window function
|
||||
static std::vector<double> KaiserWin(size_t window_length) {
|
||||
std::vector<double> window(window_length);
|
||||
static const double i0_beta = cyl_bessel_i0(kBeta);
|
||||
|
||||
for (size_t i = 0; i < window_length; i++) {
|
||||
double x = 2.0 * i / (window_length - 1.0) - 1.0;
|
||||
double bessel_value = cyl_bessel_i0(kBeta * std::sqrt(1 - x * x));
|
||||
window[i] = bessel_value / i0_beta;
|
||||
}
|
||||
|
||||
return window;
|
||||
}
|
||||
};
|
Двоичный файл не отображается.
|
@ -7,7 +7,7 @@ from onnx import checker, helper, onnx_pb as onnx_proto
|
|||
from onnxruntime_extensions import PyOrtFunction, util
|
||||
|
||||
|
||||
class TestBpeTokenizer(unittest.TestCase):
|
||||
class TestAudioCodec(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
|
||||
|
@ -47,6 +47,13 @@ class TestBpeTokenizer(unittest.TestCase):
|
|||
np.asarray([np.max(pcm_tensor), np.average(pcm_tensor), np.min(pcm_tensor)]),
|
||||
np.asarray([np.max(self.raw_data), np.average(self.raw_data), np.min(self.raw_data)]), atol=1e-01)
|
||||
|
||||
def test_decoder_resampling(self):
|
||||
test_file = util.get_test_data_file('data', 'jfk.flac')
|
||||
blob = bytearray(util.read_file(test_file, mode='rb'))
|
||||
decoder = PyOrtFunction.from_customop('AudioDecoder', cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)
|
||||
pcm_tensor = decoder(np.expand_dims(np.asarray(blob), axis=(0,)))
|
||||
self.assertEqual(pcm_tensor.shape, (1, 176000))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -8,7 +8,6 @@ import re
|
|||
import torch
|
||||
import numpy as np
|
||||
|
||||
from pathlib import Path
|
||||
from onnx import numpy_helper
|
||||
from transformers import WhisperProcessor
|
||||
|
||||
|
@ -142,7 +141,8 @@ def _torch_export(*arg, **kwargs):
|
|||
|
||||
def preprocessing(audio_data):
|
||||
if USE_AUDIO_DECODER:
|
||||
decoder = PyOrtFunction.from_customop("AudioDecoder", cpu_only=True)
|
||||
decoder = PyOrtFunction.from_customop(
|
||||
"AudioDecoder", cpu_only=True, downsampling_rate=SAMPLE_RATE, stereo_to_mono=1)
|
||||
audio_pcm = torch.from_numpy(decoder(audio_data))
|
||||
else:
|
||||
audio_pcm = torch.from_numpy(audio_data)
|
||||
|
@ -198,11 +198,14 @@ def merge_models(core: str, output_model: str, audio_data):
|
|||
make_node('Cast', ['sequences'], ["generated_ids"], to=onnx.TensorProto.INT64),
|
||||
bpe_decoder_node
|
||||
])
|
||||
onnx.save_model(m_all, output_model,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
location=f"{os.path.basename(output_model)}.data",
|
||||
convert_attribute=True)
|
||||
try:
|
||||
onnx.save_model(m_all, output_model)
|
||||
except ValueError:
|
||||
onnx.save_model(m_all, output_model,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
location=f"{os.path.basename(output_model)}.data",
|
||||
convert_attribute=True)
|
||||
print(f"The final merged model was saved as: {output_model}")
|
||||
|
||||
print("Verify the final model...")
|
||||
|
|
Загрузка…
Ссылка в новой задаче