Enable the status returnable APIs from ORT 1.16 C ABI (#558)
* Initial checkins for returnable ORT ABIs * fix for linux build * more fixes on Python, test... * remove the statusmsg * native unit tests fixing * Python unit tests fixing * last unit test fixing
This commit is contained in:
Родитель
bd5de8c420
Коммит
914509d524
|
@ -75,7 +75,9 @@ option(OCOS_BUILD_APPLE_FRAMEWORK "Enable building of the MacOS/iOS framework" O
|
|||
# Optional value. Some operators do not support old versions due to using the new custom operator interface
|
||||
# and will be disabled if this value is set and the version is incompatible.
|
||||
set(OCOS_ONNXRUNTIME_VERSION "" CACHE STRING
|
||||
"The version of ONNX Runtime being used in the build. Format is <major>.<minor>.<patch>. e.g. 1.15.1" )
|
||||
"The version of ONNX Runtime being used in the build. Format is <major>.<minor>.<patch>. e.g. 1.15.1")
|
||||
set(OCOS_ONNXRUNTIME_PKG_URI "" CACHE STRING
|
||||
"Specify the onnxruntime C++ shared library zip package path, like ./onnxruntime-win-x64-1.16.0.zip")
|
||||
|
||||
# TODO: Remove the following statements if AzureOp build is enabled by default.
|
||||
# If the build_buildid environment variable is set, this is a CI build, so always enable AzureOp.
|
||||
|
|
|
@ -4,6 +4,28 @@ if(_ONNXRUNTIME_EMBEDDED)
|
|||
elseif(ONNXRUNTIME_PKG_DIR)
|
||||
set(ONNXRUNTIME_INCLUDE_DIR ${ONNXRUNTIME_PKG_DIR}/include)
|
||||
set(ONNXRUNTIME_LIB_DIR ${ONNXRUNTIME_PKG_DIR}/lib)
|
||||
elseif(OCOS_ONNXRUNTIME_PKG_URI)
|
||||
if (NOT OCOS_ONNXRUNTIME_VERSION)
|
||||
message(FATAL_ERROR "OCOS_ONNXRUNTIME_PKG_URI is set but OCOS_ONNXRUNTIME_VERSION is not set")
|
||||
endif()
|
||||
set(ONNXRUNTIME_VER ${OCOS_ONNXRUNTIME_VERSION})
|
||||
set(ONNXRUNTIME_URL ${OCOS_ONNXRUNTIME_PKG_URI})
|
||||
message(STATUS "ONNX Runtime URL: ${OCOS_ONNXRUNTIME_PKG_URI}")
|
||||
FetchContent_Declare(
|
||||
onnxruntime
|
||||
URL ${OCOS_ONNXRUNTIME_PKG_URI}
|
||||
)
|
||||
|
||||
FetchContent_makeAvailable(onnxruntime)
|
||||
|
||||
if (ANDROID)
|
||||
set(ONNXRUNTIME_INCLUDE_DIR ${onnxruntime_SOURCE_DIR}/headers)
|
||||
set(ONNXRUNTIME_LIB_DIR ${onnxruntime_SOURCE_DIR}/jni/${ANDROID_ABI})
|
||||
message(STATUS "Android onnxruntime inc=${ONNXRUNTIME_INCLUDE_DIR} lib=${ONNXRUNTIME_LIB_DIR}")
|
||||
else()
|
||||
set(ONNXRUNTIME_INCLUDE_DIR ${onnxruntime_SOURCE_DIR}/include)
|
||||
set(ONNXRUNTIME_LIB_DIR ${onnxruntime_SOURCE_DIR}/lib)
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "CMAKE_SYSTEM_PROCESSOR=${CMAKE_SYSTEM_PROCESSOR}")
|
||||
message(STATUS "CMAKE_GENERATOR_PLATFORM=${CMAKE_GENERATOR_PLATFORM}")
|
||||
|
|
|
@ -91,6 +91,10 @@ class CuopContainer {
|
|||
// Expands to a capture-less lambda that returns a std::shared_ptr<ortc::OrtLiteCustomOp> owning the op
// created by ortc::CreateLiteCustomOp<s>(name, "CPUExecutionProvider").
#define CustomCpuStruct(name, s) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp<s>(name, "CPUExecutionProvider")); }
|
||||
// Expands to a capture-less lambda that returns a std::shared_ptr<ortc::OrtLiteCustomOp> owning the op
// created by ortc::CreateLiteCustomOp<s>(name, "AzureExecutionProvider").
#define CustomAzureStruct(name, s) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp<s>(name, "AzureExecutionProvider")); }
|
||||
|
||||
// Expands to a capture-less lambda that returns a std::shared_ptr<ortc::OrtLiteCustomOp> owning the op
// created by ortc::CreateLiteCustomOpV2(name, "CPUExecutionProvider", f); `f` is the compute function.
#define CustomCpuFuncV2(name, f) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOpV2(name, "CPUExecutionProvider", f)); }
|
||||
// Expands to a capture-less lambda that returns a std::shared_ptr<ortc::OrtLiteCustomOp> owning the op
// created by ortc::CreateLiteCustomOpV2<s>(name, "CPUExecutionProvider"); `s` is the kernel struct type.
#define CustomCpuStructV2(name, s) []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOpV2<s>(name, "CPUExecutionProvider")); }
|
||||
|
||||
|
||||
template <typename F>
|
||||
void AppendCustomOp(std::vector<std::shared_ptr<OrtCustomOp>>& ops,
|
||||
F arg) {
|
||||
|
|
|
@ -13,18 +13,57 @@
|
|||
#include <vector>
|
||||
#include <utility>
|
||||
#include <type_traits>
|
||||
#include <optional>
|
||||
|
||||
#include "onnxruntime_c_api.h"
|
||||
#include "exceptions.h"
|
||||
|
||||
#define MIN_ORT_VERSION_SUPPORTED 10
|
||||
|
||||
#define MIN_ORT_VERSION_SUPPORTED 11
|
||||
|
||||
extern "C" int ORT_API_CALL GetActiveOrtAPIVersion();
|
||||
|
||||
// namespace of ORT ABI Wrapper
|
||||
namespace OrtW {
|
||||
|
||||
class API {
|
||||
// To use ONNX C ABI in a way like OrtW::API::CreateStatus.
|
||||
public:
|
||||
static API& instance(const OrtApi* ort_api = nullptr) noexcept {
|
||||
static API self(*ort_api);
|
||||
return self;
|
||||
}
|
||||
|
||||
static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept {
|
||||
return instance()->CreateStatus(code, msg);
|
||||
}
|
||||
|
||||
static void ReleaseStatus(OrtStatusPtr ptr) noexcept {
|
||||
instance()->ReleaseStatus(ptr);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;
|
||||
|
||||
static void ThrowOnError(OrtStatusPtr ptr) {
|
||||
OrtW::ThrowOnError(instance().api_, ptr);
|
||||
}
|
||||
|
||||
private:
|
||||
const OrtApi* operator->() const {
|
||||
return &api_;
|
||||
}
|
||||
|
||||
API(const OrtApi& api) : api_(api) {
|
||||
if (&api == nullptr) {
|
||||
ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
}
|
||||
const OrtApi& api_;
|
||||
};
|
||||
|
||||
//
|
||||
// Custom OPs (only needed to implement custom OPs)
|
||||
// DEPRECATED: Custom OPs (only needed to implement custom OPs)
|
||||
//
|
||||
struct CustomOpApi {
|
||||
CustomOpApi(const OrtApi& api) : api_(api) {}
|
||||
|
@ -63,93 +102,6 @@ struct CustomOpApi {
|
|||
const OrtApi& api_;
|
||||
};
|
||||
|
||||
template <typename TOp, typename TKernel>
|
||||
struct CustomOpBase : OrtCustomOp {
|
||||
CustomOpBase() {
|
||||
OrtCustomOp::version = MIN_ORT_VERSION_SUPPORTED; // The minimum ORT version supported
|
||||
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) {
|
||||
void* result = nullptr;
|
||||
OCOS_API_IMPL_BEGIN
|
||||
result = static_cast<const TOp*>(this_)->CreateKernel(*api, *info);
|
||||
OCOS_API_IMPL_END
|
||||
return result;
|
||||
};
|
||||
|
||||
OrtCustomOp::GetName = [](const OrtCustomOp* this_) noexcept {
|
||||
return static_cast<const TOp*>(this_)->GetName();
|
||||
};
|
||||
|
||||
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) noexcept {
|
||||
return static_cast<const TOp*>(this_)->GetExecutionProviderType();
|
||||
};
|
||||
|
||||
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) noexcept {
|
||||
return static_cast<const TOp*>(this_)->GetInputTypeCount();
|
||||
};
|
||||
|
||||
OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) noexcept {
|
||||
return static_cast<const TOp*>(this_)->GetInputType(index);
|
||||
};
|
||||
|
||||
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) noexcept {
|
||||
return static_cast<const TOp*>(this_)->GetOutputTypeCount();
|
||||
};
|
||||
|
||||
OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) noexcept {
|
||||
return static_cast<const TOp*>(this_)->GetOutputType(index);
|
||||
};
|
||||
|
||||
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
|
||||
OCOS_API_IMPL_BEGIN
|
||||
static_cast<TKernel*>(op_kernel)->Compute(context);
|
||||
OCOS_API_IMPL_END
|
||||
};
|
||||
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 26409)
|
||||
#endif
|
||||
OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
#pragma warning(pop)
|
||||
#endif
|
||||
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
|
||||
return static_cast<const TOp*>(this_)->GetInputCharacteristic(index);
|
||||
};
|
||||
|
||||
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
|
||||
return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index);
|
||||
};
|
||||
}
|
||||
|
||||
// default implementation. we can't use a virtual function as the layout of this struct has to be aligned with
|
||||
// OrtCustomOp, but a derived class can override by creating a function with the same name and signature,
|
||||
// calling this base class implementation as needed. e.g. see CustomOpThree in the unit test code
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo& info) const {
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 26409)
|
||||
#endif
|
||||
return new TKernel(api, info);
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
#pragma warning(pop)
|
||||
#endif
|
||||
}
|
||||
|
||||
// Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
|
||||
const char* GetExecutionProviderType() const { return nullptr; }
|
||||
|
||||
// Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
|
||||
// (inputs and outputs are required by default)
|
||||
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
|
||||
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
|
||||
}
|
||||
|
||||
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
|
||||
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// Custom OP API Inlines
|
||||
//
|
||||
|
@ -299,8 +251,244 @@ inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context,
|
|||
return out;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
|
||||
return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept {
|
||||
return instance()->KernelInfoGetAttribute_float(&info, name, &value);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
|
||||
if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
|
||||
// Ideally, we should know which kind of error code can be ignored, but it is not availabe now.
|
||||
// Just ignore all of them.
|
||||
API::ReleaseStatus(status);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
// Convenience wrapper taking (message, code); forwards to API::CreateStatus, which
// expects the arguments in the opposite order.
inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
  OrtStatusPtr status = API::CreateStatus(code, msg);
  return status;
}
|
||||
|
||||
} // namespace OrtW
|
||||
|
||||
// !! TODO: only do this for legacy ORT builds
|
||||
|
||||
#if ORT_API_VERSION < 15
|
||||
#include "custom_op_lite.h"
|
||||
|
||||
#else
|
||||
// From onnxruntime 1.17, the custom op lite API header to use is the one from the onnxruntime package.
|
||||
// #include "onnxruntime_lite_custom_op.h"
|
||||
// The existing custom op lite API header has more features than the one from onnxruntime 1.16.
|
||||
#include "custom_op_lite.h"
|
||||
|
||||
#endif // ORT_API_VERSION < 15
|
||||
|
||||
|
||||
|
||||
namespace Ort {
|
||||
namespace Custom {
|
||||
|
||||
|
||||
template <typename... Args>
|
||||
struct FunctionKernel {
|
||||
using ComputeFn = std::function<OrtStatusPtr(Args...)>;
|
||||
|
||||
OrtStatusPtr Compute(Args... args) const {
|
||||
return compute_fn_(args...);
|
||||
}
|
||||
|
||||
ComputeFn compute_fn_;
|
||||
};
|
||||
|
||||
// Type trait: IsFunctionKernel<T>::type is std::true_type when T declares a nested
// `ComputeFn` type (as FunctionKernel does) and std::false_type otherwise.

// Primary template: chosen for types that have no nested ::ComputeFn member.
template <class, class = void>
struct IsFunctionKernel {
  using type = std::false_type;
};

// Partial specialization: chosen when `typename T::ComputeFn` is well-formed,
// i.e. for types that do have a nested ::ComputeFn member.
template <class T>
struct IsFunctionKernel<T, std::void_t<typename T::ComputeFn>> {
  using type = std::true_type;
};
|
||||
|
||||
// Helper type
|
||||
template <typename T>
|
||||
struct ComputeArgsList;
|
||||
|
||||
// Specialization for member function
|
||||
template <typename C, typename... Args>
|
||||
struct ComputeArgsList<OrtStatusPtr (C::*)(Args...) const> {
|
||||
using FunctionType = OrtStatusPtr (*)(Args...);
|
||||
using MemberFunctionType = OrtStatusPtr (C::*)(Args...) const;
|
||||
};
|
||||
|
||||
template <typename CustomOpKernel>
|
||||
struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
|
||||
using ComputeFunction = decltype(&CustomOpKernel::Compute);
|
||||
using RegularComputeType = typename ComputeArgsList<ComputeFunction>::FunctionType;
|
||||
|
||||
template <typename... Args>
|
||||
using MemberComputeType = OrtStatusPtr (CustomOpKernel::*)(Args...) const;
|
||||
|
||||
struct KernelEx : public CustomOpKernel {
|
||||
struct {
|
||||
std::string ep_{};
|
||||
std::unique_ptr<OrtW::CustomOpApi> api_;
|
||||
} extra_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static OrtStatusPtr InitKernel(KernelEx& kernel,
|
||||
const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, T t) {
|
||||
return kernel.OnModelAttach(api, info);
|
||||
}
|
||||
|
||||
static OrtStatusPtr InitKernel(
|
||||
KernelEx& kernel,
|
||||
const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, std::true_type) {
|
||||
kernel.compute_fn_ = fn;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void ParseArgs(MemberComputeType<Args...> fn) {
|
||||
OrtLiteCustomOp::ParseArgs<Args...>(OrtLiteCustomOp::input_types_, OrtLiteCustomOp::output_types_);
|
||||
}
|
||||
|
||||
// TODO: consider to disable these legacy functions for mobile build to save binary size
|
||||
template <typename... Args>
|
||||
void DefineCallbackFunctionsLegacy(MemberComputeType<Args...> fn, RegularComputeType regular_fn) {
|
||||
|
||||
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
|
||||
auto self = static_cast<const OrtLiteCustomStructV2<CustomOpKernel>*>(this_);
|
||||
auto kernel = std::make_unique<KernelEx>();
|
||||
typedef typename IsFunctionKernel<CustomOpKernel>::type type_flag;
|
||||
OrtStatusPtr status = InitKernel(*kernel, *ort_api, *info, self->regular_fn_, type_flag());
|
||||
OrtW::ThrowOnError(*ort_api, status);
|
||||
|
||||
kernel->extra_.ep_ = self->execution_provider_;
|
||||
kernel->extra_.api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
|
||||
return reinterpret_cast<void*>(kernel.release());
|
||||
};
|
||||
|
||||
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
|
||||
auto kernel = reinterpret_cast<KernelEx*>(op_kernel);
|
||||
std::vector<TensorPtr> tensors;
|
||||
auto t = CreateTuple<0, 0, Args...>(kernel->extra_.api_.get(),
|
||||
context,
|
||||
tensors,
|
||||
kernel->extra_.api_->KernelContext_GetInputCount(context),
|
||||
kernel->extra_.api_->KernelContext_GetOutputCount(context),
|
||||
kernel->extra_.ep_);
|
||||
std::apply([kernel](Args const&... t_args) {
|
||||
auto status = kernel->Compute(t_args...); OrtW::API::ThrowOnError(status);}, t);
|
||||
};
|
||||
|
||||
OrtCustomOp::KernelDestroy = [](void* op_kernel) {
|
||||
std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset();
|
||||
};
|
||||
}
|
||||
|
||||
#if ORT_API_VERSION > 15
|
||||
template <typename... Args>
|
||||
void DefineCallbackFunctions(MemberComputeType<Args...> fn, RegularComputeType regular_fn) {
|
||||
OrtCustomOp::CreateKernel = nullptr;
|
||||
OrtCustomOp::KernelCompute = nullptr;
|
||||
|
||||
OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_,
|
||||
const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
|
||||
if (api == nullptr) {
|
||||
assert(false && "Got a null pointer for ORT api on calling CreateKernelV2");
|
||||
// should never happened, what we can do?
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (this_ == nullptr || info == nullptr || op_kernel == nullptr) {
|
||||
return api->CreateStatus(ORT_INVALID_ARGUMENT, "OrtCustomOp::CreateKernelV2: received a null pointer");
|
||||
}
|
||||
|
||||
auto self = static_cast<const OrtLiteCustomStructV2<CustomOpKernel>*>(this_);
|
||||
auto kernel = std::make_unique<KernelEx>();
|
||||
if (kernel == nullptr) {
|
||||
return api->CreateStatus(ORT_FAIL, "OrtCustomOp::CreateKernelV2: failed to new a kernel, OOM?");
|
||||
}
|
||||
|
||||
typedef typename IsFunctionKernel<CustomOpKernel>::type flag_type;
|
||||
OrtStatusPtr status = InitKernel(*kernel, *api, *info, self->regular_fn_, flag_type());
|
||||
if (status == nullptr) {
|
||||
kernel->extra_.ep_ = self->execution_provider_;
|
||||
kernel->extra_.api_ = std::make_unique<OrtW::CustomOpApi>(*api);
|
||||
*op_kernel = reinterpret_cast<void*>(kernel.release());
|
||||
}
|
||||
|
||||
return status;
|
||||
};
|
||||
|
||||
OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
|
||||
auto kernel = reinterpret_cast<KernelEx* >(op_kernel);
|
||||
std::vector<TensorPtr> tensors;
|
||||
auto t = CreateTuple<0, 0, Args...>(kernel->extra_.api_.get(),
|
||||
context,
|
||||
tensors,
|
||||
kernel->extra_.api_->KernelContext_GetInputCount(context),
|
||||
kernel->extra_.api_->KernelContext_GetOutputCount(context),
|
||||
kernel->extra_.ep_);
|
||||
return std::apply([kernel](Args const&... t_args) {
|
||||
return kernel->Compute(t_args...); }, t);
|
||||
};
|
||||
|
||||
OrtCustomOp::KernelDestroy = [](void* op_kernel) {
|
||||
std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset();
|
||||
};
|
||||
}
|
||||
#endif // ORT_API_VERSION > 15
|
||||
|
||||
OrtLiteCustomStructV2(const char* op_name,
|
||||
const char* execution_provider,
|
||||
RegularComputeType fn_compute = nullptr)
|
||||
: OrtLiteCustomOp(op_name, execution_provider), regular_fn_(fn_compute) {
|
||||
|
||||
ParseArgs(&CustomOpKernel::Compute);
|
||||
|
||||
#if ORT_API_VERSION > 15
|
||||
if (OrtCustomOp::version > 15) {
|
||||
DefineCallbackFunctions(&CustomOpKernel::Compute, fn_compute);
|
||||
} else
|
||||
#endif // ORT_API_VERSION > 15
|
||||
|
||||
{
|
||||
DefineCallbackFunctionsLegacy(&CustomOpKernel::Compute, fn_compute);
|
||||
}
|
||||
}
|
||||
|
||||
RegularComputeType regular_fn_{};
|
||||
};
|
||||
|
||||
template <typename... Args>
|
||||
OrtLiteCustomOp* CreateLiteCustomOpV2(const char* op_name,
|
||||
const char* execution_provider,
|
||||
OrtStatusPtr (*custom_compute_fn)(Args...)) {
|
||||
using LiteOp = OrtLiteCustomStructV2<FunctionKernel<Args...>>;
|
||||
return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release();
|
||||
}
|
||||
|
||||
template <typename OpKernel>
|
||||
OrtLiteCustomOp* CreateLiteCustomOpV2(const char* op_name,
|
||||
const char* execution_provider) {
|
||||
using LiteOp = OrtLiteCustomStructV2<OpKernel>;
|
||||
return std::make_unique<LiteOp>(op_name, execution_provider).release();
|
||||
}
|
||||
|
||||
} // namespace Custom
|
||||
} // namespace Ort
|
||||
|
||||
namespace ortc = Ort::Custom;
|
||||
|
|
|
@ -3,30 +3,6 @@
|
|||
|
||||
#include "string_tensor.h"
|
||||
|
||||
struct KernelImageReader : BaseKernel {
|
||||
KernelImageReader(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
const OrtValue* input_data = ort_.KernelContext_GetInput(context, 0);
|
||||
OrtTensorDimensions input_data_dimensions(ort_, input_data);
|
||||
|
||||
int n = input_data_dimensions[0];
|
||||
if (n != 1) {
|
||||
ORTX_CXX_API_THROW("[ImageReader]: the dimension of input value can only be 1 now.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
std::vector<std::string> image_paths;
|
||||
GetTensorMutableDataString(api_, ort_, context, input_data, image_paths);
|
||||
|
||||
cv::Mat img = cv::imread(image_paths[0], cv::IMREAD_COLOR);
|
||||
std::vector<int64_t> output_dimensions = {1, img.size[0], img.size[1], static_cast<int64_t>(img.elemSize())};
|
||||
OrtValue* output_image = ort_.KernelContext_GetOutput(context, 0, output_dimensions.data(), output_dimensions.size());
|
||||
std::uint8_t* p_output_image = ort_.GetTensorMutableData<uint8_t>(output_image);
|
||||
memcpy(p_output_image, img.data, img.total() * img.elemSize());
|
||||
}
|
||||
};
|
||||
|
||||
void image_reader(const ortc::Tensor<std::string>& input,
|
||||
ortc::Tensor<uint8_t>& output) {
|
||||
auto& input_data_dimensions = input.Shape();
|
||||
|
@ -40,25 +16,3 @@ void image_reader(const ortc::Tensor<std::string>& input,
|
|||
std::uint8_t* p_output_image = output.Allocate(output_dimensions);
|
||||
memcpy(p_output_image, img.data, img.total() * img.elemSize());
|
||||
}
|
||||
|
||||
// Op descriptor for "ImageReader": one string input, one uint8 output,
// executed by KernelImageReader.
struct CustomOpImageReader : OrtW::CustomOpBase<CustomOpImageReader, KernelImageReader> {
  const char* GetName() const { return "ImageReader"; }

  // Exactly one input, always a string tensor.
  size_t GetInputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  }

  // Exactly one output, always a uint8 tensor.
  size_t GetOutputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
  }
};
|
||||
|
|
|
@ -6,11 +6,11 @@
|
|||
#include "ocos.h"
|
||||
#include <dlib/matrix.h>
|
||||
|
||||
void inverse(const ortc::Tensor<float>& input,
|
||||
OrtStatusPtr inverse(const ortc::Tensor<float>& input,
|
||||
ortc::Tensor<float>& output) {
|
||||
auto& dimensions = input.Shape();
|
||||
if (dimensions.size() != 2) {
|
||||
throw std::runtime_error("Only 2-d matrix supported.");
|
||||
return OrtW::CreateStatus("Only 2-d matrix supported.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
const float* X = input.Data();
|
||||
float* out = output.Allocate(dimensions);
|
||||
|
@ -19,4 +19,6 @@ void inverse(const ortc::Tensor<float>& input,
|
|||
std::copy(X, X + dm_x.size(), dm_x.begin());
|
||||
dlib::matrix<float> dm = dlib::inv(dm_x);
|
||||
memcpy(out, dm.steal_memory().get(), dm_x.size() * sizeof(float));
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -6,29 +6,29 @@
|
|||
#include "ocos.h"
|
||||
#include <dlib/matrix.h>
|
||||
|
||||
struct STFT : public BaseKernel {
|
||||
STFT(const OrtApi& api, const OrtKernelInfo& info,
|
||||
bool with_norm = false) : BaseKernel(api, info),
|
||||
with_norm_(with_norm) {
|
||||
onesided_ = TryToGetAttributeWithDefault<int64_t>("onesided", 1);
|
||||
struct StftNormal{
|
||||
StftNormal() = default;
|
||||
|
||||
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
|
||||
return OrtW::GetOpAttribute(info, "onesided", onesided_);
|
||||
}
|
||||
|
||||
void Compute(const ortc::Tensor<float>& input0,
|
||||
OrtStatusPtr Compute(const ortc::Tensor<float>& input0,
|
||||
int64_t n_fft,
|
||||
int64_t hop_length,
|
||||
const ortc::Span<float>& input3,
|
||||
int64_t frame_length,
|
||||
ortc::Tensor<float>& output0) const {
|
||||
auto X = input0.Data();
|
||||
auto window = input3.data();
|
||||
auto window = input3.data_;
|
||||
auto dimensions = input0.Shape();
|
||||
auto win_length = input3.size();
|
||||
|
||||
if (dimensions.size() < 2 || input0.NumberOfElement() != dimensions[1]) {
|
||||
ORTX_CXX_API_THROW("[Stft] Only batch == 1 tensor supported.", ORT_INVALID_ARGUMENT);
|
||||
return OrtW::CreateStatus("[Stft] Only batch == 1 tensor supported.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
if (frame_length != n_fft) {
|
||||
ORTX_CXX_API_THROW("[Stft] Only support size of FFT equals the frame length.", ORT_INVALID_ARGUMENT);
|
||||
return OrtW::CreateStatus("[Stft] Only support size of FFT equals the frame length.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
dlib::matrix<float> dm_x = dlib::mat(X, 1, dimensions[1]);
|
||||
|
@ -42,42 +42,16 @@ struct STFT : public BaseKernel {
|
|||
m_stft = dlib::subm(m_stft, 0, 0, m_stft.nr(), (m_stft.nc() >> 1) + 1);
|
||||
}
|
||||
|
||||
if (with_norm_) {
|
||||
dlib::matrix<float> result = dlib::norm(m_stft);
|
||||
result = dlib::trans(result);
|
||||
std::vector<int64_t> outdim{1, result.nr(), result.nc()};
|
||||
auto result_size = result.size();
|
||||
auto out0 = output0.Allocate(outdim);
|
||||
memcpy(out0, result.steal_memory().get(), result_size * sizeof(float));
|
||||
} else {
|
||||
auto result = m_stft;
|
||||
// No transpose here since it is done on copying data,
|
||||
// switch nr and nc, so the output dims will be the transposed ones.
|
||||
std::vector<int64_t> outdim{1, result.nc(), result.nr(), 2};
|
||||
auto out0 = output0.Allocate(outdim);
|
||||
for (size_t c = 0; c < result.nc(); ++c) {
|
||||
for (size_t r = 0; r < result.nr(); ++r) {
|
||||
*out0 = result(r, c).real();
|
||||
*(out0 + 1) = result(r, c).imag();
|
||||
out0 += 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
dlib::matrix<float> result = dlib::norm(m_stft);
|
||||
result = dlib::trans(result);
|
||||
std::vector<int64_t> outdim{1, result.nr(), result.nc()};
|
||||
auto result_size = result.size();
|
||||
auto out0 = output0.Allocate(outdim);
|
||||
memcpy(out0, result.steal_memory().get(), result_size * sizeof(float));
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t onesided_{};
|
||||
bool with_norm_{};
|
||||
};
|
||||
|
||||
struct StftNormal : public STFT {
|
||||
StftNormal(const OrtApi& api, const OrtKernelInfo& info) : STFT(api, info, true) {}
|
||||
void Compute(const ortc::Tensor<float>& input0,
|
||||
int64_t n_fft,
|
||||
int64_t hop_length,
|
||||
const ortc::Span<float>& input3,
|
||||
int64_t frame_length,
|
||||
ortc::Tensor<float>& output0) const {
|
||||
STFT::Compute(input0, n_fft, hop_length, input3, frame_length, output0);
|
||||
}
|
||||
int64_t onesided_{1};
|
||||
};
|
||||
|
|
|
@ -7,16 +7,14 @@
|
|||
#include "segment_extraction.hpp"
|
||||
#include "segment_sum.hpp"
|
||||
|
||||
const std::vector<const OrtCustomOp*>& MathLoader() {
|
||||
static OrtOpLoader op_loader(CustomCpuFunc("NegPos", neg_pos),
|
||||
#ifdef ENABLE_DLIB
|
||||
CustomCpuFunc("Inverse", inverse),
|
||||
CustomCpuStruct("STFT", STFT),
|
||||
CustomCpuStruct("StftNorm", StftNormal),
|
||||
#endif
|
||||
CustomCpuFunc("SegmentExtraction", segment_extraction),
|
||||
CustomCpuFunc("SegmentSum", segment_sum));
|
||||
return op_loader.GetCustomOps();
|
||||
}
|
||||
|
||||
FxLoadCustomOpFactory LoadCustomOpClasses_Math = MathLoader;
|
||||
FxLoadCustomOpFactory LoadCustomOpClasses_Math = []() -> CustomOpArray& {
|
||||
static OrtOpLoader op_loader(CustomCpuFuncV2("NegPos", neg_pos),
|
||||
#ifdef ENABLE_DLIB
|
||||
CustomCpuFuncV2("Inverse", inverse),
|
||||
CustomCpuStructV2("StftNorm", StftNormal),
|
||||
#endif
|
||||
CustomCpuFuncV2("SegmentExtraction", segment_extraction),
|
||||
CustomCpuFuncV2("SegmentSum", segment_sum));
|
||||
return op_loader.GetCustomOps();
|
||||
};
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
#include "ocos.h"
|
||||
|
||||
void neg_pos(const ortc::Tensor<float>& input,
|
||||
OrtStatusPtr neg_pos(const ortc::Tensor<float>& input,
|
||||
ortc::Tensor<float>& out0_tensor,
|
||||
ortc::Tensor<float>& out1_tensor) {
|
||||
int64_t size = input.NumberOfElement();
|
||||
|
@ -22,4 +22,6 @@ void neg_pos(const ortc::Tensor<float>& input,
|
|||
out1[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -3,12 +3,12 @@
|
|||
|
||||
#include "segment_extraction.hpp"
|
||||
|
||||
void segment_extraction(const ortc::Tensor<int64_t>& input,
|
||||
OrtStatusPtr segment_extraction(const ortc::Tensor<int64_t>& input,
|
||||
ortc::Tensor<int64_t>& output0,
|
||||
ortc::Tensor<int64_t>& output1) {
|
||||
auto& input_dim = input.Shape();
|
||||
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
|
||||
ORTX_CXX_API_THROW("[SegmentExtraction]: Expect input dimension [n] or [1,n].", ORT_INVALID_GRAPH);
|
||||
return OrtW::CreateStatus("[SegmentExtraction]: Expect input dimension [n] or [1,n].", ORT_INVALID_GRAPH);
|
||||
}
|
||||
const int64_t* p_data = input.Data();
|
||||
std::vector<std::int64_t> segment_value;
|
||||
|
@ -38,4 +38,5 @@ void segment_extraction(const ortc::Tensor<int64_t>& input,
|
|||
|
||||
int64_t* out1_data = output1.Allocate(segment_value_dim);
|
||||
std::copy(segment_value.begin(), segment_value.end(), out1_data);
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -6,6 +6,6 @@
|
|||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
void segment_extraction(const ortc::Tensor<int64_t>& input,
|
||||
OrtStatusPtr segment_extraction(const ortc::Tensor<int64_t>& input,
|
||||
ortc::Tensor<int64_t>& output0,
|
||||
ortc::Tensor<int64_t>& output1);
|
||||
|
|
|
@ -3,19 +3,19 @@
|
|||
|
||||
#include "segment_sum.hpp"
|
||||
|
||||
void segment_sum(const ortc::Tensor<float>& data,
|
||||
OrtStatusPtr segment_sum(const ortc::Tensor<float>& data,
|
||||
const ortc::Tensor<int64_t>& segment_ids,
|
||||
ortc::Tensor<float>& output) {
|
||||
auto& dim_data = data.Shape();
|
||||
auto& dim_seg = segment_ids.Shape();
|
||||
if (dim_data.size() == 0 || dim_seg.size() == 0)
|
||||
ORTX_CXX_API_THROW("Both inputs cannot be empty.", ORT_INVALID_GRAPH);
|
||||
return OrtW::CreateStatus("Both inputs cannot be empty.", ORT_INVALID_GRAPH);
|
||||
if (dim_seg.size() != 1)
|
||||
ORTX_CXX_API_THROW("segment_ids must a single tensor", ORT_INVALID_GRAPH);
|
||||
return OrtW::CreateStatus("segment_ids must a single tensor", ORT_INVALID_GRAPH);
|
||||
if (dim_data[0] != dim_seg[0])
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
return OrtW::CreateStatus(MakeString(
|
||||
"First dimensions of data and segment_ids should be the same, data shape: ", dim_data,
|
||||
" segment_ids shape: ", dim_seg),
|
||||
" segment_ids shape: ", dim_seg).c_str(),
|
||||
ORT_INVALID_GRAPH);
|
||||
|
||||
const int64_t* p_segment_ids = segment_ids.Data();
|
||||
|
@ -38,14 +38,18 @@ void segment_sum(const ortc::Tensor<float>& data,
|
|||
float *p_out, *p_out_end;
|
||||
const int64_t* p_seg = p_segment_ids;
|
||||
for (; begin != end; ++p_seg) {
|
||||
if ((p_seg != p_segment_ids) && (*p_seg != *(p_seg - 1)) && (*p_seg != *(p_seg - 1) + 1))
|
||||
ORTX_CXX_API_THROW(MakeString("segment_ids must be increasing but found ",
|
||||
*(p_seg - 1), " and ", *p_seg, " at position ",
|
||||
std::distance(p_segment_ids, p_seg), "."),
|
||||
ORT_RUNTIME_EXCEPTION);
|
||||
if ((p_seg != p_segment_ids) && (*p_seg != *(p_seg - 1)) && (*p_seg != *(p_seg - 1) + 1)) {
|
||||
return OrtW::CreateStatus(MakeString("segment_ids must be increasing but found ",
|
||||
*(p_seg - 1), " and ", *p_seg, " at position ",
|
||||
std::distance(p_segment_ids, p_seg), ".").c_str(),
|
||||
ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
p_out = p_output + *p_seg * in_stride;
|
||||
p_out_end = p_out + in_stride;
|
||||
for (; p_out != p_out_end; ++p_out, ++begin)
|
||||
for (; p_out != p_out_end; ++p_out, ++begin) {
|
||||
*p_out += *begin;
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -6,6 +6,6 @@
|
|||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
void segment_sum(const ortc::Tensor<float>& data,
|
||||
OrtStatusPtr segment_sum(const ortc::Tensor<float>& data,
|
||||
const ortc::Tensor<int64_t>& segment_ids,
|
||||
ortc::Tensor<float>& output);
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
struct PyCustomOpDef {
|
||||
std::string op_type;
|
||||
|
@ -48,10 +49,9 @@ struct PyCustomOpKernel {
|
|||
std::map<std::string, std::string> attrs_values_;
|
||||
};
|
||||
|
||||
struct PyCustomOpFactory : OrtW::CustomOpBase<PyCustomOpFactory, PyCustomOpKernel> {
|
||||
PyCustomOpFactory() {
|
||||
// STL vector needs it.
|
||||
}
|
||||
struct PyCustomOpFactory : public OrtCustomOp {
|
||||
|
||||
PyCustomOpFactory() = default;
|
||||
|
||||
PyCustomOpFactory(const PyCustomOpDef* opdef, const std::string& domain, const std::string& op) {
|
||||
if (opdef == nullptr)
|
||||
|
@ -59,30 +59,66 @@ struct PyCustomOpFactory : OrtW::CustomOpBase<PyCustomOpFactory, PyCustomOpKerne
|
|||
opdef_ = opdef;
|
||||
op_domain_ = domain;
|
||||
op_type_ = op;
|
||||
}
|
||||
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo& info) const {
|
||||
return new PyCustomOpKernel(api, info, opdef_->obj_id, opdef_->attrs);
|
||||
};
|
||||
OrtCustomOp::version = MIN_ORT_VERSION_SUPPORTED; // The minimum ORT version supported
|
||||
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) {
|
||||
void* p = nullptr;
|
||||
|
||||
const char* GetName() const {
|
||||
return op_type_.c_str();
|
||||
};
|
||||
OCOS_API_IMPL_BEGIN
|
||||
auto self = static_cast<const PyCustomOpFactory*>(this_);
|
||||
auto kernel = std::make_unique<PyCustomOpKernel>(*api, *info, self->opdef_->obj_id, self->opdef_->attrs).release();
|
||||
p = reinterpret_cast<void*>(kernel);
|
||||
OCOS_API_IMPL_END
|
||||
|
||||
size_t GetInputTypeCount() const {
|
||||
return opdef_->input_types.size();
|
||||
};
|
||||
return p;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType GetInputType(size_t idx) const {
|
||||
return static_cast<ONNXTensorElementDataType>(opdef_->input_types[idx]);
|
||||
};
|
||||
OrtCustomOp::GetName = [](const OrtCustomOp* this_) noexcept {
|
||||
auto self = static_cast<const PyCustomOpFactory*>(this_);
|
||||
return self->op_type_.c_str();
|
||||
};
|
||||
|
||||
size_t GetOutputTypeCount() const {
|
||||
return opdef_->output_types.size();
|
||||
};
|
||||
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) noexcept {
|
||||
return "CPUExecutionProvider";
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType GetOutputType(size_t idx) const {
|
||||
return static_cast<ONNXTensorElementDataType>(opdef_->output_types[idx]);
|
||||
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) noexcept {
|
||||
auto self = static_cast<const PyCustomOpFactory*>(this_);
|
||||
return self->opdef_->input_types.size();
|
||||
};
|
||||
|
||||
OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) noexcept {
|
||||
auto self = static_cast<const PyCustomOpFactory*>(this_);
|
||||
return static_cast<ONNXTensorElementDataType>(self->opdef_->input_types[index]);
|
||||
};
|
||||
|
||||
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) noexcept {
|
||||
auto self = static_cast<const PyCustomOpFactory*>(this_);
|
||||
return self->opdef_->output_types.size();
|
||||
};
|
||||
|
||||
OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) noexcept {
|
||||
auto self = static_cast<const PyCustomOpFactory*>(this_);
|
||||
return static_cast<ONNXTensorElementDataType>(self->opdef_->output_types[index]);
|
||||
};
|
||||
|
||||
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) noexcept {
|
||||
OCOS_API_IMPL_BEGIN
|
||||
static_cast<PyCustomOpKernel*>(op_kernel)->Compute(context);
|
||||
OCOS_API_IMPL_END
|
||||
};
|
||||
|
||||
OrtCustomOp::KernelDestroy = [](void* op_kernel) noexcept {
|
||||
std::unique_ptr<PyCustomOpKernel>(reinterpret_cast<PyCustomOpKernel*>(op_kernel)).reset();
|
||||
};
|
||||
|
||||
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
|
||||
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
|
||||
};
|
||||
|
||||
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
|
||||
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
|
||||
};
|
||||
}
|
||||
|
||||
const PyCustomOpDef* opdef_ = nullptr;
|
||||
|
|
|
@ -61,8 +61,9 @@ class ExternalCustomOps {
|
|||
std::vector<const OrtCustomOp*> op_array_;
|
||||
};
|
||||
|
||||
static int GetOrtVersion(const OrtApiBase* api_base = nullptr) {
|
||||
static int ort_version = 11; // the default version is 1.11.0
|
||||
static int GetOrtVersion(const OrtApiBase* api_base = nullptr) noexcept{
|
||||
// the version will be cached after the first call on RegisterCustomOps
|
||||
static int ort_version = MIN_ORT_VERSION_SUPPORTED; // the default version is 1.11.0
|
||||
|
||||
if (api_base != nullptr) {
|
||||
std::string str_version = api_base->GetVersionString();
|
||||
|
@ -104,13 +105,19 @@ extern "C" int ORT_API_CALL GetActiveOrtAPIVersion() {
|
|||
return ver;
|
||||
}
|
||||
|
||||
|
||||
// The main entrance of the extension library.
|
||||
extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
|
||||
OrtStatus* status = nullptr;
|
||||
OCOS_API_IMPL_BEGIN
|
||||
|
||||
OrtCustomOpDomain* domain = nullptr;
|
||||
// the following will initialize some global objects, which
|
||||
// means any other function invocation prior to these calls triggers undefined behavior.
|
||||
auto ver = GetOrtVersion(api);
|
||||
const OrtApi* ortApi = api->GetApi(ver);
|
||||
API::instance(ortApi);
|
||||
|
||||
OrtCustomOpDomain* domain = nullptr;
|
||||
std::set<std::string> pyop_nameset;
|
||||
|
||||
#if defined(PYTHON_OP_SUPPORT)
|
||||
|
|
|
@ -11,7 +11,9 @@
|
|||
// throw in ctor which will be called during model load
|
||||
struct ExceptionalKernel1 : BaseKernel {
|
||||
ExceptionalKernel1(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
OCOS_API_IMPL_BEGIN
|
||||
ORTX_CXX_API_THROW("Throw in ctor", ORT_FAIL);
|
||||
OCOS_API_IMPL_END
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {}
|
||||
|
@ -23,11 +25,19 @@ struct ExceptionalKernel2 : BaseKernel {
|
|||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
OCOS_API_IMPL_BEGIN
|
||||
ORTX_CXX_API_THROW("Throw in Compute", ORT_FAIL);
|
||||
OCOS_API_IMPL_END
|
||||
}
|
||||
};
|
||||
|
||||
struct ExceptionalCustomOp1 : OrtW::CustomOpBase<ExceptionalCustomOp1, ExceptionalKernel1> {
|
||||
struct ExceptionalCustomOp1 : Ort::CustomOpBase<ExceptionalCustomOp1, ExceptionalKernel1> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
void* result = nullptr;
|
||||
OCOS_API_IMPL_BEGIN
|
||||
result = new ExceptionalKernel1(api, *info);
|
||||
OCOS_API_IMPL_END
|
||||
return result; };
|
||||
const char* GetName() const { return "ExceptionalCustomOp1"; };
|
||||
size_t GetInputTypeCount() const { return 1; };
|
||||
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
|
@ -35,7 +45,13 @@ struct ExceptionalCustomOp1 : OrtW::CustomOpBase<ExceptionalCustomOp1, Exception
|
|||
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
};
|
||||
|
||||
struct ExceptionalCustomOp2 : OrtW::CustomOpBase<ExceptionalCustomOp2, ExceptionalKernel2> {
|
||||
struct ExceptionalCustomOp2 : Ort::CustomOpBase<ExceptionalCustomOp2, ExceptionalKernel2> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
void* result = nullptr;
|
||||
OCOS_API_IMPL_BEGIN
|
||||
result = new ExceptionalKernel2(api, *info);
|
||||
OCOS_API_IMPL_END
|
||||
return result; };
|
||||
const char* GetName() const { return "ExceptionalCustomOp2"; };
|
||||
size_t GetInputTypeCount() const { return 1; };
|
||||
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
|
|
|
@ -48,7 +48,10 @@ struct KernelOne : BaseKernel {
|
|||
}
|
||||
};
|
||||
|
||||
struct CustomOpOne : OrtW::CustomOpBase<CustomOpOne, KernelOne> {
|
||||
struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelOne(api, *info);
|
||||
};
|
||||
const char* GetName() const {
|
||||
return "CustomOpOne";
|
||||
};
|
||||
|
@ -91,7 +94,10 @@ struct KernelTwo : BaseKernel {
|
|||
}
|
||||
};
|
||||
|
||||
struct CustomOpTwo : OrtW::CustomOpBase<CustomOpTwo, KernelTwo> {
|
||||
struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelTwo(api, *info);
|
||||
};
|
||||
const char* GetName() const {
|
||||
return "CustomOpTwo";
|
||||
};
|
||||
|
@ -136,13 +142,9 @@ struct KernelThree : BaseKernel {
|
|||
std::string substr_;
|
||||
};
|
||||
|
||||
struct CustomOpThree : OrtW::CustomOpBase<CustomOpThree, KernelThree> {
|
||||
// This is example code to show how to override the CustomOpBase::CreateKernel method even though it is not virtual.
|
||||
// The CustomOpBase implementation will call the CreateKernel of the first class specified in the template,
|
||||
// and from there it's also possible to call the base CreateKernel as per below.
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo& info) const {
|
||||
std::cout << "Called CreateKernel override" << std::endl;
|
||||
return OrtW::CustomOpBase<CustomOpThree, KernelThree>::CreateKernel(api, info);
|
||||
struct CustomOpThree : Ort::CustomOpBase<CustomOpThree, KernelThree> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelThree(api, *info);
|
||||
};
|
||||
const char* GetName() const {
|
||||
return "CustomOpThree";
|
||||
|
|
|
@ -96,16 +96,6 @@ class TestAudio(unittest.TestCase):
|
|||
data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]
|
||||
return np.absolute(data.T) ** 2
|
||||
|
||||
def test_onnx_stft(self):
|
||||
audio_pcm = self.test_pcm
|
||||
expected = self.stft(audio_pcm, 400, 160, np.hanning(400).astype(np.float32))
|
||||
|
||||
ortx_stft = OrtPyFunction.from_model(_create_test_model(), cpu_only=True)
|
||||
actual = ortx_stft(np.expand_dims(audio_pcm, axis=0), 400, 160, np.hanning(400).astype(np.float32), 400)
|
||||
actual = actual[0]
|
||||
actual = actual[:, :, 0] ** 2 + actual[:, :, 1] ** 2
|
||||
np.testing.assert_allclose(expected[:, 1:], actual[:, 1:], rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_stft_norm_np(self):
|
||||
audio_pcm = self.test_pcm
|
||||
expected = self.stft(audio_pcm, 400, 160, np.hanning(400).astype(np.float32))
|
||||
|
|
Загрузка…
Ссылка в новой задаче