using the latest Ort header instead of minimum compatible headers (#485)
* using the latest Ort header instead of minimum compatible headers * Update ext_ortlib.cmake * Update ortcustomops.def * change the default ORT API version value
This commit is contained in:
Родитель
27132ced71
Коммит
e3d9198de8
|
@ -8,8 +8,8 @@ else()
|
|||
message(STATUS "CMAKE_SYSTEM_PROCESSOR=${CMAKE_SYSTEM_PROCESSOR}")
|
||||
message(STATUS "CMAKE_GENERATOR_PLATFORM=${CMAKE_GENERATOR_PLATFORM}")
|
||||
|
||||
# default to 1.11.1 if not specified
|
||||
set(ONNXRUNTIME_VER "1.12.1" CACHE STRING "ONNX Runtime version")
|
||||
# 1.15.1 is the latest ORT release.
|
||||
set(ONNXRUNTIME_VER "1.15.1" CACHE STRING "ONNX Runtime version")
|
||||
|
||||
if(APPLE)
|
||||
set(ONNXRUNTIME_URL "v${ONNXRUNTIME_VER}/onnxruntime-osx-universal2-${ONNXRUNTIME_VER}.tgz")
|
||||
|
|
|
@ -726,7 +726,7 @@ struct OrtLiteCustomOp : public OrtCustomOp {
|
|||
OrtLiteCustomOp(const char* op_name,
|
||||
const char* execution_provider) : op_name_(op_name),
|
||||
execution_provider_(execution_provider) {
|
||||
OrtCustomOp::version = ORT_API_VERSION;
|
||||
OrtCustomOp::version = GetActiveOrtAPIVersion();
|
||||
|
||||
OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
|
||||
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
|
||||
|
@ -867,11 +867,6 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp {
|
|||
void init(CustomComputeFn<Args...>) {
|
||||
ParseArgs<Args...>(input_types_, output_types_);
|
||||
|
||||
if (!input_types_.empty() && input_types_[0] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ||
|
||||
!output_types_.empty() && output_types_[0] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
|
||||
OrtCustomOp::version = 14;
|
||||
}
|
||||
|
||||
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
|
||||
auto kernel = reinterpret_cast<Kernel*>(op_kernel);
|
||||
std::vector<TensorPtr> tensors;
|
||||
|
|
|
@ -15,11 +15,13 @@
|
|||
#include <type_traits>
|
||||
|
||||
#include "onnxruntime_c_api.h"
|
||||
|
||||
#include "exceptions.h"
|
||||
|
||||
#define MIN_ORT_VERSION_SUPPORTED 10
|
||||
|
||||
extern "C" int ORT_API_CALL GetActiveOrtAPIVersion();
|
||||
|
||||
|
||||
namespace OrtW {
|
||||
|
||||
//
|
||||
|
|
|
@ -3,6 +3,8 @@
|
|||
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <cstdlib> // for std::atoi
|
||||
#include <string>
|
||||
|
||||
#include "onnxruntime_extensions.h"
|
||||
#include "ocos.h"
|
||||
|
@ -59,6 +61,34 @@ class ExternalCustomOps {
|
|||
std::vector<const OrtCustomOp*> op_array_;
|
||||
};
|
||||
|
||||
static int GetOrtVersion(const OrtApiBase* api_base = nullptr) {
|
||||
static int ort_version = 11; // the default version is 1.11.0
|
||||
|
||||
if (api_base != nullptr) {
|
||||
std::string str_version = api_base->GetVersionString();
|
||||
|
||||
std::size_t first_dot = str_version.find('.');
|
||||
if (first_dot != std::string::npos) {
|
||||
std::size_t second_dot = str_version.find('.', first_dot + 1);
|
||||
// If there is no second dot and the string has more than one character after the first dot, set second_dot to the string length
|
||||
if (second_dot == std::string::npos && first_dot + 1 < str_version.length()) {
|
||||
second_dot = str_version.length();
|
||||
}
|
||||
|
||||
if (second_dot != std::string::npos) {
|
||||
std::string str_minor_version = str_version.substr(first_dot + 1, second_dot - first_dot - 1);
|
||||
int ver = std::atoi(str_minor_version.c_str());
|
||||
// Only change ort_version if conversion is successful (non-zero value)
|
||||
if (ver != 0) {
|
||||
ort_version = ver;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ort_version;
|
||||
}
|
||||
|
||||
extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op) {
|
||||
OCOS_API_IMPL_BEGIN
|
||||
ExternalCustomOps::instance().Add(c_op);
|
||||
|
@ -66,12 +96,21 @@ extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op) {
|
|||
return true;
|
||||
}
|
||||
|
||||
extern "C" int ORT_API_CALL GetActiveOrtAPIVersion() {
|
||||
int ver = 0;
|
||||
OCOS_API_IMPL_BEGIN
|
||||
ver = GetOrtVersion();
|
||||
OCOS_API_IMPL_END
|
||||
return ver;
|
||||
}
|
||||
|
||||
extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
|
||||
OrtStatus* status = nullptr;
|
||||
OCOS_API_IMPL_BEGIN
|
||||
|
||||
OrtCustomOpDomain* domain = nullptr;
|
||||
const OrtApi* ortApi = api->GetApi(ORT_API_VERSION);
|
||||
auto ver = GetOrtVersion(api);
|
||||
const OrtApi* ortApi = api->GetApi(ver);
|
||||
std::set<std::string> pyop_nameset;
|
||||
|
||||
#if defined(PYTHON_OP_SUPPORT)
|
||||
|
|
|
@ -2,3 +2,4 @@ LIBRARY "ortextensions.dll"
|
|||
EXPORTS
|
||||
RegisterCustomOps @1
|
||||
AddExternalCustomOp @2
|
||||
GetActiveOrtAPIVersion @3
|
||||
|
|
|
@ -2,5 +2,6 @@
|
|||
global:
|
||||
RegisterCustomOps;
|
||||
AddExternalCustomOp;
|
||||
GetActiveOrtAPIVersion;
|
||||
local: *;
|
||||
};
|
||||
|
|
Загрузка…
Ссылка в новой задаче