Enable the status returnable APIs from ORT 1.16 C ABI (#558)

* Initial checkins for returnable ORT ABIs

* fix for linux build

* more fixes on Python, test...

* remove the statusmsg

* native unit tests fixing

* Python unit tests fixing

* last unit test fixing
This commit is contained in:
Wenbing Li 2023-09-13 14:59:09 -07:00 коммит произвёл GitHub
Родитель bd5de8c420
Коммит 914509d524
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
18 изменённых файлов: 459 добавлений и 257 удалений

Просмотреть файл

@ -75,7 +75,9 @@ option(OCOS_BUILD_APPLE_FRAMEWORK "Enable building of the MacOS/iOS framework" O
# Optional value. Some operators do not support old versions due to using the new custom operator interface
# and will be disabled if this value is set and the version is incompatible.
set(OCOS_ONNXRUNTIME_VERSION "" CACHE STRING
"The version of ONNX Runtime being used in the build. Format is <major>.<minor>.<patch>. e.g. 1.15.1" )
"The version of ONNX Runtime being used in the build. Format is <major>.<minor>.<patch>. e.g. 1.15.1")
set(OCOS_ONNXRUNTIME_PKG_URI "" CACHE STRING
"Specify the onnxruntime C++ shared library zip package path, like ./onnxruntime-win-x64-1.16.0.zip")
# TODO: Remove the following statements if AzureOp build is enabled by default.
# If the build_buildid environment variable is set, which means this is a CI build, then always enable AzureOp.

Просмотреть файл

@ -4,6 +4,28 @@ if(_ONNXRUNTIME_EMBEDDED)
elseif(ONNXRUNTIME_PKG_DIR)
set(ONNXRUNTIME_INCLUDE_DIR ${ONNXRUNTIME_PKG_DIR}/include)
set(ONNXRUNTIME_LIB_DIR ${ONNXRUNTIME_PKG_DIR}/lib)
elseif(OCOS_ONNXRUNTIME_PKG_URI)
if (NOT OCOS_ONNXRUNTIME_VERSION)
message(FATAL_ERROR "OCOS_ONNXRUNTIME_PKG_URI is set but OCOS_ONNXRUNTIME_VERSION is not set")
endif()
set(ONNXRUNTIME_VER ${OCOS_ONNXRUNTIME_VERSION})
set(ONNXRUNTIME_URL ${OCOS_ONNXRUNTIME_PKG_URI})
message(STATUS "ONNX Runtime URL: ${OCOS_ONNXRUNTIME_PKG_URI}")
FetchContent_Declare(
onnxruntime
URL ${OCOS_ONNXRUNTIME_PKG_URI}
)
FetchContent_makeAvailable(onnxruntime)
if (ANDROID)
set(ONNXRUNTIME_INCLUDE_DIR ${onnxruntime_SOURCE_DIR}/headers)
set(ONNXRUNTIME_LIB_DIR ${onnxruntime_SOURCE_DIR}/jni/${ANDROID_ABI})
message(STATUS "Android onnxruntime inc=${ONNXRUNTIME_INCLUDE_DIR} lib=${ONNXRUNTIME_LIB_DIR}")
else()
set(ONNXRUNTIME_INCLUDE_DIR ${onnxruntime_SOURCE_DIR}/include)
set(ONNXRUNTIME_LIB_DIR ${onnxruntime_SOURCE_DIR}/lib)
endif()
else()
message(STATUS "CMAKE_SYSTEM_PROCESSOR=${CMAKE_SYSTEM_PROCESSOR}")
message(STATUS "CMAKE_GENERATOR_PLATFORM=${CMAKE_GENERATOR_PLATFORM}")

Просмотреть файл

@ -91,6 +91,10 @@ class CuopContainer {
#define CustomCpuStruct(name, s) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp<s>(name, "CPUExecutionProvider")); }
#define CustomAzureStruct(name, s) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp<s>(name, "AzureExecutionProvider")); }
#define CustomCpuFuncV2(name, f) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOpV2(name, "CPUExecutionProvider", f)); }
#define CustomCpuStructV2(name, s) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOpV2<s>(name, "CPUExecutionProvider")); }
template <typename F>
void AppendCustomOp(std::vector<std::shared_ptr<OrtCustomOp>>& ops,
F arg) {

Просмотреть файл

@ -13,18 +13,57 @@
#include <vector>
#include <utility>
#include <type_traits>
#include <optional>
#include "onnxruntime_c_api.h"
#include "exceptions.h"
#define MIN_ORT_VERSION_SUPPORTED 10
#define MIN_ORT_VERSION_SUPPORTED 11
extern "C" int ORT_API_CALL GetActiveOrtAPIVersion();
// namespace of ORT ABI Wrapper
namespace OrtW {
class API {
// To use ONNX C ABI in a way like OrtW::API::CreateStatus.
public:
static API& instance(const OrtApi* ort_api = nullptr) noexcept {
static API self(*ort_api);
return self;
}
static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept {
return instance()->CreateStatus(code, msg);
}
static void ReleaseStatus(OrtStatusPtr ptr) noexcept {
instance()->ReleaseStatus(ptr);
}
template<typename T>
static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;
static void ThrowOnError(OrtStatusPtr ptr) {
OrtW::ThrowOnError(instance().api_, ptr);
}
private:
const OrtApi* operator->() const {
return &api_;
}
API(const OrtApi& api) : api_(api) {
if (&api == nullptr) {
ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
}
}
const OrtApi& api_;
};
//
// Custom OPs (only needed to implement custom OPs)
// DEPRECATED: Custom OPs (only needed to implement custom OPs)
//
struct CustomOpApi {
CustomOpApi(const OrtApi& api) : api_(api) {}
@ -63,93 +102,6 @@ struct CustomOpApi {
const OrtApi& api_;
};
// Legacy CRTP helper: fills the OrtCustomOp C function-pointer table by
// delegating to methods on the derived op type (TOp) and its kernel type
// (TKernel). Errors are surfaced as exceptions translated at the ABI
// boundary by the OCOS_API_IMPL_* guards.
template <typename TOp, typename TKernel>
struct CustomOpBase : OrtCustomOp {
  CustomOpBase() {
    OrtCustomOp::version = MIN_ORT_VERSION_SUPPORTED;  // The minimum ORT version supported

    // TOp::CreateKernel may throw; guard the C ABI boundary.
    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) {
      void* result = nullptr;
      OCOS_API_IMPL_BEGIN
      result = static_cast<const TOp*>(this_)->CreateKernel(*api, *info);
      OCOS_API_IMPL_END
      return result;
    };

    OrtCustomOp::GetName = [](const OrtCustomOp* this_) noexcept {
      return static_cast<const TOp*>(this_)->GetName();
    };

    OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) noexcept {
      return static_cast<const TOp*>(this_)->GetExecutionProviderType();
    };

    OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) noexcept {
      return static_cast<const TOp*>(this_)->GetInputTypeCount();
    };

    OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) noexcept {
      return static_cast<const TOp*>(this_)->GetInputType(index);
    };

    OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) noexcept {
      return static_cast<const TOp*>(this_)->GetOutputTypeCount();
    };

    OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) noexcept {
      return static_cast<const TOp*>(this_)->GetOutputType(index);
    };

    // TKernel::Compute may throw; guard the C ABI boundary.
    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
      OCOS_API_IMPL_BEGIN
      static_cast<TKernel*>(op_kernel)->Compute(context);
      OCOS_API_IMPL_END
    };

#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 26409)
#endif
    OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif

    OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
      return static_cast<const TOp*>(this_)->GetInputCharacteristic(index);
    };

    OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
      return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index);
    };
  }

  // default implementation. we can't use a virtual function as the layout of this struct has to be aligned with
  // OrtCustomOp, but a derived class can override by creating a function with the same name and signature,
  // calling this base class implementation as needed. e.g. see CustomOpThree in the unit test code
  void* CreateKernel(const OrtApi& api, const OrtKernelInfo& info) const {
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 26409)
#endif
    return new TKernel(api, info);
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif
  }

  // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
  const char* GetExecutionProviderType() const { return nullptr; }

  // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
  // (inputs and outputs are required by default)
  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
  }

  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
  }
};
//
// Custom OP API Inlines
//
@ -299,8 +251,244 @@ inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context,
return out;
}
// Specialization: read an int64 node attribute through the ORT C ABI.
template <>
inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
  return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
}

// Specialization: read a float node attribute through the ORT C ABI.
template <>
inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept {
  return instance()->KernelInfoGetAttribute_float(&info, name, &value);
}
// Fetch node attribute `name` into `value`. Failures (e.g. the attribute is
// absent) are deliberately swallowed so `value` keeps the caller's default;
// always returns nullptr (success).
template <class T>
static OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
  if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
    // Ideally, we should know which kind of error code can be ignored, but it is not available now.
    // Just ignore all of them.
    API::ReleaseStatus(status);
  }
  return nullptr;
}
// Convenience overload with (msg, code) argument order, mirroring the
// ORTX_CXX_API_THROW(msg, code) call sites that this status-returning
// path replaces.
inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
  return API::CreateStatus(code, msg);
}
} // namespace OrtW
// !! TODO: only do this for a legacy ORT build
#if ORT_API_VERSION < 15
#include "custom_op_lite.h"
#else
// From onnxruntime 1.17 onward, the custom op lite API header used is the one from the onnxruntime package.
// #include "onnxruntime_lite_custom_op.h"
// The existing custom op lite API header has more features than the one from onnxruntime 1.16.
#include "custom_op_lite.h"
#endif // ORT_API_VERSION < 15
namespace Ort {
namespace Custom {
// Adapter kernel whose Compute(...) simply forwards to a stored callable.
// Used to register a free function as a custom-op kernel.
template <typename... Args>
struct FunctionKernel {
  using ComputeFn = std::function<OrtStatusPtr(Args...)>;

  // Delegate straight to the wrapped callable and return its status.
  OrtStatusPtr Compute(Args... args) const {
    return std::invoke(compute_fn_, args...);
  }

  ComputeFn compute_fn_;
};
// Detection trait: IsFunctionKernel<T>::type is std::true_type exactly when T
// exposes a nested ComputeFn alias (i.e. T is a FunctionKernel wrapper),
// and std::false_type otherwise.
// Primary template handles types that have no nested ComputeFn member:
template <class, class = void>
struct IsFunctionKernel {
  typedef std::false_type type;
};

// Specialization recognizes types that do have a nested ComputeFn member:
template <class T>
struct IsFunctionKernel<T, std::void_t<typename T::ComputeFn>>{
  typedef std::true_type type;
};
// Helper type: decomposes a const member Compute function's signature into
// the matching free-function-pointer and member-function-pointer types.
template <typename T>
struct ComputeArgsList;

// Specialization for a const member function returning OrtStatusPtr.
template <typename C, typename... Args>
struct ComputeArgsList<OrtStatusPtr (C::*)(Args...) const> {
  using FunctionType = OrtStatusPtr (*)(Args...);
  using MemberFunctionType = OrtStatusPtr (C::*)(Args...) const;
};
// V2 lite custom-op wrapper: binds a status-returning kernel (CustomOpKernel)
// to the OrtCustomOp ABI. When built against and running on ORT >= 1.16 it
// registers the status-returning CreateKernelV2/KernelComputeV2 callbacks;
// otherwise it falls back to the legacy exception-based callbacks.
template <typename CustomOpKernel>
struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
  using ComputeFunction = decltype(&CustomOpKernel::Compute);
  using RegularComputeType = typename ComputeArgsList<ComputeFunction>::FunctionType;

  template <typename... Args>
  using MemberComputeType = OrtStatusPtr (CustomOpKernel::*)(Args...) const;

  // Kernel augmented with per-instance registration state (EP name and a
  // CustomOpApi wrapper) without modifying CustomOpKernel itself.
  struct KernelEx : public CustomOpKernel {
    struct {
      std::string ep_{};
      std::unique_ptr<OrtW::CustomOpApi> api_;
    } extra_;
  };

  // Generic case: a user-defined kernel initializes itself from the node info.
  template <typename T>
  static OrtStatusPtr InitKernel(KernelEx& kernel,
                                 const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, T t) {
    return kernel.OnModelAttach(api, info);
  }

  // FunctionKernel case (tag std::true_type): just store the compute function.
  static OrtStatusPtr InitKernel(
      KernelEx& kernel,
      const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, std::true_type) {
    kernel.compute_fn_ = fn;
    return nullptr;
  }

  // Record the input/output tensor types implied by Compute's signature.
  template <typename... Args>
  void ParseArgs(MemberComputeType<Args...> fn) {
    OrtLiteCustomOp::ParseArgs<Args...>(OrtLiteCustomOp::input_types_, OrtLiteCustomOp::output_types_);
  }

  // TODO: consider to disable these legacy functions for mobile build to save binary size
  template <typename... Args>
  void DefineCallbackFunctionsLegacy(MemberComputeType<Args...> fn, RegularComputeType regular_fn) {
    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto self = static_cast<const OrtLiteCustomStructV2<CustomOpKernel>*>(this_);
      auto kernel = std::make_unique<KernelEx>();
      typedef typename IsFunctionKernel<CustomOpKernel>::type type_flag;
      OrtStatusPtr status = InitKernel(*kernel, *ort_api, *info, self->regular_fn_, type_flag());
      // Legacy path cannot return a status, so surface failures as exceptions.
      OrtW::ThrowOnError(*ort_api, status);
      kernel->extra_.ep_ = self->execution_provider_;
      kernel->extra_.api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
      auto kernel = reinterpret_cast<KernelEx*>(op_kernel);
      std::vector<TensorPtr> tensors;
      auto t = CreateTuple<0, 0, Args...>(kernel->extra_.api_.get(),
                                          context,
                                          tensors,
                                          kernel->extra_.api_->KernelContext_GetInputCount(context),
                                          kernel->extra_.api_->KernelContext_GetOutputCount(context),
                                          kernel->extra_.ep_);
      // Compute returns a status; on this legacy path it is converted to an exception.
      std::apply([kernel](Args const&... t_args) {
        auto status = kernel->Compute(t_args...); OrtW::API::ThrowOnError(status);}, t);
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset();
    };
  }

#if ORT_API_VERSION > 15
  // Status-returning callbacks available from ORT 1.16 on: errors propagate
  // as OrtStatusPtr instead of being thrown across the C ABI.
  template <typename... Args>
  void DefineCallbackFunctions(MemberComputeType<Args...> fn, RegularComputeType regular_fn) {
    // Null out the legacy entries so ORT uses the V2 pair below.
    OrtCustomOp::CreateKernel = nullptr;
    OrtCustomOp::KernelCompute = nullptr;

    OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_,
                                     const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
      if (api == nullptr) {
        assert(false && "Got a null pointer for ORT api on calling CreateKernelV2");
        // Should never happen; without an OrtApi we cannot even create a status to report it.
        return nullptr;
      }

      if (this_ == nullptr || info == nullptr || op_kernel == nullptr) {
        return api->CreateStatus(ORT_INVALID_ARGUMENT, "OrtCustomOp::CreateKernelV2: received a null pointer");
      }

      auto self = static_cast<const OrtLiteCustomStructV2<CustomOpKernel>*>(this_);
      auto kernel = std::make_unique<KernelEx>();
      if (kernel == nullptr) {
        return api->CreateStatus(ORT_FAIL, "OrtCustomOp::CreateKernelV2: failed to new a kernel, OOM?");
      }

      typedef typename IsFunctionKernel<CustomOpKernel>::type flag_type;
      OrtStatusPtr status = InitKernel(*kernel, *api, *info, self->regular_fn_, flag_type());
      // Only hand ownership to ORT on success.
      if (status == nullptr) {
        kernel->extra_.ep_ = self->execution_provider_;
        kernel->extra_.api_ = std::make_unique<OrtW::CustomOpApi>(*api);
        *op_kernel = reinterpret_cast<void*>(kernel.release());
      }

      return status;
    };

    OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
      auto kernel = reinterpret_cast<KernelEx* >(op_kernel);
      std::vector<TensorPtr> tensors;
      auto t = CreateTuple<0, 0, Args...>(kernel->extra_.api_.get(),
                                          context,
                                          tensors,
                                          kernel->extra_.api_->KernelContext_GetInputCount(context),
                                          kernel->extra_.api_->KernelContext_GetOutputCount(context),
                                          kernel->extra_.ep_);
      // Propagate the kernel's status straight back to ORT.
      return std::apply([kernel](Args const&... t_args) {
        return kernel->Compute(t_args...); }, t);
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset();
    };
  }
#endif  // ORT_API_VERSION > 15

  // op_name/execution_provider are registration metadata; fn_compute is only
  // non-null for the FunctionKernel case.
  OrtLiteCustomStructV2(const char* op_name,
                        const char* execution_provider,
                        RegularComputeType fn_compute = nullptr)
      : OrtLiteCustomOp(op_name, execution_provider), regular_fn_(fn_compute) {
    ParseArgs(&CustomOpKernel::Compute);

#if ORT_API_VERSION > 15
    // NOTE(review): OrtCustomOp::version presumably reflects the active ORT
    // runtime version (set by the base class) — confirm before relying on it.
    if (OrtCustomOp::version > 15) {
      DefineCallbackFunctions(&CustomOpKernel::Compute, fn_compute);
    } else
#endif  // ORT_API_VERSION > 15
    {
      DefineCallbackFunctionsLegacy(&CustomOpKernel::Compute, fn_compute);
    }
  }

  RegularComputeType regular_fn_{};
};
// Build a V2 lite custom op from a free compute function returning
// OrtStatusPtr; the caller takes ownership of the returned op.
template <typename... Args>
OrtLiteCustomOp* CreateLiteCustomOpV2(const char* op_name,
                                      const char* execution_provider,
                                      OrtStatusPtr (*custom_compute_fn)(Args...)) {
  using LiteOp = OrtLiteCustomStructV2<FunctionKernel<Args...>>;
  auto lite_op = std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn);
  return lite_op.release();
}
// Build a V2 lite custom op from a kernel class (OpKernel) that implements
// OnModelAttach/Compute; the caller takes ownership of the returned op.
template <typename OpKernel>
OrtLiteCustomOp* CreateLiteCustomOpV2(const char* op_name,
                                      const char* execution_provider) {
  using LiteOp = OrtLiteCustomStructV2<OpKernel>;
  auto lite_op = std::make_unique<LiteOp>(op_name, execution_provider);
  return lite_op.release();
}
} // namespace Custom
} // namespace Ort
namespace ortc = Ort::Custom;

Просмотреть файл

@ -3,30 +3,6 @@
#include "string_tensor.h"
// Legacy (exception-throwing) kernel: reads an image file path from a string
// tensor and writes the decoded pixels to a uint8 output tensor.
struct KernelImageReader : BaseKernel {
  KernelImageReader(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
  }

  void Compute(OrtKernelContext* context) {
    const OrtValue* input_data = ort_.KernelContext_GetInput(context, 0);
    OrtTensorDimensions input_data_dimensions(ort_, input_data);

    // Only a single image path per call is supported.
    int n = input_data_dimensions[0];
    if (n != 1) {
      ORTX_CXX_API_THROW("[ImageReader]: the dimension of input value can only be 1 now.", ORT_INVALID_ARGUMENT);
    }

    std::vector<std::string> image_paths;
    GetTensorMutableDataString(api_, ort_, context, input_data, image_paths);

    // Decode with OpenCV (IMREAD_COLOR). Output shape is
    // {1, rows, cols, bytes-per-pixel}, raw pixel bytes copied verbatim.
    cv::Mat img = cv::imread(image_paths[0], cv::IMREAD_COLOR);
    std::vector<int64_t> output_dimensions = {1, img.size[0], img.size[1], static_cast<int64_t>(img.elemSize())};
    OrtValue* output_image = ort_.KernelContext_GetOutput(context, 0, output_dimensions.data(), output_dimensions.size());
    std::uint8_t* p_output_image = ort_.GetTensorMutableData<uint8_t>(output_image);
    // NOTE(review): no check that imread succeeded — a missing file yields an
    // empty Mat and a zero-sized copy; confirm this is acceptable upstream.
    memcpy(p_output_image, img.data, img.total() * img.elemSize());
  }
};
void image_reader(const ortc::Tensor<std::string>& input,
ortc::Tensor<uint8_t>& output) {
auto& input_data_dimensions = input.Shape();
@ -40,25 +16,3 @@ void image_reader(const ortc::Tensor<std::string>& input,
std::uint8_t* p_output_image = output.Allocate(output_dimensions);
memcpy(p_output_image, img.data, img.total() * img.elemSize());
}
// Registration shim describing the ImageReader op: one string input (the
// image path) and one uint8 tensor output (the decoded pixel bytes).
struct CustomOpImageReader : OrtW::CustomOpBase<CustomOpImageReader, KernelImageReader> {
  const char* GetName() const {
    return "ImageReader";
  }

  size_t GetInputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetInputType(size_t index) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  }

  size_t GetOutputTypeCount() const {
    return 1;
  }

  ONNXTensorElementDataType GetOutputType(size_t index) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
  }
};

Просмотреть файл

@ -6,11 +6,11 @@
#include "ocos.h"
#include <dlib/matrix.h>
void inverse(const ortc::Tensor<float>& input,
OrtStatusPtr inverse(const ortc::Tensor<float>& input,
ortc::Tensor<float>& output) {
auto& dimensions = input.Shape();
if (dimensions.size() != 2) {
throw std::runtime_error("Only 2-d matrix supported.");
return OrtW::CreateStatus("Only 2-d matrix supported.", ORT_INVALID_ARGUMENT);
}
const float* X = input.Data();
float* out = output.Allocate(dimensions);
@ -19,4 +19,6 @@ void inverse(const ortc::Tensor<float>& input,
std::copy(X, X + dm_x.size(), dm_x.begin());
dlib::matrix<float> dm = dlib::inv(dm_x);
memcpy(out, dm.steal_memory().get(), dm_x.size() * sizeof(float));
return nullptr;
}

Просмотреть файл

@ -6,29 +6,29 @@
#include "ocos.h"
#include <dlib/matrix.h>
struct STFT : public BaseKernel {
STFT(const OrtApi& api, const OrtKernelInfo& info,
bool with_norm = false) : BaseKernel(api, info),
with_norm_(with_norm) {
onesided_ = TryToGetAttributeWithDefault<int64_t>("onesided", 1);
struct StftNormal{
StftNormal() = default;
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
return OrtW::GetOpAttribute(info, "onesided", onesided_);
}
void Compute(const ortc::Tensor<float>& input0,
OrtStatusPtr Compute(const ortc::Tensor<float>& input0,
int64_t n_fft,
int64_t hop_length,
const ortc::Span<float>& input3,
int64_t frame_length,
ortc::Tensor<float>& output0) const {
auto X = input0.Data();
auto window = input3.data();
auto window = input3.data_;
auto dimensions = input0.Shape();
auto win_length = input3.size();
if (dimensions.size() < 2 || input0.NumberOfElement() != dimensions[1]) {
ORTX_CXX_API_THROW("[Stft] Only batch == 1 tensor supported.", ORT_INVALID_ARGUMENT);
return OrtW::CreateStatus("[Stft] Only batch == 1 tensor supported.", ORT_INVALID_ARGUMENT);
}
if (frame_length != n_fft) {
ORTX_CXX_API_THROW("[Stft] Only support size of FFT equals the frame length.", ORT_INVALID_ARGUMENT);
return OrtW::CreateStatus("[Stft] Only support size of FFT equals the frame length.", ORT_INVALID_ARGUMENT);
}
dlib::matrix<float> dm_x = dlib::mat(X, 1, dimensions[1]);
@ -42,42 +42,16 @@ struct STFT : public BaseKernel {
m_stft = dlib::subm(m_stft, 0, 0, m_stft.nr(), (m_stft.nc() >> 1) + 1);
}
if (with_norm_) {
dlib::matrix<float> result = dlib::norm(m_stft);
result = dlib::trans(result);
std::vector<int64_t> outdim{1, result.nr(), result.nc()};
auto result_size = result.size();
auto out0 = output0.Allocate(outdim);
memcpy(out0, result.steal_memory().get(), result_size * sizeof(float));
} else {
auto result = m_stft;
// No transpose here since it is done while copying the data;
// nr and nc are switched, so the output dim will be the transposed one.
std::vector<int64_t> outdim{1, result.nc(), result.nr(), 2};
auto out0 = output0.Allocate(outdim);
for (size_t c = 0; c < result.nc(); ++c) {
for (size_t r = 0; r < result.nr(); ++r) {
*out0 = result(r, c).real();
*(out0 + 1) = result(r, c).imag();
out0 += 2;
}
}
}
dlib::matrix<float> result = dlib::norm(m_stft);
result = dlib::trans(result);
std::vector<int64_t> outdim{1, result.nr(), result.nc()};
auto result_size = result.size();
auto out0 = output0.Allocate(outdim);
memcpy(out0, result.steal_memory().get(), result_size * sizeof(float));
return nullptr;
}
private:
int64_t onesided_{};
bool with_norm_{};
};
struct StftNormal : public STFT {
StftNormal(const OrtApi& api, const OrtKernelInfo& info) : STFT(api, info, true) {}
void Compute(const ortc::Tensor<float>& input0,
int64_t n_fft,
int64_t hop_length,
const ortc::Span<float>& input3,
int64_t frame_length,
ortc::Tensor<float>& output0) const {
STFT::Compute(input0, n_fft, hop_length, input3, frame_length, output0);
}
int64_t onesided_{1};
};

Просмотреть файл

@ -7,16 +7,14 @@
#include "segment_extraction.hpp"
#include "segment_sum.hpp"
const std::vector<const OrtCustomOp*>& MathLoader() {
static OrtOpLoader op_loader(CustomCpuFunc("NegPos", neg_pos),
#ifdef ENABLE_DLIB
CustomCpuFunc("Inverse", inverse),
CustomCpuStruct("STFT", STFT),
CustomCpuStruct("StftNorm", StftNormal),
#endif
CustomCpuFunc("SegmentExtraction", segment_extraction),
CustomCpuFunc("SegmentSum", segment_sum));
return op_loader.GetCustomOps();
}
FxLoadCustomOpFactory LoadCustomOpClasses_Math = MathLoader;
FxLoadCustomOpFactory LoadCustomOpClasses_Math = []() -> CustomOpArray& {
static OrtOpLoader op_loader(CustomCpuFuncV2("NegPos", neg_pos),
#ifdef ENABLE_DLIB
CustomCpuFuncV2("Inverse", inverse),
CustomCpuStructV2("StftNorm", StftNormal),
#endif
CustomCpuFuncV2("SegmentExtraction", segment_extraction),
CustomCpuFuncV2("SegmentSum", segment_sum));
return op_loader.GetCustomOps();
};

Просмотреть файл

@ -5,7 +5,7 @@
#include "ocos.h"
void neg_pos(const ortc::Tensor<float>& input,
OrtStatusPtr neg_pos(const ortc::Tensor<float>& input,
ortc::Tensor<float>& out0_tensor,
ortc::Tensor<float>& out1_tensor) {
int64_t size = input.NumberOfElement();
@ -22,4 +22,6 @@ void neg_pos(const ortc::Tensor<float>& input,
out1[i] = 0;
}
}
return nullptr;
}

Просмотреть файл

@ -3,12 +3,12 @@
#include "segment_extraction.hpp"
void segment_extraction(const ortc::Tensor<int64_t>& input,
OrtStatusPtr segment_extraction(const ortc::Tensor<int64_t>& input,
ortc::Tensor<int64_t>& output0,
ortc::Tensor<int64_t>& output1) {
auto& input_dim = input.Shape();
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
ORTX_CXX_API_THROW("[SegmentExtraction]: Expect input dimension [n] or [1,n].", ORT_INVALID_GRAPH);
return OrtW::CreateStatus("[SegmentExtraction]: Expect input dimension [n] or [1,n].", ORT_INVALID_GRAPH);
}
const int64_t* p_data = input.Data();
std::vector<std::int64_t> segment_value;
@ -38,4 +38,5 @@ void segment_extraction(const ortc::Tensor<int64_t>& input,
int64_t* out1_data = output1.Allocate(segment_value_dim);
std::copy(segment_value.begin(), segment_value.end(), out1_data);
return nullptr;
}

Просмотреть файл

@ -6,6 +6,6 @@
#include "ocos.h"
#include "string_utils.h"
void segment_extraction(const ortc::Tensor<int64_t>& input,
OrtStatusPtr segment_extraction(const ortc::Tensor<int64_t>& input,
ortc::Tensor<int64_t>& output0,
ortc::Tensor<int64_t>& output1);

Просмотреть файл

@ -3,19 +3,19 @@
#include "segment_sum.hpp"
void segment_sum(const ortc::Tensor<float>& data,
OrtStatusPtr segment_sum(const ortc::Tensor<float>& data,
const ortc::Tensor<int64_t>& segment_ids,
ortc::Tensor<float>& output) {
auto& dim_data = data.Shape();
auto& dim_seg = segment_ids.Shape();
if (dim_data.size() == 0 || dim_seg.size() == 0)
ORTX_CXX_API_THROW("Both inputs cannot be empty.", ORT_INVALID_GRAPH);
return OrtW::CreateStatus("Both inputs cannot be empty.", ORT_INVALID_GRAPH);
if (dim_seg.size() != 1)
ORTX_CXX_API_THROW("segment_ids must a single tensor", ORT_INVALID_GRAPH);
return OrtW::CreateStatus("segment_ids must a single tensor", ORT_INVALID_GRAPH);
if (dim_data[0] != dim_seg[0])
ORTX_CXX_API_THROW(MakeString(
return OrtW::CreateStatus(MakeString(
"First dimensions of data and segment_ids should be the same, data shape: ", dim_data,
" segment_ids shape: ", dim_seg),
" segment_ids shape: ", dim_seg).c_str(),
ORT_INVALID_GRAPH);
const int64_t* p_segment_ids = segment_ids.Data();
@ -38,14 +38,18 @@ void segment_sum(const ortc::Tensor<float>& data,
float *p_out, *p_out_end;
const int64_t* p_seg = p_segment_ids;
for (; begin != end; ++p_seg) {
if ((p_seg != p_segment_ids) && (*p_seg != *(p_seg - 1)) && (*p_seg != *(p_seg - 1) + 1))
ORTX_CXX_API_THROW(MakeString("segment_ids must be increasing but found ",
*(p_seg - 1), " and ", *p_seg, " at position ",
std::distance(p_segment_ids, p_seg), "."),
ORT_RUNTIME_EXCEPTION);
if ((p_seg != p_segment_ids) && (*p_seg != *(p_seg - 1)) && (*p_seg != *(p_seg - 1) + 1)) {
return OrtW::CreateStatus(MakeString("segment_ids must be increasing but found ",
*(p_seg - 1), " and ", *p_seg, " at position ",
std::distance(p_segment_ids, p_seg), ".").c_str(),
ORT_RUNTIME_EXCEPTION);
}
p_out = p_output + *p_seg * in_stride;
p_out_end = p_out + in_stride;
for (; p_out != p_out_end; ++p_out, ++begin)
for (; p_out != p_out_end; ++p_out, ++begin) {
*p_out += *begin;
}
}
return nullptr;
}

Просмотреть файл

@ -6,6 +6,6 @@
#include "ocos.h"
#include "string_utils.h"
void segment_sum(const ortc::Tensor<float>& data,
OrtStatusPtr segment_sum(const ortc::Tensor<float>& data,
const ortc::Tensor<int64_t>& segment_ids,
ortc::Tensor<float>& output);

Просмотреть файл

@ -7,6 +7,7 @@
#include <vector>
#include <map>
#include <memory>
struct PyCustomOpDef {
std::string op_type;
@ -48,10 +49,9 @@ struct PyCustomOpKernel {
std::map<std::string, std::string> attrs_values_;
};
struct PyCustomOpFactory : OrtW::CustomOpBase<PyCustomOpFactory, PyCustomOpKernel> {
PyCustomOpFactory() {
// STL vector needs it.
}
struct PyCustomOpFactory : public OrtCustomOp {
PyCustomOpFactory() = default;
PyCustomOpFactory(const PyCustomOpDef* opdef, const std::string& domain, const std::string& op) {
if (opdef == nullptr)
@ -59,30 +59,66 @@ struct PyCustomOpFactory : OrtW::CustomOpBase<PyCustomOpFactory, PyCustomOpKerne
opdef_ = opdef;
op_domain_ = domain;
op_type_ = op;
}
void* CreateKernel(const OrtApi& api, const OrtKernelInfo& info) const {
return new PyCustomOpKernel(api, info, opdef_->obj_id, opdef_->attrs);
};
OrtCustomOp::version = MIN_ORT_VERSION_SUPPORTED; // The minimum ORT version supported
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) {
void* p = nullptr;
const char* GetName() const {
return op_type_.c_str();
};
OCOS_API_IMPL_BEGIN
auto self = static_cast<const PyCustomOpFactory*>(this_);
auto kernel = std::make_unique<PyCustomOpKernel>(*api, *info, self->opdef_->obj_id, self->opdef_->attrs).release();
p = reinterpret_cast<void*>(kernel);
OCOS_API_IMPL_END
size_t GetInputTypeCount() const {
return opdef_->input_types.size();
};
return p;
};
ONNXTensorElementDataType GetInputType(size_t idx) const {
return static_cast<ONNXTensorElementDataType>(opdef_->input_types[idx]);
};
OrtCustomOp::GetName = [](const OrtCustomOp* this_) noexcept {
auto self = static_cast<const PyCustomOpFactory*>(this_);
return self->op_type_.c_str();
};
size_t GetOutputTypeCount() const {
return opdef_->output_types.size();
};
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) noexcept {
return "CPUExecutionProvider";
};
ONNXTensorElementDataType GetOutputType(size_t idx) const {
return static_cast<ONNXTensorElementDataType>(opdef_->output_types[idx]);
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) noexcept {
auto self = static_cast<const PyCustomOpFactory*>(this_);
return self->opdef_->input_types.size();
};
OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) noexcept {
auto self = static_cast<const PyCustomOpFactory*>(this_);
return static_cast<ONNXTensorElementDataType>(self->opdef_->input_types[index]);
};
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) noexcept {
auto self = static_cast<const PyCustomOpFactory*>(this_);
return self->opdef_->output_types.size();
};
OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) noexcept {
auto self = static_cast<const PyCustomOpFactory*>(this_);
return static_cast<ONNXTensorElementDataType>(self->opdef_->output_types[index]);
};
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) noexcept {
OCOS_API_IMPL_BEGIN
static_cast<PyCustomOpKernel*>(op_kernel)->Compute(context);
OCOS_API_IMPL_END
};
OrtCustomOp::KernelDestroy = [](void* op_kernel) noexcept {
std::unique_ptr<PyCustomOpKernel>(reinterpret_cast<PyCustomOpKernel*>(op_kernel)).reset();
};
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
};
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
};
}
const PyCustomOpDef* opdef_ = nullptr;

Просмотреть файл

@ -61,8 +61,9 @@ class ExternalCustomOps {
std::vector<const OrtCustomOp*> op_array_;
};
static int GetOrtVersion(const OrtApiBase* api_base = nullptr) {
static int ort_version = 11; // the default version is 1.11.0
static int GetOrtVersion(const OrtApiBase* api_base = nullptr) noexcept{
// the version will be cached after the first call on RegisterCustomOps
static int ort_version = MIN_ORT_VERSION_SUPPORTED; // the default version is 1.11.0
if (api_base != nullptr) {
std::string str_version = api_base->GetVersionString();
@ -104,13 +105,19 @@ extern "C" int ORT_API_CALL GetActiveOrtAPIVersion() {
return ver;
}
// The main entrance of the extension library.
extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
OrtStatus* status = nullptr;
OCOS_API_IMPL_BEGIN
OrtCustomOpDomain* domain = nullptr;
// The following will initialize some global objects, which
// means any other function invocation prior to these calls triggers undefined behavior.
auto ver = GetOrtVersion(api);
const OrtApi* ortApi = api->GetApi(ver);
API::instance(ortApi);
OrtCustomOpDomain* domain = nullptr;
std::set<std::string> pyop_nameset;
#if defined(PYTHON_OP_SUPPORT)

Просмотреть файл

@ -11,7 +11,9 @@
// throw in ctor which will be called during model load
struct ExceptionalKernel1 : BaseKernel {
ExceptionalKernel1(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
OCOS_API_IMPL_BEGIN
ORTX_CXX_API_THROW("Throw in ctor", ORT_FAIL);
OCOS_API_IMPL_END
}
void Compute(OrtKernelContext* context) {}
@ -23,11 +25,19 @@ struct ExceptionalKernel2 : BaseKernel {
}
void Compute(OrtKernelContext* context) {
OCOS_API_IMPL_BEGIN
ORTX_CXX_API_THROW("Throw in Compute", ORT_FAIL);
OCOS_API_IMPL_END
}
};
struct ExceptionalCustomOp1 : OrtW::CustomOpBase<ExceptionalCustomOp1, ExceptionalKernel1> {
struct ExceptionalCustomOp1 : Ort::CustomOpBase<ExceptionalCustomOp1, ExceptionalKernel1> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
void* result = nullptr;
OCOS_API_IMPL_BEGIN
result = new ExceptionalKernel1(api, *info);
OCOS_API_IMPL_END
return result; };
const char* GetName() const { return "ExceptionalCustomOp1"; };
size_t GetInputTypeCount() const { return 1; };
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
@ -35,7 +45,13 @@ struct ExceptionalCustomOp1 : OrtW::CustomOpBase<ExceptionalCustomOp1, Exception
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
};
struct ExceptionalCustomOp2 : OrtW::CustomOpBase<ExceptionalCustomOp2, ExceptionalKernel2> {
struct ExceptionalCustomOp2 : Ort::CustomOpBase<ExceptionalCustomOp2, ExceptionalKernel2> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
void* result = nullptr;
OCOS_API_IMPL_BEGIN
result = new ExceptionalKernel2(api, *info);
OCOS_API_IMPL_END
return result; };
const char* GetName() const { return "ExceptionalCustomOp2"; };
size_t GetInputTypeCount() const { return 1; };
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };

Просмотреть файл

@ -48,7 +48,10 @@ struct KernelOne : BaseKernel {
}
};
struct CustomOpOne : OrtW::CustomOpBase<CustomOpOne, KernelOne> {
struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelOne(api, *info);
};
const char* GetName() const {
return "CustomOpOne";
};
@ -91,7 +94,10 @@ struct KernelTwo : BaseKernel {
}
};
struct CustomOpTwo : OrtW::CustomOpBase<CustomOpTwo, KernelTwo> {
struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelTwo(api, *info);
};
const char* GetName() const {
return "CustomOpTwo";
};
@ -136,13 +142,9 @@ struct KernelThree : BaseKernel {
std::string substr_;
};
struct CustomOpThree : OrtW::CustomOpBase<CustomOpThree, KernelThree> {
// This is example code to show how to override the CustomOpBase::CreateKernel method even though it is not virtual.
// The CustomOpBase implementation will call the CreateKernel of the first class specified in the template,
// and from there it's also possible to call the base CreateKernel as per below.
void* CreateKernel(const OrtApi& api, const OrtKernelInfo& info) const {
std::cout << "Called CreateKernel override" << std::endl;
return OrtW::CustomOpBase<CustomOpThree, KernelThree>::CreateKernel(api, info);
struct CustomOpThree : Ort::CustomOpBase<CustomOpThree, KernelThree> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelThree(api, *info);
};
const char* GetName() const {
return "CustomOpThree";

Просмотреть файл

@ -96,16 +96,6 @@ class TestAudio(unittest.TestCase):
data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]
return np.absolute(data.T) ** 2
def test_onnx_stft(self):
audio_pcm = self.test_pcm
expected = self.stft(audio_pcm, 400, 160, np.hanning(400).astype(np.float32))
ortx_stft = OrtPyFunction.from_model(_create_test_model(), cpu_only=True)
actual = ortx_stft(np.expand_dims(audio_pcm, axis=0), 400, 160, np.hanning(400).astype(np.float32), 400)
actual = actual[0]
actual = actual[:, :, 0] ** 2 + actual[:, :, 1] ** 2
np.testing.assert_allclose(expected[:, 1:], actual[:, 1:], rtol=1e-3, atol=1e-3)
def test_stft_norm_np(self):
audio_pcm = self.test_pcm
expected = self.stft(audio_pcm, 400, 160, np.hanning(400).astype(np.float32))