Feature extraction C API for whipser model (#755)
* Feature extraction C API for whipser model * Update the docs * Update the docs2 * refine the code * fix some issues * fix the Linux build * fix more data consistency issue * More code refinements
This commit is contained in:
Родитель
95d65e4ec0
Коммит
8153bc1a3a
|
@ -171,6 +171,10 @@ if (MSVC)
|
|||
endif()
|
||||
message(STATUS "_STATIC_MSVC_RUNTIME_LIBRARY: ${_STATIC_MSVC_RUNTIME_LIBRARY}")
|
||||
|
||||
# DLL initialization errors due to old conda msvcp140.dll dll are a result of the new MSVC compiler
|
||||
# See https://developercommunity.visualstudio.com/t/Access-violation-with-std::mutex::lock-a/10664660#T-N10668856
|
||||
# Remove this definition once the conda msvcp140.dll dll is updated.
|
||||
add_compile_definitions(_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR)
|
||||
endif()
|
||||
|
||||
if(NOT OCOS_BUILD_PYTHON AND OCOS_ENABLE_PYTHON)
|
||||
|
@ -442,7 +446,9 @@ endif()
|
|||
if(OCOS_ENABLE_BERT_TOKENIZER)
|
||||
# Bert
|
||||
set(_HAS_TOKENIZER ON)
|
||||
file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*" "operators/tokenizer/bert_tokenizer.*" "operators/tokenizer/bert_tokenizer_decoder.*")
|
||||
file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*"
|
||||
"operators/tokenizer/bert_tokenizer.*"
|
||||
"operators/tokenizer/bert_tokenizer_decoder.*")
|
||||
list(APPEND TARGET_SRC ${bert_TARGET_SRC})
|
||||
endif()
|
||||
|
||||
|
@ -820,7 +826,9 @@ if(OCOS_ENABLE_AZURE)
|
|||
endif()
|
||||
|
||||
target_compile_definitions(ortcustomops PUBLIC ${OCOS_COMPILE_DEFINITIONS})
|
||||
target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:noexcep_operators,INTERFACE_INCLUDE_DIRECTORIES>")
|
||||
target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")
|
||||
|
||||
target_link_libraries(ortcustomops PUBLIC ocos_operators)
|
||||
|
||||
if(_BUILD_SHARED_LIBRARY)
|
||||
|
@ -840,7 +848,8 @@ if(_BUILD_SHARED_LIBRARY)
|
|||
standardize_output_folder(extensions_shared)
|
||||
|
||||
if(LINUX OR ANDROID)
|
||||
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
|
||||
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS
|
||||
" -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
|
||||
# strip if not a debug build
|
||||
if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
|
||||
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,-s")
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
# ONNXRuntime Extensions C ABI
|
||||
|
||||
ONNXRuntime Extensions provides a C-style ABI for pre-processing. It offers support for tokenization, image processing, speech feature extraction, and more. You can compile the ONNXRuntime Extensions as either a static library or a dynamic library to access these APIs.
|
||||
|
||||
The C ABI header files are named `ortx_*.h` and can be found in the include folder. There are three types of data processing APIs available:
|
||||
|
||||
- [`ortx_tokenizer.h`](../include/ortx_tokenizer.h): Provides tokenization for LLM models.
|
||||
- [`ortx_processor.h`](../include/ortx_processor.h): Offers image processing APIs for multimodels.
|
||||
- [`ortx_extraction.h`](../include/ortx_extractor.h): Provides speech feature extraction for audio data processing to assist the Whisper model.
|
||||
|
||||
## ABI QuickStart
|
||||
|
||||
Most APIs accept raw data inputs such as audio, image compressed binary formats, or UTF-8 encoded text for tokenization.
|
||||
|
||||
**Tokenization:** You can create a tokenizer object using `OrtxCreateTokenizer` and then use the object to tokenize a text or decode the token ID into the text. A C-style code snippet is available [here](../test/pp_api_test/c_only_test.c).
|
||||
|
||||
**Image processing:** `OrtxCreateProcessor` can create an image processor object from a pre-defined workflow in JSON format to process image files into a tensor-like data type. An example code snippet can be found [here](../test/pp_api_test/test_processor.cc#L75).
|
||||
|
||||
**Audio feature extraction:** `OrtxCreateSpeechFeatureExtractor` creates a speech feature extractor to obtain log mel spectrum data as input for the Whisper model. An example code snippet can be found [here](../test/pp_api_test/test_feature_extractor.cc#L16).
|
|
@ -0,0 +1,75 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// C ABI header file for the onnxruntime-extensions tokenization module
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ortx_utils.h"
|
||||
|
||||
typedef OrtxObject OrtxFeatureExtractor;
|
||||
typedef OrtxObject OrtxRawAudios;
|
||||
typedef OrtxObject OrtxTensorResult;
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* @brief Creates a feature extractor object.
|
||||
*
|
||||
* This function creates a feature extractor object based on the provided feature definition.
|
||||
*
|
||||
* @param[out] extractor Pointer to a pointer to the created feature extractor object.
|
||||
* @param[in] fe_def The feature definition used to create the feature extractor.
|
||||
*
|
||||
* @return An error code indicating the result of the operation.
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxCreateSpeechFeatureExtractor(OrtxFeatureExtractor** extractor, const char* fe_def);
|
||||
|
||||
/**
|
||||
* Loads a collection of audio files into memory.
|
||||
*
|
||||
* This function loads a collection of audio files specified by the `audio_paths` array
|
||||
* into memory and returns a pointer to the loaded audio data in the `audios` parameter.
|
||||
*
|
||||
* @param audios A pointer to a pointer that will be updated with the loaded audio data.
|
||||
* The caller is responsible for freeing the memory allocated for the audio data.
|
||||
* @param audio_paths An array of strings representing the paths to the audio files to be loaded.
|
||||
* @param num_audios The number of audio files to be loaded.
|
||||
*
|
||||
* @return An `extError_t` value indicating the success or failure of the operation.
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxLoadAudios(OrtxRawAudios** audios, const char* const* audio_paths, size_t num_audios);
|
||||
|
||||
/**
|
||||
* @brief Creates an array of raw audio objects.
|
||||
*
|
||||
* This function creates an array of raw audio objects based on the provided data and sizes.
|
||||
*
|
||||
* @param audios Pointer to the variable that will hold the created raw audio objects.
|
||||
* @param data Array of pointers to the audio data.
|
||||
* @param sizes Array of pointers to the sizes of the audio data.
|
||||
* @param num_audios Number of audio objects to create.
|
||||
*
|
||||
* @return extError_t Error code indicating the success or failure of the operation.
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxCreateRawAudios(OrtxRawAudios** audios, const void* data[], const int64_t* sizes[], size_t num_audios);
|
||||
|
||||
/**
|
||||
* @brief Calculates the log mel spectrogram for a given audio using the specified feature extractor.
|
||||
*
|
||||
* This function takes an instance of the OrtxFeatureExtractor struct, an instance of the OrtxRawAudios struct,
|
||||
* and a pointer to a OrtxTensorResult pointer. It calculates the log mel spectrogram for the given audio using
|
||||
* the specified feature extractor and stores the result in the provided log_mel pointer.
|
||||
*
|
||||
* @param extractor The feature extractor to use for calculating the log mel spectrogram.
|
||||
* @param audio The raw audio data to process.
|
||||
* @param log_mel A pointer to a OrtxTensorResult pointer where the result will be stored.
|
||||
* @return An extError_t value indicating the success or failure of the operation.
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxSpeechLogMel(OrtxFeatureExtractor* extractor, OrtxRawAudios* audio, OrtxTensorResult** log_mel);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
|
@ -10,7 +10,6 @@
|
|||
// typedefs to create/dispose function flood, and to make the API more C++ friendly with less casting
|
||||
typedef OrtxObject OrtxProcessor;
|
||||
typedef OrtxObject OrtxRawImages;
|
||||
typedef OrtxObject OrtxImageProcessorResult;
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
@ -40,8 +39,22 @@ extError_t ORTX_API_CALL OrtxCreateProcessor(OrtxProcessor** processor, const ch
|
|||
extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** image_paths, size_t num_images,
|
||||
size_t* num_images_loaded);
|
||||
|
||||
|
||||
/**
|
||||
* @brief Preprocesses the given raw images using the specified processor.
|
||||
* @brief Creates raw images from the provided data.
|
||||
*
|
||||
* This function creates raw images from the provided data. The raw images are stored in the `images` parameter.
|
||||
*
|
||||
* @param images Pointer to a pointer to the `OrtxRawImages` structure that will hold the created raw images.
|
||||
* @param data Array of pointers to the data for each image.
|
||||
* @param sizes Array of pointers to the sizes of each image.
|
||||
* @param num_images Number of images to create.
|
||||
* @return An `extError_t` value indicating the success or failure of the operation.
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxCreateRawImages(OrtxRawImages** images, const void* data[], const int64_t* sizes[], size_t num_images);
|
||||
|
||||
/**
|
||||
* @brief Pre-processes the given raw images using the specified processor.
|
||||
*
|
||||
* This function applies preprocessing operations on the raw images using the provided processor.
|
||||
* The result of the preprocessing is stored in the `OrtxImageProcessorResult` object.
|
||||
|
@ -52,24 +65,7 @@ extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** ima
|
|||
* @return An `extError_t` value indicating the success or failure of the preprocessing operation.
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawImages* images,
|
||||
OrtxImageProcessorResult** result);
|
||||
|
||||
/**
|
||||
* @brief Retrieves the image processor result at the specified index.
|
||||
*
|
||||
* @param result Pointer to the OrtxImageProcessorResult structure to store the result.
|
||||
* @param index The index of the result to retrieve.
|
||||
* @return extError_t The error code indicating the success or failure of the operation.
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxImageGetTensorResult(OrtxImageProcessorResult* result, size_t index, OrtxTensor** tensor);
|
||||
|
||||
/** \brief Clear the outputs of the processor
|
||||
*
|
||||
* \param processor The processor object
|
||||
* \param result The result object to clear
|
||||
* \return Error code indicating the success or failure of the operation
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxClearOutputs(OrtxProcessor* processor, OrtxImageProcessorResult* result);
|
||||
OrtxTensorResult** result);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -17,19 +17,22 @@ typedef enum {
|
|||
kOrtxKindDetokenizerCache = 0x778B,
|
||||
kOrtxKindProcessor = 0x778C,
|
||||
kOrtxKindRawImages = 0x778D,
|
||||
kOrtxKindImageProcessorResult = 0x778E,
|
||||
kOrtxKindTensorResult = 0x778E,
|
||||
kOrtxKindProcessorResult = 0x778F,
|
||||
kOrtxKindTensor = 0x7790,
|
||||
kOrtxKindFeatureExtractor = 0x7791,
|
||||
kOrtxKindRawAudios = 0x7792,
|
||||
kOrtxKindEnd = 0x7999
|
||||
} extObjectKind_t;
|
||||
|
||||
// all object managed by the library should be 'derived' from this struct
|
||||
// which eventually will be released by TfmDispose if C++, or TFM_DISPOSE if C
|
||||
typedef struct {
|
||||
int ext_kind_;
|
||||
extObjectKind_t ext_kind_;
|
||||
} OrtxObject;
|
||||
|
||||
typedef OrtxObject OrtxTensor;
|
||||
typedef OrtxObject OrtxTensorResult;
|
||||
|
||||
// C, instead of C++ doesn't cast automatically,
|
||||
// so we need to use a macro to cast the object to the correct type
|
||||
|
@ -77,6 +80,18 @@ extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object);
|
|||
*/
|
||||
extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object);
|
||||
|
||||
/**
|
||||
* @brief Retrieves the tensor at the specified index from the given tensor result.
|
||||
*
|
||||
* This function allows you to access a specific tensor from a tensor result object.
|
||||
*
|
||||
* @param result The tensor result object from which to retrieve the tensor.
|
||||
* @param index The index of the tensor to retrieve.
|
||||
* @param tensor A pointer to a variable that will hold the retrieved tensor.
|
||||
* @return An error code indicating the success or failure of the operation.
|
||||
*/
|
||||
extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t index, OrtxTensor** tensor);
|
||||
|
||||
/** \brief Get the data from the tensor
|
||||
*
|
||||
* \param tensor The tensor object
|
||||
|
|
|
@ -17,7 +17,7 @@ from onnx import numpy_helper
|
|||
from ._ortapi2 import make_onnx_model
|
||||
from ._cuops import SingleOpGraph
|
||||
from ._hf_cvt import HFTokenizerConverter
|
||||
from .util import remove_unused_initializers
|
||||
from .util import remove_unused_initializers, mel_filterbank
|
||||
|
||||
|
||||
class _WhisperHParams:
|
||||
|
@ -30,53 +30,15 @@ class _WhisperHParams:
|
|||
N_FRAMES = N_SAMPLES // HOP_LENGTH
|
||||
|
||||
|
||||
def _mel_filterbank(
|
||||
n_fft: int, n_mels: int = 80, sr=16000, min_mel=0, max_mel=45.245640471924965, dtype=np.float32):
|
||||
"""
|
||||
Compute a Mel-filterbank. The filters are stored in the rows, the columns,
|
||||
and it is Slaney normalized mel-scale filterbank.
|
||||
"""
|
||||
fbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=dtype)
|
||||
|
||||
# the centers of the frequency bins for the DFT
|
||||
freq_bins = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
|
||||
|
||||
mel = np.linspace(min_mel, max_mel, n_mels + 2)
|
||||
# Fill in the linear scale
|
||||
f_min = 0.0
|
||||
f_sp = 200.0 / 3
|
||||
freqs = f_min + f_sp * mel
|
||||
|
||||
# And now the nonlinear scale
|
||||
min_log_hz = 1000.0 # beginning of log region (Hz)
|
||||
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
||||
logstep = np.log(6.4) / 27.0 # step size for log region
|
||||
|
||||
log_t = mel >= min_log_mel
|
||||
freqs[log_t] = min_log_hz * np.exp(logstep * (mel[log_t] - min_log_mel))
|
||||
mel_bins = freqs
|
||||
|
||||
mel_spacing = np.diff(mel_bins)
|
||||
|
||||
ramps = mel_bins.reshape(-1, 1) - freq_bins.reshape(1, -1)
|
||||
for i in range(n_mels):
|
||||
left = -ramps[i] / mel_spacing[i]
|
||||
right = ramps[i + 2] / mel_spacing[i + 1]
|
||||
|
||||
# intersect them with each other and zero
|
||||
fbank[i] = np.maximum(0, np.minimum(left, right))
|
||||
|
||||
energy_norm = 2.0 / (mel_bins[2: n_mels + 2] - mel_bins[:n_mels])
|
||||
fbank *= energy_norm[:, np.newaxis]
|
||||
return fbank
|
||||
|
||||
|
||||
class CustomOpStftNorm(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def symbolic(g, self, n_fft, hop_length, window):
|
||||
t_n_fft = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
|
||||
t_hop_length = g.op('Constant', value_t=torch.tensor(hop_length, dtype=torch.int64))
|
||||
t_frame_size = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
|
||||
t_n_fft = g.op('Constant', value_t=torch.tensor(
|
||||
n_fft, dtype=torch.int64))
|
||||
t_hop_length = g.op('Constant', value_t=torch.tensor(
|
||||
hop_length, dtype=torch.int64))
|
||||
t_frame_size = g.op(
|
||||
'Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
|
||||
return g.op("ai.onnx.contrib::StftNorm", self, t_n_fft, t_hop_length, window, t_frame_size)
|
||||
|
||||
@staticmethod
|
||||
|
@ -97,7 +59,7 @@ class WhisperPrePipeline(torch.nn.Module):
|
|||
self.n_fft = n_fft
|
||||
self.window = torch.hann_window(n_fft)
|
||||
self.mel_filters = torch.from_numpy(
|
||||
_mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels))
|
||||
mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels))
|
||||
|
||||
def forward(self, audio_pcm: torch.Tensor):
|
||||
stft_norm = CustomOpStftNorm.apply(audio_pcm,
|
||||
|
@ -112,7 +74,8 @@ class WhisperPrePipeline(torch.nn.Module):
|
|||
spec_shape = log_spec.shape
|
||||
padding_spec = torch.ones(spec_shape[0],
|
||||
spec_shape[1],
|
||||
self.n_samples // self.hop_length - spec_shape[2],
|
||||
self.n_samples // self.hop_length -
|
||||
spec_shape[2],
|
||||
dtype=torch.float)
|
||||
padding_spec *= spec_min
|
||||
log_spec = torch.cat((log_spec, padding_spec), dim=2)
|
||||
|
@ -165,15 +128,20 @@ def _to_onnx_stft(onnx_model, n_fft):
|
|||
make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_minus_1_output_0',
|
||||
'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'],
|
||||
name='slice_1'),
|
||||
make_node('Constant', inputs=[], outputs=['const0_output_0'], name='const0', value_int=0),
|
||||
make_node('Constant', inputs=[], outputs=['const1_output_0'], name='const1', value_int=1),
|
||||
make_node('Constant', inputs=[], outputs=[
|
||||
'const0_output_0'], name='const0', value_int=0),
|
||||
make_node('Constant', inputs=[], outputs=[
|
||||
'const1_output_0'], name='const1', value_int=1),
|
||||
make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'], outputs=['gather_4_output_0'],
|
||||
name='gather_4', axis=3),
|
||||
make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'], outputs=['gather_5_output_0'],
|
||||
name='gather_5', axis=3),
|
||||
make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=['mul_output_0'], name='mul0'),
|
||||
make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=['mul_1_output_0'], name='mul1'),
|
||||
make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[stft_norm_node.output[0]], name='add0'),
|
||||
make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=[
|
||||
'mul_output_0'], name='mul0'),
|
||||
make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=[
|
||||
'mul_1_output_0'], name='mul1'),
|
||||
make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[
|
||||
stft_norm_node.output[0]], name='add0'),
|
||||
]
|
||||
new_stft_nodes.extend(onnx_model.graph.node[:node_idx])
|
||||
new_stft_nodes.extend(replaced_nodes)
|
||||
|
@ -253,9 +221,11 @@ class WhisperDataProcGraph:
|
|||
del g.node[:]
|
||||
g.node.extend(nodes)
|
||||
|
||||
inputs = [onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
|
||||
inputs = [onnx.helper.make_tensor_value_info(
|
||||
"sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
|
||||
del g.input[:]
|
||||
g.input.extend(inputs)
|
||||
g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'text']))
|
||||
g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(
|
||||
onnx.TensorProto.STRING, ['N', 'text']))
|
||||
|
||||
return make_onnx_model(g, opset_version=self.opset_version)
|
||||
|
|
|
@ -3,17 +3,15 @@
|
|||
|
||||
#include "ocos.h"
|
||||
#ifdef ENABLE_DR_LIBS
|
||||
#include "audio_decoder.hpp"
|
||||
#include "audio_decoder.h"
|
||||
#endif // ENABLE_DR_LIBS
|
||||
|
||||
FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []()-> CustomOpArray& {
|
||||
FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []() -> CustomOpArray& {
|
||||
static OrtOpLoader op_loader(
|
||||
[]() { return nullptr; }
|
||||
#ifdef ENABLE_DR_LIBS
|
||||
,
|
||||
CustomCpuStructV2("AudioDecoder", AudioDecoder)
|
||||
CustomCpuStructV2("AudioDecoder", AudioDecoder),
|
||||
#endif
|
||||
);
|
||||
[]() { return nullptr; });
|
||||
|
||||
return op_loader.GetCustomOps();
|
||||
};
|
||||
|
|
|
@ -0,0 +1,181 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <gsl/util>
|
||||
|
||||
#include "audio_decoder.h"
|
||||
|
||||
#define DR_FLAC_IMPLEMENTATION
|
||||
#include "dr_flac.h"
|
||||
#define DR_MP3_IMPLEMENTATION 1
|
||||
#define DR_MP3_FLOAT_OUTPUT 1
|
||||
#include "dr_mp3.h"
|
||||
#define DR_WAV_IMPLEMENTATION
|
||||
#include "dr_wav.h"
|
||||
|
||||
#include "narrow.h"
|
||||
#include "string_utils.h"
|
||||
#include "string_tensor.h"
|
||||
#include "sampling.h"
|
||||
|
||||
OrtStatusPtr AudioDecoder::OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
|
||||
auto status = OrtW::GetOpAttribute(info, "downsampling_rate", downsample_rate_);
|
||||
if (!status) {
|
||||
status = OrtW::GetOpAttribute(info, "stereo_to_mono", stereo_mixer_);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
AudioDecoder::AudioStreamType AudioDecoder::ReadStreamFormat(const uint8_t* p_data, const std::string& str_format,
|
||||
OrtxStatus& status) const {
|
||||
const std::map<std::string, AudioStreamType> format_mapping = {{"default", AudioStreamType::kDefault},
|
||||
{"wav", AudioStreamType::kWAV},
|
||||
{"mp3", AudioStreamType::kMP3},
|
||||
{"flac", AudioStreamType::kFLAC}};
|
||||
|
||||
AudioStreamType stream_format = AudioStreamType::kDefault;
|
||||
if (str_format.length() > 0) {
|
||||
auto pos = format_mapping.find(str_format);
|
||||
if (pos == format_mapping.end()) {
|
||||
status = {kOrtxErrorInvalidArgument,
|
||||
MakeString("[AudioDecoder]: Unknown audio stream format: ", str_format).c_str()};
|
||||
return stream_format;
|
||||
}
|
||||
stream_format = pos->second;
|
||||
}
|
||||
|
||||
if (stream_format == AudioStreamType::kDefault) {
|
||||
auto p_stream = reinterpret_cast<char const*>(p_data);
|
||||
std::string_view marker(p_stream, 4);
|
||||
if (marker == "fLaC") {
|
||||
stream_format = AudioStreamType::kFLAC;
|
||||
} else if (marker == "RIFF") {
|
||||
stream_format = AudioStreamType::kWAV;
|
||||
} else if (marker[0] == char(0xFF) && (marker[1] | 0x1F) == char(0xFF)) {
|
||||
// http://www.mp3-tech.org/programmer/frame_header.html
|
||||
// only detect the 8 + 3 bits sync word
|
||||
stream_format = AudioStreamType::kMP3;
|
||||
} else {
|
||||
status = {kOrtxErrorInvalidArgument, "[AudioDecoder]: Cannot detect audio stream format"};
|
||||
}
|
||||
}
|
||||
|
||||
return stream_format;
|
||||
}
|
||||
|
||||
template <typename TY_AUDIO, typename FX_DECODER>
|
||||
static size_t DrReadFrames(std::list<std::vector<float>>& frames, FX_DECODER fx, TY_AUDIO& obj) {
|
||||
const size_t default_chunk_size = 1024 * 256;
|
||||
int64_t total_buf_size = 0;
|
||||
|
||||
for (;;) {
|
||||
std::vector<float> buf;
|
||||
buf.resize(default_chunk_size * obj.channels);
|
||||
auto n_frames = fx(&obj, default_chunk_size, buf.data());
|
||||
if (n_frames <= 0) {
|
||||
break;
|
||||
}
|
||||
auto data_size = n_frames * obj.channels;
|
||||
total_buf_size += data_size;
|
||||
buf.resize(data_size);
|
||||
frames.emplace_back(std::move(buf));
|
||||
}
|
||||
|
||||
return total_buf_size;
|
||||
}
|
||||
|
||||
OrtxStatus AudioDecoder::Compute(const ortc::Tensor<uint8_t>& input, const std::optional<std::string> format,
|
||||
ortc::Tensor<float>& output0) const {
|
||||
const uint8_t* p_data = input.Data();
|
||||
auto input_dim = input.Shape();
|
||||
OrtxStatus status;
|
||||
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
|
||||
return {kOrtxErrorInvalidArgument, "[AudioDecoder]: Expect input dimension [n] or [1,n]."};
|
||||
}
|
||||
|
||||
std::string str_format;
|
||||
if (format) {
|
||||
str_format = *format;
|
||||
}
|
||||
auto stream_format = ReadStreamFormat(p_data, str_format, status);
|
||||
if (status) {
|
||||
return status;
|
||||
}
|
||||
|
||||
int64_t total_buf_size = 0;
|
||||
std::list<std::vector<float>> lst_frames;
|
||||
int64_t orig_sample_rate = 0;
|
||||
int64_t orig_channels = 0;
|
||||
|
||||
if (stream_format == AudioStreamType::kMP3) {
|
||||
auto mp3_obj_ptr = std::make_unique<drmp3>();
|
||||
if (!drmp3_init_memory(mp3_obj_ptr.get(), p_data, input.NumberOfElement(), nullptr)) {
|
||||
status = {kOrtxErrorCorruptData, "[AudioDecoder]: unexpected error on MP3 stream."};
|
||||
return status;
|
||||
}
|
||||
orig_sample_rate = mp3_obj_ptr->sampleRate;
|
||||
orig_channels = mp3_obj_ptr->channels;
|
||||
total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, *mp3_obj_ptr);
|
||||
|
||||
} else if (stream_format == AudioStreamType::kFLAC) {
|
||||
drflac* flac_obj = drflac_open_memory(p_data, input.NumberOfElement(), nullptr);
|
||||
auto flac_obj_closer = gsl::finally([flac_obj]() { drflac_close(flac_obj); });
|
||||
if (flac_obj == nullptr) {
|
||||
status = {kOrtxErrorCorruptData, "[AudioDecoder]: unexpected error on FLAC stream."};
|
||||
return status;
|
||||
}
|
||||
orig_sample_rate = flac_obj->sampleRate;
|
||||
orig_channels = flac_obj->channels;
|
||||
total_buf_size = DrReadFrames(lst_frames, drflac_read_pcm_frames_f32, *flac_obj);
|
||||
|
||||
} else {
|
||||
drwav wav_obj;
|
||||
if (!drwav_init_memory(&wav_obj, p_data, input.NumberOfElement(), nullptr)) {
|
||||
status = {kOrtxErrorCorruptData, "[AudioDecoder]: unexpected error on WAV stream."};
|
||||
return status;
|
||||
}
|
||||
orig_sample_rate = wav_obj.sampleRate;
|
||||
orig_channels = wav_obj.channels;
|
||||
total_buf_size = DrReadFrames(lst_frames, drwav_read_pcm_frames_f32, wav_obj);
|
||||
}
|
||||
|
||||
if (downsample_rate_ != 0 && orig_sample_rate < downsample_rate_) {
|
||||
status = {kOrtxErrorCorruptData, "[AudioDecoder]: only down-sampling supported."};
|
||||
return status;
|
||||
}
|
||||
|
||||
// join all frames
|
||||
std::vector<float> buf;
|
||||
buf.resize(total_buf_size);
|
||||
int64_t offset = 0;
|
||||
for (auto& _b : lst_frames) {
|
||||
std::copy(_b.begin(), _b.end(), buf.begin() + offset);
|
||||
offset += _b.size();
|
||||
}
|
||||
|
||||
// mix the stereo channels into mono channel
|
||||
if (stereo_mixer_ && orig_channels > 1) {
|
||||
if (buf.size() > 1) {
|
||||
for (size_t i = 0; i < buf.size() / 2; ++i) {
|
||||
buf[i] = (buf[i * 2] + buf[i * 2 + 1]) / 2;
|
||||
}
|
||||
buf.resize(buf.size() / 2);
|
||||
}
|
||||
}
|
||||
|
||||
if (downsample_rate_ != 0 && downsample_rate_ != orig_sample_rate) {
|
||||
// A lowpass filter on buf audio data to remove high frequency noise
|
||||
ButterworthLowpass filter(0.5 * downsample_rate_, 1.0 * orig_sample_rate);
|
||||
std::vector<float> filtered_buf = filter.Process(buf);
|
||||
// downsample the audio data
|
||||
KaiserWindowInterpolation::Process(filtered_buf, buf, 1.0f * orig_sample_rate, 1.0f * downsample_rate_);
|
||||
}
|
||||
|
||||
std::vector<int64_t> dim_out = {1, ort_extensions::narrow<int64_t>(buf.size())};
|
||||
float* p_output = output0.Allocate(dim_out);
|
||||
std::copy(buf.begin(), buf.end(), p_output);
|
||||
return status;
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ocos.h"
|
||||
|
||||
#include <list>
|
||||
#include <optional>
|
||||
|
||||
struct AudioDecoder {
|
||||
public:
|
||||
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info);
|
||||
|
||||
template <typename DictT>
|
||||
OrtxStatus Init(const DictT& attrs) {
|
||||
// in API mode, the default value is 1
|
||||
downsample_rate_ = 16000;
|
||||
stereo_mixer_ = 1;
|
||||
for (const auto& [key, value] : attrs) {
|
||||
if (key == "target_sample_rate") {
|
||||
downsample_rate_ = std::get<std::int64_t>(value);
|
||||
} else if (key == "stereo_to_mono") {
|
||||
stereo_mixer_ = std::get<std::int64_t>(value);
|
||||
} else {
|
||||
return {kOrtxErrorInvalidArgument, "[AudioDecoder]: Invalid argument"};
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
enum class AudioStreamType { kDefault = 0, kWAV, kMP3, kFLAC };
|
||||
|
||||
AudioStreamType ReadStreamFormat(const uint8_t* p_data, const std::string& str_format, OrtxStatus& status) const;
|
||||
OrtxStatus Compute(const ortc::Tensor<uint8_t>& input, const std::optional<std::string> format,
|
||||
ortc::Tensor<float>& output0) const;
|
||||
OrtxStatus ComputeNoOpt(const ortc::Tensor<uint8_t>& input, ortc::Tensor<float>& output0) {
|
||||
return Compute(input, std::nullopt, output0);
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t downsample_rate_{};
|
||||
int64_t stereo_mixer_{};
|
||||
};
|
|
@ -1,206 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ocos.h"
|
||||
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#define DR_FLAC_IMPLEMENTATION
|
||||
#include "dr_flac.h"
|
||||
#define DR_MP3_IMPLEMENTATION 1
|
||||
#define DR_MP3_FLOAT_OUTPUT 1
|
||||
#include "dr_mp3.h"
|
||||
#define DR_WAV_IMPLEMENTATION
|
||||
#include "dr_wav.h"
|
||||
|
||||
#include <gsl/util>
|
||||
#include "narrow.h"
|
||||
#include "string_utils.h"
|
||||
#include "string_tensor.h"
|
||||
#include "sampling.h"
|
||||
|
||||
struct AudioDecoder{
|
||||
public:
|
||||
|
||||
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
|
||||
auto status = OrtW::GetOpAttribute(info, "downsampling_rate", downsample_rate_);
|
||||
if (!status) {
|
||||
status = OrtW::GetOpAttribute(info, "stereo_to_mono", stereo_mixer_);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
enum class AudioStreamType {
|
||||
kDefault = 0,
|
||||
kWAV,
|
||||
kMP3,
|
||||
kFLAC
|
||||
};
|
||||
|
||||
AudioStreamType ReadStreamFormat(const uint8_t* p_data, const std::string& str_format, OrtStatusPtr& status) const {
|
||||
static const std::map<std::string, AudioStreamType> format_mapping = {
|
||||
{"default", AudioStreamType::kDefault},
|
||||
{"wav", AudioStreamType::kWAV},
|
||||
{"mp3", AudioStreamType::kMP3},
|
||||
{"flac", AudioStreamType::kFLAC}};
|
||||
|
||||
AudioStreamType stream_format = AudioStreamType::kDefault;
|
||||
if (str_format.length() > 0) {
|
||||
auto pos = format_mapping.find(str_format);
|
||||
if (pos == format_mapping.end()) {
|
||||
status = OrtW::CreateStatus(MakeString(
|
||||
"[AudioDecoder]: Unknown audio stream format: ", str_format)
|
||||
.c_str(),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
return stream_format;
|
||||
}
|
||||
stream_format = pos->second;
|
||||
}
|
||||
|
||||
if (stream_format == AudioStreamType::kDefault) {
|
||||
auto p_stream = reinterpret_cast<char const*>(p_data);
|
||||
std::string_view marker(p_stream, 4);
|
||||
if (marker == "fLaC") {
|
||||
stream_format = AudioStreamType::kFLAC;
|
||||
} else if (marker == "RIFF") {
|
||||
stream_format = AudioStreamType::kWAV;
|
||||
} else if (marker[0] == char(0xFF) && (marker[1] | 0x1F) == char(0xFF)) {
|
||||
// http://www.mp3-tech.org/programmer/frame_header.html
|
||||
// only detect the 8 + 3 bits sync word
|
||||
stream_format = AudioStreamType::kMP3;
|
||||
} else {
|
||||
status = OrtW::CreateStatus("[AudioDecoder]: Cannot detect audio stream format", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
}
|
||||
|
||||
return stream_format;
|
||||
}
|
||||
|
||||
template <typename TY_AUDIO, typename FX_DECODER>
|
||||
static size_t DrReadFrames(std::list<std::vector<float>>& frames, FX_DECODER fx, TY_AUDIO& obj) {
|
||||
const size_t default_chunk_size = 1024 * 256;
|
||||
int64_t total_buf_size = 0;
|
||||
|
||||
for (;;) {
|
||||
std::vector<float> buf;
|
||||
buf.resize(default_chunk_size * obj.channels);
|
||||
auto n_frames = fx(&obj, default_chunk_size, buf.data());
|
||||
if (n_frames <= 0) {
|
||||
break;
|
||||
}
|
||||
auto data_size = n_frames * obj.channels;
|
||||
total_buf_size += data_size;
|
||||
buf.resize(data_size);
|
||||
frames.emplace_back(std::move(buf));
|
||||
}
|
||||
|
||||
return total_buf_size;
|
||||
}
|
||||
|
||||
OrtStatusPtr Compute(const ortc::Tensor<uint8_t>& input,
|
||||
const std::optional<std::string> format,
|
||||
ortc::Tensor<float>& output0) const {
|
||||
const uint8_t* p_data = input.Data();
|
||||
auto input_dim = input.Shape();
|
||||
OrtStatusPtr status = nullptr;
|
||||
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
|
||||
status = OrtW::CreateStatus("[AudioDecoder]: Expect input dimension [n] or [1,n].", ORT_INVALID_ARGUMENT);
|
||||
return status;
|
||||
}
|
||||
|
||||
std::string str_format;
|
||||
if (format) {
|
||||
str_format = *format;
|
||||
}
|
||||
auto stream_format = ReadStreamFormat(p_data, str_format, status);
|
||||
if (status) {
|
||||
return status;
|
||||
}
|
||||
|
||||
int64_t total_buf_size = 0;
|
||||
std::list<std::vector<float>> lst_frames;
|
||||
int64_t orig_sample_rate = 0;
|
||||
int64_t orig_channels = 0;
|
||||
|
||||
if (stream_format == AudioStreamType::kMP3) {
|
||||
auto mp3_obj_ptr = std::make_unique<drmp3>();
|
||||
if (!drmp3_init_memory(mp3_obj_ptr.get(), p_data, input.NumberOfElement(), nullptr)) {
|
||||
status = OrtW::CreateStatus("[AudioDecoder]: unexpected error on MP3 stream.", ORT_RUNTIME_EXCEPTION);
|
||||
return status;
|
||||
}
|
||||
orig_sample_rate = mp3_obj_ptr->sampleRate;
|
||||
orig_channels = mp3_obj_ptr->channels;
|
||||
total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, *mp3_obj_ptr);
|
||||
|
||||
} else if (stream_format == AudioStreamType::kFLAC) {
|
||||
drflac* flac_obj = drflac_open_memory(p_data, input.NumberOfElement(), nullptr);
|
||||
auto flac_obj_closer = gsl::finally([flac_obj]() { drflac_close(flac_obj); });
|
||||
if (flac_obj == nullptr) {
|
||||
status = OrtW::CreateStatus("[AudioDecoder]: unexpected error on FLAC stream.", ORT_RUNTIME_EXCEPTION);
|
||||
return status;
|
||||
}
|
||||
orig_sample_rate = flac_obj->sampleRate;
|
||||
orig_channels = flac_obj->channels;
|
||||
total_buf_size = DrReadFrames(lst_frames, drflac_read_pcm_frames_f32, *flac_obj);
|
||||
|
||||
} else {
|
||||
drwav wav_obj;
|
||||
if (!drwav_init_memory(&wav_obj, p_data, input.NumberOfElement(), nullptr)) {
|
||||
status = OrtW::CreateStatus("[AudioDecoder]: unexpected error on WAV stream.", ORT_RUNTIME_EXCEPTION);
|
||||
return status;
|
||||
}
|
||||
orig_sample_rate = wav_obj.sampleRate;
|
||||
orig_channels = wav_obj.channels;
|
||||
total_buf_size = DrReadFrames(lst_frames, drwav_read_pcm_frames_f32, wav_obj);
|
||||
}
|
||||
|
||||
if (downsample_rate_ != 0 &&
|
||||
orig_sample_rate < downsample_rate_) {
|
||||
status = OrtW::CreateStatus("[AudioDecoder]: only down-sampling supported.", ORT_INVALID_ARGUMENT);
|
||||
return status;
|
||||
}
|
||||
|
||||
// join all frames
|
||||
std::vector<float> buf;
|
||||
buf.resize(total_buf_size);
|
||||
int64_t offset = 0;
|
||||
for (auto& _b : lst_frames) {
|
||||
std::copy(_b.begin(), _b.end(), buf.begin() + offset);
|
||||
offset += _b.size();
|
||||
}
|
||||
|
||||
// mix the stereo channels into mono channel
|
||||
if (stereo_mixer_ && orig_channels > 1) {
|
||||
if (buf.size() > 1) {
|
||||
for (size_t i = 0; i < buf.size() / 2; ++i) {
|
||||
buf[i] = (buf[i * 2] + buf[i * 2 + 1]) / 2;
|
||||
}
|
||||
buf.resize(buf.size() / 2);
|
||||
}
|
||||
}
|
||||
|
||||
if (downsample_rate_ != 0 &&
|
||||
downsample_rate_ != orig_sample_rate) {
|
||||
// A lowpass filter on buf audio data to remove high frequency noise
|
||||
ButterworthLowpass filter(0.5 * downsample_rate_, 1.0 * orig_sample_rate);
|
||||
std::vector<float> filtered_buf = filter.Process(buf);
|
||||
// downsample the audio data
|
||||
KaiserWindowInterpolation::Process(filtered_buf, buf,
|
||||
1.0f * orig_sample_rate, 1.0f * downsample_rate_);
|
||||
}
|
||||
|
||||
std::vector<int64_t> dim_out = {1, ort_extensions::narrow<int64_t>(buf.size())};
|
||||
float* p_output = output0.Allocate(dim_out);
|
||||
std::copy(buf.begin(), buf.end(), p_output);
|
||||
return status;
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t downsample_rate_{};
|
||||
int64_t stereo_mixer_{};
|
||||
};
|
|
@ -6,37 +6,32 @@
|
|||
#include "ocos.h"
|
||||
#include <dlib/matrix.h>
|
||||
|
||||
struct StftNormal{
|
||||
struct StftNormal {
|
||||
StftNormal() = default;
|
||||
|
||||
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
|
||||
return OrtW::GetOpAttribute(info, "onesided", onesided_);
|
||||
}
|
||||
|
||||
OrtStatusPtr Compute(const ortc::Tensor<float>& input0,
|
||||
int64_t n_fft,
|
||||
int64_t hop_length,
|
||||
const ortc::Span<float>& input3,
|
||||
int64_t frame_length,
|
||||
ortc::Tensor<float>& output0) const {
|
||||
OrtxStatus Compute(const ortc::Tensor<float>& input0, int64_t n_fft, int64_t hop_length,
|
||||
const ortc::Span<float>& input3, int64_t frame_length, ortc::Tensor<float>& output0) const {
|
||||
auto X = input0.Data();
|
||||
auto window = input3.data_;
|
||||
auto dimensions = input0.Shape();
|
||||
auto win_length = input3.size();
|
||||
|
||||
if (dimensions.size() < 2 || input0.NumberOfElement() != dimensions[1]) {
|
||||
return OrtW::CreateStatus("[Stft] Only batch == 1 tensor supported.", ORT_INVALID_ARGUMENT);
|
||||
return {kOrtxErrorInvalidArgument, "[Stft] Only batch == 1 tensor supported."};
|
||||
}
|
||||
if (frame_length != n_fft) {
|
||||
return OrtW::CreateStatus("[Stft] Only support size of FFT equals the frame length.", ORT_INVALID_ARGUMENT);
|
||||
return {kOrtxErrorInvalidArgument, "[Stft] Only support size of FFT equals the frame length."};
|
||||
}
|
||||
|
||||
dlib::matrix<float> dm_x = dlib::mat(X, 1, dimensions[1]);
|
||||
dlib::matrix<float> hann_win = dlib::mat(window, 1, win_length);
|
||||
|
||||
auto m_stft = dlib::stft(
|
||||
dm_x, [&hann_win](size_t x, size_t len) { return hann_win(0, x); },
|
||||
n_fft, win_length, hop_length);
|
||||
auto m_stft =
|
||||
dlib::stft(dm_x, [&hann_win](size_t x, size_t len) { return hann_win(0, x); }, n_fft, win_length, hop_length);
|
||||
|
||||
if (onesided_) {
|
||||
m_stft = dlib::subm(m_stft, 0, 0, m_stft.nr(), (m_stft.nc() >> 1) + 1);
|
||||
|
@ -49,7 +44,7 @@ struct StftNormal{
|
|||
auto out0 = output0.Allocate(outdim);
|
||||
memcpy(out0, result.steal_memory().get(), result_size * sizeof(float));
|
||||
|
||||
return nullptr;
|
||||
return {};
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "speech_extractor.h"
|
||||
|
||||
#include "c_api_utils.hpp"
|
||||
|
||||
using namespace ort_extensions;
|
||||
|
||||
class RawAudiosObject : public OrtxObjectImpl {
|
||||
public:
|
||||
RawAudiosObject() : OrtxObjectImpl(extObjectKind_t::kOrtxKindRawAudios) {}
|
||||
~RawAudiosObject() override = default;
|
||||
|
||||
std::unique_ptr<AudioRawData[]> audios_;
|
||||
size_t num_audios_;
|
||||
};
|
||||
|
||||
extError_t ORTX_API_CALL OrtxLoadAudios(OrtxRawAudios** raw_audios, const char* const* audio_paths, size_t num_audios) {
|
||||
if (raw_audios == nullptr || audio_paths == nullptr) {
|
||||
ReturnableStatus::last_error_message_ = "Invalid argument";
|
||||
return kOrtxErrorInvalidArgument;
|
||||
}
|
||||
|
||||
auto audios_obj = std::make_unique<RawAudiosObject>();
|
||||
auto [audios, num] =
|
||||
ort_extensions::LoadRawData<char const* const*, AudioRawData>(audio_paths, audio_paths + num_audios);
|
||||
audios_obj->audios_ = std::move(audios);
|
||||
audios_obj->num_audios_ = num;
|
||||
|
||||
*raw_audios = static_cast<OrtxRawAudios*>(audios_obj.release());
|
||||
return extError_t();
|
||||
}
|
||||
|
||||
extError_t ORTX_API_CALL OrtxCreateSpeechFeatureExtractor(OrtxFeatureExtractor** extractor, const char* def) {
|
||||
if (extractor == nullptr || def == nullptr) {
|
||||
ReturnableStatus::last_error_message_ = "Invalid argument";
|
||||
return kOrtxErrorInvalidArgument;
|
||||
}
|
||||
|
||||
auto extractor_ptr = std::make_unique<SpeechFeatureExtractor>();
|
||||
ReturnableStatus status = extractor_ptr->Init(def);
|
||||
if (status.IsOk()) {
|
||||
*extractor = static_cast<OrtxFeatureExtractor*>(extractor_ptr.release());
|
||||
} else {
|
||||
*extractor = nullptr;
|
||||
}
|
||||
|
||||
return status.Code();
|
||||
}
|
||||
|
||||
extError_t ORTX_API_CALL OrtxSpeechLogMel(OrtxFeatureExtractor* extractor, OrtxRawAudios* raw_audios,
|
||||
OrtxTensorResult** result) {
|
||||
if (extractor == nullptr || raw_audios == nullptr || result == nullptr) {
|
||||
ReturnableStatus::last_error_message_ = "Invalid argument";
|
||||
return kOrtxErrorInvalidArgument;
|
||||
}
|
||||
|
||||
auto extractor_ptr = static_cast<SpeechFeatureExtractor*>(extractor);
|
||||
auto audios_obj = static_cast<RawAudiosObject*>(raw_audios);
|
||||
|
||||
auto ts_result = std::make_unique<TensorResult>();
|
||||
std::unique_ptr<ortc::Tensor<float>> log_mel[1];
|
||||
ReturnableStatus status =
|
||||
extractor_ptr->DoCall(ort_extensions::span(audios_obj->audios_.get(), audios_obj->num_audios_), log_mel[0]);
|
||||
if (status.IsOk()) {
|
||||
std::vector<std::unique_ptr<ortc::TensorBase>> tensors;
|
||||
std::transform(log_mel, log_mel + 1, std::back_inserter(tensors),
|
||||
[](auto& ts) { return std::unique_ptr<ortc::TensorBase>(ts.release()); });
|
||||
ts_result->SetTensors(std::move(tensors));
|
||||
*result = ts_result.release();
|
||||
} else {
|
||||
*result = nullptr;
|
||||
}
|
||||
|
||||
return status.Code();
|
||||
}
|
|
@ -4,6 +4,8 @@
|
|||
#include "ortx_processor.h"
|
||||
#include "image_processor.h"
|
||||
|
||||
#include "c_api_utils.hpp"
|
||||
|
||||
using namespace ort_extensions;
|
||||
|
||||
extError_t OrtxCreateProcessor(OrtxProcessor** processor, const char* def) {
|
||||
|
@ -37,19 +39,19 @@ extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** ima
|
|||
}
|
||||
|
||||
auto images_obj = std::make_unique<RawImagesObject>();
|
||||
auto [img, num] = LoadRawImages(image_paths, image_paths + num_images);
|
||||
auto [img, num] = LoadRawData<char const**, ImageRawData>(image_paths, image_paths + num_images);
|
||||
images_obj->images = std::move(img);
|
||||
images_obj->num_images = num;
|
||||
if (num_images_loaded != nullptr) {
|
||||
*num_images_loaded = num;
|
||||
}
|
||||
|
||||
|
||||
*images = static_cast<OrtxRawImages*>(images_obj.release());
|
||||
return extError_t();
|
||||
}
|
||||
|
||||
extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawImages* images,
|
||||
OrtxImageProcessorResult** result) {
|
||||
OrtxTensorResult** result) {
|
||||
if (processor == nullptr || images == nullptr || result == nullptr) {
|
||||
ReturnableStatus::last_error_message_ = "Invalid argument";
|
||||
return kOrtxErrorInvalidArgument;
|
||||
|
@ -67,59 +69,14 @@ extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawIm
|
|||
return status.Code();
|
||||
}
|
||||
|
||||
auto result_ptr = std::make_unique<ImageProcessorResult>();
|
||||
auto result_ptr = std::make_unique<TensorResult>();
|
||||
status =
|
||||
processor_ptr->PreProcess(ort_extensions::span(images_ptr->images.get(), images_ptr->num_images), *result_ptr);
|
||||
if (status.IsOk()) {
|
||||
*result = static_cast<OrtxImageProcessorResult*>(result_ptr.release());
|
||||
*result = static_cast<OrtxTensorResult*>(result_ptr.release());
|
||||
} else {
|
||||
*result = nullptr;
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
extError_t ORTX_API_CALL OrtxImageGetTensorResult(OrtxImageProcessorResult* result, size_t index, OrtxTensor** tensor) {
|
||||
if (result == nullptr || tensor == nullptr) {
|
||||
ReturnableStatus::last_error_message_ = "Invalid argument";
|
||||
return kOrtxErrorInvalidArgument;
|
||||
}
|
||||
|
||||
auto result_ptr = static_cast<ImageProcessorResult*>(result);
|
||||
ReturnableStatus status(result_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindImageProcessorResult));
|
||||
if (!status.IsOk()) {
|
||||
return status.Code();
|
||||
}
|
||||
|
||||
if (index >= result_ptr->results.size()) {
|
||||
ReturnableStatus::last_error_message_ = "Index out of range";
|
||||
return kOrtxErrorInvalidArgument;
|
||||
}
|
||||
|
||||
auto tensor_ptr = std::make_unique<OrtxObjectWrapper<ortc::TensorBase>>();
|
||||
tensor_ptr->SetObject(result_ptr->results[index].get());
|
||||
*tensor = static_cast<OrtxTensor*>(tensor_ptr.release());
|
||||
return extError_t();
|
||||
}
|
||||
|
||||
extError_t ORTX_API_CALL OrtxClearOutputs(OrtxProcessor* processor, OrtxImageProcessorResult* result) {
|
||||
if (processor == nullptr || result == nullptr) {
|
||||
ReturnableStatus::last_error_message_ = "Invalid argument";
|
||||
return kOrtxErrorInvalidArgument;
|
||||
}
|
||||
|
||||
const auto processor_ptr = static_cast<const ImageProcessor*>(processor);
|
||||
ReturnableStatus status(processor_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindProcessor));
|
||||
if (!status.IsOk()) {
|
||||
return status.Code();
|
||||
}
|
||||
|
||||
auto result_ptr = static_cast<ImageProcessorResult*>(result);
|
||||
status = result_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindImageProcessorResult);
|
||||
if (!status.IsOk()) {
|
||||
return status.Code();
|
||||
}
|
||||
|
||||
ImageProcessor::ClearOutputs(result_ptr);
|
||||
return extError_t();
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
#include "c_api_utils.hpp"
|
||||
#include "tokenizer_impl.h"
|
||||
|
||||
namespace ort_extensions {
|
||||
using namespace ort_extensions;
|
||||
|
||||
class DetokenizerCache : public OrtxObjectImpl {
|
||||
public:
|
||||
|
@ -17,29 +17,20 @@ class DetokenizerCache : public OrtxObjectImpl {
|
|||
std::string last_text_{}; // last detokenized text
|
||||
};
|
||||
|
||||
template<>
|
||||
OrtxObject* OrtxObjectFactory<DetokenizerCache>::CreateForward() {
|
||||
return std::make_unique<DetokenizerCache>().release();
|
||||
template <>
|
||||
OrtxObject* OrtxObjectFactory::CreateForward<DetokenizerCache>() {
|
||||
return Create<DetokenizerCache>();
|
||||
}
|
||||
|
||||
template<>
|
||||
void OrtxObjectFactory<DetokenizerCache>::DisposeForward(OrtxObject* obj) {
|
||||
Dispose(obj);
|
||||
}
|
||||
} // namespace ort_extensions
|
||||
|
||||
using namespace ort_extensions;
|
||||
|
||||
extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer,
|
||||
const char* input[], size_t batch_size, OrtxTokenId2DArray** output) {
|
||||
extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer, const char* input[], size_t batch_size,
|
||||
OrtxTokenId2DArray** output) {
|
||||
if (tokenizer == nullptr || input == nullptr || output == nullptr) {
|
||||
ReturnableStatus::last_error_message_ = "Invalid argument";
|
||||
return kOrtxErrorInvalidArgument;
|
||||
}
|
||||
|
||||
auto token_ptr = static_cast<const TokenizerImpl*>(tokenizer);
|
||||
ReturnableStatus status =
|
||||
token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer);
|
||||
ReturnableStatus status = token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer);
|
||||
if (!status.IsOk()) {
|
||||
return status.Code();
|
||||
}
|
||||
|
@ -61,8 +52,8 @@ extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer,
|
|||
return extError_t();
|
||||
}
|
||||
|
||||
extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer,
|
||||
const OrtxTokenId2DArray* input, OrtxStringArray** output) {
|
||||
extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer, const OrtxTokenId2DArray* input,
|
||||
OrtxStringArray** output) {
|
||||
if (tokenizer == nullptr || input == nullptr || output == nullptr) {
|
||||
ReturnableStatus::last_error_message_ = "Invalid argument";
|
||||
return kOrtxErrorInvalidArgument;
|
||||
|
@ -81,11 +72,8 @@ extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer,
|
|||
}
|
||||
|
||||
std::vector<span<extTokenId_t const>> t_ids;
|
||||
std::transform(input_2d->token_ids().begin(), input_2d->token_ids().end(),
|
||||
std::back_inserter(t_ids),
|
||||
[](const std::vector<extTokenId_t>& vec) {
|
||||
return span<extTokenId_t const>(vec.data(), vec.size());
|
||||
});
|
||||
std::transform(input_2d->token_ids().begin(), input_2d->token_ids().end(), std::back_inserter(t_ids),
|
||||
[](const std::vector<extTokenId_t>& vec) { return span<extTokenId_t const>(vec.data(), vec.size()); });
|
||||
|
||||
std::vector<std::string> output_text;
|
||||
status = token_ptr->Detokenize(t_ids, output_text);
|
||||
|
@ -101,9 +89,7 @@ extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer,
|
|||
;
|
||||
}
|
||||
|
||||
extError_t ORTX_API_CALL OrtxDetokenize1D(const OrtxTokenizer* tokenizer,
|
||||
const extTokenId_t* input,
|
||||
size_t len,
|
||||
extError_t ORTX_API_CALL OrtxDetokenize1D(const OrtxTokenizer* tokenizer, const extTokenId_t* input, size_t len,
|
||||
OrtxStringArray** output) {
|
||||
if (tokenizer == nullptr || input == nullptr || output == nullptr) {
|
||||
ReturnableStatus::last_error_message_ = "Invalid argument";
|
||||
|
@ -186,8 +172,8 @@ extError_t ORTX_API_CALL OrtxTokenId2DArrayGetBatch(const OrtxTokenId2DArray* to
|
|||
return extError_t();
|
||||
}
|
||||
|
||||
extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* token_id_2d_array,
|
||||
size_t index, const extTokenId_t** item, size_t* length) {
|
||||
extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* token_id_2d_array, size_t index,
|
||||
const extTokenId_t** item, size_t* length) {
|
||||
if (token_id_2d_array == nullptr || item == nullptr || length == nullptr) {
|
||||
ReturnableStatus::last_error_message_ = "Invalid argument";
|
||||
return kOrtxErrorInvalidArgument;
|
||||
|
@ -210,9 +196,8 @@ extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* tok
|
|||
return extError_t();
|
||||
}
|
||||
|
||||
extError_t OrtxDetokenizeCached(const OrtxTokenizer* tokenizer,
|
||||
OrtxDetokenizerCache* cache,
|
||||
extTokenId_t next_id, const char** text_out) {
|
||||
extError_t OrtxDetokenizeCached(const OrtxTokenizer* tokenizer, OrtxDetokenizerCache* cache, extTokenId_t next_id,
|
||||
const char** text_out) {
|
||||
if (tokenizer == nullptr || cache == nullptr || text_out == nullptr) {
|
||||
ReturnableStatus::last_error_message_ = "Invalid argument";
|
||||
return kOrtxErrorInvalidArgument;
|
||||
|
|
|
@ -10,6 +10,8 @@
|
|||
|
||||
using namespace ort_extensions;
|
||||
|
||||
class DetokenizerCache; // forward definition in tokenizer_impl.cc
|
||||
|
||||
thread_local std::string ReturnableStatus::last_error_message_;
|
||||
|
||||
OrtxStatus OrtxObjectImpl::IsInstanceOf(extObjectKind_t kind) const {
|
||||
|
@ -37,7 +39,7 @@ extError_t ORTX_API_CALL OrtxCreate(extObjectKind_t kind, OrtxObject** object, .
|
|||
va_start(args, object);
|
||||
|
||||
if (kind == extObjectKind_t::kOrtxKindDetokenizerCache) {
|
||||
*object = OrtxObjectFactory<DetokenizerCache>::CreateForward();
|
||||
*object = OrtxObjectFactory::CreateForward<DetokenizerCache>();
|
||||
} else if (kind == extObjectKind_t::kOrtxKindTokenizer) {
|
||||
return OrtxCreateTokenizer(static_cast<OrtxTokenizer**>(object), va_arg(args, const char*));
|
||||
}
|
||||
|
@ -80,8 +82,8 @@ extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object) {
|
|||
return kOrtxErrorInvalidArgument;
|
||||
}
|
||||
|
||||
if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindStringArray) {
|
||||
OrtxObjectFactory<StringArray>::Dispose(object);
|
||||
/* if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindStringArray) {
|
||||
OrtxObjectFactory::Dispose<StringArray>(object);
|
||||
} else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindTokenId2DArray) {
|
||||
OrtxObjectFactory<TokenId2DArray>::Dispose(object);
|
||||
} else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindDetokenizerCache) {
|
||||
|
@ -94,6 +96,11 @@ extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object) {
|
|||
OrtxObjectFactory<ImageProcessorResult>::Dispose(object);
|
||||
} else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindProcessor) {
|
||||
OrtxObjectFactory<ImageProcessor>::Dispose(object);
|
||||
} */
|
||||
if (Ortx_object->ortx_kind() >= kOrtxKindBegin && Ortx_object->ortx_kind() < kOrtxKindEnd) {
|
||||
OrtxObjectFactory::Dispose<OrtxObjectImpl>(object);
|
||||
} else {
|
||||
return kOrtxErrorInvalidArgument;
|
||||
}
|
||||
|
||||
return extError_t();
|
||||
|
@ -113,6 +120,30 @@ extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object) {
|
|||
return err;
|
||||
}
|
||||
|
||||
extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t index, OrtxTensor** tensor) {
|
||||
if (result == nullptr || tensor == nullptr) {
|
||||
ReturnableStatus::last_error_message_ = "Invalid argument";
|
||||
return kOrtxErrorInvalidArgument;
|
||||
}
|
||||
|
||||
auto result_ptr = static_cast<TensorResult*>(result);
|
||||
ReturnableStatus status(result_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTensorResult));
|
||||
if (!status.IsOk()) {
|
||||
return status.Code();
|
||||
}
|
||||
|
||||
ortc::TensorBase* ts = result_ptr->GetAt(index);
|
||||
if (ts == nullptr) {
|
||||
ReturnableStatus::last_error_message_ = "Cannot get the tensor at the specified index from the result";
|
||||
return kOrtxErrorInvalidArgument;
|
||||
}
|
||||
|
||||
auto tensor_ptr = std::make_unique<OrtxObjectWrapper<ortc::TensorBase, kOrtxKindTensor>>();
|
||||
tensor_ptr->SetObject(ts);
|
||||
*tensor = static_cast<OrtxTensor*>(tensor_ptr.release());
|
||||
return extError_t();
|
||||
}
|
||||
|
||||
extError_t ORTX_API_CALL OrtxGetTensorData(OrtxTensor* tensor, const void** data, const int64_t** shape,
|
||||
size_t* num_dims) {
|
||||
if (tensor == nullptr) {
|
||||
|
@ -120,7 +151,7 @@ extError_t ORTX_API_CALL OrtxGetTensorData(OrtxTensor* tensor, const void** data
|
|||
return kOrtxErrorInvalidArgument;
|
||||
}
|
||||
|
||||
auto tensor_impl = static_cast<OrtxObjectWrapper<ortc::TensorBase>*>(tensor);
|
||||
auto tensor_impl = static_cast<OrtxObjectWrapper<ortc::TensorBase, kOrtxKindTensor>*>(tensor);
|
||||
if (tensor_impl->ortx_kind() != extObjectKind_t::kOrtxKindTensor) {
|
||||
ReturnableStatus::last_error_message_ = "Invalid argument";
|
||||
return kOrtxErrorInvalidArgument;
|
||||
|
|
|
@ -3,8 +3,10 @@
|
|||
|
||||
#pragma once
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
|
||||
#include "ortx_utils.h"
|
||||
#include "file_sys.h"
|
||||
#include "ext_status.h"
|
||||
#include "op_def_struct.h"
|
||||
|
||||
|
@ -12,7 +14,7 @@ namespace ort_extensions {
|
|||
class OrtxObjectImpl : public OrtxObject {
|
||||
public:
|
||||
explicit OrtxObjectImpl(extObjectKind_t kind = extObjectKind_t::kOrtxKindUnknown) : OrtxObject() {
|
||||
ext_kind_ = static_cast<int>(kind);
|
||||
ext_kind_ = kind;
|
||||
};
|
||||
virtual ~OrtxObjectImpl() = default;
|
||||
|
||||
|
@ -24,30 +26,21 @@ class OrtxObjectImpl : public OrtxObject {
|
|||
}
|
||||
return static_cast<extObjectKind_t>(ext_kind_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct Type2Kind {
|
||||
static const extObjectKind_t value = kOrtxKindUnknown;
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OrtxObjectImpl::Type2Kind<ortc::TensorBase> {
|
||||
static const extObjectKind_t value = kOrtxKindTensor;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
// A wrapper class to store a object pointer which is readonly. i.e. unowned.
|
||||
template <typename T, extObjectKind_t kind>
|
||||
class OrtxObjectWrapper : public OrtxObjectImpl {
|
||||
public:
|
||||
OrtxObjectWrapper() : OrtxObjectImpl(OrtxObjectImpl::Type2Kind<T>::value) {}
|
||||
OrtxObjectWrapper() : OrtxObjectImpl(kind) {}
|
||||
~OrtxObjectWrapper() override = default;
|
||||
|
||||
void SetObject(T* t) { stored_object_ = t; }
|
||||
void SetObject(const T* t) { stored_object_ = t; }
|
||||
|
||||
[[nodiscard]] T* GetObject() const { return stored_object_; }
|
||||
[[nodiscard]] const T* GetObject() const { return stored_object_; }
|
||||
|
||||
private:
|
||||
T* stored_object_{};
|
||||
const T* stored_object_{};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
@ -100,6 +93,35 @@ class StringArray : public OrtxObjectImpl {
|
|||
std::vector<std::string> strings_;
|
||||
};
|
||||
|
||||
class TensorResult : public OrtxObjectImpl {
|
||||
public:
|
||||
TensorResult() : OrtxObjectImpl(extObjectKind_t::kOrtxKindTensorResult) {}
|
||||
~TensorResult() override = default;
|
||||
|
||||
void SetTensors(std::vector<std::unique_ptr<ortc::TensorBase>>&& tensors) { tensors_ = std::move(tensors); }
|
||||
|
||||
[[nodiscard]] const std::vector<std::unique_ptr<ortc::TensorBase>>& tensors() const { return tensors_; }
|
||||
|
||||
[[nodiscard]] std::vector<ortc::TensorBase*> GetTensors() const {
|
||||
std::vector<ortc::TensorBase*> ts;
|
||||
ts.reserve(tensors_.size());
|
||||
for (auto& t : tensors_) {
|
||||
ts.push_back(t.get());
|
||||
}
|
||||
return ts;
|
||||
}
|
||||
|
||||
ortc::TensorBase* GetAt(size_t i) const {
|
||||
if (i < tensors_.size()) {
|
||||
return tensors_[i].get();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<ortc::TensorBase>> tensors_;
|
||||
};
|
||||
|
||||
struct ReturnableStatus {
|
||||
public:
|
||||
thread_local static std::string last_error_message_;
|
||||
|
@ -123,25 +145,26 @@ struct ReturnableStatus {
|
|||
OrtxStatus status_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class OrtxObjectFactory {
|
||||
public:
|
||||
static std::unique_ptr<T> Create() { return std::make_unique<T>(); }
|
||||
|
||||
static OrtxObject* CreateForward();
|
||||
static void DisposeForward(OrtxObject* object);
|
||||
template <typename T>
|
||||
static OrtxObject* Create() {
|
||||
return std::make_unique<T>().release();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void Dispose(OrtxObject* object) {
|
||||
auto obj_ptr = static_cast<T*>(object);
|
||||
std::unique_ptr<T> ptr(obj_ptr);
|
||||
ptr.reset();
|
||||
}
|
||||
|
||||
// Forward declaration for creating an object which isn't visible to c_api_utils.cc
|
||||
// and the definition is in the corresponding .cc file.
|
||||
template <typename T>
|
||||
static OrtxObject* CreateForward();
|
||||
};
|
||||
|
||||
class DetokenizerCache; // forward definition in tokenizer_impl.cc
|
||||
class ProcessorResult; // forward definition in image_processor.h
|
||||
|
||||
|
||||
class CppAllocator : public ortc::IAllocator {
|
||||
public:
|
||||
void* Alloc(size_t size) override { return std::make_unique<char[]>(size).release(); }
|
||||
|
@ -157,4 +180,25 @@ class CppAllocator : public ortc::IAllocator {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename It, typename T>
|
||||
std::tuple<std::unique_ptr<T[]>, size_t> LoadRawData(It begin, It end) {
|
||||
auto raw_data = std::make_unique<T[]>(end - begin);
|
||||
size_t n = 0;
|
||||
for (auto it = begin; it != end; ++it) {
|
||||
std::ifstream ifs = path(*it).open(std::ios::binary | std::ios::in);
|
||||
if (!ifs.is_open()) {
|
||||
break;
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
size_t size = ifs.tellg();
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
|
||||
T& datum = raw_data[n++];
|
||||
datum.resize(size);
|
||||
ifs.read(reinterpret_cast<char*>(datum.data()), size);
|
||||
}
|
||||
|
||||
return std::make_tuple(std::move(raw_data), n);
|
||||
}
|
||||
} // namespace ort_extensions
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
#include "file_sys.h"
|
||||
|
||||
#include "image_processor.h"
|
||||
#include "c_api_utils.hpp"
|
||||
#include "cv2/imgcodecs/imdecode.hpp"
|
||||
#include "image_transforms.hpp"
|
||||
#include "image_transforms_phi_3.hpp"
|
||||
|
@ -14,38 +15,11 @@
|
|||
using namespace ort_extensions;
|
||||
using json = nlohmann::json;
|
||||
|
||||
namespace ort_extensions {
|
||||
template <typename It>
|
||||
std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages(It begin, It end) {
|
||||
auto raw_images = std::make_unique<ImageRawData[]>(end - begin);
|
||||
size_t n = 0;
|
||||
for (auto it = begin; it != end; ++it) {
|
||||
std::ifstream ifs = path(*it).open(std::ios::binary);
|
||||
if (!ifs.is_open()) {
|
||||
break;
|
||||
}
|
||||
|
||||
ifs.seekg(0, std::ios::end);
|
||||
size_t size = ifs.tellg();
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
|
||||
ImageRawData& raw_image = raw_images[n++];
|
||||
raw_image.resize(size);
|
||||
ifs.read(reinterpret_cast<char*>(raw_image.data()), size);
|
||||
}
|
||||
|
||||
return std::make_tuple(std::move(raw_images), n);
|
||||
std::tuple<std::unique_ptr<ImageRawData[]>, size_t>
|
||||
ort_extensions::LoadRawImages(const std::initializer_list<const char*>& image_paths) {
|
||||
return ort_extensions::LoadRawData<const char* const*, ImageRawData>(image_paths.begin(), image_paths.end());
|
||||
}
|
||||
|
||||
std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages(
|
||||
const std::initializer_list<const char*>& image_paths) {
|
||||
return LoadRawImages(image_paths.begin(), image_paths.end());
|
||||
}
|
||||
|
||||
template std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages<char const**>(char const**, char const**);
|
||||
|
||||
} // namespace ort_extensions
|
||||
|
||||
Operation::KernelRegistry ImageProcessor::kernel_registry_ = {
|
||||
{"DecodeImage", []() { return CreateKernelInstance(image_decoder); }},
|
||||
{"Resize", []() { return CreateKernelInstance(&Resize::Compute); }},
|
||||
|
@ -97,9 +71,7 @@ OrtxStatus ImageProcessor::Init(std::string_view processor_def) {
|
|||
return {};
|
||||
}
|
||||
|
||||
ImageProcessor::ImageProcessor()
|
||||
: OrtxObjectImpl(kOrtxKindProcessor), allocator_(&CppAllocator::Instance()) {
|
||||
}
|
||||
ImageProcessor::ImageProcessor() : OrtxObjectImpl(kOrtxKindProcessor), allocator_(&CppAllocator::Instance()) {}
|
||||
|
||||
template <typename T>
|
||||
static ortc::Tensor<T>* StackTensor(const std::vector<TensorArgs>& arg_lists, int axis, ortc::IAllocator* allocator) {
|
||||
|
@ -136,39 +108,6 @@ static ortc::Tensor<T>* StackTensor(const std::vector<TensorArgs>& arg_lists, in
|
|||
return output.release();
|
||||
}
|
||||
|
||||
static OrtxStatus StackTensors(const std::vector<TensorArgs>& arg_lists, std::vector<TensorPtr>& outputs,
|
||||
ortc::IAllocator* allocator) {
|
||||
if (arg_lists.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
size_t batch_size = arg_lists.size();
|
||||
size_t num_outputs = arg_lists[0].size();
|
||||
for (size_t axis = 0; axis < num_outputs; ++axis) {
|
||||
std::vector<ortc::TensorBase*> ts_ptrs;
|
||||
ts_ptrs.reserve(arg_lists.size());
|
||||
std::vector<int64_t> shape = arg_lists[0][axis]->Shape();
|
||||
for (auto& ts : arg_lists) {
|
||||
if (shape != ts[axis]->Shape()) {
|
||||
return {kOrtxErrorInvalidArgument, "[StackTensors]: shapes of tensors to stack are not the same."};
|
||||
}
|
||||
ts_ptrs.push_back(ts[axis]);
|
||||
}
|
||||
|
||||
std::vector<int64_t> output_shape = shape;
|
||||
output_shape.insert(output_shape.begin(), batch_size);
|
||||
std::byte* tensor_buf = outputs[axis]->AllocateRaw(output_shape);
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
auto ts = ts_ptrs[i];
|
||||
const std::byte* ts_buff = reinterpret_cast<const std::byte*>(ts->DataRaw());
|
||||
auto ts_size = ts->SizeInBytes();
|
||||
std::memcpy(tensor_buf + i * ts_size, ts_buff, ts_size);
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
std::tuple<OrtxStatus, ProcessorResult> ImageProcessor::PreProcess(ort_extensions::span<ImageRawData> image_data,
|
||||
ortc::Tensor<float>** pixel_values,
|
||||
ortc::Tensor<int64_t>** image_sizes,
|
||||
|
@ -209,7 +148,7 @@ std::tuple<OrtxStatus, ProcessorResult> ImageProcessor::PreProcess(ort_extension
|
|||
return {status, std::move(r)};
|
||||
}
|
||||
|
||||
OrtxStatus ImageProcessor::PreProcess(ort_extensions::span<ImageRawData> image_data, ImageProcessorResult& r) const {
|
||||
OrtxStatus ImageProcessor::PreProcess(ort_extensions::span<ImageRawData> image_data, TensorResult& r) const {
|
||||
std::vector<TensorArgs> inputs;
|
||||
inputs.resize(image_data.size());
|
||||
for (size_t i = 0; i < image_data.size(); ++i) {
|
||||
|
@ -235,9 +174,13 @@ OrtxStatus ImageProcessor::PreProcess(ort_extensions::span<ImageRawData> image_d
|
|||
}
|
||||
}
|
||||
|
||||
r.results = operations_.back()->AllocateOutputs(allocator_);
|
||||
status = StackTensors(outputs, r.results, allocator_);
|
||||
auto img_result = operations_.back()->AllocateOutputs(allocator_);
|
||||
status = OrtxRunner::StackTensors(outputs, img_result, allocator_);
|
||||
operations_.back()->ResetTensors(allocator_);
|
||||
if (status.IsOk()) {
|
||||
r.SetTensors(std::move(img_result));
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
|
@ -257,14 +200,3 @@ void ImageProcessor::ClearOutputs(ProcessorResult* r) {
|
|||
r->num_img_takens = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void ort_extensions::ImageProcessor::ClearOutputs(ImageProcessorResult* r) {
|
||||
if (r == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto& ts : r->results) {
|
||||
ts.reset();
|
||||
}
|
||||
r->results.clear(); // clear the vector
|
||||
}
|
||||
|
|
|
@ -16,9 +16,6 @@ namespace ort_extensions {
|
|||
|
||||
using ImageRawData = std::vector<uint8_t>;
|
||||
|
||||
template <typename It>
|
||||
std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages(It begin, It end);
|
||||
|
||||
std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages(
|
||||
const std::initializer_list<const char*>& image_paths);
|
||||
|
||||
|
@ -29,13 +26,6 @@ class ProcessorResult : public OrtxObjectImpl {
|
|||
ortc::Tensor<int64_t>* image_sizes{};
|
||||
ortc::Tensor<int64_t>* num_img_takens{};
|
||||
};
|
||||
|
||||
class ImageProcessorResult : public OrtxObjectImpl {
|
||||
public:
|
||||
ImageProcessorResult() : OrtxObjectImpl(kOrtxKindImageProcessorResult) {}
|
||||
std::vector<TensorPtr> results;
|
||||
};
|
||||
|
||||
class ImageProcessor : public OrtxObjectImpl {
|
||||
public:
|
||||
ImageProcessor();
|
||||
|
@ -43,15 +33,16 @@ class ImageProcessor : public OrtxObjectImpl {
|
|||
|
||||
OrtxStatus Init(std::string_view processor_def);
|
||||
|
||||
// Deprecated, using the next function instead
|
||||
std::tuple<OrtxStatus, ProcessorResult> PreProcess(ort_extensions::span<ImageRawData> image_data,
|
||||
ortc::Tensor<float>** pixel_values,
|
||||
ortc::Tensor<int64_t>** image_sizes,
|
||||
ortc::Tensor<int64_t>** num_img_takens) const;
|
||||
|
||||
OrtxStatus PreProcess(ort_extensions::span<ImageRawData> image_data, ImageProcessorResult& r) const;
|
||||
OrtxStatus PreProcess(ort_extensions::span<ImageRawData> image_data, TensorResult& r) const;
|
||||
|
||||
// Deprecated, using the next function instead
|
||||
static void ClearOutputs(ProcessorResult* r);
|
||||
static void ClearOutputs(ImageProcessorResult* r);
|
||||
|
||||
static Operation::KernelRegistry kernel_registry_;
|
||||
|
||||
|
|
|
@ -28,7 +28,8 @@ class KernelDef {
|
|||
virtual TensorArgs AllocateOutput(ortc::IAllocator* allocator) const = 0;
|
||||
virtual OrtxStatus Apply(TensorArgs& inputs, TensorArgs& output) const = 0;
|
||||
|
||||
using AttrType = std::variant<std::string, double, int64_t, std::vector<double>>;
|
||||
using AttrType =
|
||||
std::variant<std::string, double, int64_t, std::vector<std::string>, std::vector<double>, std::vector<int64_t>>;
|
||||
using AttrDict = std::unordered_map<std::string, AttrType>;
|
||||
|
||||
template <typename... Args>
|
||||
|
@ -98,7 +99,7 @@ class KernelDef {
|
|||
template <typename... Args>
|
||||
class KernelFunction : public KernelDef {
|
||||
public:
|
||||
KernelFunction(OrtxStatus (*body)(Args...)) : body_(body){};
|
||||
KernelFunction(OrtxStatus (*body)(Args...)) : body_(body) {};
|
||||
virtual ~KernelFunction() = default;
|
||||
|
||||
TensorArgs AllocateOutput(ortc::IAllocator* allocator) const override {
|
||||
|
@ -132,7 +133,7 @@ class KernelFunction : public KernelDef {
|
|||
template <typename T, typename... Args>
|
||||
class KernelStruct : public KernelDef {
|
||||
public:
|
||||
KernelStruct(OrtxStatus (T::*body)(Args...)) : body_(body){};
|
||||
KernelStruct(OrtxStatus (T::*body)(Args...)) : body_(body) {};
|
||||
virtual ~KernelStruct() = default;
|
||||
|
||||
TensorArgs AllocateOutput(ortc::IAllocator* allocator) const override {
|
||||
|
@ -167,8 +168,18 @@ class KernelStruct : public KernelDef {
|
|||
attr_dict[key] = value.template get<int64_t>();
|
||||
} else if (value.is_number_float()) {
|
||||
attr_dict[key] = value.template get<double>();
|
||||
} else if (value.is_array()) {
|
||||
attr_dict[key] = value.template get<std::vector<double>>();
|
||||
} else if (value.is_array() && value.size() > 0) {
|
||||
auto& elem_0 = value.at(0);
|
||||
if (elem_0.is_number_float()) {
|
||||
attr_dict[key] = value.template get<std::vector<double>>();
|
||||
} else if (elem_0.is_string()) {
|
||||
attr_dict[key] = value.template get<std::vector<std::string>>();
|
||||
} else if (elem_0.is_number_integer() || elem_0.is_number_unsigned()) {
|
||||
attr_dict[key] = value.template get<std::vector<int64_t>>();
|
||||
} else {
|
||||
return {kOrtxErrorCorruptData, "Unsupported mix types in attribute value."};
|
||||
}
|
||||
|
||||
} else {
|
||||
return {kOrtxErrorCorruptData, "Invalid attribute type."};
|
||||
}
|
||||
|
@ -309,6 +320,39 @@ class OrtxRunner {
|
|||
return {};
|
||||
}
|
||||
|
||||
static OrtxStatus StackTensors(const std::vector<TensorArgs>& arg_lists, std::vector<TensorPtr>& outputs,
|
||||
ortc::IAllocator* allocator) {
|
||||
if (arg_lists.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
size_t batch_size = arg_lists.size();
|
||||
size_t num_outputs = arg_lists[0].size();
|
||||
for (size_t axis = 0; axis < num_outputs; ++axis) {
|
||||
std::vector<ortc::TensorBase*> ts_ptrs;
|
||||
ts_ptrs.reserve(arg_lists.size());
|
||||
std::vector<int64_t> shape = arg_lists[0][axis]->Shape();
|
||||
for (auto& ts : arg_lists) {
|
||||
if (shape != ts[axis]->Shape()) {
|
||||
return {kOrtxErrorInvalidArgument, "[StackTensors]: shapes of tensors to stack are not the same."};
|
||||
}
|
||||
ts_ptrs.push_back(ts[axis]);
|
||||
}
|
||||
|
||||
std::vector<int64_t> output_shape = shape;
|
||||
output_shape.insert(output_shape.begin(), batch_size);
|
||||
std::byte* tensor_buf = outputs[axis]->AllocateRaw(output_shape);
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
auto ts = ts_ptrs[i];
|
||||
const std::byte* ts_buff = reinterpret_cast<const std::byte*>(ts->DataRaw());
|
||||
auto ts_size = ts->SizeInBytes();
|
||||
std::memcpy(tensor_buf + i * ts_size, ts_buff, ts_size);
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
private:
|
||||
ortc::IAllocator* allocator_;
|
||||
std::vector<Operation*> ops_;
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "speech_extractor.h"
|
||||
|
||||
#include "audio/audio_decoder.h"
|
||||
#include "speech_features.hpp"
|
||||
|
||||
using namespace ort_extensions;
|
||||
|
||||
Operation::KernelRegistry SpeechFeatureExtractor::kernel_registry_ = {
|
||||
{"AudioDecoder", []() { return CreateKernelInstance(&AudioDecoder::ComputeNoOpt); }},
|
||||
{"STFTNorm", []() { return CreateKernelInstance(&SpeechFeatures::STFTNorm); }},
|
||||
{"LogMelSpectrum", []() { return CreateKernelInstance(&LogMel::Compute); }},
|
||||
};
|
||||
|
||||
SpeechFeatureExtractor::SpeechFeatureExtractor()
|
||||
: OrtxObjectImpl(extObjectKind_t::kOrtxKindFeatureExtractor), allocator_(&CppAllocator::Instance()) {}
|
||||
|
||||
OrtxStatus SpeechFeatureExtractor::Init(std::string_view extractor_def) {
|
||||
std::string fe_def_str;
|
||||
if (extractor_def.size() >= 5 && extractor_def.substr(extractor_def.size() - 5) == ".json") {
|
||||
std::ifstream ifs = path({extractor_def.data(), extractor_def.size()}).open();
|
||||
if (!ifs.is_open()) {
|
||||
return {kOrtxErrorInvalidArgument, std::string("[ImageProcessor]: failed to open ") + std::string(extractor_def)};
|
||||
}
|
||||
fe_def_str = std::string(std::istreambuf_iterator<char>(ifs), std::istreambuf_iterator<char>());
|
||||
extractor_def = fe_def_str.c_str();
|
||||
}
|
||||
|
||||
// pase the extraction_def by json
|
||||
auto fe_json = json::parse(extractor_def, nullptr, false);
|
||||
if (fe_json.is_discarded()) {
|
||||
return {kOrtxErrorInvalidArgument, "[SpeechFeatureExtractor]: failed to parse extractor json configuration."};
|
||||
}
|
||||
|
||||
auto fe_root = fe_json.at("feature_extraction");
|
||||
if (!fe_root.is_object()) {
|
||||
return {kOrtxErrorInvalidArgument, "[SpeechFeatureExtractor]: feature_extraction field is missing."};
|
||||
}
|
||||
|
||||
auto op_sequence = fe_root.at("sequence");
|
||||
if (!op_sequence.is_array() || op_sequence.empty()) {
|
||||
return {kOrtxErrorInvalidArgument, "[SpeechFeatureExtractor]: sequence field is missing."};
|
||||
}
|
||||
|
||||
operations_.reserve(op_sequence.size());
|
||||
for (auto mod_iter = op_sequence.begin(); mod_iter != op_sequence.end(); ++mod_iter) {
|
||||
auto op = std::make_unique<Operation>(kernel_registry_);
|
||||
auto status = op->Init(mod_iter->dump());
|
||||
if (!status.IsOk()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
operations_.push_back(std::move(op));
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
OrtxStatus SpeechFeatureExtractor::DoCall(ort_extensions::span<AudioRawData> raw_speech,
|
||||
std::unique_ptr<ortc::Tensor<float>>& log_mel) const {
|
||||
// setup the input tensors
|
||||
std::vector<TensorArgs> inputs;
|
||||
inputs.resize(raw_speech.size());
|
||||
for (size_t i = 0; i < raw_speech.size(); ++i) {
|
||||
auto& ts_input = inputs[i];
|
||||
AudioRawData& speech = raw_speech[i];
|
||||
std::vector<int64_t> shape = {static_cast<int64_t>(speech.size())};
|
||||
ts_input.push_back(std::make_unique<ortc::Tensor<uint8_t>>(shape, speech.data()).release());
|
||||
}
|
||||
|
||||
std::vector<TensorArgs> outputs;
|
||||
std::vector<Operation*> ops(operations_.size());
|
||||
std::transform(operations_.begin(), operations_.end(), ops.begin(), [](auto& op) { return op.get(); });
|
||||
OrtxRunner runner(allocator_, ops.data(), ops.size());
|
||||
auto status = runner.Run(inputs, outputs);
|
||||
if (!status.IsOk()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
// clear the input tensors
|
||||
for (auto& input : inputs) {
|
||||
for (auto& ts : input) {
|
||||
std::unique_ptr<ortc::TensorBase>(ts).reset();
|
||||
}
|
||||
}
|
||||
|
||||
auto results = operations_.back()->AllocateOutputs(allocator_);
|
||||
status = OrtxRunner::StackTensors(outputs, results, allocator_);
|
||||
if (status.IsOk()) {
|
||||
log_mel.reset(static_cast<ortc::Tensor<float>*>(results[0].release()));
|
||||
operations_.back()->ResetTensors(allocator_);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ortx_extractor.h"
|
||||
#include "c_api_utils.hpp"
|
||||
#include "runner.hpp"
|
||||
|
||||
|
||||
namespace ort_extensions {
|
||||
|
||||
typedef std::vector<std::byte> AudioRawData;
|
||||
|
||||
class SpeechFeatureExtractor : public OrtxObjectImpl {
|
||||
public:
|
||||
SpeechFeatureExtractor();
|
||||
|
||||
virtual ~SpeechFeatureExtractor() = default;
|
||||
|
||||
public:
|
||||
OrtxStatus Init(std::string_view extractor_def);
|
||||
|
||||
OrtxStatus DoCall(ort_extensions::span<AudioRawData> raw_speech, std::unique_ptr<ortc::Tensor<float>>& log_mel) const;
|
||||
|
||||
static Operation::KernelRegistry kernel_registry_;
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<Operation>> operations_;
|
||||
ortc::IAllocator* allocator_;
|
||||
};
|
||||
|
||||
} // namespace ort_extensions
|
|
@ -0,0 +1,224 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <dlib/matrix.h>
|
||||
#include <math/dlib/stft_norm.hpp>
|
||||
|
||||
#ifndef M_PI
|
||||
#define M_PI 3.14159265358979323846
|
||||
#endif
|
||||
|
||||
namespace ort_extensions {
|
||||
|
||||
class SpeechFeatures {
|
||||
public:
|
||||
template <typename DictT>
|
||||
OrtxStatus Init(const DictT& attrs) {
|
||||
for (const auto& [key, value] : attrs) {
|
||||
if (key == "n_fft") {
|
||||
n_fft_ = std::get<int64_t>(value);
|
||||
} else if (key == "hop_length") {
|
||||
hop_length_ = std::get<int64_t>(value);
|
||||
} else if (key == "frame_length") {
|
||||
frame_length_ = std::get<int64_t>(value);
|
||||
} else if (key == "hann_win") {
|
||||
auto& win = std::get<std::vector<double>>(value);
|
||||
hann_win_.resize(win.size());
|
||||
std::transform(win.begin(), win.end(), hann_win_.begin(), [](double x) { return static_cast<float>(x); });
|
||||
} else if (key != "_comment") {
|
||||
return {kOrtxErrorInvalidArgument, "[AudioFeatures]: Invalid key in the JSON configuration."};
|
||||
}
|
||||
}
|
||||
|
||||
if (hann_win_.empty()) {
|
||||
hann_win_ = hann_window(frame_length_);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
OrtxStatus STFTNorm(const ortc::Tensor<float>& pcm, ortc::Tensor<float>& stft_norm) {
|
||||
return stft_norm_.Compute(pcm, n_fft_, hop_length_, {hann_win_.data(), hann_win_.size()}, frame_length_, stft_norm);
|
||||
}
|
||||
|
||||
static std::vector<float> hann_window(int N) {
|
||||
std::vector<float> window(N);
|
||||
|
||||
for (int n = 0; n < N; ++n) {
|
||||
// this formula leads to more rounding errors than the one below
|
||||
// window[n] = static_cast<float>(0.5 * (1 - std::cos(2 * M_PI * n / (N - 1))));
|
||||
double n_sin = std::sin(M_PI * n / N);
|
||||
window[n] = static_cast<float>(n_sin * n_sin);
|
||||
}
|
||||
|
||||
return window;
|
||||
}
|
||||
|
||||
private:
|
||||
StftNormal stft_norm_;
|
||||
int64_t n_fft_{};
|
||||
int64_t hop_length_{};
|
||||
int64_t frame_length_{};
|
||||
std::vector<float> hann_win_;
|
||||
};
|
||||
|
||||
class LogMel {
|
||||
public:
|
||||
template <typename DictT>
|
||||
OrtxStatus Init(const DictT& attrs) {
|
||||
int n_fft = 0;
|
||||
int n_mel = 0;
|
||||
int chunk_size = 0;
|
||||
for (const auto& [key, value] : attrs) {
|
||||
if (key == "hop_length") {
|
||||
hop_length_ = std::get<int64_t>(value);
|
||||
} else if (key == "n_fft") {
|
||||
n_fft = std::get<int64_t>(value);
|
||||
} else if (key == "n_mel") {
|
||||
n_mel = std::get<int64_t>(value);
|
||||
} else if (key == "chunk_size") {
|
||||
chunk_size = std::get<int64_t>(value);
|
||||
} else {
|
||||
return {kOrtxErrorInvalidArgument, "[LogMel]: Invalid key in the JSON configuration."};
|
||||
}
|
||||
}
|
||||
|
||||
n_samples_ = n_sr_ * chunk_size;
|
||||
mel_filters_ = MelFilterBank(n_fft, n_mel, n_sr_);
|
||||
return {};
|
||||
}
|
||||
|
||||
OrtxStatus Compute(const ortc::Tensor<float>& stft_norm, ortc::Tensor<float>& logmel) {
|
||||
// Compute the Mel spectrogram by following Python code
|
||||
/*
|
||||
magnitudes = stft_norm[:, :, :-1]
|
||||
mel_spec = self.mel_filters @ magnitudes
|
||||
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
spec_min = log_spec.max() - 8.0
|
||||
log_spec = torch.maximum(log_spec, spec_min)
|
||||
spec_shape = log_spec.shape
|
||||
padding_spec = torch.ones(spec_shape[0],
|
||||
spec_shape[1],
|
||||
self.n_samples // self.hop_length - spec_shape[2],
|
||||
dtype=torch.float)
|
||||
padding_spec *= spec_min
|
||||
log_spec = torch.cat((log_spec, padding_spec), dim=2)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
return log_spec
|
||||
*/
|
||||
assert(stft_norm.Shape().size() == 3 && stft_norm.Shape()[0] == 1);
|
||||
std::vector<int64_t> stft_shape = stft_norm.Shape();
|
||||
dlib::matrix<float> magnitudes(stft_norm.Shape()[1], stft_norm.Shape()[2] - 1);
|
||||
for (int i = 0; i < magnitudes.nr(); ++i) {
|
||||
std::copy(stft_norm.Data() + i * stft_shape[2], stft_norm.Data() + (i + 1) * stft_shape[2] - 1,
|
||||
magnitudes.begin() + i * magnitudes.nc());
|
||||
}
|
||||
|
||||
dlib::matrix<float> mel_spec = mel_filters_ * magnitudes;
|
||||
for (int i = 0; i < mel_spec.nr(); ++i) {
|
||||
for (int j = 0; j < mel_spec.nc(); ++j) {
|
||||
mel_spec(i, j) = std::max(1e-10f, mel_spec(i, j));
|
||||
}
|
||||
}
|
||||
|
||||
dlib::matrix<float> log_spec = dlib::log10(mel_spec);
|
||||
float log_spec_min = dlib::max(log_spec) - 8.0f;
|
||||
for (int i = 0; i < log_spec.nr(); ++i) {
|
||||
for (int j = 0; j < log_spec.nc(); ++j) {
|
||||
float v = std::max(log_spec(i, j), log_spec_min);
|
||||
v = (v + 4.0f) / 4.0f;
|
||||
log_spec(i, j) = v;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> shape = {mel_filters_.nr(), n_samples_ / hop_length_};
|
||||
float* buff = logmel.Allocate(shape);
|
||||
std::fill(buff, buff + logmel.NumberOfElement(), (log_spec_min + 4.0f) / 4.0f);
|
||||
for (int i = 0; i < log_spec.nr(); ++i) {
|
||||
auto row_len = log_spec.nc() * i;
|
||||
std::copy(log_spec.begin() + i * log_spec.nc(), log_spec.begin() + (i + 1) * log_spec.nc(), buff + i * shape[1]);
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
// Function to compute the Mel filterbank
|
||||
static dlib::matrix<float> MelFilterBank(int n_fft, int n_mels, int sr = 16000, float min_mel = 0,
|
||||
float max_mel = 45.245640471924965) {
|
||||
// Initialize the filterbank matrix
|
||||
dlib::matrix<float> fbank(n_mels, n_fft / 2 + 1);
|
||||
memset(fbank.begin(), 0, fbank.size() * sizeof(float));
|
||||
|
||||
// Compute the frequency bins for the DFT
|
||||
std::vector<float> freq_bins(n_fft / 2 + 1);
|
||||
for (int i = 0; i <= n_fft / 2; ++i) {
|
||||
freq_bins[i] = i * sr / static_cast<float>(n_fft);
|
||||
}
|
||||
|
||||
// Compute the Mel scale frequencies
|
||||
std::vector<float> mel(n_mels + 2);
|
||||
for (int i = 0; i < n_mels + 2; ++i) {
|
||||
mel[i] = min_mel + i * (max_mel - min_mel) / (n_mels + 1);
|
||||
}
|
||||
|
||||
// Fill in the linear scale
|
||||
float f_min = 0.0f;
|
||||
float f_sp = 200.0f / 3.0f;
|
||||
std::vector<float> freqs(n_mels + 2);
|
||||
for (int i = 0; i < n_mels + 2; ++i) {
|
||||
freqs[i] = f_min + f_sp * mel[i];
|
||||
}
|
||||
|
||||
// Nonlinear scale
|
||||
float min_log_hz = 1000.0f;
|
||||
float min_log_mel = (min_log_hz - f_min) / f_sp;
|
||||
float logstep = log(6.4) / 27.0;
|
||||
|
||||
for (int i = 0; i < n_mels + 2; ++i) {
|
||||
if (mel[i] >= min_log_mel) {
|
||||
freqs[i] = min_log_hz * exp(logstep * (mel[i] - min_log_mel));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<float> mel_bins = freqs;
|
||||
std::vector<float> mel_spacing(n_mels + 1);
|
||||
for (int i = 0; i < n_mels + 1; ++i) {
|
||||
mel_spacing[i] = mel_bins[i + 1] - mel_bins[i];
|
||||
}
|
||||
|
||||
// Compute the ramps
|
||||
std::vector<std::vector<float>> ramps(n_mels + 2, std::vector<float>(n_fft / 2 + 1));
|
||||
for (int i = 0; i < n_mels + 2; ++i) {
|
||||
for (int j = 0; j <= n_fft / 2; ++j) {
|
||||
ramps[i][j] = mel_bins[i] - freq_bins[j];
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_mels; ++i) {
|
||||
for (int j = 0; j <= n_fft / 2; ++j) {
|
||||
float left = -ramps[i][j] / mel_spacing[i];
|
||||
float right = ramps[i + 2][j] / mel_spacing[i + 1];
|
||||
fbank(i, j) = std::max(0.0f, std::min(left, right));
|
||||
}
|
||||
}
|
||||
|
||||
// Energy normalization
|
||||
for (int i = 0; i < n_mels; ++i) {
|
||||
float energy_norm = 2.0f / (mel_bins[i + 2] - mel_bins[i]);
|
||||
for (int j = 0; j <= n_fft / 2; ++j) {
|
||||
fbank(i, j) *= energy_norm;
|
||||
}
|
||||
}
|
||||
|
||||
return fbank;
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t n_samples_ = {}; // sr * chunk_size
|
||||
int64_t hop_length_{};
|
||||
const int64_t n_sr_{16000};
|
||||
dlib::matrix<float> mel_filters_;
|
||||
};
|
||||
|
||||
} // namespace ort_extensions
|
|
@ -0,0 +1,437 @@
|
|||
{
|
||||
"feature_extraction": {
|
||||
"sequence": [
|
||||
{
|
||||
"operation": {
|
||||
"name": "audio_decoder",
|
||||
"type": "AudioDecoder"
|
||||
}
|
||||
},
|
||||
{
|
||||
"operation": {
|
||||
"name": "STFT",
|
||||
"type": "STFTNorm",
|
||||
"attrs": {
|
||||
"n_fft": 400,
|
||||
"frame_length": 400,
|
||||
"hop_length": 160,
|
||||
"_comment": [
|
||||
0.0,
|
||||
0.0000616908073425293,
|
||||
0.0002467334270477295,
|
||||
0.0005550682544708252,
|
||||
0.000986635684967041,
|
||||
0.0015413463115692139,
|
||||
0.0022190213203430176,
|
||||
0.0030195116996765137,
|
||||
0.003942638635635376,
|
||||
0.004988163709640503,
|
||||
0.006155818700790405,
|
||||
0.007445335388183594,
|
||||
0.008856385946273804,
|
||||
0.010388582944869995,
|
||||
0.012041628360748291,
|
||||
0.013815045356750488,
|
||||
0.01570841670036316,
|
||||
0.01772129535675049,
|
||||
0.019853144884109497,
|
||||
0.022103488445281982,
|
||||
0.02447172999382019,
|
||||
0.026957333087921143,
|
||||
0.029559612274169922,
|
||||
0.03227800130844116,
|
||||
0.03511175513267517,
|
||||
0.03806024789810181,
|
||||
0.0411226749420166,
|
||||
0.044298380613327026,
|
||||
0.04758647084236145,
|
||||
0.05098623037338257,
|
||||
0.05449673533439636,
|
||||
0.058117181062698364,
|
||||
0.06184667348861694,
|
||||
0.0656842589378357,
|
||||
0.06962898373603821,
|
||||
0.07367992401123047,
|
||||
0.0778360664844513,
|
||||
0.08209633827209473,
|
||||
0.08645972609519958,
|
||||
0.09092515707015991,
|
||||
0.09549149870872498,
|
||||
0.10015767812728882,
|
||||
0.10492250323295593,
|
||||
0.1097848117351532,
|
||||
0.11474338173866272,
|
||||
0.11979702115058899,
|
||||
0.12494447827339172,
|
||||
0.13018447160720825,
|
||||
0.1355157196521759,
|
||||
0.14093685150146484,
|
||||
0.1464466154575348,
|
||||
0.15204361081123352,
|
||||
0.1577264666557312,
|
||||
0.16349375247955322,
|
||||
0.16934409737586975,
|
||||
0.1752760112285614,
|
||||
0.18128803372383118,
|
||||
0.18737870454788208,
|
||||
0.19354650378227234,
|
||||
0.1997898817062378,
|
||||
0.20610737800598145,
|
||||
0.21249738335609436,
|
||||
0.21895831823349,
|
||||
0.2254886031150818,
|
||||
0.23208662867546082,
|
||||
0.23875075578689575,
|
||||
0.24547931551933289,
|
||||
0.2522706985473633,
|
||||
0.25912320613861084,
|
||||
0.26603513956069946,
|
||||
0.27300477027893066,
|
||||
0.2800304591655731,
|
||||
0.2871103882789612,
|
||||
0.29424285888671875,
|
||||
0.30142611265182495,
|
||||
0.30865830183029175,
|
||||
0.31593772768974304,
|
||||
0.3232625722885132,
|
||||
0.3306310474872589,
|
||||
0.3380413055419922,
|
||||
0.34549152851104736,
|
||||
0.352979838848114,
|
||||
0.3605044484138489,
|
||||
0.3680635094642639,
|
||||
0.37565508484840393,
|
||||
0.38327735662460327,
|
||||
0.3909284174442291,
|
||||
0.39860638976097107,
|
||||
0.4063093662261963,
|
||||
0.41403549909591675,
|
||||
0.42178282141685486,
|
||||
0.4295494258403778,
|
||||
0.43733343482017517,
|
||||
0.44513291120529175,
|
||||
0.45294591784477234,
|
||||
0.46077051758766174,
|
||||
0.46860480308532715,
|
||||
0.4764467775821686,
|
||||
0.4842946231365204,
|
||||
0.492146372795105,
|
||||
0.5,
|
||||
0.5078536868095398,
|
||||
0.515705406665802,
|
||||
0.5235532522201538,
|
||||
0.5313953161239624,
|
||||
0.5392295718193054,
|
||||
0.5470541715621948,
|
||||
0.5548672080039978,
|
||||
0.562666654586792,
|
||||
0.5704506635665894,
|
||||
0.5782172679901123,
|
||||
0.5859646201133728,
|
||||
0.5936906933784485,
|
||||
0.6013936996459961,
|
||||
0.609071671962738,
|
||||
0.6167227625846863,
|
||||
0.6243450045585632,
|
||||
0.6319366097450256,
|
||||
0.6394955515861511,
|
||||
0.6470202207565308,
|
||||
0.6545085310935974,
|
||||
0.6619587540626526,
|
||||
0.6693689823150635,
|
||||
0.6767374277114868,
|
||||
0.6840623021125793,
|
||||
0.691341757774353,
|
||||
0.6985740065574646,
|
||||
0.7057572603225708,
|
||||
0.7128896713256836,
|
||||
0.719969630241394,
|
||||
0.7269952893257141,
|
||||
0.7339649796485901,
|
||||
0.7408769130706787,
|
||||
0.7477294206619263,
|
||||
0.7545207738876343,
|
||||
0.761249303817749,
|
||||
0.7679134607315063,
|
||||
0.774511456489563,
|
||||
0.7810417413711548,
|
||||
0.7875027060508728,
|
||||
0.7938927412033081,
|
||||
0.800210177898407,
|
||||
0.8064535856246948,
|
||||
0.8126214146614075,
|
||||
0.8187121152877808,
|
||||
0.8247240781784058,
|
||||
0.8306560516357422,
|
||||
0.8365063667297363,
|
||||
0.8422735929489136,
|
||||
0.8479564785957336,
|
||||
0.8535534143447876,
|
||||
0.8590631484985352,
|
||||
0.8644843101501465,
|
||||
0.8698155879974365,
|
||||
0.8750555515289307,
|
||||
0.8802030086517334,
|
||||
0.8852566480636597,
|
||||
0.8902152180671692,
|
||||
0.8950775265693665,
|
||||
0.899842381477356,
|
||||
0.9045084714889526,
|
||||
0.9090749025344849,
|
||||
0.9135403037071228,
|
||||
0.9179036617279053,
|
||||
0.9221639633178711,
|
||||
0.9263200759887695,
|
||||
0.9303710460662842,
|
||||
0.9343158006668091,
|
||||
0.9381533861160278,
|
||||
0.941882848739624,
|
||||
0.945503294467926,
|
||||
0.9490138292312622,
|
||||
0.9524135589599609,
|
||||
0.9557017087936401,
|
||||
0.9588773250579834,
|
||||
0.961939811706543,
|
||||
0.9648882746696472,
|
||||
0.9677220582962036,
|
||||
0.9704403877258301,
|
||||
0.9730427265167236,
|
||||
0.9755282998085022,
|
||||
0.9778965711593628,
|
||||
0.9801468849182129,
|
||||
0.9822787046432495,
|
||||
0.9842916131019592,
|
||||
0.9861849546432495,
|
||||
0.9879584312438965,
|
||||
0.9896113872528076,
|
||||
0.9911436438560486,
|
||||
0.9925546646118164,
|
||||
0.9938441514968872,
|
||||
0.9950118064880371,
|
||||
0.996057391166687,
|
||||
0.9969804883003235,
|
||||
0.997780978679657,
|
||||
0.9984586238861084,
|
||||
0.999013364315033,
|
||||
0.9994449615478516,
|
||||
0.9997532367706299,
|
||||
0.9999383091926575,
|
||||
1,
|
||||
0.9999383091926575,
|
||||
0.9997532367706299,
|
||||
0.9994449615478516,
|
||||
0.999013364315033,
|
||||
0.9984586238861084,
|
||||
0.997780978679657,
|
||||
0.9969804286956787,
|
||||
0.9960573315620422,
|
||||
0.9950118064880371,
|
||||
0.9938441514968872,
|
||||
0.9925546646118164,
|
||||
0.9911435842514038,
|
||||
0.9896113872528076,
|
||||
0.9879583716392517,
|
||||
0.9861849546432495,
|
||||
0.9842915534973145,
|
||||
0.9822787046432495,
|
||||
0.9801468253135681,
|
||||
0.9778964519500732,
|
||||
0.9755282402038574,
|
||||
0.9730426073074341,
|
||||
0.9704403877258301,
|
||||
0.9677219390869141,
|
||||
0.9648882150650024,
|
||||
0.9619396924972534,
|
||||
0.9588772654533386,
|
||||
0.9557015895843506,
|
||||
0.9524134397506714,
|
||||
0.9490137100219727,
|
||||
0.9455032348632812,
|
||||
0.9418827295303345,
|
||||
0.9381532669067383,
|
||||
0.9343156814575195,
|
||||
0.9303709268569946,
|
||||
0.9263200759887695,
|
||||
0.9221639633178711,
|
||||
0.9179036617279053,
|
||||
0.913540244102478,
|
||||
0.9090747833251953,
|
||||
0.9045084714889526,
|
||||
0.8998422622680664,
|
||||
0.8950774669647217,
|
||||
0.8902151584625244,
|
||||
0.8852565884590149,
|
||||
0.8802029490470886,
|
||||
0.8750554919242859,
|
||||
0.869815468788147,
|
||||
0.8644842505455017,
|
||||
0.8590630888938904,
|
||||
0.853553295135498,
|
||||
0.8479562997817993,
|
||||
0.842273473739624,
|
||||
0.836506187915802,
|
||||
0.8306558728218079,
|
||||
0.8247239589691162,
|
||||
0.8187118768692017,
|
||||
0.8126212358474731,
|
||||
0.8064534664154053,
|
||||
0.8002099990844727,
|
||||
0.793892502784729,
|
||||
0.7875025272369385,
|
||||
0.7810416221618652,
|
||||
0.7745113372802734,
|
||||
0.767913281917572,
|
||||
0.7612491846084595,
|
||||
0.7545205950737,
|
||||
0.7477291822433472,
|
||||
0.7408767342567444,
|
||||
0.7339648008346558,
|
||||
0.7269951105117798,
|
||||
0.7199694514274597,
|
||||
0.7128894925117493,
|
||||
0.7057570219039917,
|
||||
0.6985738277435303,
|
||||
0.6913415789604187,
|
||||
0.684062123298645,
|
||||
0.6767372488975525,
|
||||
0.6693688035011292,
|
||||
0.6619585752487183,
|
||||
0.6545083522796631,
|
||||
0.6470199823379517,
|
||||
0.6394953727722168,
|
||||
0.6319363117218018,
|
||||
0.6243447661399841,
|
||||
0.6167224645614624,
|
||||
0.6090714335441589,
|
||||
0.601393461227417,
|
||||
0.5936904549598694,
|
||||
0.5859643220901489,
|
||||
0.5782170295715332,
|
||||
0.5704504251480103,
|
||||
0.5626664161682129,
|
||||
0.5548669099807739,
|
||||
0.5470539331436157,
|
||||
0.5392293334007263,
|
||||
0.5313950181007385,
|
||||
0.5235530138015747,
|
||||
0.5157051682472229,
|
||||
0.507853627204895,
|
||||
0.5,
|
||||
0.4921463429927826,
|
||||
0.484294593334198,
|
||||
0.4764467477798462,
|
||||
0.46860471367836,
|
||||
0.4607704281806946,
|
||||
0.4529458284378052,
|
||||
0.4451328217983246,
|
||||
0.437333345413208,
|
||||
0.42954933643341064,
|
||||
0.4217827320098877,
|
||||
0.4140354096889496,
|
||||
0.4063093066215515,
|
||||
0.3986063003540039,
|
||||
0.39092832803726196,
|
||||
0.3832772672176361,
|
||||
0.37565499544143677,
|
||||
0.36806342005729675,
|
||||
0.3605043888092041,
|
||||
0.35297977924346924,
|
||||
0.3454914391040802,
|
||||
0.338041216135025,
|
||||
0.33063095808029175,
|
||||
0.3232625126838684,
|
||||
0.3159376382827759,
|
||||
0.3086581826210022,
|
||||
0.3014259934425354,
|
||||
0.2942427396774292,
|
||||
0.28711026906967163,
|
||||
0.2800303101539612,
|
||||
0.2730046510696411,
|
||||
0.2660350203514099,
|
||||
0.2591230869293213,
|
||||
0.25227057933807373,
|
||||
0.24547919631004333,
|
||||
0.2387506067752838,
|
||||
0.23208650946617126,
|
||||
0.22548848390579224,
|
||||
0.21895819902420044,
|
||||
0.2124972641468048,
|
||||
0.2061072587966919,
|
||||
0.19978976249694824,
|
||||
0.1935463547706604,
|
||||
0.18737855553627014,
|
||||
0.18128788471221924,
|
||||
0.17527586221694946,
|
||||
0.1693439483642578,
|
||||
0.16349363327026367,
|
||||
0.15772631764411926,
|
||||
0.15204349160194397,
|
||||
0.14644649624824524,
|
||||
0.1409367322921753,
|
||||
0.13551557064056396,
|
||||
0.1301843225955963,
|
||||
0.12494435906410217,
|
||||
0.11979690194129944,
|
||||
0.11474326252937317,
|
||||
0.10978469252586365,
|
||||
0.10492238402366638,
|
||||
0.10015755891799927,
|
||||
0.09549137949943542,
|
||||
0.09092503786087036,
|
||||
0.08645960688591003,
|
||||
0.08209621906280518,
|
||||
0.07783591747283936,
|
||||
0.07367980480194092,
|
||||
0.06962886452674866,
|
||||
0.06568413972854614,
|
||||
0.06184655427932739,
|
||||
0.0581170916557312,
|
||||
0.0544966459274292,
|
||||
0.05098611116409302,
|
||||
0.04758638143539429,
|
||||
0.044298261404037476,
|
||||
0.04112258553504944,
|
||||
0.038060128688812256,
|
||||
0.03511166572570801,
|
||||
0.03227788209915161,
|
||||
0.02955952286720276,
|
||||
0.02695724368095398,
|
||||
0.024471670389175415,
|
||||
0.02210339903831482,
|
||||
0.01985308527946472,
|
||||
0.017721205949783325,
|
||||
0.015708357095718384,
|
||||
0.0138150155544281,
|
||||
0.012041598558425903,
|
||||
0.010388582944869995,
|
||||
0.008856356143951416,
|
||||
0.007445335388183594,
|
||||
0.006155818700790405,
|
||||
0.004988163709640503,
|
||||
0.003942638635635376,
|
||||
0.0030195116996765137,
|
||||
0.0022190213203430176,
|
||||
0.0015413165092468262,
|
||||
0.000986635684967041,
|
||||
0.0005550682544708252,
|
||||
0.0002467334270477295,
|
||||
0.0000616908073425293
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"operation": {
|
||||
"name": "log_mel_spectrogram",
|
||||
"type": "LogMelSpectrum",
|
||||
"attrs": {
|
||||
"chunk_size": 30,
|
||||
"hop_length": 160,
|
||||
"n_fft": 400,
|
||||
"n_mel": 80
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
|
@ -4,7 +4,9 @@
|
|||
#pragma once
|
||||
|
||||
#include "ortx_tokenizer.h"
|
||||
// make sure the C only compiler compatibility only.
|
||||
#include "ortx_processor.h"
|
||||
#include "ortx_extractor.h"
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <vector>
|
||||
#include <tuple>
|
||||
#include <fstream>
|
||||
#include <filesystem>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortx_cpp_helper.h"
|
||||
#include "shared/api/speech_extractor.h"
|
||||
|
||||
using namespace ort_extensions;
|
||||
|
||||
TEST(ExtractorTest, TestWhisperFeatureExtraction) {
|
||||
const char* audio_path[] = {"data/jfk.flac", "data/1272-141231-0002.wav", "data/1272-141231-0002.mp3"};
|
||||
OrtxObjectPtr<OrtxRawAudios> raw_audios;
|
||||
extError_t err = OrtxLoadAudios(ort_extensions::ptr(raw_audios), audio_path, 3);
|
||||
ASSERT_EQ(err, kOrtxOK);
|
||||
|
||||
OrtxObjectPtr<OrtxFeatureExtractor> feature_extractor(OrtxCreateSpeechFeatureExtractor, "data/whisper/feature_extraction.json");
|
||||
OrtxObjectPtr<OrtxTensorResult> result;
|
||||
err = OrtxSpeechLogMel(feature_extractor.get(), raw_audios.get(), ort_extensions::ptr(result));
|
||||
ASSERT_EQ(err, kOrtxOK);
|
||||
|
||||
OrtxObjectPtr<OrtxTensor> tensor;
|
||||
err = OrtxTensorResultGetAt(result.get(), 0, ort_extensions::ptr(tensor));
|
||||
ASSERT_EQ(err, kOrtxOK);
|
||||
|
||||
const float* data{};
|
||||
const int64_t* shape{};
|
||||
size_t num_dims;
|
||||
err = OrtxGetTensorDataFloat(tensor.get(), &data, &shape, &num_dims);
|
||||
ASSERT_EQ(err, kOrtxOK);
|
||||
ASSERT_EQ(num_dims, 3);
|
||||
ASSERT_EQ(shape[0], 3);
|
||||
ASSERT_EQ(shape[1], 80);
|
||||
ASSERT_EQ(shape[2], 3000);
|
||||
}
|
|
@ -7,7 +7,7 @@
|
|||
#include <filesystem>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortx_c_helper.h"
|
||||
#include "ortx_cpp_helper.h"
|
||||
#include "shared/api/image_processor.h"
|
||||
|
||||
using namespace ort_extensions;
|
||||
|
@ -85,18 +85,18 @@ TEST(ProcessorTest, TestClipImageProcessing) {
|
|||
}
|
||||
ASSERT_EQ(err, kOrtxOK);
|
||||
|
||||
OrtxObjectPtr<OrtxImageProcessorResult> result;
|
||||
OrtxObjectPtr<OrtxTensorResult> result;
|
||||
err = OrtxImagePreProcess(processor.get(), raw_images.get(), ort_extensions::ptr(result));
|
||||
ASSERT_EQ(err, kOrtxOK);
|
||||
|
||||
OrtxObjectPtr<OrtxTensor> tensor;
|
||||
err = OrtxImageGetTensorResult(result.get(), 0, ort_extensions::ptr(tensor));
|
||||
OrtxTensor* tensor;
|
||||
err = OrtxTensorResultGetAt(result.get(), 0, &tensor);
|
||||
ASSERT_EQ(err, kOrtxOK);
|
||||
|
||||
const float* data{};
|
||||
const int64_t* shape{};
|
||||
size_t num_dims;
|
||||
err = OrtxGetTensorDataFloat(tensor.get(), &data, &shape, &num_dims);
|
||||
err = OrtxGetTensorDataFloat(tensor, &data, &shape, &num_dims);
|
||||
ASSERT_EQ(err, kOrtxOK);
|
||||
ASSERT_EQ(num_dims, 4);
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче