Adding down-sampling and stereo mixing features for AudioDecoder (#420)

* initial draft

* second

* third

* polishing

* fix the M_PI name in LINUX platform

* fix bessel function issue

* add a unit test case

* fix the unit test name
This commit is contained in:
Wenbing Li 2023-05-04 13:30:10 -07:00 коммит произвёл GitHub
Родитель ad0fd98221
Коммит 2fa0b710ea
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
8 изменённых файлов: 199 добавлений и 20 удалений

Просмотреть файл

@ -249,7 +249,7 @@ inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ con
template <typename T>
inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
T* data;
T* data = nullptr;
ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
return data;
}

Просмотреть файл

@ -157,8 +157,7 @@ class OrtPyFunction:
x = args[idx]
ts_x = np.array(x) if isinstance(x, (int, float, bool)) else x
# an annoying bug is numpy by default is int32, while pytorch is int64.
# so cast the input here automatically.
# numpy by default is int32 in some platforms, sometimes it is int64.
feed[i_.name] = \
ts_x.astype(
np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x

Просмотреть файл

@ -7,6 +7,7 @@
#include <list>
#include <map>
#include <memory>
#define DR_FLAC_IMPLEMENTATION
#include "dr_flac.h"
#define DR_MP3_IMPLEMENTATION 1
@ -16,12 +17,17 @@
#include "dr_wav.h"
#include <gsl/util>
#include "narrow.h"
#include "string_utils.h"
#include "string_tensor.h"
#include "sampling.h"
struct KernelAudioDecoder : public BaseKernel {
public:
KernelAudioDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
KernelAudioDecoder(const OrtApi& api, const OrtKernelInfo& info)
: BaseKernel(api, info),
downsample_rate_(TryToGetAttributeWithDefault<int64_t>("downsampling_rate", 0)),
stereo_mixer_(TryToGetAttributeWithDefault<int64_t>("stereo_to_mono", 0)) {
}
enum class AudioStreamType {
@ -47,7 +53,8 @@ struct KernelAudioDecoder : public BaseKernel {
if (pos == format_mapping.end()) {
ORTX_CXX_API_THROW(MakeString(
"[AudioDecoder]: Unknown audio stream format: ", str_format),
ORT_INVALID_ARGUMENT); }
ORT_INVALID_ARGUMENT);
}
stream_format = pos->second;
}
@ -104,12 +111,16 @@ struct KernelAudioDecoder : public BaseKernel {
int64_t total_buf_size = 0;
std::list<std::vector<float>> lst_frames;
int64_t orig_sample_rate = 0;
int64_t orig_channels = 0;
if (stream_format == AudioStreamType::kMP3) {
auto mp3_obj_ptr = std::make_unique<drmp3>();
if (!drmp3_init_memory(mp3_obj_ptr.get(), p_data, input_dim.Size(), nullptr)) {
ORTX_CXX_API_THROW("[AudioDecoder]: unexpected error on MP3 stream.", ORT_RUNTIME_EXCEPTION);
}
orig_sample_rate = mp3_obj_ptr->sampleRate;
orig_channels = mp3_obj_ptr->channels;
total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, *mp3_obj_ptr);
} else if (stream_format == AudioStreamType::kFLAC) {
@ -118,6 +129,8 @@ struct KernelAudioDecoder : public BaseKernel {
if (flac_obj == nullptr) {
ORTX_CXX_API_THROW("[AudioDecoder]: unexpected error on FLAC stream.", ORT_RUNTIME_EXCEPTION);
}
orig_sample_rate = flac_obj->sampleRate;
orig_channels = flac_obj->channels;
total_buf_size = DrReadFrames(lst_frames, drflac_read_pcm_frames_f32, *flac_obj);
} else {
@ -125,18 +138,54 @@ struct KernelAudioDecoder : public BaseKernel {
if (!drwav_init_memory(&wav_obj, p_data, input_dim.Size(), nullptr)) {
ORTX_CXX_API_THROW("[AudioDecoder]: unexpected error on WAV stream.", ORT_RUNTIME_EXCEPTION);
}
orig_sample_rate = wav_obj.sampleRate;
orig_channels = wav_obj.channels;
total_buf_size = DrReadFrames(lst_frames, drwav_read_pcm_frames_f32, wav_obj);
}
std::vector<int64_t> dim_out = {1, total_buf_size};
OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size());
float* p_output = ort_.GetTensorMutableData<float>(v);
if (downsample_rate_ != 0 &&
orig_sample_rate < downsample_rate_) {
ORTX_CXX_API_THROW("[AudioDecoder]: only down sampling supported.", ORT_INVALID_ARGUMENT);
}
// join all frames
std::vector<float> buf;
buf.resize(total_buf_size);
int64_t offset = 0;
for (auto& _b : lst_frames) {
std::copy(_b.begin(), _b.end(), p_output + offset);
std::copy(_b.begin(), _b.end(), buf.begin() + offset);
offset += _b.size();
}
// mix the stereo channels into mono channel
if (stereo_mixer_ && orig_channels > 1) {
if (buf.size() > 1) {
for (size_t i = 0; i < buf.size() / 2; ++i) {
buf[i] = (buf[i * 2] + buf[i * 2 + 1]) / 2;
}
buf.resize(buf.size() / 2);
}
}
if (downsample_rate_ != 0 &&
downsample_rate_ != orig_sample_rate) {
// A lowpass filter on buf audio data to remove high frequency noise
ButterworthLowpass filter(1.0f * orig_sample_rate, 0.5f * downsample_rate_);
std::vector<float> filtered_buf = filter.Process(buf);
// downsample the audio data
KaiserWindowInterpolation::Process(filtered_buf, buf,
1.0f * orig_sample_rate, 1.0f * downsample_rate_);
}
std::vector<int64_t> dim_out = {1, ort_extensions::narrow<int64_t>(buf.size())};
OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size());
float* p_output = ort_.GetTensorMutableData<float>(v);
std::copy(buf.begin(), buf.end(), p_output);
}
private:
int64_t downsample_rate_ = 0;
int64_t stereo_mixer_ = 0;
};
struct CustomOpAudioDecoder : OrtW::CustomOpBase<CustomOpAudioDecoder, KernelAudioDecoder> {

121
operators/audio/sampling.h Normal file
Просмотреть файл

@ -0,0 +1,121 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include <algorithm>
#include <cmath>
#include <complex>
#include <vector>
#include "narrow.h"
// https://en.wikipedia.org/wiki/Butterworth_filter
/// First-order Butterworth low-pass filter, designed with the bilinear transform.
///
/// The filter implements the recurrence
///     y[n] = a0 * x[n] + a1 * x[n-1] - b1 * y[n-1]
/// with a0 = a1 = K / (1 + K) and b1 = (K - 1) / (K + 1), where
/// K = tan(pi * cutoff_frequency / sample_rate). These coefficients give unity
/// gain at DC and a zero at the Nyquist frequency, so the pass band is not
/// attenuated before the signal is resampled.
///
/// The object keeps one sample of input/output history, so a single instance
/// must be fed one continuous stream, in order.
class ButterworthLowpass {
 public:
  /// @param sample_rate       sampling rate of the signal to filter, in Hz.
  /// @param cutoff_frequency  -3 dB corner frequency, in Hz.
  ///
  /// When the cutoff is at or above sample_rate / 2 (or the ratio is not a
  /// number, e.g. both arguments are zero) there is nothing to remove and
  /// tan() would not yield a usable coefficient, so the filter degrades to a
  /// pass-through. A non-positive cutoff yields an all-zero output.
  ButterworthLowpass(float sample_rate, float cutoff_frequency)
      : x_prev_(0.0f), y_prev_(0.0f) {
    constexpr float kPi = 3.14159265359f;
    const float ratio = cutoff_frequency / sample_rate;

    // Negated comparison on purpose: it also catches NaN.
    if (!(ratio < 0.5f)) {
      a0_ = 1.0f;
      a1_ = 0.0f;
      b1_ = 0.0f;
      return;
    }

    // Pre-warped analog corner frequency; clamp negative ratios to zero so the
    // recurrence can never become unstable.
    const float k = std::tan(kPi * (ratio > 0.0f ? ratio : 0.0f));
    a0_ = k / (1.0f + k);
    a1_ = a0_;
    b1_ = (k - 1.0f) / (k + 1.0f);
  }

  /// Filters one sample and updates the internal one-sample history.
  /// @return the filtered sample.
  float Process(float input) {
    const float output = a0_ * input + a1_ * x_prev_ - b1_ * y_prev_;
    x_prev_ = input;
    y_prev_ = output;
    return output;
  }

  /// Filters a whole buffer, continuing from the current filter state.
  /// @return a new buffer with the same number of samples as inputSignal.
  std::vector<float> Process(const std::vector<float>& inputSignal) {
    std::vector<float> outputSignal(inputSignal.size());
    for (size_t i = 0; i < inputSignal.size(); ++i) {
      outputSignal[i] = Process(inputSignal[i]);
    }
    return outputSignal;
  }

 private:
  float x_prev_, y_prev_;  // previous input / previous output sample
  float a0_, a1_, b1_;     // feed-forward (a0, a1) and feedback (b1) coefficients
};
// https://ccrma.stanford.edu/~jos/sasp/Kaiser_Window.html
/// Sample-rate conversion by Kaiser-windowed sinc interpolation.
///
/// Every output sample is a weighted sum of the input samples within `range`
/// taps of the corresponding (fractional) input position. The weight of a tap
/// is sinc(distance to the fractional position) multiplied by a Kaiser window
/// of length 2 * range + 1 centred on the integer part of that position.
class KaiserWindowInterpolation {
 private:
  // Kaiser window parameters, empirically
  constexpr static double kBeta = 6.0;  // Beta controls the width of the transition band

 public:
  /// Resamples `input` from inputSampleRate to outputSampleRate.
  ///
  /// @param input             source samples.
  /// @param output            receives ceil(input.size() * outputSampleRate / inputSampleRate)
  ///                          samples; any previous content is discarded. Must not
  ///                          be the same vector as `input`.
  /// @param inputSampleRate   rate of `input`, in Hz.
  /// @param outputSampleRate  requested rate of `output`, in Hz.
  ///
  /// `output` is left empty when `input` is empty or when the rate ratio is
  /// not a positive finite number.
  static void Process(const std::vector<float>& input, std::vector<float>& output,
                      float inputSampleRate, float outputSampleRate) {
    // Downsampling factor
    const float factor = outputSampleRate / inputSampleRate;
    if (input.empty() || !(factor > 0.0f) || !std::isfinite(factor)) {
      output.clear();
      return;
    }

    const double MY_PI = 3.14159265359;

    // Calculate the number of output samples
    const int outputSize = static_cast<int>(std::ceil(static_cast<float>(input.size()) * factor));
    output.resize(outputSize);

    // The tap range and the window depend only on the rate ratio, so they are
    // computed once instead of once per output sample.
    const int range = static_cast<int>(std::ceil(kBeta / (2.0f * factor)));
    const std::vector<double> window = KaiserWin(static_cast<size_t>(2 * range + 1));
    const int lastSample = static_cast<int>(input.size()) - 1;

    // Input samples advanced per output sample. Kept in double: a float index
    // has no fractional precision left once it exceeds 2^24 samples.
    const double step = static_cast<double>(inputSampleRate) / static_cast<double>(outputSampleRate);

    for (int i = 0; i < outputSize; i++) {
      const double index = i * step;  // Fractional index for interpolation
      const int integerPart = static_cast<int>(index);

      // Range of input samples for interpolation, clamped to the buffer.
      const int startSample = std::max(0, integerPart - range);
      const int endSample = std::min(lastSample, integerPart + range);

      // Input position that corresponds to window[0]. Indexing the window
      // relative to this origin keeps its peak on integerPart even when the
      // tap range is clamped at either end of the buffer.
      const int windowOrigin = integerPart - range;

      // Perform the interpolation
      double interpolatedValue = 0.0;
      for (int j = startSample; j <= endSample; j++) {
        const double distance = std::abs(j - index);
        const double sincValue =
            (distance < 1e-6) ? 1.0 : std::sin(MY_PI * distance) / (MY_PI * distance);
        interpolatedValue += input[j] * (window[j - windowOrigin] * sincValue);
      }
      output[i] = static_cast<float>(interpolatedValue);
    }
  }

 private:
  // std::cyl_bessel_i is not available for every platform.
  // Zeroth-order modified Bessel function of the first kind, evaluated with
  // its power series sum_n ((x/2)^(2n) / (n!)^2); the summation stops once a
  // term is below 1e-8 of the running sum.
  static double cyl_bessel_i0(double x) {
    double sum = 0.0;
    double term = 1.0;
    const double x_squared = x * x / 4.0;
    int n = 0;
    const double tolerance = 1e-8;
    while (term > tolerance * sum) {
      sum += term;
      n += 1;
      term *= x_squared / (n * n);
    }
    return sum;
  }

  // Kaiser Window function
  // Returns window_length samples of a Kaiser window with shape parameter
  // kBeta, normalized so that the centre of the window is 1.
  static std::vector<double> KaiserWin(size_t window_length) {
    std::vector<double> window(window_length);
    if (window_length <= 1) {
      // The general formula divides by (window_length - 1); a one-tap window
      // is simply its peak value.
      if (window_length == 1) {
        window[0] = 1.0;
      }
      return window;
    }

    static const double i0_beta = cyl_bessel_i0(kBeta);
    for (size_t i = 0; i < window_length; i++) {
      const double x = 2.0 * i / (window_length - 1.0) - 1.0;  // position in [-1, 1]
      const double bessel_value = cyl_bessel_i0(kBeta * std::sqrt(1 - x * x));
      window[i] = bessel_value / i0_beta;
    }
    return window;
  }
};

Двоичные данные
test/data/jfk.flac Normal file

Двоичный файл не отображается.

Просмотреть файл

@ -7,7 +7,7 @@ from onnx import checker, helper, onnx_pb as onnx_proto
from onnxruntime_extensions import PyOrtFunction, util
class TestBpeTokenizer(unittest.TestCase):
class TestAudioCodec(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.test_mp3_file = util.get_test_data_file('data', '1272-141231-0002.mp3')
@ -47,6 +47,13 @@ class TestBpeTokenizer(unittest.TestCase):
np.asarray([np.max(pcm_tensor), np.average(pcm_tensor), np.min(pcm_tensor)]),
np.asarray([np.max(self.raw_data), np.average(self.raw_data), np.min(self.raw_data)]), atol=1e-01)
def test_decoder_resampling(self):
test_file = util.get_test_data_file('data', 'jfk.flac')
blob = bytearray(util.read_file(test_file, mode='rb'))
decoder = PyOrtFunction.from_customop('AudioDecoder', cpu_only=True, downsampling_rate=16000, stereo_to_mono=1)
pcm_tensor = decoder(np.expand_dims(np.asarray(blob), axis=(0,)))
self.assertEqual(pcm_tensor.shape, (1, 176000))
if __name__ == "__main__":
unittest.main()

Просмотреть файл

@ -8,7 +8,6 @@ import re
import torch
import numpy as np
from pathlib import Path
from onnx import numpy_helper
from transformers import WhisperProcessor
@ -142,7 +141,8 @@ def _torch_export(*arg, **kwargs):
def preprocessing(audio_data):
if USE_AUDIO_DECODER:
decoder = PyOrtFunction.from_customop("AudioDecoder", cpu_only=True)
decoder = PyOrtFunction.from_customop(
"AudioDecoder", cpu_only=True, downsampling_rate=SAMPLE_RATE, stereo_to_mono=1)
audio_pcm = torch.from_numpy(decoder(audio_data))
else:
audio_pcm = torch.from_numpy(audio_data)
@ -198,6 +198,9 @@ def merge_models(core: str, output_model: str, audio_data):
make_node('Cast', ['sequences'], ["generated_ids"], to=onnx.TensorProto.INT64),
bpe_decoder_node
])
try:
onnx.save_model(m_all, output_model)
except ValueError:
onnx.save_model(m_all, output_model,
save_as_external_data=True,
all_tensors_to_one_file=True,