This commit is contained in:
Xavier Dupré 2024-07-12 20:06:20 +02:00 коммит произвёл GitHub
Родитель 6710e81f97 8153bc1a3a
Коммит 5c928b9aed
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
41 изменённых файлов: 2848 добавлений и 1549 удалений

Просмотреть файл

@ -171,6 +171,10 @@ if (MSVC)
endif()
message(STATUS "_STATIC_MSVC_RUNTIME_LIBRARY: ${_STATIC_MSVC_RUNTIME_LIBRARY}")
# DLL initialization errors due to old conda msvcp140.dll dll are a result of the new MSVC compiler
# See https://developercommunity.visualstudio.com/t/Access-violation-with-std::mutex::lock-a/10664660#T-N10668856
# Remove this definition once the conda msvcp140.dll dll is updated.
add_compile_definitions(_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR)
endif()
if(NOT OCOS_BUILD_PYTHON AND OCOS_ENABLE_PYTHON)
@ -442,7 +446,9 @@ endif()
if(OCOS_ENABLE_BERT_TOKENIZER)
# Bert
set(_HAS_TOKENIZER ON)
file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*" "operators/tokenizer/bert_tokenizer.*" "operators/tokenizer/bert_tokenizer_decoder.*")
file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*"
"operators/tokenizer/bert_tokenizer.*"
"operators/tokenizer/bert_tokenizer_decoder.*")
list(APPEND TARGET_SRC ${bert_TARGET_SRC})
endif()
@ -820,7 +826,9 @@ if(OCOS_ENABLE_AZURE)
endif()
target_compile_definitions(ortcustomops PUBLIC ${OCOS_COMPILE_DEFINITIONS})
target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:noexcep_operators,INTERFACE_INCLUDE_DIRECTORIES>")
target_include_directories(ortcustomops PUBLIC "$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")
target_link_libraries(ortcustomops PUBLIC ocos_operators)
if(_BUILD_SHARED_LIBRARY)
@ -840,7 +848,8 @@ if(_BUILD_SHARED_LIBRARY)
standardize_output_folder(extensions_shared)
if(LINUX OR ANDROID)
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS
" -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
# strip if not a debug build
if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS " -Wl,-s")

Просмотреть файл

@ -30,8 +30,6 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=expr_has_no
add_compile_definitions(USE_CUDA)
set(OCOS_USE_MEMORY_EFFICIENT_ATTENTION OFF) # turn off for the build time. Turn them on when these 2 libs are really in use
set(OCOS_USE_FLASH_ATTENTION OFF)
if (OCOS_USE_FLASH_ATTENTION)
message(STATUS "Enable flash attention")
add_compile_definitions(OCOS_USE_FLASH_ATTENTION)

Просмотреть файл

@ -0,0 +1,60 @@
# How to write custom ops
Custom Ops are based on ONNXRuntime-extensions API, especially **OrtLiteCustomOp** and **Tensor** class. C++ template metaprogramming is heavily used under the hood to provide big flexibility to the Custom Op authors on the parameter's count, type and order.
## Basic scenario
You have 2 ways to write a custom op: by writing a function, or by writing a structure.
### Custom op in the form of function
If your kernel is simple, you can use this option by just providing a function to compute the customized kernel. That function can have arbitrary number of inputs and outputs. For the inputs that are mandatory, their type would be like:
```C++
const Ort::Custom::Tensor<T>&
// or
const Ort::Custom::Tensor<T>*
```
For the inputs that are optional, their type would be like:
```C++
std::optional<const Ort::Custom::Tensor<T>*>
```
The function can also accept the pointer of **CUDAKernelContext**, where you can retrieve CUDA stream and other CUDA resources, if it requires to be run in CUDA GPU.
The function will return the type **OrtStatusPtr**
Please refer to [negpos_def.h](https://github.com/microsoft/onnxruntime-extensions/blob/main/operators/math/cuda/negpos_def.h) as an example and [tensor_tuple.inc](https://github.com/microsoft/onnxruntime-extensions/blob/main/include/custom_op/tensor_tuple.inc) for more possible parameter types.
### Custom op in the form of structure
If the kernel is complicated and there are extra properties of the custom op, you can use this option by providing a C++ structure where you can put these properties as the structure's member variables. Besides that, you also need to provide the following member functions:
```C++
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) // This function initialize the properties of the custom op
OrtStatusPtr Compute(...) const // This function computes the customized kernel.
```
The specification of the parameters of the Compute function is the same as the first way (custom op in the form of function)
## Advanced scenario
In some cases you need more control on the parameters, in this case you have to use the structure form, which you need to provide the implementations of the following member functions such as:
```C++
// By default the function will return OrtMemType::OrtMemTypeDefault for all the inputs,
// you can provide your own implementation to specify the ith input is in CPU or GPU.
static OrtMemType GetInputMemoryType(size_t input_index)
// You can specify input i shares the same memory with output j if possible, by allocating
// two array with same length for the pointer input_index and output_index seperately, and
// then let (*input_index)[k] = i and (*output_index)[k] = j.
// The return value is the length of the allocated array.
static size_t GetMayInplace(int** input_index, int** output_index)
// Release the allocated array from the GetMayInplace() function.
static void ReleaseMayInplace(int* input_index, int* output_index)
```

19
docs/c_api.md Normal file
Просмотреть файл

@ -0,0 +1,19 @@
# ONNXRuntime Extensions C ABI
ONNXRuntime Extensions provides a C-style ABI for pre-processing. It offers support for tokenization, image processing, speech feature extraction, and more. You can compile the ONNXRuntime Extensions as either a static library or a dynamic library to access these APIs.
The C ABI header files are named `ortx_*.h` and can be found in the include folder. There are three types of data processing APIs available:
- [`ortx_tokenizer.h`](../include/ortx_tokenizer.h): Provides tokenization for LLM models.
- [`ortx_processor.h`](../include/ortx_processor.h): Offers image processing APIs for multimodels.
- [`ortx_extraction.h`](../include/ortx_extractor.h): Provides speech feature extraction for audio data processing to assist the Whisper model.
## ABI QuickStart
Most APIs accept raw data inputs such as audio, image compressed binary formats, or UTF-8 encoded text for tokenization.
**Tokenization:** You can create a tokenizer object using `OrtxCreateTokenizer` and then use the object to tokenize a text or decode the token ID into the text. A C-style code snippet is available [here](../test/pp_api_test/c_only_test.c).
**Image processing:** `OrtxCreateProcessor` can create an image processor object from a pre-defined workflow in JSON format to process image files into a tensor-like data type. An example code snippet can be found [here](../test/pp_api_test/test_processor.cc#L75).
**Audio feature extraction:** `OrtxCreateSpeechFeatureExtractor` creates a speech feature extractor to obtain log mel spectrum data as input for the Whisper model. An example code snippet can be found [here](../test/pp_api_test/test_feature_extractor.cc#L16).

Просмотреть файл

@ -886,6 +886,13 @@ struct OrtLiteCustomOp : public OrtCustomOp {
return INPUT_OUTPUT_OPTIONAL;
};
#endif
#if ORT_API_VERSION >= 18
OrtCustomOp::GetMayInplace = [](int**, int**) -> size_t {
return 0;
};
OrtCustomOp::ReleaseMayInplace = [](int*, int*) -> void {};
#endif
}
const std::string op_name_;

Просмотреть файл

@ -106,6 +106,18 @@ struct CustomOp_defined_getInputMemoryType : std::false_type {};
template <typename T>
struct CustomOp_defined_getInputMemoryType<T, std::void_t<decltype(&T::GetInputMemoryType)>> : std::true_type {};
template <typename T, typename = void>
struct CustomOp_defined_getMayInplace : std::false_type {};
template <typename T>
struct CustomOp_defined_getMayInplace<T, std::void_t<decltype(&T::GetMayInplace)>> : std::true_type {};
template <typename T, typename = void>
struct CustomOp_defined_releaseMayInplace : std::false_type {};
template <typename T>
struct CustomOp_defined_releaseMayInplace<T, std::void_t<decltype(&T::ReleaseMayInplace)>> : std::true_type {};
template <typename CustomOpKernel>
struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
using ComputeFunction = decltype(&CustomOpKernel::Compute);
@ -192,6 +204,19 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
};
}
#if ORT_API_VERSION >= 18
if constexpr (CustomOp_defined_getMayInplace<CustomOpKernel>::value) {
OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) -> size_t {
return CustomOpKernel::GetMayInplace(input_index, output_index);
};
}
if constexpr (CustomOp_defined_releaseMayInplace<CustomOpKernel>::value) {
OrtCustomOp::ReleaseMayInplace = [](int* input_index, int* output_index) -> void {
CustomOpKernel::ReleaseMayInplace(input_index, output_index);
};
}
#endif
OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_,
const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
if (api == nullptr) {

Просмотреть файл

75
include/ortx_extractor.h Normal file
Просмотреть файл

@ -0,0 +1,75 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// C ABI header file for the onnxruntime-extensions tokenization module
#pragma once
#include "ortx_utils.h"
typedef OrtxObject OrtxFeatureExtractor;
typedef OrtxObject OrtxRawAudios;
typedef OrtxObject OrtxTensorResult;
#ifdef __cplusplus
extern "C" {
#endif
/**
* @brief Creates a feature extractor object.
*
* This function creates a feature extractor object based on the provided feature definition.
*
* @param[out] extractor Pointer to a pointer to the created feature extractor object.
* @param[in] fe_def The feature definition used to create the feature extractor.
*
* @return An error code indicating the result of the operation.
*/
extError_t ORTX_API_CALL OrtxCreateSpeechFeatureExtractor(OrtxFeatureExtractor** extractor, const char* fe_def);
/**
* Loads a collection of audio files into memory.
*
* This function loads a collection of audio files specified by the `audio_paths` array
* into memory and returns a pointer to the loaded audio data in the `audios` parameter.
*
* @param audios A pointer to a pointer that will be updated with the loaded audio data.
* The caller is responsible for freeing the memory allocated for the audio data.
* @param audio_paths An array of strings representing the paths to the audio files to be loaded.
* @param num_audios The number of audio files to be loaded.
*
* @return An `extError_t` value indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxLoadAudios(OrtxRawAudios** audios, const char* const* audio_paths, size_t num_audios);
/**
* @brief Creates an array of raw audio objects.
*
* This function creates an array of raw audio objects based on the provided data and sizes.
*
* @param audios Pointer to the variable that will hold the created raw audio objects.
* @param data Array of pointers to the audio data.
* @param sizes Array of pointers to the sizes of the audio data.
* @param num_audios Number of audio objects to create.
*
* @return extError_t Error code indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxCreateRawAudios(OrtxRawAudios** audios, const void* data[], const int64_t* sizes[], size_t num_audios);
/**
* @brief Calculates the log mel spectrogram for a given audio using the specified feature extractor.
*
* This function takes an instance of the OrtxFeatureExtractor struct, an instance of the OrtxRawAudios struct,
* and a pointer to a OrtxTensorResult pointer. It calculates the log mel spectrogram for the given audio using
* the specified feature extractor and stores the result in the provided log_mel pointer.
*
* @param extractor The feature extractor to use for calculating the log mel spectrogram.
* @param audio The raw audio data to process.
* @param log_mel A pointer to a OrtxTensorResult pointer where the result will be stored.
* @return An extError_t value indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxSpeechLogMel(OrtxFeatureExtractor* extractor, OrtxRawAudios* audio, OrtxTensorResult** log_mel);
#ifdef __cplusplus
}
#endif

Просмотреть файл

@ -10,7 +10,6 @@
// typedefs to create/dispose function flood, and to make the API more C++ friendly with less casting
typedef OrtxObject OrtxProcessor;
typedef OrtxObject OrtxRawImages;
typedef OrtxObject OrtxImageProcessorResult;
#ifdef __cplusplus
extern "C" {
@ -40,8 +39,22 @@ extError_t ORTX_API_CALL OrtxCreateProcessor(OrtxProcessor** processor, const ch
extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** image_paths, size_t num_images,
size_t* num_images_loaded);
/**
* @brief Preprocesses the given raw images using the specified processor.
* @brief Creates raw images from the provided data.
*
* This function creates raw images from the provided data. The raw images are stored in the `images` parameter.
*
* @param images Pointer to a pointer to the `OrtxRawImages` structure that will hold the created raw images.
* @param data Array of pointers to the data for each image.
* @param sizes Array of pointers to the sizes of each image.
* @param num_images Number of images to create.
* @return An `extError_t` value indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxCreateRawImages(OrtxRawImages** images, const void* data[], const int64_t* sizes[], size_t num_images);
/**
* @brief Pre-processes the given raw images using the specified processor.
*
* This function applies preprocessing operations on the raw images using the provided processor.
* The result of the preprocessing is stored in the `OrtxImageProcessorResult` object.
@ -52,24 +65,7 @@ extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** ima
* @return An `extError_t` value indicating the success or failure of the preprocessing operation.
*/
extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawImages* images,
OrtxImageProcessorResult** result);
/**
* @brief Retrieves the image processor result at the specified index.
*
* @param result Pointer to the OrtxImageProcessorResult structure to store the result.
* @param index The index of the result to retrieve.
* @return extError_t The error code indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxImageGetTensorResult(OrtxImageProcessorResult* result, size_t index, OrtxTensor** tensor);
/** \brief Clear the outputs of the processor
*
* \param processor The processor object
* \param result The result object to clear
* \return Error code indicating the success or failure of the operation
*/
extError_t ORTX_API_CALL OrtxClearOutputs(OrtxProcessor* processor, OrtxImageProcessorResult* result);
OrtxTensorResult** result);
#ifdef __cplusplus
}

Просмотреть файл

@ -17,19 +17,22 @@ typedef enum {
kOrtxKindDetokenizerCache = 0x778B,
kOrtxKindProcessor = 0x778C,
kOrtxKindRawImages = 0x778D,
kOrtxKindImageProcessorResult = 0x778E,
kOrtxKindTensorResult = 0x778E,
kOrtxKindProcessorResult = 0x778F,
kOrtxKindTensor = 0x7790,
kOrtxKindFeatureExtractor = 0x7791,
kOrtxKindRawAudios = 0x7792,
kOrtxKindEnd = 0x7999
} extObjectKind_t;
// all object managed by the library should be 'derived' from this struct
// which eventually will be released by TfmDispose if C++, or TFM_DISPOSE if C
typedef struct {
int ext_kind_;
extObjectKind_t ext_kind_;
} OrtxObject;
typedef OrtxObject OrtxTensor;
typedef OrtxObject OrtxTensorResult;
// C, instead of C++ doesn't cast automatically,
// so we need to use a macro to cast the object to the correct type
@ -77,6 +80,18 @@ extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object);
*/
extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object);
/**
* @brief Retrieves the tensor at the specified index from the given tensor result.
*
* This function allows you to access a specific tensor from a tensor result object.
*
* @param result The tensor result object from which to retrieve the tensor.
* @param index The index of the tensor to retrieve.
* @param tensor A pointer to a variable that will hold the retrieved tensor.
* @return An error code indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t index, OrtxTensor** tensor);
/** \brief Get the data from the tensor
*
* \param tensor The tensor object

Просмотреть файл

@ -10,7 +10,6 @@ This enables more flexibility and control over model execution, thus expanding t
__author__ = "Microsoft"
from ._version import __version__
from ._ocos import get_library_path
from ._ocos import Opdef, PyCustomOpDef
@ -66,6 +65,10 @@ if _lib_only:
gen_processing_models = _unimplemented
OrtPyFunction = _unimplemented
ort_inference = _unimplemented
PyOrtFunction = _unimplemented
optimize_model = _unimplemented
make_onnx_model = _unimplemented
ONNXRuntimeError = _unimplemented
else:
__all__ += _offline_api

Просмотреть файл

@ -17,7 +17,7 @@ from onnx import numpy_helper
from ._ortapi2 import make_onnx_model
from ._cuops import SingleOpGraph
from ._hf_cvt import HFTokenizerConverter
from .util import remove_unused_initializers
from .util import remove_unused_initializers, mel_filterbank
class _WhisperHParams:
@ -30,53 +30,15 @@ class _WhisperHParams:
N_FRAMES = N_SAMPLES // HOP_LENGTH
def _mel_filterbank(
n_fft: int, n_mels: int = 80, sr=16000, min_mel=0, max_mel=45.245640471924965, dtype=np.float32):
"""
Compute a Mel-filterbank. The filters are stored in the rows, the columns,
and it is Slaney normalized mel-scale filterbank.
"""
fbank = np.zeros((n_mels, n_fft // 2 + 1), dtype=dtype)
# the centers of the frequency bins for the DFT
freq_bins = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
mel = np.linspace(min_mel, max_mel, n_mels + 2)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mel
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
log_t = mel >= min_log_mel
freqs[log_t] = min_log_hz * np.exp(logstep * (mel[log_t] - min_log_mel))
mel_bins = freqs
mel_spacing = np.diff(mel_bins)
ramps = mel_bins.reshape(-1, 1) - freq_bins.reshape(1, -1)
for i in range(n_mels):
left = -ramps[i] / mel_spacing[i]
right = ramps[i + 2] / mel_spacing[i + 1]
# intersect them with each other and zero
fbank[i] = np.maximum(0, np.minimum(left, right))
energy_norm = 2.0 / (mel_bins[2: n_mels + 2] - mel_bins[:n_mels])
fbank *= energy_norm[:, np.newaxis]
return fbank
class CustomOpStftNorm(torch.autograd.Function):
@staticmethod
def symbolic(g, self, n_fft, hop_length, window):
t_n_fft = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
t_hop_length = g.op('Constant', value_t=torch.tensor(hop_length, dtype=torch.int64))
t_frame_size = g.op('Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
t_n_fft = g.op('Constant', value_t=torch.tensor(
n_fft, dtype=torch.int64))
t_hop_length = g.op('Constant', value_t=torch.tensor(
hop_length, dtype=torch.int64))
t_frame_size = g.op(
'Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
return g.op("ai.onnx.contrib::StftNorm", self, t_n_fft, t_hop_length, window, t_frame_size)
@staticmethod
@ -97,7 +59,7 @@ class WhisperPrePipeline(torch.nn.Module):
self.n_fft = n_fft
self.window = torch.hann_window(n_fft)
self.mel_filters = torch.from_numpy(
_mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels))
mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels))
def forward(self, audio_pcm: torch.Tensor):
stft_norm = CustomOpStftNorm.apply(audio_pcm,
@ -112,7 +74,8 @@ class WhisperPrePipeline(torch.nn.Module):
spec_shape = log_spec.shape
padding_spec = torch.ones(spec_shape[0],
spec_shape[1],
self.n_samples // self.hop_length - spec_shape[2],
self.n_samples // self.hop_length -
spec_shape[2],
dtype=torch.float)
padding_spec *= spec_min
log_spec = torch.cat((log_spec, padding_spec), dim=2)
@ -165,15 +128,20 @@ def _to_onnx_stft(onnx_model, n_fft):
make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_minus_1_output_0',
'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'],
name='slice_1'),
make_node('Constant', inputs=[], outputs=['const0_output_0'], name='const0', value_int=0),
make_node('Constant', inputs=[], outputs=['const1_output_0'], name='const1', value_int=1),
make_node('Constant', inputs=[], outputs=[
'const0_output_0'], name='const0', value_int=0),
make_node('Constant', inputs=[], outputs=[
'const1_output_0'], name='const1', value_int=1),
make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'], outputs=['gather_4_output_0'],
name='gather_4', axis=3),
make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'], outputs=['gather_5_output_0'],
name='gather_5', axis=3),
make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=['mul_output_0'], name='mul0'),
make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=['mul_1_output_0'], name='mul1'),
make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[stft_norm_node.output[0]], name='add0'),
make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=[
'mul_output_0'], name='mul0'),
make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=[
'mul_1_output_0'], name='mul1'),
make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[
stft_norm_node.output[0]], name='add0'),
]
new_stft_nodes.extend(onnx_model.graph.node[:node_idx])
new_stft_nodes.extend(replaced_nodes)
@ -253,9 +221,11 @@ class WhisperDataProcGraph:
del g.node[:]
g.node.extend(nodes)
inputs = [onnx.helper.make_tensor_value_info("sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
inputs = [onnx.helper.make_tensor_value_info(
"sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
del g.input[:]
g.input.extend(inputs)
g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(onnx.TensorProto.STRING, ['N', 'text']))
g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(
onnx.TensorProto.STRING, ['N', 'text']))
return make_onnx_model(g, opset_version=self.opset_version)

Просмотреть файл

@ -3,17 +3,15 @@
#include "ocos.h"
#ifdef ENABLE_DR_LIBS
#include "audio_decoder.hpp"
#include "audio_decoder.h"
#endif // ENABLE_DR_LIBS
FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []()-> CustomOpArray& {
FxLoadCustomOpFactory LoadCustomOpClasses_Audio = []() -> CustomOpArray& {
static OrtOpLoader op_loader(
[]() { return nullptr; }
#ifdef ENABLE_DR_LIBS
,
CustomCpuStructV2("AudioDecoder", AudioDecoder)
CustomCpuStructV2("AudioDecoder", AudioDecoder),
#endif
);
[]() { return nullptr; });
return op_loader.GetCustomOps();
};

Просмотреть файл

@ -0,0 +1,181 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include <map>
#include <memory>
#include <gsl/util>
#include "audio_decoder.h"
#define DR_FLAC_IMPLEMENTATION
#include "dr_flac.h"
#define DR_MP3_IMPLEMENTATION 1
#define DR_MP3_FLOAT_OUTPUT 1
#include "dr_mp3.h"
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include "narrow.h"
#include "string_utils.h"
#include "string_tensor.h"
#include "sampling.h"
OrtStatusPtr AudioDecoder::OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
auto status = OrtW::GetOpAttribute(info, "downsampling_rate", downsample_rate_);
if (!status) {
status = OrtW::GetOpAttribute(info, "stereo_to_mono", stereo_mixer_);
}
return status;
}
AudioDecoder::AudioStreamType AudioDecoder::ReadStreamFormat(const uint8_t* p_data, const std::string& str_format,
OrtxStatus& status) const {
const std::map<std::string, AudioStreamType> format_mapping = {{"default", AudioStreamType::kDefault},
{"wav", AudioStreamType::kWAV},
{"mp3", AudioStreamType::kMP3},
{"flac", AudioStreamType::kFLAC}};
AudioStreamType stream_format = AudioStreamType::kDefault;
if (str_format.length() > 0) {
auto pos = format_mapping.find(str_format);
if (pos == format_mapping.end()) {
status = {kOrtxErrorInvalidArgument,
MakeString("[AudioDecoder]: Unknown audio stream format: ", str_format).c_str()};
return stream_format;
}
stream_format = pos->second;
}
if (stream_format == AudioStreamType::kDefault) {
auto p_stream = reinterpret_cast<char const*>(p_data);
std::string_view marker(p_stream, 4);
if (marker == "fLaC") {
stream_format = AudioStreamType::kFLAC;
} else if (marker == "RIFF") {
stream_format = AudioStreamType::kWAV;
} else if (marker[0] == char(0xFF) && (marker[1] | 0x1F) == char(0xFF)) {
// http://www.mp3-tech.org/programmer/frame_header.html
// only detect the 8 + 3 bits sync word
stream_format = AudioStreamType::kMP3;
} else {
status = {kOrtxErrorInvalidArgument, "[AudioDecoder]: Cannot detect audio stream format"};
}
}
return stream_format;
}
template <typename TY_AUDIO, typename FX_DECODER>
static size_t DrReadFrames(std::list<std::vector<float>>& frames, FX_DECODER fx, TY_AUDIO& obj) {
const size_t default_chunk_size = 1024 * 256;
int64_t total_buf_size = 0;
for (;;) {
std::vector<float> buf;
buf.resize(default_chunk_size * obj.channels);
auto n_frames = fx(&obj, default_chunk_size, buf.data());
if (n_frames <= 0) {
break;
}
auto data_size = n_frames * obj.channels;
total_buf_size += data_size;
buf.resize(data_size);
frames.emplace_back(std::move(buf));
}
return total_buf_size;
}
OrtxStatus AudioDecoder::Compute(const ortc::Tensor<uint8_t>& input, const std::optional<std::string> format,
ortc::Tensor<float>& output0) const {
const uint8_t* p_data = input.Data();
auto input_dim = input.Shape();
OrtxStatus status;
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
return {kOrtxErrorInvalidArgument, "[AudioDecoder]: Expect input dimension [n] or [1,n]."};
}
std::string str_format;
if (format) {
str_format = *format;
}
auto stream_format = ReadStreamFormat(p_data, str_format, status);
if (status) {
return status;
}
int64_t total_buf_size = 0;
std::list<std::vector<float>> lst_frames;
int64_t orig_sample_rate = 0;
int64_t orig_channels = 0;
if (stream_format == AudioStreamType::kMP3) {
auto mp3_obj_ptr = std::make_unique<drmp3>();
if (!drmp3_init_memory(mp3_obj_ptr.get(), p_data, input.NumberOfElement(), nullptr)) {
status = {kOrtxErrorCorruptData, "[AudioDecoder]: unexpected error on MP3 stream."};
return status;
}
orig_sample_rate = mp3_obj_ptr->sampleRate;
orig_channels = mp3_obj_ptr->channels;
total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, *mp3_obj_ptr);
} else if (stream_format == AudioStreamType::kFLAC) {
drflac* flac_obj = drflac_open_memory(p_data, input.NumberOfElement(), nullptr);
auto flac_obj_closer = gsl::finally([flac_obj]() { drflac_close(flac_obj); });
if (flac_obj == nullptr) {
status = {kOrtxErrorCorruptData, "[AudioDecoder]: unexpected error on FLAC stream."};
return status;
}
orig_sample_rate = flac_obj->sampleRate;
orig_channels = flac_obj->channels;
total_buf_size = DrReadFrames(lst_frames, drflac_read_pcm_frames_f32, *flac_obj);
} else {
drwav wav_obj;
if (!drwav_init_memory(&wav_obj, p_data, input.NumberOfElement(), nullptr)) {
status = {kOrtxErrorCorruptData, "[AudioDecoder]: unexpected error on WAV stream."};
return status;
}
orig_sample_rate = wav_obj.sampleRate;
orig_channels = wav_obj.channels;
total_buf_size = DrReadFrames(lst_frames, drwav_read_pcm_frames_f32, wav_obj);
}
if (downsample_rate_ != 0 && orig_sample_rate < downsample_rate_) {
status = {kOrtxErrorCorruptData, "[AudioDecoder]: only down-sampling supported."};
return status;
}
// join all frames
std::vector<float> buf;
buf.resize(total_buf_size);
int64_t offset = 0;
for (auto& _b : lst_frames) {
std::copy(_b.begin(), _b.end(), buf.begin() + offset);
offset += _b.size();
}
// mix the stereo channels into mono channel
if (stereo_mixer_ && orig_channels > 1) {
if (buf.size() > 1) {
for (size_t i = 0; i < buf.size() / 2; ++i) {
buf[i] = (buf[i * 2] + buf[i * 2 + 1]) / 2;
}
buf.resize(buf.size() / 2);
}
}
if (downsample_rate_ != 0 && downsample_rate_ != orig_sample_rate) {
// A lowpass filter on buf audio data to remove high frequency noise
ButterworthLowpass filter(0.5 * downsample_rate_, 1.0 * orig_sample_rate);
std::vector<float> filtered_buf = filter.Process(buf);
// downsample the audio data
KaiserWindowInterpolation::Process(filtered_buf, buf, 1.0f * orig_sample_rate, 1.0f * downsample_rate_);
}
std::vector<int64_t> dim_out = {1, ort_extensions::narrow<int64_t>(buf.size())};
float* p_output = output0.Allocate(dim_out);
std::copy(buf.begin(), buf.end(), p_output);
return status;
}

Просмотреть файл

@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
#include <list>
#include <optional>
struct AudioDecoder {
public:
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info);
template <typename DictT>
OrtxStatus Init(const DictT& attrs) {
// in API mode, the default value is 1
downsample_rate_ = 16000;
stereo_mixer_ = 1;
for (const auto& [key, value] : attrs) {
if (key == "target_sample_rate") {
downsample_rate_ = std::get<std::int64_t>(value);
} else if (key == "stereo_to_mono") {
stereo_mixer_ = std::get<std::int64_t>(value);
} else {
return {kOrtxErrorInvalidArgument, "[AudioDecoder]: Invalid argument"};
}
}
return {};
}
enum class AudioStreamType { kDefault = 0, kWAV, kMP3, kFLAC };
AudioStreamType ReadStreamFormat(const uint8_t* p_data, const std::string& str_format, OrtxStatus& status) const;
OrtxStatus Compute(const ortc::Tensor<uint8_t>& input, const std::optional<std::string> format,
ortc::Tensor<float>& output0) const;
OrtxStatus ComputeNoOpt(const ortc::Tensor<uint8_t>& input, ortc::Tensor<float>& output0) {
return Compute(input, std::nullopt, output0);
}
private:
int64_t downsample_rate_{};
int64_t stereo_mixer_{};
};

Просмотреть файл

@ -1,206 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include "ocos.h"
#include <list>
#include <map>
#include <memory>
#include <optional>
#define DR_FLAC_IMPLEMENTATION
#include "dr_flac.h"
#define DR_MP3_IMPLEMENTATION 1
#define DR_MP3_FLOAT_OUTPUT 1
#include "dr_mp3.h"
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include <gsl/util>
#include "narrow.h"
#include "string_utils.h"
#include "string_tensor.h"
#include "sampling.h"
struct AudioDecoder{
public:
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
auto status = OrtW::GetOpAttribute(info, "downsampling_rate", downsample_rate_);
if (!status) {
status = OrtW::GetOpAttribute(info, "stereo_to_mono", stereo_mixer_);
}
return status;
}
enum class AudioStreamType {
kDefault = 0,
kWAV,
kMP3,
kFLAC
};
AudioStreamType ReadStreamFormat(const uint8_t* p_data, const std::string& str_format, OrtStatusPtr& status) const {
static const std::map<std::string, AudioStreamType> format_mapping = {
{"default", AudioStreamType::kDefault},
{"wav", AudioStreamType::kWAV},
{"mp3", AudioStreamType::kMP3},
{"flac", AudioStreamType::kFLAC}};
AudioStreamType stream_format = AudioStreamType::kDefault;
if (str_format.length() > 0) {
auto pos = format_mapping.find(str_format);
if (pos == format_mapping.end()) {
status = OrtW::CreateStatus(MakeString(
"[AudioDecoder]: Unknown audio stream format: ", str_format)
.c_str(),
ORT_INVALID_ARGUMENT);
return stream_format;
}
stream_format = pos->second;
}
if (stream_format == AudioStreamType::kDefault) {
auto p_stream = reinterpret_cast<char const*>(p_data);
std::string_view marker(p_stream, 4);
if (marker == "fLaC") {
stream_format = AudioStreamType::kFLAC;
} else if (marker == "RIFF") {
stream_format = AudioStreamType::kWAV;
} else if (marker[0] == char(0xFF) && (marker[1] | 0x1F) == char(0xFF)) {
// http://www.mp3-tech.org/programmer/frame_header.html
// only detect the 8 + 3 bits sync word
stream_format = AudioStreamType::kMP3;
} else {
status = OrtW::CreateStatus("[AudioDecoder]: Cannot detect audio stream format", ORT_INVALID_ARGUMENT);
}
}
return stream_format;
}
template <typename TY_AUDIO, typename FX_DECODER>
static size_t DrReadFrames(std::list<std::vector<float>>& frames, FX_DECODER fx, TY_AUDIO& obj) {
const size_t default_chunk_size = 1024 * 256;
int64_t total_buf_size = 0;
for (;;) {
std::vector<float> buf;
buf.resize(default_chunk_size * obj.channels);
auto n_frames = fx(&obj, default_chunk_size, buf.data());
if (n_frames <= 0) {
break;
}
auto data_size = n_frames * obj.channels;
total_buf_size += data_size;
buf.resize(data_size);
frames.emplace_back(std::move(buf));
}
return total_buf_size;
}
OrtStatusPtr Compute(const ortc::Tensor<uint8_t>& input,
const std::optional<std::string> format,
ortc::Tensor<float>& output0) const {
const uint8_t* p_data = input.Data();
auto input_dim = input.Shape();
OrtStatusPtr status = nullptr;
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
status = OrtW::CreateStatus("[AudioDecoder]: Expect input dimension [n] or [1,n].", ORT_INVALID_ARGUMENT);
return status;
}
std::string str_format;
if (format) {
str_format = *format;
}
auto stream_format = ReadStreamFormat(p_data, str_format, status);
if (status) {
return status;
}
int64_t total_buf_size = 0;
std::list<std::vector<float>> lst_frames;
int64_t orig_sample_rate = 0;
int64_t orig_channels = 0;
if (stream_format == AudioStreamType::kMP3) {
auto mp3_obj_ptr = std::make_unique<drmp3>();
if (!drmp3_init_memory(mp3_obj_ptr.get(), p_data, input.NumberOfElement(), nullptr)) {
status = OrtW::CreateStatus("[AudioDecoder]: unexpected error on MP3 stream.", ORT_RUNTIME_EXCEPTION);
return status;
}
orig_sample_rate = mp3_obj_ptr->sampleRate;
orig_channels = mp3_obj_ptr->channels;
total_buf_size = DrReadFrames(lst_frames, drmp3_read_pcm_frames_f32, *mp3_obj_ptr);
} else if (stream_format == AudioStreamType::kFLAC) {
drflac* flac_obj = drflac_open_memory(p_data, input.NumberOfElement(), nullptr);
auto flac_obj_closer = gsl::finally([flac_obj]() { drflac_close(flac_obj); });
if (flac_obj == nullptr) {
status = OrtW::CreateStatus("[AudioDecoder]: unexpected error on FLAC stream.", ORT_RUNTIME_EXCEPTION);
return status;
}
orig_sample_rate = flac_obj->sampleRate;
orig_channels = flac_obj->channels;
total_buf_size = DrReadFrames(lst_frames, drflac_read_pcm_frames_f32, *flac_obj);
} else {
drwav wav_obj;
if (!drwav_init_memory(&wav_obj, p_data, input.NumberOfElement(), nullptr)) {
status = OrtW::CreateStatus("[AudioDecoder]: unexpected error on WAV stream.", ORT_RUNTIME_EXCEPTION);
return status;
}
orig_sample_rate = wav_obj.sampleRate;
orig_channels = wav_obj.channels;
total_buf_size = DrReadFrames(lst_frames, drwav_read_pcm_frames_f32, wav_obj);
}
if (downsample_rate_ != 0 &&
orig_sample_rate < downsample_rate_) {
status = OrtW::CreateStatus("[AudioDecoder]: only down-sampling supported.", ORT_INVALID_ARGUMENT);
return status;
}
// join all frames
std::vector<float> buf;
buf.resize(total_buf_size);
int64_t offset = 0;
for (auto& _b : lst_frames) {
std::copy(_b.begin(), _b.end(), buf.begin() + offset);
offset += _b.size();
}
// mix the stereo channels into mono channel
if (stereo_mixer_ && orig_channels > 1) {
if (buf.size() > 1) {
for (size_t i = 0; i < buf.size() / 2; ++i) {
buf[i] = (buf[i * 2] + buf[i * 2 + 1]) / 2;
}
buf.resize(buf.size() / 2);
}
}
if (downsample_rate_ != 0 &&
downsample_rate_ != orig_sample_rate) {
// A lowpass filter on buf audio data to remove high frequency noise
ButterworthLowpass filter(0.5 * downsample_rate_, 1.0 * orig_sample_rate);
std::vector<float> filtered_buf = filter.Process(buf);
// downsample the audio data
KaiserWindowInterpolation::Process(filtered_buf, buf,
1.0f * orig_sample_rate, 1.0f * downsample_rate_);
}
std::vector<int64_t> dim_out = {1, ort_extensions::narrow<int64_t>(buf.size())};
float* p_output = output0.Allocate(dim_out);
std::copy(buf.begin(), buf.end(), p_output);
return status;
}
private:
int64_t downsample_rate_{};
int64_t stereo_mixer_{};
};

Просмотреть файл

@ -87,6 +87,13 @@ struct Flash_fwd_params : public Qkv_params {
// The indices to index into the KV cache.
int* __restrict__ cache_batch_idx = nullptr;
// Paged KV cache
int * __restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
float rp_dropout;
// Local window size
int window_size_left = -1;
int window_size_right = -1;
@ -102,6 +109,9 @@ struct Flash_fwd_params : public Qkv_params {
int num_splits = 0; // For split-KV version
void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;
const cudaDeviceProp* dprops = nullptr;
};

Просмотреть файл

@ -32,7 +32,9 @@ void set_params_fprop(Flash_fwd_params& params,
bool is_bf16,
bool kv_bsnh = true,
int window_size_left = -1,
int window_size_right = -1) {
int window_size_right = -1,
bool paged_KV = false,
int page_block_size = -1) {
// Set the pointers and strides.
params.q_ptr = q;
params.k_ptr = k;
@ -64,8 +66,8 @@ void set_params_fprop(Flash_fwd_params& params,
if (cu_seqlens_q_d == nullptr) {
params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0)
params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0)
params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0)
params.k_batch_stride = (paged_KV ? page_block_size : seqlen_k) * num_heads_k * head_size; // stride(0)
params.v_batch_stride = (paged_KV ? page_block_size : seqlen_k) * num_heads_k * head_size; // stride(0)
params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0)
} else {
params.q_batch_stride = 0;
@ -99,6 +101,10 @@ void set_params_fprop(Flash_fwd_params& params,
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
params.rp_dropout = 1.f;
params.alibi_slopes_ptr = nullptr;
params.alibi_slopes_batch_stride = 0;
// In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API seperates
// local and causal, meaning when we have local window size
params.is_causal = is_causal;
@ -349,8 +355,8 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in
OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
cudaStream_t stream,
void* q, // batch_size x seqlen_q x num_heads x head_size
void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table
void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table
void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* out, // batch_size x seqlen_q x num_heads x head_size
@ -374,7 +380,10 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
int local_window_size,
bool is_rotary_interleaved,
bool is_packed_qkv) {
bool is_packed_qkv,
int32_t* block_table, // batch_size x max_num_blocks_per_seq
int32_t max_num_blocks_per_seq,
int32_t page_block_size) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
@ -398,7 +407,9 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
is_bf16,
past_bsnh,
local_window_size,
is_causal ? 0 : -1);
is_causal ? 0 : -1,
block_table != nullptr,
page_block_size);
params.dprops = &dprops;
if (k_new != nullptr && v_new != nullptr) {
@ -454,6 +465,10 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
params.oaccum_ptr = nullptr;
}
params.block_table = block_table;
params.block_table_batch_stride = max_num_blocks_per_seq;
params.page_block_size = page_block_size;
// Only split kernel supports appending to KV cache
run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr);

Просмотреть файл

@ -53,8 +53,8 @@ OrtStatusPtr mha_varlen_fwd(const cudaDeviceProp& dprops,
OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
cudaStream_t stream,
void* q, // batch_size x seqlen_q x num_heads x head_size
void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size
void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size
void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table
void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table
void* k, // batch_size x seqlen_k_new x num_heads_k x head_size
void* v, // batch_size x seqlen_k_new x num_heads_k x head_size
void* out, // batch_size x seqlen_q x num_heads x head_size
@ -78,7 +78,10 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
int local_window_size = -1,
bool is_rotary_interleaved = false,
bool is_packed_qkv = false);
bool is_packed_qkv = false,
int32_t* block_table = nullptr, // batch_size x max_num_blocks_per_seq
int32_t max_num_blocks_per_seq = -1,
int32_t page_block_size = 1);
size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -9,20 +9,20 @@
namespace flash {
template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
template <typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
flash::compute_attn<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax>(params);
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
#else
(void)params;
#endif
}
template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Split, Append_KV>(params);
flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
#else
(void)params;
#endif
@ -38,7 +38,7 @@ __global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
#endif
}
template <typename Kernel_traits, bool Is_causal>
template <typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) {
constexpr size_t smem_size = Kernel_traits::kSmemSize;
@ -53,23 +53,25 @@ void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) {
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
// Will only return softmax if dropout, to reduce compilation time.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ;
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
// ORT_ENFORCE(cudaFuncSetAttribute(
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
// int ctas_per_sm;
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// Will only return softmax if dropout, to reduce compilation time.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_kernel < Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ;
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
// ORT_ENFORCE(cudaFuncSetAttribute(
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
// int ctas_per_sm;
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
});
});
});
});
@ -90,16 +92,18 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) {
BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
// printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr);
auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
// printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr);
auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
});
});
});
});
@ -143,7 +147,7 @@ template <typename T>
void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) {
constexpr static int Headdim = 32;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
});
}
@ -154,7 +158,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) {
// Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
// Using block size (64 x 256) is 27% slower for seqlen=2k
// Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_causal>(params, stream);
});
@ -168,12 +172,12 @@ void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
if (is_sm8x) {
if constexpr (!Is_causal) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
}
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_causal>(params, stream);
@ -192,12 +196,12 @@ void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) {
// and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
if (is_sm8x) {
if constexpr (!Is_causal) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_causal>(params, stream);
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
}
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_causal>(params, stream);
@ -220,12 +224,12 @@ void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) {
// and 128 x 64 with 8 warps is the fastest for non-causal.
if (is_sm8x) {
if constexpr (!Is_causal) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_causal>(params, stream);
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
}
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_causal>(params, stream);
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
@ -241,7 +245,7 @@ template <typename T>
void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) {
constexpr int Headdim = 192;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_causal>(params, stream);
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
@ -257,9 +261,9 @@ void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) {
// printf("max_smem_per_block = %d\n", max_smem_per_block);
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_causal>(params, stream);
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_causal>(params, stream);
@ -280,9 +284,9 @@ void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) {
// For A100, we want to run with 128 x 64 (128KB smem).
// For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_causal>(params, stream);
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, false /*Is_dropout*/, Is_causal>(params, stream);
}
// 64 KB
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_causal>(params, stream);

Просмотреть файл

@ -54,10 +54,10 @@ __device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor
reduce_<zero_init>(tensor, max, max_op);
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& sum) {
SumOp<float> sum_op;
reduce_(tensor, sum, sum_op);
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op;
thread_reduce_<zero_init>(tensor, sum, sum_op);
}
// Apply the exp to all the elements.
@ -212,4 +212,168 @@ inline __device__ void apply_mask_causal_w_idx(
}
}
template <int kNRows>
struct Softmax {
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
TensorT row_max, row_sum;
__forceinline__ __device__ Softmax() {};
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
__forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
};
template<bool Is_dropout=false, bool Split=false, typename Tensor0>
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
TensorT lse = make_fragment_like(row_sum);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
}
return lse;
};
};
template <bool Is_causal, bool Is_local, bool Has_alibi>
struct Mask {
const int max_seqlen_k, max_seqlen_q;
const int window_size_left, window_size_right;
const float alibi_slope;
__forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q,
const int window_size_left, const int window_size_right,
const float alibi_slope=0.f)
: max_seqlen_k(max_seqlen_k)
, max_seqlen_q(max_seqlen_q)
, window_size_left(window_size_left)
, window_size_right(window_size_right)
, alibi_slope(!Has_alibi ? 0.0 : alibi_slope) {
};
// Causal_mask: whether this particular iteration needs causal masking
template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor_,
const int col_idx_offset_,
const int row_idx_offset,
const int warp_row_stride) {
static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
static_assert(Layout::rank == 3, "Only support 3D Tensor");
static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
// if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); }
if constexpr (Need_masking) {
// Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
// Do we need both row and column indices, or just column incides?
static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
if constexpr (Col_idx_only) {
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
// No causal, no local
if constexpr (Has_alibi) {
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
}
if constexpr (!Is_even_MN) {
if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
}
}
}
}
} else {
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
if constexpr (Has_alibi) {
if constexpr (Is_causal) {
tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx;
} else {
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
}
}
if constexpr (Causal_mask) {
if (col_idx >= col_idx_limit_right) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
if constexpr (Is_local) {
if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
// Causal and Local already handles MN masking
if (col_idx >= max_seqlen_k) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
}
}
}
}
}
}
};
};
} // namespace flash

Просмотреть файл

@ -198,6 +198,28 @@ inline __device__ void gemm_A_in_regs(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
template <typename Layout>
inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
@ -212,6 +234,25 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
template<typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
using X = Underscore;
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
if constexpr (mma_shape_K == 8) {
return acc_layout;
} else {
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
template <typename MMA_traits, typename Layout>

Просмотреть файл

@ -6,37 +6,32 @@
#include "ocos.h"
#include <dlib/matrix.h>
struct StftNormal{
struct StftNormal {
StftNormal() = default;
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
return OrtW::GetOpAttribute(info, "onesided", onesided_);
}
OrtStatusPtr Compute(const ortc::Tensor<float>& input0,
int64_t n_fft,
int64_t hop_length,
const ortc::Span<float>& input3,
int64_t frame_length,
ortc::Tensor<float>& output0) const {
OrtxStatus Compute(const ortc::Tensor<float>& input0, int64_t n_fft, int64_t hop_length,
const ortc::Span<float>& input3, int64_t frame_length, ortc::Tensor<float>& output0) const {
auto X = input0.Data();
auto window = input3.data_;
auto dimensions = input0.Shape();
auto win_length = input3.size();
if (dimensions.size() < 2 || input0.NumberOfElement() != dimensions[1]) {
return OrtW::CreateStatus("[Stft] Only batch == 1 tensor supported.", ORT_INVALID_ARGUMENT);
return {kOrtxErrorInvalidArgument, "[Stft] Only batch == 1 tensor supported."};
}
if (frame_length != n_fft) {
return OrtW::CreateStatus("[Stft] Only support size of FFT equals the frame length.", ORT_INVALID_ARGUMENT);
return {kOrtxErrorInvalidArgument, "[Stft] Only support size of FFT equals the frame length."};
}
dlib::matrix<float> dm_x = dlib::mat(X, 1, dimensions[1]);
dlib::matrix<float> hann_win = dlib::mat(window, 1, win_length);
auto m_stft = dlib::stft(
dm_x, [&hann_win](size_t x, size_t len) { return hann_win(0, x); },
n_fft, win_length, hop_length);
auto m_stft =
dlib::stft(dm_x, [&hann_win](size_t x, size_t len) { return hann_win(0, x); }, n_fft, win_length, hop_length);
if (onesided_) {
m_stft = dlib::subm(m_stft, 0, 0, m_stft.nr(), (m_stft.nc() >> 1) + 1);
@ -49,7 +44,7 @@ struct StftNormal{
auto out0 = output0.Allocate(outdim);
memcpy(out0, result.steal_memory().get(), result_size * sizeof(float));
return nullptr;
return {};
}
private:

Просмотреть файл

@ -0,0 +1,77 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "speech_extractor.h"
#include "c_api_utils.hpp"
using namespace ort_extensions;
class RawAudiosObject : public OrtxObjectImpl {
public:
RawAudiosObject() : OrtxObjectImpl(extObjectKind_t::kOrtxKindRawAudios) {}
~RawAudiosObject() override = default;
std::unique_ptr<AudioRawData[]> audios_;
size_t num_audios_;
};
extError_t ORTX_API_CALL OrtxLoadAudios(OrtxRawAudios** raw_audios, const char* const* audio_paths, size_t num_audios) {
if (raw_audios == nullptr || audio_paths == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
auto audios_obj = std::make_unique<RawAudiosObject>();
auto [audios, num] =
ort_extensions::LoadRawData<char const* const*, AudioRawData>(audio_paths, audio_paths + num_audios);
audios_obj->audios_ = std::move(audios);
audios_obj->num_audios_ = num;
*raw_audios = static_cast<OrtxRawAudios*>(audios_obj.release());
return extError_t();
}
extError_t ORTX_API_CALL OrtxCreateSpeechFeatureExtractor(OrtxFeatureExtractor** extractor, const char* def) {
if (extractor == nullptr || def == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
auto extractor_ptr = std::make_unique<SpeechFeatureExtractor>();
ReturnableStatus status = extractor_ptr->Init(def);
if (status.IsOk()) {
*extractor = static_cast<OrtxFeatureExtractor*>(extractor_ptr.release());
} else {
*extractor = nullptr;
}
return status.Code();
}
extError_t ORTX_API_CALL OrtxSpeechLogMel(OrtxFeatureExtractor* extractor, OrtxRawAudios* raw_audios,
OrtxTensorResult** result) {
if (extractor == nullptr || raw_audios == nullptr || result == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
auto extractor_ptr = static_cast<SpeechFeatureExtractor*>(extractor);
auto audios_obj = static_cast<RawAudiosObject*>(raw_audios);
auto ts_result = std::make_unique<TensorResult>();
std::unique_ptr<ortc::Tensor<float>> log_mel[1];
ReturnableStatus status =
extractor_ptr->DoCall(ort_extensions::span(audios_obj->audios_.get(), audios_obj->num_audios_), log_mel[0]);
if (status.IsOk()) {
std::vector<std::unique_ptr<ortc::TensorBase>> tensors;
std::transform(log_mel, log_mel + 1, std::back_inserter(tensors),
[](auto& ts) { return std::unique_ptr<ortc::TensorBase>(ts.release()); });
ts_result->SetTensors(std::move(tensors));
*result = ts_result.release();
} else {
*result = nullptr;
}
return status.Code();
}

Просмотреть файл

@ -4,6 +4,8 @@
#include "ortx_processor.h"
#include "image_processor.h"
#include "c_api_utils.hpp"
using namespace ort_extensions;
extError_t OrtxCreateProcessor(OrtxProcessor** processor, const char* def) {
@ -37,19 +39,19 @@ extError_t ORTX_API_CALL OrtxLoadImages(OrtxRawImages** images, const char** ima
}
auto images_obj = std::make_unique<RawImagesObject>();
auto [img, num] = LoadRawImages(image_paths, image_paths + num_images);
auto [img, num] = LoadRawData<char const**, ImageRawData>(image_paths, image_paths + num_images);
images_obj->images = std::move(img);
images_obj->num_images = num;
if (num_images_loaded != nullptr) {
*num_images_loaded = num;
}
*images = static_cast<OrtxRawImages*>(images_obj.release());
return extError_t();
}
extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawImages* images,
OrtxImageProcessorResult** result) {
OrtxTensorResult** result) {
if (processor == nullptr || images == nullptr || result == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
@ -67,59 +69,14 @@ extError_t ORTX_API_CALL OrtxImagePreProcess(OrtxProcessor* processor, OrtxRawIm
return status.Code();
}
auto result_ptr = std::make_unique<ImageProcessorResult>();
auto result_ptr = std::make_unique<TensorResult>();
status =
processor_ptr->PreProcess(ort_extensions::span(images_ptr->images.get(), images_ptr->num_images), *result_ptr);
if (status.IsOk()) {
*result = static_cast<OrtxImageProcessorResult*>(result_ptr.release());
*result = static_cast<OrtxTensorResult*>(result_ptr.release());
} else {
*result = nullptr;
}
return {};
}
extError_t ORTX_API_CALL OrtxImageGetTensorResult(OrtxImageProcessorResult* result, size_t index, OrtxTensor** tensor) {
if (result == nullptr || tensor == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
auto result_ptr = static_cast<ImageProcessorResult*>(result);
ReturnableStatus status(result_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindImageProcessorResult));
if (!status.IsOk()) {
return status.Code();
}
if (index >= result_ptr->results.size()) {
ReturnableStatus::last_error_message_ = "Index out of range";
return kOrtxErrorInvalidArgument;
}
auto tensor_ptr = std::make_unique<OrtxObjectWrapper<ortc::TensorBase>>();
tensor_ptr->SetObject(result_ptr->results[index].get());
*tensor = static_cast<OrtxTensor*>(tensor_ptr.release());
return extError_t();
}
extError_t ORTX_API_CALL OrtxClearOutputs(OrtxProcessor* processor, OrtxImageProcessorResult* result) {
if (processor == nullptr || result == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
const auto processor_ptr = static_cast<const ImageProcessor*>(processor);
ReturnableStatus status(processor_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindProcessor));
if (!status.IsOk()) {
return status.Code();
}
auto result_ptr = static_cast<ImageProcessorResult*>(result);
status = result_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindImageProcessorResult);
if (!status.IsOk()) {
return status.Code();
}
ImageProcessor::ClearOutputs(result_ptr);
return extError_t();
}

Просмотреть файл

@ -6,7 +6,7 @@
#include "c_api_utils.hpp"
#include "tokenizer_impl.h"
namespace ort_extensions {
using namespace ort_extensions;
class DetokenizerCache : public OrtxObjectImpl {
public:
@ -17,29 +17,20 @@ class DetokenizerCache : public OrtxObjectImpl {
std::string last_text_{}; // last detokenized text
};
template<>
OrtxObject* OrtxObjectFactory<DetokenizerCache>::CreateForward() {
return std::make_unique<DetokenizerCache>().release();
template <>
OrtxObject* OrtxObjectFactory::CreateForward<DetokenizerCache>() {
return Create<DetokenizerCache>();
}
template<>
void OrtxObjectFactory<DetokenizerCache>::DisposeForward(OrtxObject* obj) {
Dispose(obj);
}
} // namespace ort_extensions
using namespace ort_extensions;
extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer,
const char* input[], size_t batch_size, OrtxTokenId2DArray** output) {
extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer, const char* input[], size_t batch_size,
OrtxTokenId2DArray** output) {
if (tokenizer == nullptr || input == nullptr || output == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
auto token_ptr = static_cast<const TokenizerImpl*>(tokenizer);
ReturnableStatus status =
token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer);
ReturnableStatus status = token_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTokenizer);
if (!status.IsOk()) {
return status.Code();
}
@ -61,8 +52,8 @@ extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer,
return extError_t();
}
extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer,
const OrtxTokenId2DArray* input, OrtxStringArray** output) {
extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer, const OrtxTokenId2DArray* input,
OrtxStringArray** output) {
if (tokenizer == nullptr || input == nullptr || output == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
@ -81,11 +72,8 @@ extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer,
}
std::vector<span<extTokenId_t const>> t_ids;
std::transform(input_2d->token_ids().begin(), input_2d->token_ids().end(),
std::back_inserter(t_ids),
[](const std::vector<extTokenId_t>& vec) {
return span<extTokenId_t const>(vec.data(), vec.size());
});
std::transform(input_2d->token_ids().begin(), input_2d->token_ids().end(), std::back_inserter(t_ids),
[](const std::vector<extTokenId_t>& vec) { return span<extTokenId_t const>(vec.data(), vec.size()); });
std::vector<std::string> output_text;
status = token_ptr->Detokenize(t_ids, output_text);
@ -101,9 +89,7 @@ extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer,
;
}
extError_t ORTX_API_CALL OrtxDetokenize1D(const OrtxTokenizer* tokenizer,
const extTokenId_t* input,
size_t len,
extError_t ORTX_API_CALL OrtxDetokenize1D(const OrtxTokenizer* tokenizer, const extTokenId_t* input, size_t len,
OrtxStringArray** output) {
if (tokenizer == nullptr || input == nullptr || output == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
@ -186,8 +172,8 @@ extError_t ORTX_API_CALL OrtxTokenId2DArrayGetBatch(const OrtxTokenId2DArray* to
return extError_t();
}
extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* token_id_2d_array,
size_t index, const extTokenId_t** item, size_t* length) {
extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* token_id_2d_array, size_t index,
const extTokenId_t** item, size_t* length) {
if (token_id_2d_array == nullptr || item == nullptr || length == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
@ -210,9 +196,8 @@ extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* tok
return extError_t();
}
extError_t OrtxDetokenizeCached(const OrtxTokenizer* tokenizer,
OrtxDetokenizerCache* cache,
extTokenId_t next_id, const char** text_out) {
extError_t OrtxDetokenizeCached(const OrtxTokenizer* tokenizer, OrtxDetokenizerCache* cache, extTokenId_t next_id,
const char** text_out) {
if (tokenizer == nullptr || cache == nullptr || text_out == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;

Просмотреть файл

@ -10,6 +10,8 @@
using namespace ort_extensions;
class DetokenizerCache; // forward definition in tokenizer_impl.cc
thread_local std::string ReturnableStatus::last_error_message_;
OrtxStatus OrtxObjectImpl::IsInstanceOf(extObjectKind_t kind) const {
@ -37,7 +39,7 @@ extError_t ORTX_API_CALL OrtxCreate(extObjectKind_t kind, OrtxObject** object, .
va_start(args, object);
if (kind == extObjectKind_t::kOrtxKindDetokenizerCache) {
*object = OrtxObjectFactory<DetokenizerCache>::CreateForward();
*object = OrtxObjectFactory::CreateForward<DetokenizerCache>();
} else if (kind == extObjectKind_t::kOrtxKindTokenizer) {
return OrtxCreateTokenizer(static_cast<OrtxTokenizer**>(object), va_arg(args, const char*));
}
@ -80,8 +82,8 @@ extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object) {
return kOrtxErrorInvalidArgument;
}
if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindStringArray) {
OrtxObjectFactory<StringArray>::Dispose(object);
/* if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindStringArray) {
OrtxObjectFactory::Dispose<StringArray>(object);
} else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindTokenId2DArray) {
OrtxObjectFactory<TokenId2DArray>::Dispose(object);
} else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindDetokenizerCache) {
@ -94,6 +96,11 @@ extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object) {
OrtxObjectFactory<ImageProcessorResult>::Dispose(object);
} else if (Ortx_object->ortx_kind() == extObjectKind_t::kOrtxKindProcessor) {
OrtxObjectFactory<ImageProcessor>::Dispose(object);
} */
if (Ortx_object->ortx_kind() >= kOrtxKindBegin && Ortx_object->ortx_kind() < kOrtxKindEnd) {
OrtxObjectFactory::Dispose<OrtxObjectImpl>(object);
} else {
return kOrtxErrorInvalidArgument;
}
return extError_t();
@ -113,6 +120,30 @@ extError_t ORTX_API_CALL OrtxDispose(OrtxObject** object) {
return err;
}
extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t index, OrtxTensor** tensor) {
if (result == nullptr || tensor == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
}
auto result_ptr = static_cast<TensorResult*>(result);
ReturnableStatus status(result_ptr->IsInstanceOf(extObjectKind_t::kOrtxKindTensorResult));
if (!status.IsOk()) {
return status.Code();
}
ortc::TensorBase* ts = result_ptr->GetAt(index);
if (ts == nullptr) {
ReturnableStatus::last_error_message_ = "Cannot get the tensor at the specified index from the result";
return kOrtxErrorInvalidArgument;
}
auto tensor_ptr = std::make_unique<OrtxObjectWrapper<ortc::TensorBase, kOrtxKindTensor>>();
tensor_ptr->SetObject(ts);
*tensor = static_cast<OrtxTensor*>(tensor_ptr.release());
return extError_t();
}
extError_t ORTX_API_CALL OrtxGetTensorData(OrtxTensor* tensor, const void** data, const int64_t** shape,
size_t* num_dims) {
if (tensor == nullptr) {
@ -120,7 +151,7 @@ extError_t ORTX_API_CALL OrtxGetTensorData(OrtxTensor* tensor, const void** data
return kOrtxErrorInvalidArgument;
}
auto tensor_impl = static_cast<OrtxObjectWrapper<ortc::TensorBase>*>(tensor);
auto tensor_impl = static_cast<OrtxObjectWrapper<ortc::TensorBase, kOrtxKindTensor>*>(tensor);
if (tensor_impl->ortx_kind() != extObjectKind_t::kOrtxKindTensor) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;

Просмотреть файл

@ -3,8 +3,10 @@
#pragma once
#include <vector>
#include <fstream>
#include "ortx_utils.h"
#include "file_sys.h"
#include "ext_status.h"
#include "op_def_struct.h"
@ -12,7 +14,7 @@ namespace ort_extensions {
class OrtxObjectImpl : public OrtxObject {
public:
explicit OrtxObjectImpl(extObjectKind_t kind = extObjectKind_t::kOrtxKindUnknown) : OrtxObject() {
ext_kind_ = static_cast<int>(kind);
ext_kind_ = kind;
};
virtual ~OrtxObjectImpl() = default;
@ -24,30 +26,21 @@ class OrtxObjectImpl : public OrtxObject {
}
return static_cast<extObjectKind_t>(ext_kind_);
}
template <typename T>
struct Type2Kind {
static const extObjectKind_t value = kOrtxKindUnknown;
};
};
template <>
struct OrtxObjectImpl::Type2Kind<ortc::TensorBase> {
static const extObjectKind_t value = kOrtxKindTensor;
};
template <typename T>
// A wrapper class to store a object pointer which is readonly. i.e. unowned.
template <typename T, extObjectKind_t kind>
class OrtxObjectWrapper : public OrtxObjectImpl {
public:
OrtxObjectWrapper() : OrtxObjectImpl(OrtxObjectImpl::Type2Kind<T>::value) {}
OrtxObjectWrapper() : OrtxObjectImpl(kind) {}
~OrtxObjectWrapper() override = default;
void SetObject(T* t) { stored_object_ = t; }
void SetObject(const T* t) { stored_object_ = t; }
[[nodiscard]] T* GetObject() const { return stored_object_; }
[[nodiscard]] const T* GetObject() const { return stored_object_; }
private:
T* stored_object_{};
const T* stored_object_{};
};
template <typename T>
@ -100,6 +93,35 @@ class StringArray : public OrtxObjectImpl {
std::vector<std::string> strings_;
};
class TensorResult : public OrtxObjectImpl {
public:
TensorResult() : OrtxObjectImpl(extObjectKind_t::kOrtxKindTensorResult) {}
~TensorResult() override = default;
void SetTensors(std::vector<std::unique_ptr<ortc::TensorBase>>&& tensors) { tensors_ = std::move(tensors); }
[[nodiscard]] const std::vector<std::unique_ptr<ortc::TensorBase>>& tensors() const { return tensors_; }
[[nodiscard]] std::vector<ortc::TensorBase*> GetTensors() const {
std::vector<ortc::TensorBase*> ts;
ts.reserve(tensors_.size());
for (auto& t : tensors_) {
ts.push_back(t.get());
}
return ts;
}
ortc::TensorBase* GetAt(size_t i) const {
if (i < tensors_.size()) {
return tensors_[i].get();
}
return nullptr;
}
private:
std::vector<std::unique_ptr<ortc::TensorBase>> tensors_;
};
struct ReturnableStatus {
public:
thread_local static std::string last_error_message_;
@ -123,25 +145,26 @@ struct ReturnableStatus {
OrtxStatus status_;
};
template <typename T>
class OrtxObjectFactory {
public:
static std::unique_ptr<T> Create() { return std::make_unique<T>(); }
static OrtxObject* CreateForward();
static void DisposeForward(OrtxObject* object);
template <typename T>
static OrtxObject* Create() {
return std::make_unique<T>().release();
}
template <typename T>
static void Dispose(OrtxObject* object) {
auto obj_ptr = static_cast<T*>(object);
std::unique_ptr<T> ptr(obj_ptr);
ptr.reset();
}
// Forward declaration for creating an object which isn't visible to c_api_utils.cc
// and the definition is in the corresponding .cc file.
template <typename T>
static OrtxObject* CreateForward();
};
class DetokenizerCache; // forward definition in tokenizer_impl.cc
class ProcessorResult; // forward definition in image_processor.h
class CppAllocator : public ortc::IAllocator {
public:
void* Alloc(size_t size) override { return std::make_unique<char[]>(size).release(); }
@ -157,4 +180,25 @@ class CppAllocator : public ortc::IAllocator {
}
};
template <typename It, typename T>
std::tuple<std::unique_ptr<T[]>, size_t> LoadRawData(It begin, It end) {
auto raw_data = std::make_unique<T[]>(end - begin);
size_t n = 0;
for (auto it = begin; it != end; ++it) {
std::ifstream ifs = path(*it).open(std::ios::binary | std::ios::in);
if (!ifs.is_open()) {
break;
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
ifs.seekg(0, std::ios::beg);
T& datum = raw_data[n++];
datum.resize(size);
ifs.read(reinterpret_cast<char*>(datum.data()), size);
}
return std::make_tuple(std::move(raw_data), n);
}
} // namespace ort_extensions

Просмотреть файл

@ -7,6 +7,7 @@
#include "file_sys.h"
#include "image_processor.h"
#include "c_api_utils.hpp"
#include "cv2/imgcodecs/imdecode.hpp"
#include "image_transforms.hpp"
#include "image_transforms_phi_3.hpp"
@ -14,38 +15,11 @@
using namespace ort_extensions;
using json = nlohmann::json;
namespace ort_extensions {
template <typename It>
std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages(It begin, It end) {
auto raw_images = std::make_unique<ImageRawData[]>(end - begin);
size_t n = 0;
for (auto it = begin; it != end; ++it) {
std::ifstream ifs = path(*it).open(std::ios::binary);
if (!ifs.is_open()) {
break;
}
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
ifs.seekg(0, std::ios::beg);
ImageRawData& raw_image = raw_images[n++];
raw_image.resize(size);
ifs.read(reinterpret_cast<char*>(raw_image.data()), size);
}
return std::make_tuple(std::move(raw_images), n);
std::tuple<std::unique_ptr<ImageRawData[]>, size_t>
ort_extensions::LoadRawImages(const std::initializer_list<const char*>& image_paths) {
return ort_extensions::LoadRawData<const char* const*, ImageRawData>(image_paths.begin(), image_paths.end());
}
std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages(
const std::initializer_list<const char*>& image_paths) {
return LoadRawImages(image_paths.begin(), image_paths.end());
}
template std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages<char const**>(char const**, char const**);
} // namespace ort_extensions
Operation::KernelRegistry ImageProcessor::kernel_registry_ = {
{"DecodeImage", []() { return CreateKernelInstance(image_decoder); }},
{"Resize", []() { return CreateKernelInstance(&Resize::Compute); }},
@ -97,9 +71,7 @@ OrtxStatus ImageProcessor::Init(std::string_view processor_def) {
return {};
}
ImageProcessor::ImageProcessor()
: OrtxObjectImpl(kOrtxKindProcessor), allocator_(&CppAllocator::Instance()) {
}
ImageProcessor::ImageProcessor() : OrtxObjectImpl(kOrtxKindProcessor), allocator_(&CppAllocator::Instance()) {}
template <typename T>
static ortc::Tensor<T>* StackTensor(const std::vector<TensorArgs>& arg_lists, int axis, ortc::IAllocator* allocator) {
@ -136,39 +108,6 @@ static ortc::Tensor<T>* StackTensor(const std::vector<TensorArgs>& arg_lists, in
return output.release();
}
static OrtxStatus StackTensors(const std::vector<TensorArgs>& arg_lists, std::vector<TensorPtr>& outputs,
ortc::IAllocator* allocator) {
if (arg_lists.empty()) {
return {};
}
size_t batch_size = arg_lists.size();
size_t num_outputs = arg_lists[0].size();
for (size_t axis = 0; axis < num_outputs; ++axis) {
std::vector<ortc::TensorBase*> ts_ptrs;
ts_ptrs.reserve(arg_lists.size());
std::vector<int64_t> shape = arg_lists[0][axis]->Shape();
for (auto& ts : arg_lists) {
if (shape != ts[axis]->Shape()) {
return {kOrtxErrorInvalidArgument, "[StackTensors]: shapes of tensors to stack are not the same."};
}
ts_ptrs.push_back(ts[axis]);
}
std::vector<int64_t> output_shape = shape;
output_shape.insert(output_shape.begin(), batch_size);
std::byte* tensor_buf = outputs[axis]->AllocateRaw(output_shape);
for (size_t i = 0; i < batch_size; ++i) {
auto ts = ts_ptrs[i];
const std::byte* ts_buff = reinterpret_cast<const std::byte*>(ts->DataRaw());
auto ts_size = ts->SizeInBytes();
std::memcpy(tensor_buf + i * ts_size, ts_buff, ts_size);
}
}
return {};
}
std::tuple<OrtxStatus, ProcessorResult> ImageProcessor::PreProcess(ort_extensions::span<ImageRawData> image_data,
ortc::Tensor<float>** pixel_values,
ortc::Tensor<int64_t>** image_sizes,
@ -209,7 +148,7 @@ std::tuple<OrtxStatus, ProcessorResult> ImageProcessor::PreProcess(ort_extension
return {status, std::move(r)};
}
OrtxStatus ImageProcessor::PreProcess(ort_extensions::span<ImageRawData> image_data, ImageProcessorResult& r) const {
OrtxStatus ImageProcessor::PreProcess(ort_extensions::span<ImageRawData> image_data, TensorResult& r) const {
std::vector<TensorArgs> inputs;
inputs.resize(image_data.size());
for (size_t i = 0; i < image_data.size(); ++i) {
@ -235,9 +174,13 @@ OrtxStatus ImageProcessor::PreProcess(ort_extensions::span<ImageRawData> image_d
}
}
r.results = operations_.back()->AllocateOutputs(allocator_);
status = StackTensors(outputs, r.results, allocator_);
auto img_result = operations_.back()->AllocateOutputs(allocator_);
status = OrtxRunner::StackTensors(outputs, img_result, allocator_);
operations_.back()->ResetTensors(allocator_);
if (status.IsOk()) {
r.SetTensors(std::move(img_result));
}
return status;
}
@ -257,14 +200,3 @@ void ImageProcessor::ClearOutputs(ProcessorResult* r) {
r->num_img_takens = nullptr;
}
}
void ort_extensions::ImageProcessor::ClearOutputs(ImageProcessorResult* r) {
if (r == nullptr) {
return;
}
for (auto& ts : r->results) {
ts.reset();
}
r->results.clear(); // clear the vector
}

Просмотреть файл

@ -16,9 +16,6 @@ namespace ort_extensions {
using ImageRawData = std::vector<uint8_t>;
template <typename It>
std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages(It begin, It end);
std::tuple<std::unique_ptr<ImageRawData[]>, size_t> LoadRawImages(
const std::initializer_list<const char*>& image_paths);
@ -29,13 +26,6 @@ class ProcessorResult : public OrtxObjectImpl {
ortc::Tensor<int64_t>* image_sizes{};
ortc::Tensor<int64_t>* num_img_takens{};
};
class ImageProcessorResult : public OrtxObjectImpl {
public:
ImageProcessorResult() : OrtxObjectImpl(kOrtxKindImageProcessorResult) {}
std::vector<TensorPtr> results;
};
class ImageProcessor : public OrtxObjectImpl {
public:
ImageProcessor();
@ -43,15 +33,16 @@ class ImageProcessor : public OrtxObjectImpl {
OrtxStatus Init(std::string_view processor_def);
// Deprecated, using the next function instead
std::tuple<OrtxStatus, ProcessorResult> PreProcess(ort_extensions::span<ImageRawData> image_data,
ortc::Tensor<float>** pixel_values,
ortc::Tensor<int64_t>** image_sizes,
ortc::Tensor<int64_t>** num_img_takens) const;
OrtxStatus PreProcess(ort_extensions::span<ImageRawData> image_data, ImageProcessorResult& r) const;
OrtxStatus PreProcess(ort_extensions::span<ImageRawData> image_data, TensorResult& r) const;
// Deprecated, using the next function instead
static void ClearOutputs(ProcessorResult* r);
static void ClearOutputs(ImageProcessorResult* r);
static Operation::KernelRegistry kernel_registry_;

Просмотреть файл

@ -28,7 +28,8 @@ class KernelDef {
virtual TensorArgs AllocateOutput(ortc::IAllocator* allocator) const = 0;
virtual OrtxStatus Apply(TensorArgs& inputs, TensorArgs& output) const = 0;
using AttrType = std::variant<std::string, double, int64_t, std::vector<double>>;
using AttrType =
std::variant<std::string, double, int64_t, std::vector<std::string>, std::vector<double>, std::vector<int64_t>>;
using AttrDict = std::unordered_map<std::string, AttrType>;
template <typename... Args>
@ -98,7 +99,7 @@ class KernelDef {
template <typename... Args>
class KernelFunction : public KernelDef {
public:
KernelFunction(OrtxStatus (*body)(Args...)) : body_(body){};
KernelFunction(OrtxStatus (*body)(Args...)) : body_(body) {};
virtual ~KernelFunction() = default;
TensorArgs AllocateOutput(ortc::IAllocator* allocator) const override {
@ -132,7 +133,7 @@ class KernelFunction : public KernelDef {
template <typename T, typename... Args>
class KernelStruct : public KernelDef {
public:
KernelStruct(OrtxStatus (T::*body)(Args...)) : body_(body){};
KernelStruct(OrtxStatus (T::*body)(Args...)) : body_(body) {};
virtual ~KernelStruct() = default;
TensorArgs AllocateOutput(ortc::IAllocator* allocator) const override {
@ -167,8 +168,18 @@ class KernelStruct : public KernelDef {
attr_dict[key] = value.template get<int64_t>();
} else if (value.is_number_float()) {
attr_dict[key] = value.template get<double>();
} else if (value.is_array()) {
attr_dict[key] = value.template get<std::vector<double>>();
} else if (value.is_array() && value.size() > 0) {
auto& elem_0 = value.at(0);
if (elem_0.is_number_float()) {
attr_dict[key] = value.template get<std::vector<double>>();
} else if (elem_0.is_string()) {
attr_dict[key] = value.template get<std::vector<std::string>>();
} else if (elem_0.is_number_integer() || elem_0.is_number_unsigned()) {
attr_dict[key] = value.template get<std::vector<int64_t>>();
} else {
return {kOrtxErrorCorruptData, "Unsupported mix types in attribute value."};
}
} else {
return {kOrtxErrorCorruptData, "Invalid attribute type."};
}
@ -309,6 +320,39 @@ class OrtxRunner {
return {};
}
static OrtxStatus StackTensors(const std::vector<TensorArgs>& arg_lists, std::vector<TensorPtr>& outputs,
ortc::IAllocator* allocator) {
if (arg_lists.empty()) {
return {};
}
size_t batch_size = arg_lists.size();
size_t num_outputs = arg_lists[0].size();
for (size_t axis = 0; axis < num_outputs; ++axis) {
std::vector<ortc::TensorBase*> ts_ptrs;
ts_ptrs.reserve(arg_lists.size());
std::vector<int64_t> shape = arg_lists[0][axis]->Shape();
for (auto& ts : arg_lists) {
if (shape != ts[axis]->Shape()) {
return {kOrtxErrorInvalidArgument, "[StackTensors]: shapes of tensors to stack are not the same."};
}
ts_ptrs.push_back(ts[axis]);
}
std::vector<int64_t> output_shape = shape;
output_shape.insert(output_shape.begin(), batch_size);
std::byte* tensor_buf = outputs[axis]->AllocateRaw(output_shape);
for (size_t i = 0; i < batch_size; ++i) {
auto ts = ts_ptrs[i];
const std::byte* ts_buff = reinterpret_cast<const std::byte*>(ts->DataRaw());
auto ts_size = ts->SizeInBytes();
std::memcpy(tensor_buf + i * ts_size, ts_buff, ts_size);
}
}
return {};
}
private:
ortc::IAllocator* allocator_;
std::vector<Operation*> ops_;

Просмотреть файл

@ -0,0 +1,97 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "speech_extractor.h"
#include "audio/audio_decoder.h"
#include "speech_features.hpp"
using namespace ort_extensions;
Operation::KernelRegistry SpeechFeatureExtractor::kernel_registry_ = {
{"AudioDecoder", []() { return CreateKernelInstance(&AudioDecoder::ComputeNoOpt); }},
{"STFTNorm", []() { return CreateKernelInstance(&SpeechFeatures::STFTNorm); }},
{"LogMelSpectrum", []() { return CreateKernelInstance(&LogMel::Compute); }},
};
SpeechFeatureExtractor::SpeechFeatureExtractor()
: OrtxObjectImpl(extObjectKind_t::kOrtxKindFeatureExtractor), allocator_(&CppAllocator::Instance()) {}
OrtxStatus SpeechFeatureExtractor::Init(std::string_view extractor_def) {
std::string fe_def_str;
if (extractor_def.size() >= 5 && extractor_def.substr(extractor_def.size() - 5) == ".json") {
std::ifstream ifs = path({extractor_def.data(), extractor_def.size()}).open();
if (!ifs.is_open()) {
return {kOrtxErrorInvalidArgument, std::string("[ImageProcessor]: failed to open ") + std::string(extractor_def)};
}
fe_def_str = std::string(std::istreambuf_iterator<char>(ifs), std::istreambuf_iterator<char>());
extractor_def = fe_def_str.c_str();
}
// pase the extraction_def by json
auto fe_json = json::parse(extractor_def, nullptr, false);
if (fe_json.is_discarded()) {
return {kOrtxErrorInvalidArgument, "[SpeechFeatureExtractor]: failed to parse extractor json configuration."};
}
auto fe_root = fe_json.at("feature_extraction");
if (!fe_root.is_object()) {
return {kOrtxErrorInvalidArgument, "[SpeechFeatureExtractor]: feature_extraction field is missing."};
}
auto op_sequence = fe_root.at("sequence");
if (!op_sequence.is_array() || op_sequence.empty()) {
return {kOrtxErrorInvalidArgument, "[SpeechFeatureExtractor]: sequence field is missing."};
}
operations_.reserve(op_sequence.size());
for (auto mod_iter = op_sequence.begin(); mod_iter != op_sequence.end(); ++mod_iter) {
auto op = std::make_unique<Operation>(kernel_registry_);
auto status = op->Init(mod_iter->dump());
if (!status.IsOk()) {
return status;
}
operations_.push_back(std::move(op));
}
return {};
}
OrtxStatus SpeechFeatureExtractor::DoCall(ort_extensions::span<AudioRawData> raw_speech,
std::unique_ptr<ortc::Tensor<float>>& log_mel) const {
// setup the input tensors
std::vector<TensorArgs> inputs;
inputs.resize(raw_speech.size());
for (size_t i = 0; i < raw_speech.size(); ++i) {
auto& ts_input = inputs[i];
AudioRawData& speech = raw_speech[i];
std::vector<int64_t> shape = {static_cast<int64_t>(speech.size())};
ts_input.push_back(std::make_unique<ortc::Tensor<uint8_t>>(shape, speech.data()).release());
}
std::vector<TensorArgs> outputs;
std::vector<Operation*> ops(operations_.size());
std::transform(operations_.begin(), operations_.end(), ops.begin(), [](auto& op) { return op.get(); });
OrtxRunner runner(allocator_, ops.data(), ops.size());
auto status = runner.Run(inputs, outputs);
if (!status.IsOk()) {
return status;
}
// clear the input tensors
for (auto& input : inputs) {
for (auto& ts : input) {
std::unique_ptr<ortc::TensorBase>(ts).reset();
}
}
auto results = operations_.back()->AllocateOutputs(allocator_);
status = OrtxRunner::StackTensors(outputs, results, allocator_);
if (status.IsOk()) {
log_mel.reset(static_cast<ortc::Tensor<float>*>(results[0].release()));
operations_.back()->ResetTensors(allocator_);
}
return status;
}

Просмотреть файл

@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "ortx_extractor.h"
#include "c_api_utils.hpp"
#include "runner.hpp"
namespace ort_extensions {
typedef std::vector<std::byte> AudioRawData;
class SpeechFeatureExtractor : public OrtxObjectImpl {
public:
SpeechFeatureExtractor();
virtual ~SpeechFeatureExtractor() = default;
public:
OrtxStatus Init(std::string_view extractor_def);
OrtxStatus DoCall(ort_extensions::span<AudioRawData> raw_speech, std::unique_ptr<ortc::Tensor<float>>& log_mel) const;
static Operation::KernelRegistry kernel_registry_;
private:
std::vector<std::unique_ptr<Operation>> operations_;
ortc::IAllocator* allocator_;
};
} // namespace ort_extensions

Просмотреть файл

@ -0,0 +1,224 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <dlib/matrix.h>
#include <math/dlib/stft_norm.hpp>
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
namespace ort_extensions {
class SpeechFeatures {
public:
template <typename DictT>
OrtxStatus Init(const DictT& attrs) {
for (const auto& [key, value] : attrs) {
if (key == "n_fft") {
n_fft_ = std::get<int64_t>(value);
} else if (key == "hop_length") {
hop_length_ = std::get<int64_t>(value);
} else if (key == "frame_length") {
frame_length_ = std::get<int64_t>(value);
} else if (key == "hann_win") {
auto& win = std::get<std::vector<double>>(value);
hann_win_.resize(win.size());
std::transform(win.begin(), win.end(), hann_win_.begin(), [](double x) { return static_cast<float>(x); });
} else if (key != "_comment") {
return {kOrtxErrorInvalidArgument, "[AudioFeatures]: Invalid key in the JSON configuration."};
}
}
if (hann_win_.empty()) {
hann_win_ = hann_window(frame_length_);
}
return {};
}
OrtxStatus STFTNorm(const ortc::Tensor<float>& pcm, ortc::Tensor<float>& stft_norm) {
return stft_norm_.Compute(pcm, n_fft_, hop_length_, {hann_win_.data(), hann_win_.size()}, frame_length_, stft_norm);
}
static std::vector<float> hann_window(int N) {
std::vector<float> window(N);
for (int n = 0; n < N; ++n) {
// this formula leads to more rounding errors than the one below
// window[n] = static_cast<float>(0.5 * (1 - std::cos(2 * M_PI * n / (N - 1))));
double n_sin = std::sin(M_PI * n / N);
window[n] = static_cast<float>(n_sin * n_sin);
}
return window;
}
private:
StftNormal stft_norm_;
int64_t n_fft_{};
int64_t hop_length_{};
int64_t frame_length_{};
std::vector<float> hann_win_;
};
class LogMel {
public:
template <typename DictT>
OrtxStatus Init(const DictT& attrs) {
int n_fft = 0;
int n_mel = 0;
int chunk_size = 0;
for (const auto& [key, value] : attrs) {
if (key == "hop_length") {
hop_length_ = std::get<int64_t>(value);
} else if (key == "n_fft") {
n_fft = std::get<int64_t>(value);
} else if (key == "n_mel") {
n_mel = std::get<int64_t>(value);
} else if (key == "chunk_size") {
chunk_size = std::get<int64_t>(value);
} else {
return {kOrtxErrorInvalidArgument, "[LogMel]: Invalid key in the JSON configuration."};
}
}
n_samples_ = n_sr_ * chunk_size;
mel_filters_ = MelFilterBank(n_fft, n_mel, n_sr_);
return {};
}
OrtxStatus Compute(const ortc::Tensor<float>& stft_norm, ortc::Tensor<float>& logmel) {
// Compute the Mel spectrogram by following Python code
/*
magnitudes = stft_norm[:, :, :-1]
mel_spec = self.mel_filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
spec_min = log_spec.max() - 8.0
log_spec = torch.maximum(log_spec, spec_min)
spec_shape = log_spec.shape
padding_spec = torch.ones(spec_shape[0],
spec_shape[1],
self.n_samples // self.hop_length - spec_shape[2],
dtype=torch.float)
padding_spec *= spec_min
log_spec = torch.cat((log_spec, padding_spec), dim=2)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
*/
assert(stft_norm.Shape().size() == 3 && stft_norm.Shape()[0] == 1);
std::vector<int64_t> stft_shape = stft_norm.Shape();
dlib::matrix<float> magnitudes(stft_norm.Shape()[1], stft_norm.Shape()[2] - 1);
for (int i = 0; i < magnitudes.nr(); ++i) {
std::copy(stft_norm.Data() + i * stft_shape[2], stft_norm.Data() + (i + 1) * stft_shape[2] - 1,
magnitudes.begin() + i * magnitudes.nc());
}
dlib::matrix<float> mel_spec = mel_filters_ * magnitudes;
for (int i = 0; i < mel_spec.nr(); ++i) {
for (int j = 0; j < mel_spec.nc(); ++j) {
mel_spec(i, j) = std::max(1e-10f, mel_spec(i, j));
}
}
dlib::matrix<float> log_spec = dlib::log10(mel_spec);
float log_spec_min = dlib::max(log_spec) - 8.0f;
for (int i = 0; i < log_spec.nr(); ++i) {
for (int j = 0; j < log_spec.nc(); ++j) {
float v = std::max(log_spec(i, j), log_spec_min);
v = (v + 4.0f) / 4.0f;
log_spec(i, j) = v;
}
}
std::vector<int64_t> shape = {mel_filters_.nr(), n_samples_ / hop_length_};
float* buff = logmel.Allocate(shape);
std::fill(buff, buff + logmel.NumberOfElement(), (log_spec_min + 4.0f) / 4.0f);
for (int i = 0; i < log_spec.nr(); ++i) {
auto row_len = log_spec.nc() * i;
std::copy(log_spec.begin() + i * log_spec.nc(), log_spec.begin() + (i + 1) * log_spec.nc(), buff + i * shape[1]);
}
return {};
}
// Function to compute the Mel filterbank
static dlib::matrix<float> MelFilterBank(int n_fft, int n_mels, int sr = 16000, float min_mel = 0,
float max_mel = 45.245640471924965) {
// Initialize the filterbank matrix
dlib::matrix<float> fbank(n_mels, n_fft / 2 + 1);
memset(fbank.begin(), 0, fbank.size() * sizeof(float));
// Compute the frequency bins for the DFT
std::vector<float> freq_bins(n_fft / 2 + 1);
for (int i = 0; i <= n_fft / 2; ++i) {
freq_bins[i] = i * sr / static_cast<float>(n_fft);
}
// Compute the Mel scale frequencies
std::vector<float> mel(n_mels + 2);
for (int i = 0; i < n_mels + 2; ++i) {
mel[i] = min_mel + i * (max_mel - min_mel) / (n_mels + 1);
}
// Fill in the linear scale
float f_min = 0.0f;
float f_sp = 200.0f / 3.0f;
std::vector<float> freqs(n_mels + 2);
for (int i = 0; i < n_mels + 2; ++i) {
freqs[i] = f_min + f_sp * mel[i];
}
// Nonlinear scale
float min_log_hz = 1000.0f;
float min_log_mel = (min_log_hz - f_min) / f_sp;
float logstep = log(6.4) / 27.0;
for (int i = 0; i < n_mels + 2; ++i) {
if (mel[i] >= min_log_mel) {
freqs[i] = min_log_hz * exp(logstep * (mel[i] - min_log_mel));
}
}
std::vector<float> mel_bins = freqs;
std::vector<float> mel_spacing(n_mels + 1);
for (int i = 0; i < n_mels + 1; ++i) {
mel_spacing[i] = mel_bins[i + 1] - mel_bins[i];
}
// Compute the ramps
std::vector<std::vector<float>> ramps(n_mels + 2, std::vector<float>(n_fft / 2 + 1));
for (int i = 0; i < n_mels + 2; ++i) {
for (int j = 0; j <= n_fft / 2; ++j) {
ramps[i][j] = mel_bins[i] - freq_bins[j];
}
}
for (int i = 0; i < n_mels; ++i) {
for (int j = 0; j <= n_fft / 2; ++j) {
float left = -ramps[i][j] / mel_spacing[i];
float right = ramps[i + 2][j] / mel_spacing[i + 1];
fbank(i, j) = std::max(0.0f, std::min(left, right));
}
}
// Energy normalization
for (int i = 0; i < n_mels; ++i) {
float energy_norm = 2.0f / (mel_bins[i + 2] - mel_bins[i]);
for (int j = 0; j <= n_fft / 2; ++j) {
fbank(i, j) *= energy_norm;
}
}
return fbank;
}
private:
int64_t n_samples_ = {}; // sr * chunk_size
int64_t hop_length_{};
const int64_t n_sr_{16000};
dlib::matrix<float> mel_filters_;
};
} // namespace ort_extensions

Просмотреть файл

@ -0,0 +1,437 @@
{
"feature_extraction": {
"sequence": [
{
"operation": {
"name": "audio_decoder",
"type": "AudioDecoder"
}
},
{
"operation": {
"name": "STFT",
"type": "STFTNorm",
"attrs": {
"n_fft": 400,
"frame_length": 400,
"hop_length": 160,
"_comment": [
0.0,
0.0000616908073425293,
0.0002467334270477295,
0.0005550682544708252,
0.000986635684967041,
0.0015413463115692139,
0.0022190213203430176,
0.0030195116996765137,
0.003942638635635376,
0.004988163709640503,
0.006155818700790405,
0.007445335388183594,
0.008856385946273804,
0.010388582944869995,
0.012041628360748291,
0.013815045356750488,
0.01570841670036316,
0.01772129535675049,
0.019853144884109497,
0.022103488445281982,
0.02447172999382019,
0.026957333087921143,
0.029559612274169922,
0.03227800130844116,
0.03511175513267517,
0.03806024789810181,
0.0411226749420166,
0.044298380613327026,
0.04758647084236145,
0.05098623037338257,
0.05449673533439636,
0.058117181062698364,
0.06184667348861694,
0.0656842589378357,
0.06962898373603821,
0.07367992401123047,
0.0778360664844513,
0.08209633827209473,
0.08645972609519958,
0.09092515707015991,
0.09549149870872498,
0.10015767812728882,
0.10492250323295593,
0.1097848117351532,
0.11474338173866272,
0.11979702115058899,
0.12494447827339172,
0.13018447160720825,
0.1355157196521759,
0.14093685150146484,
0.1464466154575348,
0.15204361081123352,
0.1577264666557312,
0.16349375247955322,
0.16934409737586975,
0.1752760112285614,
0.18128803372383118,
0.18737870454788208,
0.19354650378227234,
0.1997898817062378,
0.20610737800598145,
0.21249738335609436,
0.21895831823349,
0.2254886031150818,
0.23208662867546082,
0.23875075578689575,
0.24547931551933289,
0.2522706985473633,
0.25912320613861084,
0.26603513956069946,
0.27300477027893066,
0.2800304591655731,
0.2871103882789612,
0.29424285888671875,
0.30142611265182495,
0.30865830183029175,
0.31593772768974304,
0.3232625722885132,
0.3306310474872589,
0.3380413055419922,
0.34549152851104736,
0.352979838848114,
0.3605044484138489,
0.3680635094642639,
0.37565508484840393,
0.38327735662460327,
0.3909284174442291,
0.39860638976097107,
0.4063093662261963,
0.41403549909591675,
0.42178282141685486,
0.4295494258403778,
0.43733343482017517,
0.44513291120529175,
0.45294591784477234,
0.46077051758766174,
0.46860480308532715,
0.4764467775821686,
0.4842946231365204,
0.492146372795105,
0.5,
0.5078536868095398,
0.515705406665802,
0.5235532522201538,
0.5313953161239624,
0.5392295718193054,
0.5470541715621948,
0.5548672080039978,
0.562666654586792,
0.5704506635665894,
0.5782172679901123,
0.5859646201133728,
0.5936906933784485,
0.6013936996459961,
0.609071671962738,
0.6167227625846863,
0.6243450045585632,
0.6319366097450256,
0.6394955515861511,
0.6470202207565308,
0.6545085310935974,
0.6619587540626526,
0.6693689823150635,
0.6767374277114868,
0.6840623021125793,
0.691341757774353,
0.6985740065574646,
0.7057572603225708,
0.7128896713256836,
0.719969630241394,
0.7269952893257141,
0.7339649796485901,
0.7408769130706787,
0.7477294206619263,
0.7545207738876343,
0.761249303817749,
0.7679134607315063,
0.774511456489563,
0.7810417413711548,
0.7875027060508728,
0.7938927412033081,
0.800210177898407,
0.8064535856246948,
0.8126214146614075,
0.8187121152877808,
0.8247240781784058,
0.8306560516357422,
0.8365063667297363,
0.8422735929489136,
0.8479564785957336,
0.8535534143447876,
0.8590631484985352,
0.8644843101501465,
0.8698155879974365,
0.8750555515289307,
0.8802030086517334,
0.8852566480636597,
0.8902152180671692,
0.8950775265693665,
0.899842381477356,
0.9045084714889526,
0.9090749025344849,
0.9135403037071228,
0.9179036617279053,
0.9221639633178711,
0.9263200759887695,
0.9303710460662842,
0.9343158006668091,
0.9381533861160278,
0.941882848739624,
0.945503294467926,
0.9490138292312622,
0.9524135589599609,
0.9557017087936401,
0.9588773250579834,
0.961939811706543,
0.9648882746696472,
0.9677220582962036,
0.9704403877258301,
0.9730427265167236,
0.9755282998085022,
0.9778965711593628,
0.9801468849182129,
0.9822787046432495,
0.9842916131019592,
0.9861849546432495,
0.9879584312438965,
0.9896113872528076,
0.9911436438560486,
0.9925546646118164,
0.9938441514968872,
0.9950118064880371,
0.996057391166687,
0.9969804883003235,
0.997780978679657,
0.9984586238861084,
0.999013364315033,
0.9994449615478516,
0.9997532367706299,
0.9999383091926575,
1,
0.9999383091926575,
0.9997532367706299,
0.9994449615478516,
0.999013364315033,
0.9984586238861084,
0.997780978679657,
0.9969804286956787,
0.9960573315620422,
0.9950118064880371,
0.9938441514968872,
0.9925546646118164,
0.9911435842514038,
0.9896113872528076,
0.9879583716392517,
0.9861849546432495,
0.9842915534973145,
0.9822787046432495,
0.9801468253135681,
0.9778964519500732,
0.9755282402038574,
0.9730426073074341,
0.9704403877258301,
0.9677219390869141,
0.9648882150650024,
0.9619396924972534,
0.9588772654533386,
0.9557015895843506,
0.9524134397506714,
0.9490137100219727,
0.9455032348632812,
0.9418827295303345,
0.9381532669067383,
0.9343156814575195,
0.9303709268569946,
0.9263200759887695,
0.9221639633178711,
0.9179036617279053,
0.913540244102478,
0.9090747833251953,
0.9045084714889526,
0.8998422622680664,
0.8950774669647217,
0.8902151584625244,
0.8852565884590149,
0.8802029490470886,
0.8750554919242859,
0.869815468788147,
0.8644842505455017,
0.8590630888938904,
0.853553295135498,
0.8479562997817993,
0.842273473739624,
0.836506187915802,
0.8306558728218079,
0.8247239589691162,
0.8187118768692017,
0.8126212358474731,
0.8064534664154053,
0.8002099990844727,
0.793892502784729,
0.7875025272369385,
0.7810416221618652,
0.7745113372802734,
0.767913281917572,
0.7612491846084595,
0.7545205950737,
0.7477291822433472,
0.7408767342567444,
0.7339648008346558,
0.7269951105117798,
0.7199694514274597,
0.7128894925117493,
0.7057570219039917,
0.6985738277435303,
0.6913415789604187,
0.684062123298645,
0.6767372488975525,
0.6693688035011292,
0.6619585752487183,
0.6545083522796631,
0.6470199823379517,
0.6394953727722168,
0.6319363117218018,
0.6243447661399841,
0.6167224645614624,
0.6090714335441589,
0.601393461227417,
0.5936904549598694,
0.5859643220901489,
0.5782170295715332,
0.5704504251480103,
0.5626664161682129,
0.5548669099807739,
0.5470539331436157,
0.5392293334007263,
0.5313950181007385,
0.5235530138015747,
0.5157051682472229,
0.507853627204895,
0.5,
0.4921463429927826,
0.484294593334198,
0.4764467477798462,
0.46860471367836,
0.4607704281806946,
0.4529458284378052,
0.4451328217983246,
0.437333345413208,
0.42954933643341064,
0.4217827320098877,
0.4140354096889496,
0.4063093066215515,
0.3986063003540039,
0.39092832803726196,
0.3832772672176361,
0.37565499544143677,
0.36806342005729675,
0.3605043888092041,
0.35297977924346924,
0.3454914391040802,
0.338041216135025,
0.33063095808029175,
0.3232625126838684,
0.3159376382827759,
0.3086581826210022,
0.3014259934425354,
0.2942427396774292,
0.28711026906967163,
0.2800303101539612,
0.2730046510696411,
0.2660350203514099,
0.2591230869293213,
0.25227057933807373,
0.24547919631004333,
0.2387506067752838,
0.23208650946617126,
0.22548848390579224,
0.21895819902420044,
0.2124972641468048,
0.2061072587966919,
0.19978976249694824,
0.1935463547706604,
0.18737855553627014,
0.18128788471221924,
0.17527586221694946,
0.1693439483642578,
0.16349363327026367,
0.15772631764411926,
0.15204349160194397,
0.14644649624824524,
0.1409367322921753,
0.13551557064056396,
0.1301843225955963,
0.12494435906410217,
0.11979690194129944,
0.11474326252937317,
0.10978469252586365,
0.10492238402366638,
0.10015755891799927,
0.09549137949943542,
0.09092503786087036,
0.08645960688591003,
0.08209621906280518,
0.07783591747283936,
0.07367980480194092,
0.06962886452674866,
0.06568413972854614,
0.06184655427932739,
0.0581170916557312,
0.0544966459274292,
0.05098611116409302,
0.04758638143539429,
0.044298261404037476,
0.04112258553504944,
0.038060128688812256,
0.03511166572570801,
0.03227788209915161,
0.02955952286720276,
0.02695724368095398,
0.024471670389175415,
0.02210339903831482,
0.01985308527946472,
0.017721205949783325,
0.015708357095718384,
0.0138150155544281,
0.012041598558425903,
0.010388582944869995,
0.008856356143951416,
0.007445335388183594,
0.006155818700790405,
0.004988163709640503,
0.003942638635635376,
0.0030195116996765137,
0.0022190213203430176,
0.0015413165092468262,
0.000986635684967041,
0.0005550682544708252,
0.0002467334270477295,
0.0000616908073425293
]
}
}
},
{
"operation": {
"name": "log_mel_spectrogram",
"type": "LogMelSpectrum",
"attrs": {
"chunk_size": 30,
"hop_length": 160,
"n_fft": 400,
"n_mel": 80
}
}
}
]
}
}

Просмотреть файл

@ -4,7 +4,9 @@
#pragma once
#include "ortx_tokenizer.h"
// make sure the C only compiler compatibility only.
#include "ortx_processor.h"
#include "ortx_extractor.h"
#ifdef __cplusplus

Просмотреть файл

@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <vector>
#include <tuple>
#include <fstream>
#include <filesystem>
#include "gtest/gtest.h"
#include "ortx_cpp_helper.h"
#include "shared/api/speech_extractor.h"
using namespace ort_extensions;
TEST(ExtractorTest, TestWhisperFeatureExtraction) {
const char* audio_path[] = {"data/jfk.flac", "data/1272-141231-0002.wav", "data/1272-141231-0002.mp3"};
OrtxObjectPtr<OrtxRawAudios> raw_audios;
extError_t err = OrtxLoadAudios(ort_extensions::ptr(raw_audios), audio_path, 3);
ASSERT_EQ(err, kOrtxOK);
OrtxObjectPtr<OrtxFeatureExtractor> feature_extractor(OrtxCreateSpeechFeatureExtractor, "data/whisper/feature_extraction.json");
OrtxObjectPtr<OrtxTensorResult> result;
err = OrtxSpeechLogMel(feature_extractor.get(), raw_audios.get(), ort_extensions::ptr(result));
ASSERT_EQ(err, kOrtxOK);
OrtxObjectPtr<OrtxTensor> tensor;
err = OrtxTensorResultGetAt(result.get(), 0, ort_extensions::ptr(tensor));
ASSERT_EQ(err, kOrtxOK);
const float* data{};
const int64_t* shape{};
size_t num_dims;
err = OrtxGetTensorDataFloat(tensor.get(), &data, &shape, &num_dims);
ASSERT_EQ(err, kOrtxOK);
ASSERT_EQ(num_dims, 3);
ASSERT_EQ(shape[0], 3);
ASSERT_EQ(shape[1], 80);
ASSERT_EQ(shape[2], 3000);
}

Просмотреть файл

@ -7,7 +7,7 @@
#include <filesystem>
#include "gtest/gtest.h"
#include "ortx_c_helper.h"
#include "ortx_cpp_helper.h"
#include "shared/api/image_processor.h"
using namespace ort_extensions;
@ -85,18 +85,18 @@ TEST(ProcessorTest, TestClipImageProcessing) {
}
ASSERT_EQ(err, kOrtxOK);
OrtxObjectPtr<OrtxImageProcessorResult> result;
OrtxObjectPtr<OrtxTensorResult> result;
err = OrtxImagePreProcess(processor.get(), raw_images.get(), ort_extensions::ptr(result));
ASSERT_EQ(err, kOrtxOK);
OrtxObjectPtr<OrtxTensor> tensor;
err = OrtxImageGetTensorResult(result.get(), 0, ort_extensions::ptr(tensor));
OrtxTensor* tensor;
err = OrtxTensorResultGetAt(result.get(), 0, &tensor);
ASSERT_EQ(err, kOrtxOK);
const float* data{};
const int64_t* shape{};
size_t num_dims;
err = OrtxGetTensorDataFloat(tensor.get(), &data, &shape, &num_dims);
err = OrtxGetTensorDataFloat(tensor, &data, &shape, &num_dims);
ASSERT_EQ(err, kOrtxOK);
ASSERT_EQ(num_dims, 4);
}

Просмотреть файл

Просмотреть файл

@ -92,17 +92,17 @@ class TestPreprocessing(unittest.TestCase):
merges_file=util.get_test_data_file("data", "gpt2.merges.txt"),
)
inputs = tok.forward(test_sentence)
pnp.export(tok, test_sentence, opset_version=12, output_path="temp_tok2.onnx")
pnp.export(tok, test_sentence, opset_version=14, output_path="temp_tok2.onnx")
with open("temp_gpt2lmh.onnx", "wb") as f:
torch.onnx.export(
gpt2_m, inputs, f, opset_version=12, do_constant_folding=False
gpt2_m, inputs, f, opset_version=14, do_constant_folding=False
)
pnp.export(gpt2_m, *inputs, opset_version=12, do_constant_folding=False)
pnp.export(gpt2_m, *inputs, opset_version=14, do_constant_folding=False)
full_model = pnp.SequentialProcessingModule(tok, gpt2_m)
expected = full_model.forward(test_sentence)
model = pnp.export(
full_model, test_sentence, opset_version=12, do_constant_folding=False
full_model, test_sentence, opset_version=14, do_constant_folding=False
)
mfunc = OrtPyFunction.from_model(model)
actuals = mfunc(test_sentence)