182 строки
6.4 KiB
C++
182 строки
6.4 KiB
C++
// Copyright (c) Microsoft Corporation.
|
|
// Licensed under the MIT License.
|
|
|
|
#include <map>
|
|
#include <memory>
|
|
#include <gsl/util>
|
|
|
|
#include "audio_decoder.h"
|
|
|
|
#define DR_FLAC_IMPLEMENTATION
|
|
#include "dr_flac.h"
|
|
#define DR_MP3_IMPLEMENTATION 1
|
|
#define DR_MP3_FLOAT_OUTPUT 1
|
|
#include "dr_mp3.h"
|
|
#define DR_WAV_IMPLEMENTATION
|
|
#include "dr_wav.h"
|
|
|
|
#include "narrow.h"
|
|
#include "string_utils.h"
|
|
#include "string_tensor.h"
|
|
#include "sampling.h"
|
|
|
|
OrtStatusPtr AudioDecoder::OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
|
|
auto status = OrtW::GetOpAttribute(info, "downsampling_rate", downsample_rate_);
|
|
if (!status) {
|
|
status = OrtW::GetOpAttribute(info, "stereo_to_mono", stereo_mixer_);
|
|
}
|
|
|
|
return status;
|
|
}
|
|
|
|
AudioDecoder::AudioStreamType AudioDecoder::ReadStreamFormat(const uint8_t* p_data, const std::string& str_format,
|
|
OrtxStatus& status) const {
|
|
const std::map<std::string, AudioStreamType> format_mapping = {{"default", AudioStreamType::kDefault},
|
|
{"wav", AudioStreamType::kWAV},
|
|
{"mp3", AudioStreamType::kMP3},
|
|
{"flac", AudioStreamType::kFLAC}};
|
|
|
|
AudioStreamType stream_format = AudioStreamType::kDefault;
|
|
if (str_format.length() > 0) {
|
|
auto pos = format_mapping.find(str_format);
|
|
if (pos == format_mapping.end()) {
|
|
status = {kOrtxErrorInvalidArgument,
|
|
MakeString("[AudioDecoder]: Unknown audio stream format: ", str_format).c_str()};
|
|
return stream_format;
|
|
}
|
|
stream_format = pos->second;
|
|
}
|
|
|
|
if (stream_format == AudioStreamType::kDefault) {
|
|
auto p_stream = reinterpret_cast<char const*>(p_data);
|
|
std::string_view marker(p_stream, 4);
|
|
if (marker == "fLaC") {
|
|
stream_format = AudioStreamType::kFLAC;
|
|
} else if (marker == "RIFF") {
|
|
stream_format = AudioStreamType::kWAV;
|
|
} else if (marker[0] == char(0xFF) && (marker[1] | 0x1F) == char(0xFF)) {
|
|
// http://www.mp3-tech.org/programmer/frame_header.html
|
|
// only detect the 8 + 3 bits sync word
|
|
stream_format = AudioStreamType::kMP3;
|
|
} else {
|
|
status = {kOrtxErrorInvalidArgument, "[AudioDecoder]: Cannot detect audio stream format"};
|
|
}
|
|
}
|
|
|
|
return stream_format;
|
|
}
|
|
|
|
template <typename TY_AUDIO, typename FX_DECODER>
|
|
static size_t DrReadFrames(std::list<std::vector<float>>& frames, FX_DECODER fx, TY_AUDIO& obj) {
|
|
const size_t default_chunk_size = 1024 * 256;
|
|
int64_t total_buf_size = 0;
|
|
|
|
for (;;) {
|
|
std::vector<float> buf;
|
|
buf.resize(default_chunk_size * obj.channels);
|
|
auto n_frames = fx(&obj, default_chunk_size, buf.data());
|
|
if (n_frames <= 0) {
|
|
break;
|
|
}
|
|
auto data_size = n_frames * obj.channels;
|
|
total_buf_size += data_size;
|
|
buf.resize(data_size);
|
|
frames.emplace_back(std::move(buf));
|
|
}
|
|
|
|
return total_buf_size;
|
|
}
|
|
|
|
OrtxStatus AudioDecoder::Compute(const ortc::Tensor<uint8_t>& input, const std::optional<std::string> format,
|
|
ortc::Tensor<float>& output0) const {
|
|
const uint8_t* p_data = input.Data();
|
|
auto input_dim = input.Shape();
|
|
OrtxStatus status;
|
|
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
|
|
return {kOrtxErrorInvalidArgument, "[AudioDecoder]: Expect input dimension [n] or [1,n]."};
|
|
}
|
|
|
|
std::string str_format;
|
|
if (format) {
|
|
str_format = *format;
|
|
}
|
|
auto stream_format = ReadStreamFormat(p_data, str_format, status);
|
|
if (status) {
|
|
return status;
|
|
}
|
|
|
|
int64_t total_buf_size = 0;
|
|
std::list<std::vector<float>> lst_frames;
|
|
int64_t orig_sample_rate = 0;
|
|
int64_t orig_channels = 0;
|
|
|
|
if (stream_format == AudioStreamType::kMP3) {
|
|
auto mp3_obj_ptr = std::make_unique<drmp3>();
|
|
if (!drmp3_init_memory(mp3_obj_ptr.get(), p_data, input.NumberOfElement(), nullptr)) {
|
|
status = {kOrtxErrorCorruptData, "[AudioDecoder]: unexpected error on MP3 stream."};
|
|
return status;
|
|
}
|
|
orig_sample_rate = mp3_obj_ptr->sampleRate;
|
|
orig_channels = mp3_obj_ptr->channels;
|
|
total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, *mp3_obj_ptr);
|
|
|
|
} else if (stream_format == AudioStreamType::kFLAC) {
|
|
drflac* flac_obj = drflac_open_memory(p_data, input.NumberOfElement(), nullptr);
|
|
auto flac_obj_closer = gsl::finally([flac_obj]() { drflac_close(flac_obj); });
|
|
if (flac_obj == nullptr) {
|
|
status = {kOrtxErrorCorruptData, "[AudioDecoder]: unexpected error on FLAC stream."};
|
|
return status;
|
|
}
|
|
orig_sample_rate = flac_obj->sampleRate;
|
|
orig_channels = flac_obj->channels;
|
|
total_buf_size = DrReadFrames(lst_frames, drflac_read_pcm_frames_f32, *flac_obj);
|
|
|
|
} else {
|
|
drwav wav_obj;
|
|
if (!drwav_init_memory(&wav_obj, p_data, input.NumberOfElement(), nullptr)) {
|
|
status = {kOrtxErrorCorruptData, "[AudioDecoder]: unexpected error on WAV stream."};
|
|
return status;
|
|
}
|
|
orig_sample_rate = wav_obj.sampleRate;
|
|
orig_channels = wav_obj.channels;
|
|
total_buf_size = DrReadFrames(lst_frames, drwav_read_pcm_frames_f32, wav_obj);
|
|
}
|
|
|
|
if (downsample_rate_ != 0 && orig_sample_rate < downsample_rate_) {
|
|
status = {kOrtxErrorCorruptData, "[AudioDecoder]: only down-sampling supported."};
|
|
return status;
|
|
}
|
|
|
|
// join all frames
|
|
std::vector<float> buf;
|
|
buf.resize(total_buf_size);
|
|
int64_t offset = 0;
|
|
for (auto& _b : lst_frames) {
|
|
std::copy(_b.begin(), _b.end(), buf.begin() + offset);
|
|
offset += _b.size();
|
|
}
|
|
|
|
// mix the stereo channels into mono channel
|
|
if (stereo_mixer_ && orig_channels > 1) {
|
|
if (buf.size() > 1) {
|
|
for (size_t i = 0; i < buf.size() / 2; ++i) {
|
|
buf[i] = (buf[i * 2] + buf[i * 2 + 1]) / 2;
|
|
}
|
|
buf.resize(buf.size() / 2);
|
|
}
|
|
}
|
|
|
|
if (downsample_rate_ != 0 && downsample_rate_ != orig_sample_rate) {
|
|
// A lowpass filter on buf audio data to remove high frequency noise
|
|
ButterworthLowpass filter(0.5 * downsample_rate_, 1.0 * orig_sample_rate);
|
|
std::vector<float> filtered_buf = filter.Process(buf);
|
|
// downsample the audio data
|
|
KaiserWindowInterpolation::Process(filtered_buf, buf, 1.0f * orig_sample_rate, 1.0f * downsample_rate_);
|
|
}
|
|
|
|
std::vector<int64_t> dim_out = {1, ort_extensions::narrow<int64_t>(buf.size())};
|
|
float* p_output = output0.Allocate(dim_out);
|
|
std::copy(buf.begin(), buf.end(), p_output);
|
|
return status;
|
|
}
|