From e3d9198de801fe80ec3896063e016f9db8cf2be2 Mon Sep 17 00:00:00 2001 From: Wenbing Li <10278425+wenbingl@users.noreply.github.com> Date: Mon, 10 Jul 2023 16:10:11 -0700 Subject: [PATCH] using the latest Ort header instead of minimum compatible headers (#485) * using the latest Ort header instead of minimum compatible headers * Update ext_ortlib.cmake * Update ortcustomops.def * change the default ORT API version value --- cmake/ext_ortlib.cmake | 4 +-- includes/custom_op_lite.h | 9 ++----- includes/onnxruntime_customop.hpp | 4 ++- shared/lib/ortcustomops.cc | 41 ++++++++++++++++++++++++++++++- shared/ortcustomops.def | 1 + shared/ortcustomops.ver | 1 + 6 files changed, 49 insertions(+), 11 deletions(-) diff --git a/cmake/ext_ortlib.cmake b/cmake/ext_ortlib.cmake index 391b6411..7634ddb0 100644 --- a/cmake/ext_ortlib.cmake +++ b/cmake/ext_ortlib.cmake @@ -8,8 +8,8 @@ else() message(STATUS "CMAKE_SYSTEM_PROCESSOR=${CMAKE_SYSTEM_PROCESSOR}") message(STATUS "CMAKE_GENERATOR_PLATFORM=${CMAKE_GENERATOR_PLATFORM}") - # default to 1.11.1 if not specified - set(ONNXRUNTIME_VER "1.12.1" CACHE STRING "ONNX Runtime version") + # 1.15.1 is the latest ORT release. + set(ONNXRUNTIME_VER "1.15.1" CACHE STRING "ONNX Runtime version") if(APPLE) set(ONNXRUNTIME_URL "v${ONNXRUNTIME_VER}/onnxruntime-osx-universal2-${ONNXRUNTIME_VER}.tgz") diff --git a/includes/custom_op_lite.h b/includes/custom_op_lite.h index 590e76a4..19f15ee4 100644 --- a/includes/custom_op_lite.h +++ b/includes/custom_op_lite.h @@ -726,7 +726,7 @@ struct OrtLiteCustomOp : public OrtCustomOp { OrtLiteCustomOp(const char* op_name, const char* execution_provider) : op_name_(op_name), execution_provider_(execution_provider) { - OrtCustomOp::version = ORT_API_VERSION; + OrtCustomOp::version = GetActiveOrtAPIVersion(); OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast(op)->op_name_.c_str(); }; OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); }; @@ -867,11 +867,6 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { void init(CustomComputeFn) { ParseArgs(input_types_, output_types_); - if (!input_types_.empty() && input_types_[0] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED || - !output_types_.empty() && output_types_[0] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) { - OrtCustomOp::version = 14; - } - OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { auto kernel = reinterpret_cast(op_kernel); std::vector tensors; @@ -915,4 +910,4 @@ OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, } } // namespace Custom -} // namespace Ort \ No newline at end of file +} // namespace Ort diff --git a/includes/onnxruntime_customop.hpp b/includes/onnxruntime_customop.hpp index 839e6a5a..f3bfee2b 100644 --- a/includes/onnxruntime_customop.hpp +++ b/includes/onnxruntime_customop.hpp @@ -15,11 +15,13 @@ #include #include "onnxruntime_c_api.h" - #include "exceptions.h" #define MIN_ORT_VERSION_SUPPORTED 10 +extern "C" int ORT_API_CALL GetActiveOrtAPIVersion(); + + namespace OrtW { // diff --git a/shared/lib/ortcustomops.cc b/shared/lib/ortcustomops.cc index 798c4156..bb33c78f 100644 --- a/shared/lib/ortcustomops.cc +++ b/shared/lib/ortcustomops.cc @@ -3,6 +3,8 @@ #include #include +#include // for std::atoi +#include #include "onnxruntime_extensions.h" #include "ocos.h" @@ -59,6 +61,34 @@ class ExternalCustomOps { std::vector op_array_; }; +static int GetOrtVersion(const OrtApiBase* api_base = nullptr) { + static int ort_version = 11; // the default version is 1.11.0 + + if (api_base != nullptr) { + std::string str_version = api_base->GetVersionString(); + + std::size_t first_dot = str_version.find('.'); + if (first_dot != std::string::npos) { + std::size_t second_dot = str_version.find('.', first_dot + 1); + // If there is no second dot and the string has more than one character after the first dot, set second_dot to the string length + if (second_dot == std::string::npos && first_dot + 1 < str_version.length()) { + second_dot = str_version.length(); + } + + if (second_dot != std::string::npos) { + std::string str_minor_version = str_version.substr(first_dot + 1, second_dot - first_dot - 1); + int ver = std::atoi(str_minor_version.c_str()); + // Only change ort_version if conversion is successful (non-zero value) + if (ver != 0) { + ort_version = ver; + } + } + } + } + + return ort_version; +} + extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op) { OCOS_API_IMPL_BEGIN ExternalCustomOps::instance().Add(c_op); @@ -66,12 +96,21 @@ extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op) { return true; } +extern "C" int ORT_API_CALL GetActiveOrtAPIVersion() { + int ver = 0; + OCOS_API_IMPL_BEGIN + ver = GetOrtVersion(); + OCOS_API_IMPL_END + return ver; +} + extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) { OrtStatus* status = nullptr; OCOS_API_IMPL_BEGIN OrtCustomOpDomain* domain = nullptr; - const OrtApi* ortApi = api->GetApi(ORT_API_VERSION); + auto ver = GetOrtVersion(api); + const OrtApi* ortApi = api->GetApi(ver); std::set pyop_nameset; #if defined(PYTHON_OP_SUPPORT) diff --git a/shared/ortcustomops.def b/shared/ortcustomops.def index 0786e042..a9a62adc 100644 --- a/shared/ortcustomops.def +++ b/shared/ortcustomops.def @@ -2,3 +2,4 @@ LIBRARY "ortextensions.dll" EXPORTS RegisterCustomOps @1 AddExternalCustomOp @2 + GetActiveOrtAPIVersion @3 diff --git a/shared/ortcustomops.ver b/shared/ortcustomops.ver index 508184d2..bd7c5af8 100644 --- a/shared/ortcustomops.ver +++ b/shared/ortcustomops.ver @@ -2,5 +2,6 @@ global: RegisterCustomOps; AddExternalCustomOp; + GetActiveOrtAPIVersion; local: *; };