Feature extraction C API for the Whisper model (#755)

* Feature extraction C API for the Whisper model

* Update the docs

* Update the docs (2)

* refine the code

* fix some issues

* fix the Linux build

* fix more data consistency issues

* More code refinements
Wenbing Li 2024-07-11 11:20:36 -07:00, committed by GitHub
Parent 95d65e4ec0
Commit 8153bc1a3a
No key found matching this signature
GPG key ID: B5690EEEBB952194
28 changed files with 1505 additions and 515 deletions


@ -171,6 +171,10 @@ if (MSVC)
endif()
message(STATUS "_STATIC_MSVC_RUNTIME_LIBRARY: ${_STATIC_MSVC_RUNTIME_LIBRARY}")
# DLL initialization errors due to an old conda msvcp140.dll are a result of the new MSVC compiler
# See https://developercommunity.visualstudio.com/t/Access-violation-with-std::mutex::lock-a/10664660#T-N10668856
# Remove this definition once the conda msvcp140.dll is updated.
add_compile_definitions(_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR)
endif()
if(NOT OCOS_BUILD_PYTHON AND OCOS_ENABLE_PYTHON)
@ -442,7 +446,9 @@ endif()
if(OCOS_ENABLE_BERT_TOKENIZER)
# Bert
set(_HAS_TOKENIZER ON)
file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*" "operators/tokenizer/bert_tokenizer.*" "operators/tokenizer/bert_tokenizer_decoder.*")
file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*"
"operators/tokenizer/bert_tokenizer.*"
"operators/tokenizer/bert_tokenizer_decoder.*")
list(APPEND TARGET_SRC ${bert_TARGET_SRC})
endif()
@ -820,7 +826,9 @@ if(OCOS_ENABLE_AZURE)
endif()
target_compile_definitions(ortcustomops PUBLIC ${OCOS_COMPILE_DEFINITIONS})
target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:noexcep_operators,INTERFACE_INCLUDE_DIRECTORIES>")
target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")
target_link_libraries(ortcustomops PUBLIC ocos_operators)
if(_BUILD_SHARED_LIBRARY)
@ -840,7 +848,8 @@ if(_BUILD_SHARED_LIBRARY)
standardize_output_folder(extensions_shared)
if(LINUX OR ANDROID)
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS
" -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
# strip if not a debug build
if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,-s")

docs/c_api.md (new file, +19)

@ -0,0 +1,19 @@
# ONNXRuntime Extensions C ABI

ONNXRuntime Extensions provides a C-style ABI for pre-processing. It offers support for tokenization, image processing, speech feature extraction, and more. You can compile ONNXRuntime Extensions as either a static or a dynamic library to access these APIs.

The C ABI header files are named `ortx_*.h` and can be found in the include folder. There are three types of data processing APIs available:

- [`ortx_tokenizer.h`](../include/ortx_tokenizer.h): Provides tokenization for LLMs.
- [`ortx_processor.h`](../include/ortx_processor.h): Offers image processing APIs for multimodal models.
- [`ortx_extractor.h`](../include/ortx_extractor.h): Provides speech feature extraction for audio data processing to assist the Whisper model.

## ABI QuickStart

Most APIs accept raw data inputs such as audio or image data in compressed binary formats, or UTF-8 encoded text for tokenization.

**Tokenization:** You can create a tokenizer object with `OrtxCreateTokenizer` and then use it to tokenize text or decode token IDs back into text. A C-style code snippet is available [here](../test/pp_api_test/c_only_test.c).

**Image processing:** `OrtxCreateProcessor` creates an image processor object from a pre-defined workflow in JSON format to process image files into a tensor-like data type. An example code snippet can be found [here](../test/pp_api_test/test_processor.cc#L75).

**Audio feature extraction:** `OrtxCreateSpeechFeatureExtractor` creates a speech feature extractor to obtain log mel spectrogram data as input for the Whisper model. An example code snippet can be found [here](../test/pp_api_test/test_feature_extractor.cc#L16).
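
Putting these together, below is a minimal, hypothetical sketch of the Whisper feature-extraction flow through the C ABI. The JSON definition path and audio file name are placeholders, and it assumes a zero `extError_t` return value indicates success:

```c
// Hypothetical sketch: extract Whisper log mel features for one audio file.
// "whisper_feature_def.json" and "audio.wav" are placeholder names.
#include "ortx_extractor.h"

int extract_log_mel(void) {
  OrtxFeatureExtractor* extractor = NULL;
  OrtxRawAudios* audios = NULL;
  OrtxTensorResult* result = NULL;
  const char* const audio_paths[] = {"audio.wav"};

  extError_t err = OrtxCreateSpeechFeatureExtractor(&extractor, "whisper_feature_def.json");
  if (!err) err = OrtxLoadAudios(&audios, audio_paths, 1);
  if (!err) err = OrtxSpeechLogMel(extractor, audios, &result);
  // The resulting log mel tensor(s) can be read back with
  // OrtxTensorResultGetAt / OrtxGetTensorData declared in ortx_utils.h.

  if (result) OrtxDisposeOnly(result);
  if (audios) OrtxDisposeOnly(audios);
  if (extractor) OrtxDisposeOnly(extractor);
  return (int)err;
}
```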


include/ortx_extractor.h (new file, +75)

@ -0,0 +1,75 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// C ABI header file for the onnxruntime-extensions audio feature extraction module
#pragma once
#include "ortx_utils.h"
typedef OrtxObject OrtxFeatureExtractor;
typedef OrtxObject OrtxRawAudios;
typedef OrtxObject OrtxTensorResult;
#ifdef __cplusplus
extern "C" {
#endif
/**
* @brief Creates a feature extractor object.
*
* This function creates a feature extractor object based on the provided feature definition.
*
* @param[out] extractor Pointer to a pointer to the created feature extractor object.
* @param[in] fe_def The feature definition used to create the feature extractor.
*
* @return An error code indicating the result of the operation.
*/
extError_t ORTX_API_CALL OrtxCreateSpeechFeatureExtractor(OrtxFeatureExtractor** extractor, const char* fe_def);
/**
* Loads a collection of audio files into memory.
*
* This function loads a collection of audio files specified by the `audio_paths` array
* into memory and returns a pointer to the loaded audio data in the `audios` parameter.
*
* @param audios A pointer to a pointer that will be updated with the loaded audio data.
* The caller is responsible for freeing the memory allocated for the audio data.
* @param audio_paths An array of strings representing the paths to the audio files to be loaded.
* @param num_audios The number of audio files to be loaded.
*
* @return An `extError_t` value indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxLoadAudios(OrtxRawAudios** audios, const char* const* audio_paths, size_t num_audios);
/**
* @brief Creates an array of raw audio objects.
*
* This function creates an array of raw audio objects based on the provided data and sizes.
*
* @param audios Pointer to the variable that will hold the created raw audio objects.
* @param data Array of pointers to the audio data.
* @param sizes Array of pointers to the sizes of the audio data.
* @param num_audios Number of audio objects to create.
*
* @return extError_t Error code indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxCreateRawAudios(OrtxRawAudios** audios, const void* data[], const int64_t* sizes[], size_t num_audios);
/**
* @brief Calculates the log mel spectrogram for a given audio using the specified feature extractor.
*
* This function takes an instance of the OrtxFeatureExtractor struct, an instance of the OrtxRawAudios struct,
* and a pointer to a OrtxTensorResult pointer. It calculates the log mel spectrogram for the given audio using
* the specified feature extractor and stores the result in the provided log_mel pointer.
*
* @param extractor The feature extractor to use for calculating the log mel spectrogram.
* @param audio The raw audio data to process.
* @param log_mel A pointer to a OrtxTensorResult pointer where the result will be stored.
* @return An extError_t value indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxSpeechLogMel(OrtxFeatureExtractor* extractor, OrtxRawAudios* audio, OrtxTensorResult** log_mel);
#ifdef __cplusplus
}
#endif
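
For audio data that is already in memory, `OrtxCreateRawAudios` (declared above) can be used in place of `OrtxLoadAudios`. The following is a minimal sketch under the assumption that each buffer holds the bytes of a complete encoded audio file (WAV/MP3/FLAC), mirroring what `OrtxLoadAudios` reads from disk; the buffer names and lengths are placeholders:

```c
// Hypothetical sketch: wrap two in-memory audio buffers as an OrtxRawAudios
// collection. wav_bytes/mp3_bytes and the int64_t byte lengths wav_size/mp3_size
// are supplied by the caller.
const void* data[2] = {wav_bytes, mp3_bytes};
const int64_t* sizes[2] = {&wav_size, &mp3_size};
OrtxRawAudios* audios = NULL;
extError_t err = OrtxCreateRawAudios(&audios, data, sizes, 2);
// ... pass audios to OrtxSpeechLogMel, then release it with OrtxDisposeOnly ...
```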


@ -10,7 +10,6 @@
// typedefs to avoid a flood of create/dispose functions, and to make the API more C++ friendly with less casting
typedef OrtxObject OrtxProcessor;
typedef OrtxObject OrtxRawImages;
typedef OrtxObject OrtxImageProcessorResult;
#ifdef __cplusplus
extern "C" {
@ -40,8 +39,22 @@ extError_t ORTX_API_CALL OrtxCreateProcessor(OrtxProcessor** processor, const ch
extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** image_paths, size_t num_images,
size_t* num_images_loaded);
/**
* @brief Preprocesses the given raw images using the specified processor.
* @brief Creates raw images from the provided data.
*
* This function creates raw images from the provided data. The raw images are stored in the `images` parameter.
*
* @param images Pointer to a pointer to the `OrtxRawImages` structure that will hold the created raw images.
* @param data Array of pointers to the data for each image.
* @param sizes Array of pointers to the sizes of each image.
* @param num_images Number of images to create.
* @return An `extError_t` value indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxCreateRawImages(OrtxRawImages** images, const void* data[], const int64_t* sizes[], size_t num_images);
/**
* @brief Pre-processes the given raw images using the specified processor.
*
* This function applies preprocessing operations on the raw images using the provided processor.
* The result of the preprocessing is stored in the `OrtxImageProcessorResult` object.
@ -52,24 +65,7 @@ extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** ima
* @return An `extError_t` value indicating the success or failure of the preprocessing operation.
*/
extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawImages* images,
OrtxImageProcessorResult** result);
/**
* @brief Retrieves the image processor result at the specified index.
*
* @param result Pointer to the OrtxImageProcessorResult structure to store the result.
* @param index The index of the result to retrieve.
* @return extError_t The error code indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxImageGetTensorResult(OrtxImageProcessorResult* result, size_t index, OrtxTensor** tensor);
/** \brief Clear the outputs of the processor
*
* \param processor The processor object
* \param result The result object to clear
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxClearOutputs(OrtxProcessor* processor, OrtxImageProcessorResult* result);
OrtxTensorResult** result);
#ifdef __cplusplus
}
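
The image-processing entry points above compose the same way as the audio ones. A minimal, hypothetical sketch follows; the processor definition path and image name are placeholders, and a zero `extError_t` is treated as success:

```c
// Hypothetical sketch: run the pre-processing pipeline on a single image file.
OrtxProcessor* processor = NULL;
OrtxRawImages* images = NULL;
OrtxTensorResult* result = NULL;
const char* image_paths[] = {"image.png"};
size_t num_loaded = 0;

extError_t err = OrtxCreateProcessor(&processor, "processor_def.json");
if (!err) err = OrtxLoadImages(&images, image_paths, 1, &num_loaded);
if (!err) err = OrtxImagePreProcess(processor, images, &result);
// Individual outputs (e.g. pixel values) are retrieved with OrtxTensorResultGetAt.

if (result) OrtxDisposeOnly(result);
if (images) OrtxDisposeOnly(images);
if (processor) OrtxDisposeOnly(processor);
```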


@ -17,19 +17,22 @@ typedef enum {
kOrtxKindDetokenizerCache = 0x778B,
kOrtxKindProcessor = 0x778C,
kOrtxKindRawImages = 0x778D,
kOrtxKindImageProcessorResult = 0x778E,
kOrtxKindTensorResult = 0x778E,
kOrtxKindProcessorResult = 0x778F,
kOrtxKindTensor = 0x7790,
kOrtxKindFeatureExtractor = 0x7791,
kOrtxKindRawAudios = 0x7792,
kOrtxKindEnd = 0x7999
} extObjectKind_t;
// all objects managed by the library should be 'derived' from this struct,
// which will eventually be released by OrtxDispose (or OrtxDisposeOnly)
typedef struct {
int ext_kind_;
extObjectKind_t ext_kind_;
} OrtxObject;
typedef OrtxObject OrtxTensor;
typedef OrtxObject OrtxTensorResult;
// C, unlike C++, doesn't cast automatically,
// so we need a macro to cast the object to the correct type
@ -77,6 +80,18 @@ extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object);
*/
extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object);
/**
* @brief Retrieves the tensor at the specified index from the given tensor result.
*
* This function allows you to access a specific tensor from a tensor result object.
*
* @param result The tensor result object from which to retrieve the tensor.
* @param index The index of the tensor to retrieve.
* @param tensor A pointer to a variable that will hold the retrieved tensor.
* @return An error code indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t index, OrtxTensor** tensor);
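
A small usage sketch for reading a tensor out of a result object, assuming `result` is an `OrtxTensorResult*` produced by, for example, `OrtxSpeechLogMel` or `OrtxImagePreProcess`, and that a zero return code means success; the float element type is an assumption matching the log mel and pixel-value outputs in this change:

```c
OrtxTensor* tensor = NULL;
const void* data = NULL;
const int64_t* shape = NULL;
size_t num_dims = 0;

if (!OrtxTensorResultGetAt(result, 0, &tensor) &&
    !OrtxGetTensorData(tensor, &data, &shape, &num_dims)) {
  const float* values = (const float*)data;  // element type depends on the operation
  // ... consume values according to shape[0..num_dims-1] ...
}
if (tensor) OrtxDisposeOnly(tensor);
```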
/** \brief Get the data from the tensor
*
* \param tensor The tensor object


@ -17,7 +17,7 @@ from onnx import numpy_helper
from ._ortapi2 import make_onnx_model
from ._cuops import SingleOpGraph
from ._hf_cvt import HFTokenizerConverter
from .util import remove_unused_initializers
from .util import remove_unused_initializers, mel_filterbank
class _WhisperHParams:
@ -30,53 +30,15 @@ class _WhisperHParams:
N_FRAMES = N_SAMPLES // HOP_LENGTH
def _mel_filterbank(
n_fft: int, n_mels: int = 80, sr=16000, min_mel=0, max_mel=45.245640471924965, dtype=np.float32):
"""
Compute a Mel-filterbank. The filters are stored in the rows, the columns,
and it is Slaney normalized mel-scale filterbank.
"""
fbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=dtype)
# the centers of the frequency bins for the DFT
freq_bins = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
mel = np.linspace(min_mel, max_mel, n_mels + 2)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mel
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
log_t = mel >= min_log_mel
freqs[log_t] = min_log_hz * np.exp(logstep * (mel[log_t] - min_log_mel))
mel_bins = freqs
mel_spacing = np.diff(mel_bins)
ramps = mel_bins.reshape(-1, 1) - freq_bins.reshape(1, -1)
for i in range(n_mels):
left = -ramps[i] / mel_spacing[i]
right = ramps[i + 2] / mel_spacing[i + 1]
# intersect them with each other and zero
fbank[i] = np.maximum(0, np.minimum(left, right))
energy_norm = 2.0 / (mel_bins[2: n_mels + 2] - mel_bins[:n_mels])
fbank *= energy_norm[:, np.newaxis]
return fbank
class CustomOpStftNorm(torch.autograd.Function):
@staticmethod
def symbolic(g, self, n_fft, hop_length, window):
t_n_fft = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
t_hop_length = g.op('Constant', value_t=torch.tensor(hop_length, dtype=torch.int64))
t_frame_size = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
t_n_fft = g.op('Constant', value_t=torch.tensor(
n_fft, dtype=torch.int64))
t_hop_length = g.op('Constant', value_t=torch.tensor(
hop_length, dtype=torch.int64))
t_frame_size = g.op(
'Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
return g.op("ai.onnx.contrib::StftNorm", self, t_n_fft, t_hop_length, window, t_frame_size)
@staticmethod
@ -97,7 +59,7 @@ class WhisperPrePipeline(torch.nn.Module):
self.n_fft = n_fft
self.window = torch.hann_window(n_fft)
self.mel_filters = torch.from_numpy(
_mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels))
mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels))
def forward(self, audio_pcm: torch.Tensor):
stft_norm = CustomOpStftNorm.apply(audio_pcm,
@ -112,7 +74,8 @@ class WhisperPrePipeline(torch.nn.Module):
spec_shape = log_spec.shape
padding_spec = torch.ones(spec_shape[0],
spec_shape[1],
self.n_samples // self.hop_length - spec_shape[2],
self.n_samples // self.hop_length -
spec_shape[2],
dtype=torch.float)
padding_spec *= spec_min
log_spec = torch.cat((log_spec, padding_spec), dim=2)
@ -165,15 +128,20 @@ def _to_onnx_stft(onnx_model, n_fft):
make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_minus_1_output_0',
'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'],
name='slice_1'),
make_node('Constant', inputs=[], outputs=['const0_output_0'], name='const0', value_int=0),
make_node('Constant', inputs=[], outputs=['const1_output_0'], name='const1', value_int=1),
make_node('Constant', inputs=[], outputs=[
'const0_output_0'], name='const0', value_int=0),
make_node('Constant', inputs=[], outputs=[
'const1_output_0'], name='const1', value_int=1),
make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'], outputs=['gather_4_output_0'],
name='gather_4', axis=3),
make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'], outputs=['gather_5_output_0'],
name='gather_5', axis=3),
make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=['mul_output_0'], name='mul0'),
make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=['mul_1_output_0'], name='mul1'),
make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[stft_norm_node.output[0]], name='add0'),
make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=[
'mul_output_0'], name='mul0'),
make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=[
'mul_1_output_0'], name='mul1'),
make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[
stft_norm_node.output[0]], name='add0'),
]
new_stft_nodes.extend(onnx_model.graph.node[:node_idx])
new_stft_nodes.extend(replaced_nodes)
@ -253,9 +221,11 @@ class WhisperDataProcGraph:
del g.node[:]
g.node.extend(nodes)
inputs = [onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
inputs = [onnx.helper.make_tensor_value_info(
"sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
del g.input[:]
g.input.extend(inputs)
g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'text']))
g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(
onnx.TensorProto.STRING, ['N', 'text']))
return make_onnx_model(g, opset_version=self.opset_version)


@ -3,17 +3,15 @@
#include "ocos.h"
#ifdef ENABLE_DR_LIBS
#include "audio_decoder.hpp"
#include "audio_decoder.h"
#endif // ENABLE_DR_LIBS
FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []()-> CustomOpArray& {
FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []() -> CustomOpArray& {
static OrtOpLoader op_loader(
[]() { return nullptr; }
#ifdef ENABLE_DR_LIBS
,
CustomCpuStructV2("AudioDecoder", AudioDecoder)
CustomCpuStructV2("AudioDecoder", AudioDecoder),
#endif
);
[]() { return nullptr; });
return op_loader.GetCustomOps();
};


@ -0,0 +1,181 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include <map>
#include <memory>
#include <gsl/util>
#include "audio_decoder.h"
#define DR_FLAC_IMPLEMENTATION
#include "dr_flac.h"
#define DR_MP3_IMPLEMENTATION 1
#define DR_MP3_FLOAT_OUTPUT 1
#include "dr_mp3.h"
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include "narrow.h"
#include "string_utils.h"
#include "string_tensor.h"
#include "sampling.h"
OrtStatusPtr AudioDecoder::OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
auto status = OrtW::GetOpAttribute(info, "downsampling_rate", downsample_rate_);
if (!status) {
status = OrtW::GetOpAttribute(info, "stereo_to_mono", stereo_mixer_);
}
return status;
}
AudioDecoder::AudioStreamType AudioDecoder::ReadStreamFormat(const uint8_t* p_data, const std::string& str_format,
OrtxStatus& status) const {
const std::map<std::string, AudioStreamType> format_mapping = {{"default", AudioStreamType::kDefault},
{"wav", AudioStreamType::kWAV},
{"mp3", AudioStreamType::kMP3},
{"flac", AudioStreamType::kFLAC}};
AudioStreamType stream_format = AudioStreamType::kDefault;
if (str_format.length() > 0) {
auto pos = format_mapping.find(str_format);
if (pos == format_mapping.end()) {
status = {kOrtxErrorInvalidArgument,
MakeString("[AudioDecoder]: Unknown audio stream format: ", str_format).c_str()};
return stream_format;
}
stream_format = pos->second;
}
if (stream_format == AudioStreamType::kDefault) {
auto p_stream = reinterpret_cast<char const*>(p_data);
std::string_view marker(p_stream, 4);
if (marker == "fLaC") {
stream_format = AudioStreamType::kFLAC;
} else if (marker == "RIFF") {
stream_format = AudioStreamType::kWAV;
} else if (marker[0] == char(0xFF) && (marker[1] | 0x1F) == char(0xFF)) {
// http://www.mp3-tech.org/programmer/frame_header.html
// only detect the 8 + 3 bits sync word
stream_format = AudioStreamType::kMP3;
} else {
status = {kOrtxErrorInvalidArgument, "[AudioDecoder]: Cannot detect audio stream format"};
}
}
return stream_format;
}
template <typename TY_AUDIO, typename FX_DECODER>
static size_t DrReadFrames(std::list<std::vector<float>>& frames, FX_DECODER fx, TY_AUDIO& obj) {
const size_t default_chunk_size = 1024 * 256;
int64_t total_buf_size = 0;
for (;;) {
std::vector<float> buf;
buf.resize(default_chunk_size * obj.channels);
auto n_frames = fx(&obj, default_chunk_size, buf.data());
if (n_frames <= 0) {
break;
}
auto data_size = n_frames * obj.channels;
total_buf_size += data_size;
buf.resize(data_size);
frames.emplace_back(std::move(buf));
}
return total_buf_size;
}
OrtxStatus AudioDecoder::Compute(const ortc::Tensor<uint8_t>& input, const std::optional<std::string> format,
ortc::Tensor<float>& output0) const {
const uint8_t* p_data = input.Data();
auto input_dim = input.Shape();
OrtxStatus status;
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
return {kOrtxErrorInvalidArgument, "[AudioDecoder]: Expect input dimension [n] or [1,n]."};
}
std::string str_format;
if (format) {
str_format = *format;
}
auto stream_format = ReadStreamFormat(p_data, str_format, status);
if (status) {
return status;
}
int64_t total_buf_size = 0;
std::list<std::vector<float>> lst_frames;
int64_t orig_sample_rate = 0;
int64_t orig_channels = 0;
if (stream_format == AudioStreamType::kMP3) {
auto mp3_obj_ptr = std::make_unique<drmp3>();
if (!drmp3_init_memory(mp3_obj_ptr.get(), p_data, input.NumberOfElement(), nullptr)) {
status = {kOrtxErrorCorruptData, "[AudioDecoder]: unexpected error on MP3 stream."};
return status;
}
orig_sample_rate = mp3_obj_ptr->sampleRate;
orig_channels = mp3_obj_ptr->channels;
total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, *mp3_obj_ptr);
} else if (stream_format == AudioStreamType::kFLAC) {
drflac* flac_obj = drflac_open_memory(p_data, input.NumberOfElement(), nullptr);
auto flac_obj_closer = gsl::finally([flac_obj]() { drflac_close(flac_obj); });
if (flac_obj == nullptr) {
status = {kOrtxErrorCorruptData, "[AudioDecoder]: unexpected error on FLAC stream."};
return status;
}
orig_sample_rate = flac_obj->sampleRate;
orig_channels = flac_obj->channels;
total_buf_size = DrReadFrames(lst_frames, drflac_read_pcm_frames_f32, *flac_obj);
} else {
drwav wav_obj;
if (!drwav_init_memory(&wav_obj, p_data, input.NumberOfElement(), nullptr)) {
status = {kOrtxErrorCorruptData, "[AudioDecoder]: unexpected error on WAV stream."};
return status;
}
orig_sample_rate = wav_obj.sampleRate;
orig_channels = wav_obj.channels;
total_buf_size = DrReadFrames(lst_frames, drwav_read_pcm_frames_f32, wav_obj);
}
if (downsample_rate_ != 0 && orig_sample_rate < downsample_rate_) {
status = {kOrtxErrorCorruptData, "[AudioDecoder]: only down-sampling supported."};
return status;
}
// join all frames
std::vector<float> buf;
buf.resize(total_buf_size);
int64_t offset = 0;
for (auto& _b : lst_frames) {
std::copy(_b.begin(), _b.end(), buf.begin() + offset);
offset += _b.size();
}
// mix the stereo channels into mono channel
if (stereo_mixer_ && orig_channels > 1) {
if (buf.size() > 1) {
for (size_t i = 0; i < buf.size() / 2; ++i) {
buf[i] = (buf[i * 2] + buf[i * 2 + 1]) / 2;
}
buf.resize(buf.size() / 2);
}
}
if (downsample_rate_ != 0 && downsample_rate_ != orig_sample_rate) {
// A lowpass filter on buf audio data to remove high frequency noise
ButterworthLowpass filter(0.5 * downsample_rate_, 1.0 * orig_sample_rate);
std::vector<float> filtered_buf = filter.Process(buf);
// downsample the audio data
KaiserWindowInterpolation::Process(filtered_buf, buf, 1.0f * orig_sample_rate, 1.0f * downsample_rate_);
}
std::vector<int64_t> dim_out = {1, ort_extensions::narrow<int64_t>(buf.size())};
float* p_output = output0.Allocate(dim_out);
std::copy(buf.begin(), buf.end(), p_output);
return status;
}


@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
#include <list>
#include <optional>
struct AudioDecoder {
public:
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info);
template <typename DictT>
OrtxStatus Init(const DictT& attrs) {
// in API mode, the defaults are a 16 kHz target sample rate and stereo-to-mono mixing enabled (1)
downsample_rate_ = 16000;
stereo_mixer_ = 1;
for (const auto& [key, value] : attrs) {
if (key == "target_sample_rate") {
downsample_rate_ = std::get<std::int64_t>(value);
} else if (key == "stereo_to_mono") {
stereo_mixer_ = std::get<std::int64_t>(value);
} else {
return {kOrtxErrorInvalidArgument, "[AudioDecoder]: Invalid argument"};
}
}
return {};
}
enum class AudioStreamType { kDefault = 0, kWAV, kMP3, kFLAC };
AudioStreamType ReadStreamFormat(const uint8_t* p_data, const std::string& str_format, OrtxStatus& status) const;
OrtxStatus Compute(const ortc::Tensor<uint8_t>& input, const std::optional<std::string> format,
ortc::Tensor<float>& output0) const;
OrtxStatus ComputeNoOpt(const ortc::Tensor<uint8_t>& input, ortc::Tensor<float>& output0) {
return Compute(input, std::nullopt, output0);
}
private:
int64_t downsample_rate_{};
int64_t stereo_mixer_{};
};


@ -1,206 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
#include <list>
#include <map>
#include <memory>
#include <optional>
#define DR_FLAC_IMPLEMENTATION
#include "dr_flac.h"
#define DR_MP3_IMPLEMENTATION 1
#define DR_MP3_FLOAT_OUTPUT 1
#include "dr_mp3.h"
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include <gsl/util>
#include "narrow.h"
#include "string_utils.h"
#include "string_tensor.h"
#include "sampling.h"
struct AudioDecoder{
public:
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
auto status = OrtW::GetOpAttribute(info, "downsampling_rate", downsample_rate_);
if (!status) {
status = OrtW::GetOpAttribute(info, "stereo_to_mono", stereo_mixer_);
}
return status;
}
enum class AudioStreamType {
kDefault = 0,
kWAV,
kMP3,
kFLAC
};
AudioStreamType ReadStreamFormat(const uint8_t* p_data, const std::string& str_format, OrtStatusPtr& status) const {
static const std::map<std::string, AudioStreamType> format_mapping = {
{"default", AudioStreamType::kDefault},
{"wav", AudioStreamType::kWAV},
{"mp3", AudioStreamType::kMP3},
{"flac", AudioStreamType::kFLAC}};
AudioStreamType stream_format = AudioStreamType::kDefault;
if (str_format.length() > 0) {
auto pos = format_mapping.find(str_format);
if (pos == format_mapping.end()) {
status = OrtW::CreateStatus(MakeString(
"[AudioDecoder]: Unknown audio stream format: ", str_format)
.c_str(),
ORT_INVALID_ARGUMENT);
return stream_format;
}
stream_format = pos->second;
}
if (stream_format == AudioStreamType::kDefault) {
auto p_stream = reinterpret_cast<char const*>(p_data);
std::string_view marker(p_stream, 4);
if (marker == "fLaC") {
stream_format = AudioStreamType::kFLAC;
} else if (marker == "RIFF") {
stream_format = AudioStreamType::kWAV;
} else if (marker[0] == char(0xFF) && (marker[1] | 0x1F) == char(0xFF)) {
// http://www.mp3-tech.org/programmer/frame_header.html
// only detect the 8 + 3 bits sync word
stream_format = AudioStreamType::kMP3;
} else {
status = OrtW::CreateStatus("[AudioDecoder]: Cannot detect audio stream format", ORT_INVALID_ARGUMENT);
}
}
return stream_format;
}
template <typename TY_AUDIO, typename FX_DECODER>
static size_t DrReadFrames(std::list<std::vector<float>>& frames, FX_DECODER fx, TY_AUDIO& obj) {
const size_t default_chunk_size = 1024 * 256;
int64_t total_buf_size = 0;
for (;;) {
std::vector<float> buf;
buf.resize(default_chunk_size * obj.channels);
auto n_frames = fx(&obj, default_chunk_size, buf.data());
if (n_frames <= 0) {
break;
}
auto data_size = n_frames * obj.channels;
total_buf_size += data_size;
buf.resize(data_size);
frames.emplace_back(std::move(buf));
}
return total_buf_size;
}
OrtStatusPtr Compute(const ortc::Tensor<uint8_t>& input,
const std::optional<std::string> format,
ortc::Tensor<float>& output0) const {
const uint8_t* p_data = input.Data();
auto input_dim = input.Shape();
OrtStatusPtr status = nullptr;
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
status = OrtW::CreateStatus("[AudioDecoder]: Expect input dimension [n] or [1,n].", ORT_INVALID_ARGUMENT);
return status;
}
std::string str_format;
if (format) {
str_format = *format;
}
auto stream_format = ReadStreamFormat(p_data, str_format, status);
if (status) {
return status;
}
int64_t total_buf_size = 0;
std::list<std::vector<float>> lst_frames;
int64_t orig_sample_rate = 0;
int64_t orig_channels = 0;
if (stream_format == AudioStreamType::kMP3) {
auto mp3_obj_ptr = std::make_unique<drmp3>();
if (!drmp3_init_memory(mp3_obj_ptr.get(), p_data, input.NumberOfElement(), nullptr)) {
status = OrtW::CreateStatus("[AudioDecoder]: unexpected error on MP3 stream.", ORT_RUNTIME_EXCEPTION);
return status;
}
orig_sample_rate = mp3_obj_ptr->sampleRate;
orig_channels = mp3_obj_ptr->channels;
total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, *mp3_obj_ptr);
} else if (stream_format == AudioStreamType::kFLAC) {
drflac* flac_obj = drflac_open_memory(p_data, input.NumberOfElement(), nullptr);
auto flac_obj_closer = gsl::finally([flac_obj]() { drflac_close(flac_obj); });
if (flac_obj == nullptr) {
status = OrtW::CreateStatus("[AudioDecoder]: unexpected error on FLAC stream.", ORT_RUNTIME_EXCEPTION);
return status;
}
orig_sample_rate = flac_obj->sampleRate;
orig_channels = flac_obj->channels;
total_buf_size = DrReadFrames(lst_frames, drflac_read_pcm_frames_f32, *flac_obj);
} else {
drwav wav_obj;
if (!drwav_init_memory(&wav_obj, p_data, input.NumberOfElement(), nullptr)) {
status = OrtW::CreateStatus("[AudioDecoder]: unexpected error on WAV stream.", ORT_RUNTIME_EXCEPTION);
return status;
}
orig_sample_rate = wav_obj.sampleRate;
orig_channels = wav_obj.channels;
total_buf_size = DrReadFrames(lst_frames, drwav_read_pcm_frames_f32, wav_obj);
}
if (downsample_rate_ != 0 &&
orig_sample_rate < downsample_rate_) {
status = OrtW::CreateStatus("[AudioDecoder]: only down-sampling supported.", ORT_INVALID_ARGUMENT);
return status;
}
// join all frames
std::vector<float> buf;
buf.resize(total_buf_size);
int64_t offset = 0;
for (auto& _b : lst_frames) {
std::copy(_b.begin(), _b.end(), buf.begin() + offset);
offset += _b.size();
}
// mix the stereo channels into mono channel
if (stereo_mixer_ && orig_channels > 1) {
if (buf.size() > 1) {
for (size_t i = 0; i < buf.size() / 2; ++i) {
buf[i] = (buf[i * 2] + buf[i * 2 + 1]) / 2;
}
buf.resize(buf.size() / 2);
}
}
if (downsample_rate_ != 0 &&
downsample_rate_ != orig_sample_rate) {
// A lowpass filter on buf audio data to remove high frequency noise
ButterworthLowpass filter(0.5 * downsample_rate_, 1.0 * orig_sample_rate);
std::vector<float> filtered_buf = filter.Process(buf);
// downsample the audio data
KaiserWindowInterpolation::Process(filtered_buf, buf,
1.0f * orig_sample_rate, 1.0f * downsample_rate_);
}
std::vector<int64_t> dim_out = {1, ort_extensions::narrow<int64_t>(buf.size())};
float* p_output = output0.Allocate(dim_out);
std::copy(buf.begin(), buf.end(), p_output);
return status;
}
private:
int64_t downsample_rate_{};
int64_t stereo_mixer_{};
};


@ -6,37 +6,32 @@
#include "ocos.h"
#include <dlib/matrix.h>
struct StftNormal{
struct StftNormal {
StftNormal() = default;
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
return OrtW::GetOpAttribute(info, "onesided", onesided_);
}
OrtStatusPtr Compute(const ortc::Tensor<float>& input0,
int64_t n_fft,
int64_t hop_length,
const ortc::Span<float>& input3,
int64_t frame_length,
ortc::Tensor<float>& output0) const {
OrtxStatus Compute(const ortc::Tensor<float>& input0, int64_t n_fft, int64_t hop_length,
const ortc::Span<float>& input3, int64_t frame_length, ortc::Tensor<float>& output0) const {
auto X = input0.Data();
auto window = input3.data_;
auto dimensions = input0.Shape();
auto win_length = input3.size();
if (dimensions.size() < 2 || input0.NumberOfElement() != dimensions[1]) {
return OrtW::CreateStatus("[Stft] Only batch == 1 tensor supported.", ORT_INVALID_ARGUMENT);
return {kOrtxErrorInvalidArgument, "[Stft] Only batch == 1 tensor supported."};
}
if (frame_length != n_fft) {
return OrtW::CreateStatus("[Stft] Only support size of FFT equals the frame length.", ORT_INVALID_ARGUMENT);
return {kOrtxErrorInvalidArgument, "[Stft] Only support size of FFT equals the frame length."};
}
dlib::matrix<float> dm_x = dlib::mat(X, 1, dimensions[1]);
dlib::matrix<float> hann_win = dlib::mat(window, 1, win_length);
auto m_stft = dlib::stft(
dm_x, [&hann_win](size_t x, size_t len) { return hann_win(0, x); },
n_fft, win_length, hop_length);
auto m_stft =
dlib::stft(dm_x, [&hann_win](size_t x, size_t len) { return hann_win(0, x); }, n_fft, win_length, hop_length);
if (onesided_) {
m_stft = dlib::subm(m_stft, 0, 0, m_stft.nr(), (m_stft.nc() >> 1) + 1);
@ -49,7 +44,7 @@ struct StftNormal{
auto out0 = output0.Allocate(outdim);
memcpy(out0, result.steal_memory().get(), result_size * sizeof(float));
return nullptr;
return {};
}
private:


@ -0,0 +1,77 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "speech_extractor.h"
#include "c_api_utils.hpp"
using namespace ort_extensions;
class RawAudiosObject : public OrtxObjectImpl {
public:
RawAudiosObject() : OrtxObjectImpl(extObjectKind_t::kOrtxKindRawAudios) {}
~RawAudiosObject() override = default;
std::unique_ptr<AudioRawData[]> audios_;
size_t num_audios_;
};
extError_t ORTX_API_CALL OrtxLoadAudios(OrtxRawAudios** raw_audios, const char* const* audio_paths, size_t num_audios) {
if (raw_audios == nullptr || audio_paths == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
auto audios_obj = std::make_unique<RawAudiosObject>();
auto [audios, num] =
ort_extensions::LoadRawData<char const* const*, AudioRawData>(audio_paths, audio_paths + num_audios);
audios_obj->audios_ = std::move(audios);
audios_obj->num_audios_ = num;
*raw_audios = static_cast<OrtxRawAudios*>(audios_obj.release());
return extError_t();
}
extError_t ORTX_API_CALL OrtxCreateSpeechFeatureExtractor(OrtxFeatureExtractor** extractor, const char* def) {
if (extractor == nullptr || def == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
auto extractor_ptr = std::make_unique<SpeechFeatureExtractor>();
ReturnableStatus status = extractor_ptr->Init(def);
if (status.IsOk()) {
*extractor = static_cast<OrtxFeatureExtractor*>(extractor_ptr.release());
} else {
*extractor = nullptr;
}
return status.Code();
}
extError_t ORTX_API_CALL OrtxSpeechLogMel(OrtxFeatureExtractor* extractor, OrtxRawAudios* raw_audios,
OrtxTensorResult** result) {
if (extractor == nullptr || raw_audios == nullptr || result == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
auto extractor_ptr = static_cast<SpeechFeatureExtractor*>(extractor);
auto audios_obj = static_cast<RawAudiosObject*>(raw_audios);
auto ts_result = std::make_unique<TensorResult>();
std::unique_ptr<ortc::Tensor<float>> log_mel[1];
ReturnableStatus status =
extractor_ptr->DoCall(ort_extensions::span(audios_obj->audios_.get(), audios_obj->num_audios_), log_mel[0]);
if (status.IsOk()) {
std::vector<std::unique_ptr<ortc::TensorBase>> tensors;
std::transform(log_mel, log_mel + 1, std::back_inserter(tensors),
[](auto& ts) { return std::unique_ptr<ortc::TensorBase>(ts.release()); });
ts_result->SetTensors(std::move(tensors));
*result = ts_result.release();
} else {
*result = nullptr;
}
return status.Code();
}


@ -4,6 +4,8 @@
#include "ortx_processor.h"
#include "image_processor.h"
#include "c_api_utils.hpp"
using namespace ort_extensions;
extError_t OrtxCreateProcessor(OrtxProcessor** processor, const char* def) {
@ -37,19 +39,19 @@ extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** ima
}
auto images_obj = std::make_unique<RawImagesObject>();
auto [img, num] = LoadRawImages(image_paths, image_paths + num_images);
auto [img, num] = LoadRawData<char const**, ImageRawData>(image_paths, image_paths + num_images);
images_obj->images = std::move(img);
images_obj->num_images = num;
if (num_images_loaded != nullptr) {
*num_images_loaded = num;
}
*images = static_cast<OrtxRawImages*>(images_obj.release());
return extError_t();
}
extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawImages* images,
OrtxImageProcessorResult** result) {
OrtxTensorResult** result) {
if (processor == nullptr || images == nullptr || result == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
@ -67,59 +69,14 @@ extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawIm
return status.Code();
}
auto result_ptr = std::make_unique<ImageProcessorResult>();
auto result_ptr = std::make_unique<TensorResult>();
status =
processor_ptr->PreProcess(ort_extensions::span(images_ptr->images.get(), images_ptr->num_images), *result_ptr);
if (status.IsOk()) {
*result = static_cast<OrtxImageProcessorResult*>(result_ptr.release());
*result = static_cast<OrtxTensorResult*>(result_ptr.release());
} else {
*result = nullptr;
}
return {};
}
extError_t ORTX_API_CALL OrtxImageGetTensorResult(OrtxImageProcessorResult* result, size_t index, OrtxTensor** tensor) {
if (result == nullptr || tensor == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
auto result_ptr = static_cast<ImageProcessorResult*>(result);
ReturnableStatus status(result_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindImageProcessorResult));
if (!status.IsOk()) {
return status.Code();
}
if (index >= result_ptr->results.size()) {
ReturnableStatus::last_error_message_ = "Index out of range";
return kOrtxErrorInvalidArgument;
}
auto tensor_ptr = std::make_unique<OrtxObjectWrapper<ortc::TensorBase>>();
tensor_ptr->SetObject(result_ptr->results[index].get());
*tensor = static_cast<OrtxTensor*>(tensor_ptr.release());
return extError_t();
}
extError_t ORTX_API_CALL OrtxClearOutputs(OrtxProcessor* processor, OrtxImageProcessorResult* result) {
if (processor == nullptr || result == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
const auto processor_ptr = static_cast<const ImageProcessor*>(processor);
ReturnableStatus status(processor_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindProcessor));
if (!status.IsOk()) {
return status.Code();
}
auto result_ptr = static_cast<ImageProcessorResult*>(result);
status = result_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindImageProcessorResult);
if (!status.IsOk()) {
return status.Code();
}
ImageProcessor::ClearOutputs(result_ptr);
return extError_t();
}


@ -6,7 +6,7 @@
#include "c_api_utils.hpp"
#include "tokenizer_impl.h"
namespace ort_extensions {
using namespace ort_extensions;
class DetokenizerCache : public OrtxObjectImpl {
public:
@ -17,29 +17,20 @@ class DetokenizerCache : public OrtxObjectImpl {
std::string last_text_{}; // last detokenized text
};
template<>
OrtxObject* OrtxObjectFactory<DetokenizerCache>::CreateForward() {
return std::make_unique<DetokenizerCache>().release();
template <>
OrtxObject* OrtxObjectFactory::CreateForward<DetokenizerCache>() {
return Create<DetokenizerCache>();
}
template<>
void OrtxObjectFactory<DetokenizerCache>::DisposeForward(OrtxObject* obj) {
Dispose(obj);
}
} // namespace ort_extensions
using namespace ort_extensions;
extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer,
const char* input[], size_t batch_size, OrtxTokenId2DArray** output) {
extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer, const char* input[], size_t batch_size,
OrtxTokenId2DArray** output) {
if (tokenizer == nullptr || input == nullptr || output == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
auto token_ptr = static_cast<const TokenizerImpl*>(tokenizer);
ReturnableStatus status =
token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer);
ReturnableStatus status = token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer);
if (!status.IsOk()) {
return status.Code();
}
@ -61,8 +52,8 @@ extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer,
return extError_t();
}
extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer,
const OrtxTokenId2DArray* input, OrtxStringArray** output) {
extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer, const OrtxTokenId2DArray* input,
OrtxStringArray** output) {
if (tokenizer == nullptr || input == nullptr || output == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
@ -81,11 +72,8 @@ extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer,
}
std::vector<span<extTokenId_t const>> t_ids;
std::transform(input_2d->token_ids().begin(), input_2d->token_ids().end(),
std::back_inserter(t_ids),
[](const std::vector<extTokenId_t>& vec) {
return span<extTokenId_t const>(vec.data(), vec.size());
});
std::transform(input_2d->token_ids().begin(), input_2d->token_ids().end(), std::back_inserter(t_ids),
[](const std::vector<extTokenId_t>& vec) { return span<extTokenId_t const>(vec.data(), vec.size()); });
std::vector<std::string> output_text;
status = token_ptr->Detokenize(t_ids, output_text);
@ -101,9 +89,7 @@ extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer,
;
}
extError_t ORTX_API_CALL OrtxDetokenize1D(const OrtxTokenizer* tokenizer,
const extTokenId_t* input,
size_t len,
extError_t ORTX_API_CALL OrtxDetokenize1D(const OrtxTokenizer* tokenizer, const extTokenId_t* input, size_t len,
OrtxStringArray** output) {
if (tokenizer == nullptr || input == nullptr || output == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
@ -186,8 +172,8 @@ extError_t ORTX_API_CALL OrtxTokenId2DArrayGetBatch(const OrtxTokenId2DArray* to
return extError_t();
}
extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* token_id_2d_array,
size_t index, const extTokenId_t** item, size_t* length) {
extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* token_id_2d_array, size_t index,
const extTokenId_t** item, size_t* length) {
if (token_id_2d_array == nullptr || item == nullptr || length == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
@ -210,9 +196,8 @@ extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* tok
return extError_t();
}
extError_t OrtxDetokenizeCached(const OrtxTokenizer* tokenizer,
OrtxDetokenizerCache* cache,
extTokenId_t next_id, const char** text_out) {
extError_t OrtxDetokenizeCached(const OrtxTokenizer* tokenizer, OrtxDetokenizerCache* cache, extTokenId_t next_id,
const char** text_out) {
if (tokenizer == nullptr || cache == nullptr || text_out == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;


@ -10,6 +10,8 @@
using namespace ort_extensions;
class DetokenizerCache; // forward definition in tokenizer_impl.cc
thread_local std::string ReturnableStatus::last_error_message_;
OrtxStatus OrtxObjectImpl::IsInstanceOf(extObjectKind_t kind) const {
@ -37,7 +39,7 @@ extError_t ORTX_API_CALL OrtxCreate(extObjectKind_t kind, OrtxObject** object, .
va_start(args, object);
if (kind == extObjectKind_t::kOrtxKindDetokenizerCache) {
*object = OrtxObjectFactory<DetokenizerCache>::CreateForward();
*object = OrtxObjectFactory::CreateForward<DetokenizerCache>();
} else if (kind == extObjectKind_t::kOrtxKindTokenizer) {
return OrtxCreateTokenizer(static_cast<OrtxTokenizer**>(object), va_arg(args, const char*));
}
@ -80,8 +82,8 @@ extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object) {
return kOrtxErrorInvalidArgument;
}
if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindStringArray) {
OrtxObjectFactory<StringArray>::Dispose(object);
/* if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindStringArray) {
OrtxObjectFactory::Dispose<StringArray>(object);
} else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindTokenId2DArray) {
OrtxObjectFactory<TokenId2DArray>::Dispose(object);
} else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindDetokenizerCache) {
@ -94,6 +96,11 @@ extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object) {
OrtxObjectFactory<ImageProcessorResult>::Dispose(object);
} else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindProcessor) {
OrtxObjectFactory<ImageProcessor>::Dispose(object);
} */
if (Ortx_object->ortx_kind() >= kOrtxKindBegin && Ortx_object->ortx_kind() < kOrtxKindEnd) {
OrtxObjectFactory::Dispose<OrtxObjectImpl>(object);
} else {
return kOrtxErrorInvalidArgument;
}
return extError_t();
@ -113,6 +120,30 @@ extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object) {
return err;
}
extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t index, OrtxTensor** tensor) {
if (result == nullptr || tensor == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
auto result_ptr = static_cast<TensorResult*>(result);
ReturnableStatus status(result_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTensorResult));
if (!status.IsOk()) {
return status.Code();
}
ortc::TensorBase* ts = result_ptr->GetAt(index);
if (ts == nullptr) {
ReturnableStatus::last_error_message_ = "Cannot get the tensor at the specified index from the result";
return kOrtxErrorInvalidArgument;
}
auto tensor_ptr = std::make_unique<OrtxObjectWrapper<ortc::TensorBase, kOrtxKindTensor>>();
tensor_ptr->SetObject(ts);
*tensor = static_cast<OrtxTensor*>(tensor_ptr.release());
return extError_t();
}
extError_t ORTX_API_CALL OrtxGetTensorData(OrtxTensor* tensor, const void** data, const int64_t** shape,
size_t* num_dims) {
if (tensor == nullptr) {
@ -120,7 +151,7 @@ extError_t ORTX_API_CALL OrtxGetTensorData(OrtxTensor* tensor, const void** data
return kOrtxErrorInvalidArgument;
}
auto tensor_impl = static_cast<OrtxObjectWrapper<ortc::TensorBase>*>(tensor);
auto tensor_impl = static_cast<OrtxObjectWrapper<ortc::TensorBase, kOrtxKindTensor>*>(tensor);
if (tensor_impl->ortx_kind() != extObjectKind_t::kOrtxKindTensor) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;


@ -3,8 +3,10 @@
#pragma once
#include <vector>
#include <fstream>
#include "ortx_utils.h"
#include "file_sys.h"
#include "ext_status.h"
#include "op_def_struct.h"
@ -12,7 +14,7 @@ namespace ort_extensions {
class OrtxObjectImpl : public OrtxObject {
public:
explicit OrtxObjectImpl(extObjectKind_t kind = extObjectKind_t::kOrtxKindUnknown) : OrtxObject() {
ext_kind_ = static_cast<int>(kind);
ext_kind_ = kind;
};
virtual ~OrtxObjectImpl() = default;
@ -24,30 +26,21 @@ class OrtxObjectImpl : public OrtxObject {
}
return static_cast<extObjectKind_t>(ext_kind_);
}
template <typename T>
struct Type2Kind {
static const extObjectKind_t value = kOrtxKindUnknown;
};
};
template <>
struct OrtxObjectImpl::Type2Kind<ortc::TensorBase> {
static const extObjectKind_t value = kOrtxKindTensor;
};
template <typename T>
// A wrapper class to store an object pointer which is read-only, i.e. unowned.
template <typename T, extObjectKind_t kind>
class OrtxObjectWrapper : public OrtxObjectImpl {
public:
OrtxObjectWrapper() : OrtxObjectImpl(OrtxObjectImpl::Type2Kind<T>::value) {}
OrtxObjectWrapper() : OrtxObjectImpl(kind) {}
~OrtxObjectWrapper() override = default;
void SetObject(T* t) { stored_object_ = t; }
void SetObject(const T* t) { stored_object_ = t; }
[[nodiscard]] T* GetObject() const { return stored_object_; }
[[nodiscard]] const T* GetObject() const { return stored_object_; }
private:
T* stored_object_{};
const T* stored_object_{};
};
template <typename T>
@ -100,6 +93,35 @@ class StringArray : public OrtxObjectImpl {
std::vector<std::string> strings_;
};
class TensorResult : public OrtxObjectImpl {
public:
TensorResult() : OrtxObjectImpl(extObjectKind_t::kOrtxKindTensorResult) {}
~TensorResult() override = default;
void SetTensors(std::vector<std::unique_ptr<ortc::TensorBase>>&& tensors) { tensors_ = std::move(tensors); }
[[nodiscard]] const std::vector<std::unique_ptr<ortc::TensorBase>>& tensors() const { return tensors_; }
[[nodiscard]] std::vector<ortc::TensorBase*> GetTensors() const {
std::vector<ortc::TensorBase*> ts;
ts.reserve(tensors_.size());
for (auto& t : tensors_) {
ts.push_back(t.get());
}
return ts;
}
ortc::TensorBase* GetAt(size_t i) const {
if (i < tensors_.size()) {
return tensors_[i].get();
}
return nullptr;
}
private:
std::vector<std::unique_ptr<ortc::TensorBase>> tensors_;
};
struct ReturnableStatus {
public:
thread_local static std::string last_error_message_;
@ -123,25 +145,26 @@ struct ReturnableStatus {
OrtxStatus status_;
};
template <typename T>
class OrtxObjectFactory {
public:
static std::unique_ptr<T> Create() { return std::make_unique<T>(); }
static OrtxObject* CreateForward();
static void DisposeForward(OrtxObject* object);
template <typename T>
static OrtxObject* Create() {
return std::make_unique<T>().release();
}
template <typename T>
static void Dispose(OrtxObject* object) {
auto obj_ptr = static_cast<T*>(object);
std::unique_ptr<T> ptr(obj_ptr);
ptr.reset();
}
// Forward declaration for creating an object which isn't visible to c_api_utils.cc
// and the definition is in the corresponding .cc file.
template <typename T>
static OrtxObject* CreateForward();
};
class DetokenizerCache; // forward definition in tokenizer_impl.cc
class ProcessorResult; // forward definition in image_processor.h
class CppAllocator : public ortc::IAllocator {
public:
void* Alloc(size_t size) override { return std::make_unique<char[]>(size).release(); }
@ -157,4 +180,25 @@ class CppAllocator : public ortc::IAllocator {
}
};
template <typename It, typename T>
std::tuple<std::unique_ptr<T[]>, size_t> LoadRawData(It begin, It end) {
auto raw_data = std::make_unique<T[]>(end - begin);
size_t n = 0;
for (auto it = begin; it != end; ++it) {
std::ifstream ifs = path(*it).open(std::ios::binary | std::ios::in);
if (!ifs.is_open()) {
break;
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
ifs.seekg(0, std::ios::beg);
T& datum = raw_data[n++];
datum.resize(size);
ifs.read(reinterpret_cast<char*>(datum.data()), size);
}
return std::make_tuple(std::move(raw_data), n);
}
} // namespace ort_extensions


@ -7,6 +7,7 @@
#include "file_sys.h"
#include "image_processor.h"
#include "c_api_utils.hpp"
#include "cv2/imgcodecs/imdecode.hpp"
#include "image_transforms.hpp"
#include "image_transforms_phi_3.hpp"
@ -14,38 +15,11 @@
using namespace ort_extensions;
using json = nlohmann::json;
namespace ort_extensions {
template <typename It>
std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages(It begin, It end) {
auto raw_images = std::make_unique<ImageRawData[]>(end - begin);
size_t n = 0;
for (auto it = begin; it != end; ++it) {
std::ifstream ifs = path(*it).open(std::ios::binary);
if (!ifs.is_open()) {
break;
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
ifs.seekg(0, std::ios::beg);
ImageRawData& raw_image = raw_images[n++];
raw_image.resize(size);
ifs.read(reinterpret_cast<char*>(raw_image.data()), size);
}
return std::make_tuple(std::move(raw_images), n);
std::tuple<std::unique_ptr<ImageRawData[]>, size_t>
ort_extensions::LoadRawImages(const std::initializer_list<const char*>& image_paths) {
return ort_extensions::LoadRawData<const char* const*, ImageRawData>(image_paths.begin(), image_paths.end());
}
std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages(
const std::initializer_list<const char*>& image_paths) {
return LoadRawImages(image_paths.begin(), image_paths.end());
}
template std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages<char const**>(char const**, char const**);
} // namespace ort_extensions
Operation::KernelRegistry ImageProcessor::kernel_registry_ = {
{"DecodeImage", []() { return CreateKernelInstance(image_decoder); }},
{"Resize", []() { return CreateKernelInstance(&Resize::Compute); }},
@ -97,9 +71,7 @@ OrtxStatus ImageProcessor::Init(std::string_view processor_def) {
return {};
}
ImageProcessor::ImageProcessor()
: OrtxObjectImpl(kOrtxKindProcessor), allocator_(&CppAllocator::Instance()) {
}
ImageProcessor::ImageProcessor() : OrtxObjectImpl(kOrtxKindProcessor), allocator_(&CppAllocator::Instance()) {}
template <typename T>
static ortc::Tensor<T>* StackTensor(const std::vector<TensorArgs>& arg_lists, int axis, ortc::IAllocator* allocator) {
@ -136,39 +108,6 @@ static ortc::Tensor<T>* StackTensor(const std::vector<TensorArgs>& arg_lists, in
return output.release();
}
static OrtxStatus StackTensors(const std::vector<TensorArgs>& arg_lists, std::vector<TensorPtr>& outputs,
ortc::IAllocator* allocator) {
if (arg_lists.empty()) {
return {};
}
size_t batch_size = arg_lists.size();
size_t num_outputs = arg_lists[0].size();
for (size_t axis = 0; axis < num_outputs; ++axis) {
std::vector<ortc::TensorBase*> ts_ptrs;
ts_ptrs.reserve(arg_lists.size());
std::vector<int64_t> shape = arg_lists[0][axis]->Shape();
for (auto& ts : arg_lists) {
if (shape != ts[axis]->Shape()) {
return {kOrtxErrorInvalidArgument, "[StackTensors]: shapes of tensors to stack are not the same."};
}
ts_ptrs.push_back(ts[axis]);
}
std::vector<int64_t> output_shape = shape;
output_shape.insert(output_shape.begin(), batch_size);
std::byte* tensor_buf = outputs[axis]->AllocateRaw(output_shape);
for (size_t i = 0; i < batch_size; ++i) {
auto ts = ts_ptrs[i];
const std::byte* ts_buff = reinterpret_cast<const std::byte*>(ts->DataRaw());
auto ts_size = ts->SizeInBytes();
std::memcpy(tensor_buf + i * ts_size, ts_buff, ts_size);
}
}
return {};
}
std::tuple<OrtxStatus, ProcessorResult> ImageProcessor::PreProcess(ort_extensions::span<ImageRawData> image_data,
ortc::Tensor<float>** pixel_values,
ortc::Tensor<int64_t>** image_sizes,
@ -209,7 +148,7 @@ std::tuple<OrtxStatus, ProcessorResult> ImageProcessor::PreProcess(ort_extension
return {status, std::move(r)};
}
OrtxStatus ImageProcessor::PreProcess(ort_extensions::span<ImageRawData> image_data, ImageProcessorResult& r) const {
OrtxStatus ImageProcessor::PreProcess(ort_extensions::span<ImageRawData> image_data, TensorResult& r) const {
std::vector<TensorArgs> inputs;
inputs.resize(image_data.size());
for (size_t i = 0; i < image_data.size(); ++i) {
@ -235,9 +174,13 @@ OrtxStatus ImageProcessor::PreProcess(ort_extensions::span<ImageRawData> image_d
}
}
r.results = operations_.back()->AllocateOutputs(allocator_);
status = StackTensors(outputs, r.results, allocator_);
auto img_result = operations_.back()->AllocateOutputs(allocator_);
status = OrtxRunner::StackTensors(outputs, img_result, allocator_);
operations_.back()->ResetTensors(allocator_);
if (status.IsOk()) {
r.SetTensors(std::move(img_result));
}
return status;
}
@ -257,14 +200,3 @@ void ImageProcessor::ClearOutputs(ProcessorResult* r) {
r->num_img_takens = nullptr;
}
}
void ort_extensions::ImageProcessor::ClearOutputs(ImageProcessorResult* r) {
if (r == nullptr) {
return;
}
for (auto& ts : r->results) {
ts.reset();
}
r->results.clear(); // clear the vector
}


@ -16,9 +16,6 @@ namespace ort_extensions {
using ImageRawData = std::vector<uint8_t>;
template <typename It>
std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages(It begin, It end);
std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages(
const std::initializer_list<const char*>& image_paths);
@ -29,13 +26,6 @@ class ProcessorResult : public OrtxObjectImpl {
ortc::Tensor<int64_t>* image_sizes{};
ortc::Tensor<int64_t>* num_img_takens{};
};
class ImageProcessorResult : public OrtxObjectImpl {
public:
ImageProcessorResult() : OrtxObjectImpl(kOrtxKindImageProcessorResult) {}
std::vector<TensorPtr> results;
};
class ImageProcessor : public OrtxObjectImpl {
public:
ImageProcessor();
@ -43,15 +33,16 @@ class ImageProcessor : public OrtxObjectImpl {
OrtxStatus Init(std::string_view processor_def);
// Deprecated, use the next function instead
std::tuple<OrtxStatus, ProcessorResult> PreProcess(ort_extensions::span<ImageRawData> image_data,
ortc::Tensor<float>** pixel_values,
ortc::Tensor<int64_t>** image_sizes,
ortc::Tensor<int64_t>** num_img_takens) const;
OrtxStatus PreProcess(ort_extensions::span<ImageRawData> image_data, ImageProcessorResult& r) const;
OrtxStatus PreProcess(ort_extensions::span<ImageRawData> image_data, TensorResult& r) const;
// Deprecated, use the next function instead
static void ClearOutputs(ProcessorResult* r);
static void ClearOutputs(ImageProcessorResult* r);
static Operation::KernelRegistry kernel_registry_;


@ -28,7 +28,8 @@ class KernelDef {
virtual TensorArgs AllocateOutput(ortc::IAllocator* allocator) const = 0;
virtual OrtxStatus Apply(TensorArgs& inputs, TensorArgs& output) const = 0;
using AttrType = std::variant<std::string, double, int64_t, std::vector<double>>;
using AttrType =
std::variant<std::string, double, int64_t, std::vector<std::string>, std::vector<double>, std::vector<int64_t>>;
using AttrDict = std::unordered_map<std::string, AttrType>;
template <typename... Args>
@ -98,7 +99,7 @@ class KernelDef {
template <typename... Args>
class KernelFunction : public KernelDef {
public:
KernelFunction(OrtxStatus (*body)(Args...)) : body_(body){};
KernelFunction(OrtxStatus (*body)(Args...)) : body_(body) {};
virtual ~KernelFunction() = default;
TensorArgs AllocateOutput(ortc::IAllocator* allocator) const override {
@ -132,7 +133,7 @@ class KernelFunction : public KernelDef {
template <typename T, typename... Args>
class KernelStruct : public KernelDef {
public:
KernelStruct(OrtxStatus (T::*body)(Args...)) : body_(body){};
KernelStruct(OrtxStatus (T::*body)(Args...)) : body_(body) {};
virtual ~KernelStruct() = default;
TensorArgs AllocateOutput(ortc::IAllocator* allocator) const override {
@ -167,8 +168,18 @@ class KernelStruct : public KernelDef {
attr_dict[key] = value.template get<int64_t>();
} else if (value.is_number_float()) {
attr_dict[key] = value.template get<double>();
} else if (value.is_array()) {
attr_dict[key] = value.template get<std::vector<double>>();
} else if (value.is_array() && value.size() > 0) {
auto& elem_0 = value.at(0);
if (elem_0.is_number_float()) {
attr_dict[key] = value.template get<std::vector<double>>();
} else if (elem_0.is_string()) {
attr_dict[key] = value.template get<std::vector<std::string>>();
} else if (elem_0.is_number_integer() || elem_0.is_number_unsigned()) {
attr_dict[key] = value.template get<std::vector<int64_t>>();
} else {
return {kOrtxErrorCorruptData, "Unsupported mix types in attribute value."};
}
} else {
return {kOrtxErrorCorruptData, "Invalid attribute type."};
}
@ -309,6 +320,39 @@ class OrtxRunner {
return {};
}
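// Stacks the i-th output tensor of every per-input run along a new leading batch dimension.
// All tensors stacked for a given output index must have identical shapes; the stacked
// output shape is [batch_size, ...per-input shape...].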
static OrtxStatus StackTensors(const std::vector<TensorArgs>& arg_lists, std::vector<TensorPtr>& outputs,
ortc::IAllocator* allocator) {
if (arg_lists.empty()) {
return {};
}
size_t batch_size = arg_lists.size();
size_t num_outputs = arg_lists[0].size();
for (size_t axis = 0; axis < num_outputs; ++axis) {
std::vector<ortc::TensorBase*> ts_ptrs;
ts_ptrs.reserve(arg_lists.size());
std::vector<int64_t> shape = arg_lists[0][axis]->Shape();
for (auto& ts : arg_lists) {
if (shape != ts[axis]->Shape()) {
return {kOrtxErrorInvalidArgument, "[StackTensors]: shapes of tensors to stack are not the same."};
}
ts_ptrs.push_back(ts[axis]);
}
std::vector<int64_t> output_shape = shape;
output_shape.insert(output_shape.begin(), batch_size);
std::byte* tensor_buf = outputs[axis]->AllocateRaw(output_shape);
for (size_t i = 0; i < batch_size; ++i) {
auto ts = ts_ptrs[i];
const std::byte* ts_buff = reinterpret_cast<const std::byte*>(ts->DataRaw());
auto ts_size = ts->SizeInBytes();
std::memcpy(tensor_buf + i * ts_size, ts_buff, ts_size);
}
}
return {};
}
private:
ortc::IAllocator* allocator_;
std::vector<Operation*> ops_;


@ -0,0 +1,97 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "speech_extractor.h"
#include "audio/audio_decoder.h"
#include "speech_features.hpp"
using namespace ort_extensions;
Operation::KernelRegistry SpeechFeatureExtractor::kernel_registry_ = {
{"AudioDecoder", []() { return CreateKernelInstance(&AudioDecoder::ComputeNoOpt); }},
{"STFTNorm", []() { return CreateKernelInstance(&SpeechFeatures::STFTNorm); }},
{"LogMelSpectrum", []() { return CreateKernelInstance(&LogMel::Compute); }},
};
SpeechFeatureExtractor::SpeechFeatureExtractor()
: OrtxObjectImpl(extObjectKind_t::kOrtxKindFeatureExtractor), allocator_(&CppAllocator::Instance()) {}
OrtxStatus SpeechFeatureExtractor::Init(std::string_view extractor_def) {
std::string fe_def_str;
if (extractor_def.size() >= 5 && extractor_def.substr(extractor_def.size() - 5) == ".json") {
std::ifstream ifs = path({extractor_def.data(), extractor_def.size()}).open();
if (!ifs.is_open()) {
return {kOrtxErrorInvalidArgument, std::string("[ImageProcessor]: failed to open ") + std::string(extractor_def)};
}
fe_def_str = std::string(std::istreambuf_iterator<char>(ifs), std::istreambuf_iterator<char>());
extractor_def = fe_def_str.c_str();
}
// parse extractor_def as JSON
auto fe_json = json::parse(extractor_def, nullptr, false);
if (fe_json.is_discarded()) {
return {kOrtxErrorInvalidArgument, "[SpeechFeatureExtractor]: failed to parse extractor json configuration."};
}
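// Expected layout: {"feature_extraction": {"sequence": [{"operation": {...}}, ...]}},
// as in the whisper feature_extraction.json configuration added alongside this code.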
auto fe_root = fe_json.at("feature_extraction");
if (!fe_root.is_object()) {
return {kOrtxErrorInvalidArgument, "[SpeechFeatureExtractor]: feature_extraction field is missing."};
}
auto op_sequence = fe_root.at("sequence");
if (!op_sequence.is_array() || op_sequence.empty()) {
return {kOrtxErrorInvalidArgument, "[SpeechFeatureExtractor]: sequence field is missing."};
}
operations_.reserve(op_sequence.size());
for (auto mod_iter = op_sequence.begin(); mod_iter != op_sequence.end(); ++mod_iter) {
auto op = std::make_unique<Operation>(kernel_registry_);
auto status = op->Init(mod_iter->dump());
if (!status.IsOk()) {
return status;
}
operations_.push_back(std::move(op));
}
return {};
}
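// Runs the configured kernel sequence (e.g. AudioDecoder -> STFTNorm -> LogMelSpectrum) on each
// audio clip and stacks the per-clip results into one batched log-mel tensor.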
OrtxStatus SpeechFeatureExtractor::DoCall(ort_extensions::span<AudioRawData> raw_speech,
std::unique_ptr<ortc::Tensor<float>>& log_mel) const {
// set up the input tensors
std::vector<TensorArgs> inputs;
inputs.resize(raw_speech.size());
for (size_t i = 0; i < raw_speech.size(); ++i) {
auto& ts_input = inputs[i];
AudioRawData& speech = raw_speech[i];
std::vector<int64_t> shape = {static_cast<int64_t>(speech.size())};
ts_input.push_back(std::make_unique<ortc::Tensor<uint8_t>>(shape, speech.data()).release());
}
std::vector<TensorArgs> outputs;
std::vector<Operation*> ops(operations_.size());
std::transform(operations_.begin(), operations_.end(), ops.begin(), [](auto& op) { return op.get(); });
OrtxRunner runner(allocator_, ops.data(), ops.size());
auto status = runner.Run(inputs, outputs);
if (!status.IsOk()) {
return status;
}
// free the input tensors created above; their ownership was released into raw pointers
for (auto& input : inputs) {
for (auto& ts : input) {
std::unique_ptr<ortc::TensorBase>(ts).reset();
}
}
auto results = operations_.back()->AllocateOutputs(allocator_);
status = OrtxRunner::StackTensors(outputs, results, allocator_);
if (status.IsOk()) {
log_mel.reset(static_cast<ortc::Tensor<float>*>(results[0].release()));
operations_.back()->ResetTensors(allocator_);
}
return status;
}


@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ortx_extractor.h"
#include "c_api_utils.hpp"
#include "runner.hpp"
namespace ort_extensions {
typedef std::vector<std::byte> AudioRawData;
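// holds the raw bytes of one audio file (e.g. wav/flac/mp3); decoding to PCM is done by the AudioDecoder kernel.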
class SpeechFeatureExtractor : public OrtxObjectImpl {
public:
SpeechFeatureExtractor();
virtual ~SpeechFeatureExtractor() = default;
public:
OrtxStatus Init(std::string_view extractor_def);
OrtxStatus DoCall(ort_extensions::span<AudioRawData> raw_speech, std::unique_ptr<ortc::Tensor<float>>& log_mel) const;
static Operation::KernelRegistry kernel_registry_;
private:
std::vector<std::unique_ptr<Operation>> operations_;
ortc::IAllocator* allocator_;
};
} // namespace ort_extensions


@ -0,0 +1,224 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <cmath>
#include <vector>
#include <cassert>
#include <cstring>
#include <algorithm>
#include <dlib/matrix.h>
#include <math/dlib/stft_norm.hpp>
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
namespace ort_extensions {
class SpeechFeatures {
public:
template <typename DictT>
OrtxStatus Init(const DictT& attrs) {
for (const auto& [key, value] : attrs) {
if (key == "n_fft") {
n_fft_ = std::get<int64_t>(value);
} else if (key == "hop_length") {
hop_length_ = std::get<int64_t>(value);
} else if (key == "frame_length") {
frame_length_ = std::get<int64_t>(value);
} else if (key == "hann_win") {
auto& win = std::get<std::vector<double>>(value);
hann_win_.resize(win.size());
std::transform(win.begin(), win.end(), hann_win_.begin(), [](double x) { return static_cast<float>(x); });
} else if (key != "_comment") {
return {kOrtxErrorInvalidArgument, "[AudioFeatures]: Invalid key in the JSON configuration."};
}
}
if (hann_win_.empty()) {
hann_win_ = hann_window(frame_length_);
}
return {};
}
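// Computes the framed STFT of the decoded PCM input using the configured Hann window;
// its output is consumed by the LogMel kernel below.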
OrtxStatus STFTNorm(const ortc::Tensor<float>& pcm, ortc::Tensor<float>& stft_norm) {
return stft_norm_.Compute(pcm, n_fft_, hop_length_, {hann_win_.data(), hann_win_.size()}, frame_length_, stft_norm);
}
static std::vector<float> hann_window(int N) {
std::vector<float> window(N);
for (int n = 0; n < N; ++n) {
// the cosine form below accumulates more rounding error than the periodic-window
// identity used instead: sin^2(M_PI * n / N) == 0.5 * (1 - cos(2 * M_PI * n / N))
// window[n] = static_cast<float>(0.5 * (1 - std::cos(2 * M_PI * n / (N - 1))));
double n_sin = std::sin(M_PI * n / N);
window[n] = static_cast<float>(n_sin * n_sin);
}
return window;
}
private:
StftNormal stft_norm_;
int64_t n_fft_{};
int64_t hop_length_{};
int64_t frame_length_{};
std::vector<float> hann_win_;
};
class LogMel {
public:
template <typename DictT>
OrtxStatus Init(const DictT& attrs) {
int n_fft = 0;
int n_mel = 0;
int chunk_size = 0;
for (const auto& [key, value] : attrs) {
if (key == "hop_length") {
hop_length_ = std::get<int64_t>(value);
} else if (key == "n_fft") {
n_fft = std::get<int64_t>(value);
} else if (key == "n_mel") {
n_mel = std::get<int64_t>(value);
} else if (key == "chunk_size") {
chunk_size = std::get<int64_t>(value);
} else {
return {kOrtxErrorInvalidArgument, "[LogMel]: Invalid key in the JSON configuration."};
}
}
n_samples_ = n_sr_ * chunk_size;
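// e.g. for whisper: 30 s chunks at 16 kHz -> 480000 samples, i.e. 480000 / 160 = 3000 output frames.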
mel_filters_ = MelFilterBank(n_fft, n_mel, n_sr_);
return {};
}
OrtxStatus Compute(const ortc::Tensor<float>& stft_norm, ortc::Tensor<float>& logmel) {
// Compute the log-mel spectrogram, mirroring the following Python reference:
/*
magnitudes = stft_norm[:, :, :-1]
mel_spec = self.mel_filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
spec_min = log_spec.max() - 8.0
log_spec = torch.maximum(log_spec, spec_min)
spec_shape = log_spec.shape
padding_spec = torch.ones(spec_shape[0],
spec_shape[1],
self.n_samples // self.hop_length - spec_shape[2],
dtype=torch.float)
padding_spec *= spec_min
log_spec = torch.cat((log_spec, padding_spec), dim=2)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
*/
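// The runner invokes this kernel once per audio clip, so the incoming STFT tensor is expected to
// have a leading batch dimension of 1; batching across clips happens later in OrtxRunner::StackTensors.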
assert(stft_norm.Shape().size() == 3 && stft_norm.Shape()[0] == 1);
std::vector<int64_t> stft_shape = stft_norm.Shape();
dlib::matrix<float> magnitudes(stft_norm.Shape()[1], stft_norm.Shape()[2] - 1);
for (int i = 0; i < magnitudes.nr(); ++i) {
std::copy(stft_norm.Data() + i * stft_shape[2], stft_norm.Data() + (i + 1) * stft_shape[2] - 1,
magnitudes.begin() + i * magnitudes.nc());
}
dlib::matrix<float> mel_spec = mel_filters_ * magnitudes;
for (int i = 0; i < mel_spec.nr(); ++i) {
for (int j = 0; j < mel_spec.nc(); ++j) {
mel_spec(i, j) = std::max(1e-10f, mel_spec(i, j));
}
}
dlib::matrix<float> log_spec = dlib::log10(mel_spec);
float log_spec_min = dlib::max(log_spec) - 8.0f;
for (int i = 0; i < log_spec.nr(); ++i) {
for (int j = 0; j < log_spec.nc(); ++j) {
float v = std::max(log_spec(i, j), log_spec_min);
v = (v + 4.0f) / 4.0f;
log_spec(i, j) = v;
}
}
std::vector<int64_t> shape = {mel_filters_.nr(), n_samples_ / hop_length_};
float* buff = logmel.Allocate(shape);
std::fill(buff, buff + logmel.NumberOfElement(), (log_spec_min + 4.0f) / 4.0f);
for (int i = 0; i < log_spec.nr(); ++i) {
std::copy(log_spec.begin() + i * log_spec.nc(), log_spec.begin() + (i + 1) * log_spec.nc(), buff + i * shape[1]);
}
return {};
}
// Function to compute the Mel filterbank
static dlib::matrix<float> MelFilterBank(int n_fft, int n_mels, int sr = 16000, float min_mel = 0,
float max_mel = 45.245640471924965) {
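// Slaney-style mel scale: linear below 1000 Hz, logarithmic above. The default max_mel of
// ~45.2456 corresponds to 8000 Hz, the Nyquist frequency at the default 16 kHz sample rate.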
// Initialize the filterbank matrix
dlib::matrix<float> fbank(n_mels, n_fft / 2 + 1);
memset(fbank.begin(), 0, fbank.size() * sizeof(float));
// Compute the frequency bins for the DFT
std::vector<float> freq_bins(n_fft / 2 + 1);
for (int i = 0; i <= n_fft / 2; ++i) {
freq_bins[i] = i * sr / static_cast<float>(n_fft);
}
// Compute the Mel scale frequencies
std::vector<float> mel(n_mels + 2);
for (int i = 0; i < n_mels + 2; ++i) {
mel[i] = min_mel + i * (max_mel - min_mel) / (n_mels + 1);
}
// Fill in the linear scale
float f_min = 0.0f;
float f_sp = 200.0f / 3.0f;
std::vector<float> freqs(n_mels + 2);
for (int i = 0; i < n_mels + 2; ++i) {
freqs[i] = f_min + f_sp * mel[i];
}
// Nonlinear scale
float min_log_hz = 1000.0f;
float min_log_mel = (min_log_hz - f_min) / f_sp;
float logstep = static_cast<float>(std::log(6.4) / 27.0);
for (int i = 0; i < n_mels + 2; ++i) {
if (mel[i] >= min_log_mel) {
freqs[i] = min_log_hz * std::exp(logstep * (mel[i] - min_log_mel));
}
}
std::vector<float> mel_bins = freqs;
std::vector<float> mel_spacing(n_mels + 1);
for (int i = 0; i < n_mels + 1; ++i) {
mel_spacing[i] = mel_bins[i + 1] - mel_bins[i];
}
// Compute the ramps
std::vector<std::vector<float>> ramps(n_mels + 2, std::vector<float>(n_fft / 2 + 1));
for (int i = 0; i < n_mels + 2; ++i) {
for (int j = 0; j <= n_fft / 2; ++j) {
ramps[i][j] = mel_bins[i] - freq_bins[j];
}
}
for (int i = 0; i < n_mels; ++i) {
for (int j = 0; j <= n_fft / 2; ++j) {
float left = -ramps[i][j] / mel_spacing[i];
float right = ramps[i + 2][j] / mel_spacing[i + 1];
fbank(i, j) = std::max(0.0f, std::min(left, right));
}
}
// Energy normalization
for (int i = 0; i < n_mels; ++i) {
float energy_norm = 2.0f / (mel_bins[i + 2] - mel_bins[i]);
for (int j = 0; j <= n_fft / 2; ++j) {
fbank(i, j) *= energy_norm;
}
}
return fbank;
}
private:
int64_t n_samples_ = {}; // sr * chunk_size
int64_t hop_length_{};
const int64_t n_sr_{16000};
dlib::matrix<float> mel_filters_;
};
} // namespace ort_extensions


@ -0,0 +1,437 @@
{
"feature_extraction": {
"sequence": [
{
"operation": {
"name": "audio_decoder",
"type": "AudioDecoder"
}
},
{
"operation": {
"name": "STFT",
"type": "STFTNorm",
"attrs": {
"n_fft": 400,
"frame_length": 400,
"hop_length": 160,
"_comment": [
0.0,
0.0000616908073425293,
0.0002467334270477295,
0.0005550682544708252,
0.000986635684967041,
0.0015413463115692139,
0.0022190213203430176,
0.0030195116996765137,
0.003942638635635376,
0.004988163709640503,
0.006155818700790405,
0.007445335388183594,
0.008856385946273804,
0.010388582944869995,
0.012041628360748291,
0.013815045356750488,
0.01570841670036316,
0.01772129535675049,
0.019853144884109497,
0.022103488445281982,
0.02447172999382019,
0.026957333087921143,
0.029559612274169922,
0.03227800130844116,
0.03511175513267517,
0.03806024789810181,
0.0411226749420166,
0.044298380613327026,
0.04758647084236145,
0.05098623037338257,
0.05449673533439636,
0.058117181062698364,
0.06184667348861694,
0.0656842589378357,
0.06962898373603821,
0.07367992401123047,
0.0778360664844513,
0.08209633827209473,
0.08645972609519958,
0.09092515707015991,
0.09549149870872498,
0.10015767812728882,
0.10492250323295593,
0.1097848117351532,
0.11474338173866272,
0.11979702115058899,
0.12494447827339172,
0.13018447160720825,
0.1355157196521759,
0.14093685150146484,
0.1464466154575348,
0.15204361081123352,
0.1577264666557312,
0.16349375247955322,
0.16934409737586975,
0.1752760112285614,
0.18128803372383118,
0.18737870454788208,
0.19354650378227234,
0.1997898817062378,
0.20610737800598145,
0.21249738335609436,
0.21895831823349,
0.2254886031150818,
0.23208662867546082,
0.23875075578689575,
0.24547931551933289,
0.2522706985473633,
0.25912320613861084,
0.26603513956069946,
0.27300477027893066,
0.2800304591655731,
0.2871103882789612,
0.29424285888671875,
0.30142611265182495,
0.30865830183029175,
0.31593772768974304,
0.3232625722885132,
0.3306310474872589,
0.3380413055419922,
0.34549152851104736,
0.352979838848114,
0.3605044484138489,
0.3680635094642639,
0.37565508484840393,
0.38327735662460327,
0.3909284174442291,
0.39860638976097107,
0.4063093662261963,
0.41403549909591675,
0.42178282141685486,
0.4295494258403778,
0.43733343482017517,
0.44513291120529175,
0.45294591784477234,
0.46077051758766174,
0.46860480308532715,
0.4764467775821686,
0.4842946231365204,
0.492146372795105,
0.5,
0.5078536868095398,
0.515705406665802,
0.5235532522201538,
0.5313953161239624,
0.5392295718193054,
0.5470541715621948,
0.5548672080039978,
0.562666654586792,
0.5704506635665894,
0.5782172679901123,
0.5859646201133728,
0.5936906933784485,
0.6013936996459961,
0.609071671962738,
0.6167227625846863,
0.6243450045585632,
0.6319366097450256,
0.6394955515861511,
0.6470202207565308,
0.6545085310935974,
0.6619587540626526,
0.6693689823150635,
0.6767374277114868,
0.6840623021125793,
0.691341757774353,
0.6985740065574646,
0.7057572603225708,
0.7128896713256836,
0.719969630241394,
0.7269952893257141,
0.7339649796485901,
0.7408769130706787,
0.7477294206619263,
0.7545207738876343,
0.761249303817749,
0.7679134607315063,
0.774511456489563,
0.7810417413711548,
0.7875027060508728,
0.7938927412033081,
0.800210177898407,
0.8064535856246948,
0.8126214146614075,
0.8187121152877808,
0.8247240781784058,
0.8306560516357422,
0.8365063667297363,
0.8422735929489136,
0.8479564785957336,
0.8535534143447876,
0.8590631484985352,
0.8644843101501465,
0.8698155879974365,
0.8750555515289307,
0.8802030086517334,
0.8852566480636597,
0.8902152180671692,
0.8950775265693665,
0.899842381477356,
0.9045084714889526,
0.9090749025344849,
0.9135403037071228,
0.9179036617279053,
0.9221639633178711,
0.9263200759887695,
0.9303710460662842,
0.9343158006668091,
0.9381533861160278,
0.941882848739624,
0.945503294467926,
0.9490138292312622,
0.9524135589599609,
0.9557017087936401,
0.9588773250579834,
0.961939811706543,
0.9648882746696472,
0.9677220582962036,
0.9704403877258301,
0.9730427265167236,
0.9755282998085022,
0.9778965711593628,
0.9801468849182129,
0.9822787046432495,
0.9842916131019592,
0.9861849546432495,
0.9879584312438965,
0.9896113872528076,
0.9911436438560486,
0.9925546646118164,
0.9938441514968872,
0.9950118064880371,
0.996057391166687,
0.9969804883003235,
0.997780978679657,
0.9984586238861084,
0.999013364315033,
0.9994449615478516,
0.9997532367706299,
0.9999383091926575,
1,
0.9999383091926575,
0.9997532367706299,
0.9994449615478516,
0.999013364315033,
0.9984586238861084,
0.997780978679657,
0.9969804286956787,
0.9960573315620422,
0.9950118064880371,
0.9938441514968872,
0.9925546646118164,
0.9911435842514038,
0.9896113872528076,
0.9879583716392517,
0.9861849546432495,
0.9842915534973145,
0.9822787046432495,
0.9801468253135681,
0.9778964519500732,
0.9755282402038574,
0.9730426073074341,
0.9704403877258301,
0.9677219390869141,
0.9648882150650024,
0.9619396924972534,
0.9588772654533386,
0.9557015895843506,
0.9524134397506714,
0.9490137100219727,
0.9455032348632812,
0.9418827295303345,
0.9381532669067383,
0.9343156814575195,
0.9303709268569946,
0.9263200759887695,
0.9221639633178711,
0.9179036617279053,
0.913540244102478,
0.9090747833251953,
0.9045084714889526,
0.8998422622680664,
0.8950774669647217,
0.8902151584625244,
0.8852565884590149,
0.8802029490470886,
0.8750554919242859,
0.869815468788147,
0.8644842505455017,
0.8590630888938904,
0.853553295135498,
0.8479562997817993,
0.842273473739624,
0.836506187915802,
0.8306558728218079,
0.8247239589691162,
0.8187118768692017,
0.8126212358474731,
0.8064534664154053,
0.8002099990844727,
0.793892502784729,
0.7875025272369385,
0.7810416221618652,
0.7745113372802734,
0.767913281917572,
0.7612491846084595,
0.7545205950737,
0.7477291822433472,
0.7408767342567444,
0.7339648008346558,
0.7269951105117798,
0.7199694514274597,
0.7128894925117493,
0.7057570219039917,
0.6985738277435303,
0.6913415789604187,
0.684062123298645,
0.6767372488975525,
0.6693688035011292,
0.6619585752487183,
0.6545083522796631,
0.6470199823379517,
0.6394953727722168,
0.6319363117218018,
0.6243447661399841,
0.6167224645614624,
0.6090714335441589,
0.601393461227417,
0.5936904549598694,
0.5859643220901489,
0.5782170295715332,
0.5704504251480103,
0.5626664161682129,
0.5548669099807739,
0.5470539331436157,
0.5392293334007263,
0.5313950181007385,
0.5235530138015747,
0.5157051682472229,
0.507853627204895,
0.5,
0.4921463429927826,
0.484294593334198,
0.4764467477798462,
0.46860471367836,
0.4607704281806946,
0.4529458284378052,
0.4451328217983246,
0.437333345413208,
0.42954933643341064,
0.4217827320098877,
0.4140354096889496,
0.4063093066215515,
0.3986063003540039,
0.39092832803726196,
0.3832772672176361,
0.37565499544143677,
0.36806342005729675,
0.3605043888092041,
0.35297977924346924,
0.3454914391040802,
0.338041216135025,
0.33063095808029175,
0.3232625126838684,
0.3159376382827759,
0.3086581826210022,
0.3014259934425354,
0.2942427396774292,
0.28711026906967163,
0.2800303101539612,
0.2730046510696411,
0.2660350203514099,
0.2591230869293213,
0.25227057933807373,
0.24547919631004333,
0.2387506067752838,
0.23208650946617126,
0.22548848390579224,
0.21895819902420044,
0.2124972641468048,
0.2061072587966919,
0.19978976249694824,
0.1935463547706604,
0.18737855553627014,
0.18128788471221924,
0.17527586221694946,
0.1693439483642578,
0.16349363327026367,
0.15772631764411926,
0.15204349160194397,
0.14644649624824524,
0.1409367322921753,
0.13551557064056396,
0.1301843225955963,
0.12494435906410217,
0.11979690194129944,
0.11474326252937317,
0.10978469252586365,
0.10492238402366638,
0.10015755891799927,
0.09549137949943542,
0.09092503786087036,
0.08645960688591003,
0.08209621906280518,
0.07783591747283936,
0.07367980480194092,
0.06962886452674866,
0.06568413972854614,
0.06184655427932739,
0.0581170916557312,
0.0544966459274292,
0.05098611116409302,
0.04758638143539429,
0.044298261404037476,
0.04112258553504944,
0.038060128688812256,
0.03511166572570801,
0.03227788209915161,
0.02955952286720276,
0.02695724368095398,
0.024471670389175415,
0.02210339903831482,
0.01985308527946472,
0.017721205949783325,
0.015708357095718384,
0.0138150155544281,
0.012041598558425903,
0.010388582944869995,
0.008856356143951416,
0.007445335388183594,
0.006155818700790405,
0.004988163709640503,
0.003942638635635376,
0.0030195116996765137,
0.0022190213203430176,
0.0015413165092468262,
0.000986635684967041,
0.0005550682544708252,
0.0002467334270477295,
0.0000616908073425293
]
}
}
},
{
"operation": {
"name": "log_mel_spectrogram",
"type": "LogMelSpectrum",
"attrs": {
"chunk_size": 30,
"hop_length": 160,
"n_fft": 400,
"n_mel": 80
}
}
}
]
}
}


@ -4,7 +4,9 @@
#pragma once
#include "ortx_tokenizer.h"
// make sure these headers stay compatible with a C-only compiler.
#include "ortx_processor.h"
#include "ortx_extractor.h"
#ifdef __cplusplus


@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <vector>
#include <tuple>
#include <fstream>
#include <filesystem>
#include "gtest/gtest.h"
#include "ortx_cpp_helper.h"
#include "shared/api/speech_extractor.h"
using namespace ort_extensions;
TEST(ExtractorTest, TestWhisperFeatureExtraction) {
const char* audio_path[] = {"data/jfk.flac", "data/1272-141231-0002.wav", "data/1272-141231-0002.mp3"};
OrtxObjectPtr<OrtxRawAudios> raw_audios;
extError_t err = OrtxLoadAudios(ort_extensions::ptr(raw_audios), audio_path, 3);
ASSERT_EQ(err, kOrtxOK);
OrtxObjectPtr<OrtxFeatureExtractor> feature_extractor(OrtxCreateSpeechFeatureExtractor, "data/whisper/feature_extraction.json");
OrtxObjectPtr<OrtxTensorResult> result;
err = OrtxSpeechLogMel(feature_extractor.get(), raw_audios.get(), ort_extensions::ptr(result));
ASSERT_EQ(err, kOrtxOK);
OrtxObjectPtr<OrtxTensor> tensor;
err = OrtxTensorResultGetAt(result.get(), 0, ort_extensions::ptr(tensor));
ASSERT_EQ(err, kOrtxOK);
const float* data{};
const int64_t* shape{};
size_t num_dims;
err = OrtxGetTensorDataFloat(tensor.get(), &data, &shape, &num_dims);
ASSERT_EQ(err, kOrtxOK);
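// Expected log-mel shape: [batch, n_mel, n_frames] = [3, 80, 3000] (30 s * 16000 Hz / 160 hop).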
ASSERT_EQ(num_dims, 3);
ASSERT_EQ(shape[0], 3);
ASSERT_EQ(shape[1], 80);
ASSERT_EQ(shape[2], 3000);
}


@ -7,7 +7,7 @@
#include <filesystem>
#include "gtest/gtest.h"
#include "ortx_c_helper.h"
#include "ortx_cpp_helper.h"
#include "shared/api/image_processor.h"
using namespace ort_extensions;
@ -85,18 +85,18 @@ TEST(ProcessorTest, TestClipImageProcessing) {
}
ASSERT_EQ(err, kOrtxOK);
OrtxObjectPtr<OrtxImageProcessorResult> result;
OrtxObjectPtr<OrtxTensorResult> result;
err = OrtxImagePreProcess(processor.get(), raw_images.get(), ort_extensions::ptr(result));
ASSERT_EQ(err, kOrtxOK);
OrtxObjectPtr<OrtxTensor> tensor;
err = OrtxImageGetTensorResult(result.get(), 0, ort_extensions::ptr(tensor));
OrtxTensor* tensor;
err = OrtxTensorResultGetAt(result.get(), 0, &tensor);
ASSERT_EQ(err, kOrtxOK);
const float* data{};
const int64_t* shape{};
size_t num_dims;
err = OrtxGetTensorDataFloat(tensor.get(), &data, &shape, &num_dims);
err = OrtxGetTensorDataFloat(tensor, &data, &shape, &num_dims);
ASSERT_EQ(err, kOrtxOK);
ASSERT_EQ(num_dims, 4);
}
