using the latest Ort header instead of minimum compatible headers (#485)

* using the latest Ort header instead of minimum compatible headers

* Update ext_ortlib.cmake

* Update ortcustomops.def

* change the default ORT API version value
This commit is contained in:
Wenbing Li 2023-07-10 16:10:11 -07:00 коммит произвёл GitHub
Родитель 27132ced71
Коммит e3d9198de8
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 49 добавлений и 11 удалений

Просмотреть файл

@ -8,8 +8,8 @@ else()
message(STATUS "CMAKE_SYSTEM_PROCESSOR=${CMAKE_SYSTEM_PROCESSOR}")
message(STATUS "CMAKE_GENERATOR_PLATFORM=${CMAKE_GENERATOR_PLATFORM}")
# default to 1.11.1 if not specified
set(ONNXRUNTIME_VER "1.12.1" CACHE STRING "ONNX Runtime version")
# 1.15.1 is the latest ORT release.
set(ONNXRUNTIME_VER "1.15.1" CACHE STRING "ONNX Runtime version")
if(APPLE)
set(ONNXRUNTIME_URL "v${ONNXRUNTIME_VER}/onnxruntime-osx-universal2-${ONNXRUNTIME_VER}.tgz")

Просмотреть файл

@ -726,7 +726,7 @@ struct OrtLiteCustomOp : public OrtCustomOp {
OrtLiteCustomOp(const char* op_name,
const char* execution_provider) : op_name_(op_name),
execution_provider_(execution_provider) {
OrtCustomOp::version = ORT_API_VERSION;
OrtCustomOp::version = GetActiveOrtAPIVersion();
OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
@ -867,11 +867,6 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp {
void init(CustomComputeFn<Args...>) {
ParseArgs<Args...>(input_types_, output_types_);
if (!input_types_.empty() && input_types_[0] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ||
!output_types_.empty() && output_types_[0] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
OrtCustomOp::version = 14;
}
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
auto kernel = reinterpret_cast<Kernel*>(op_kernel);
std::vector<TensorPtr> tensors;

Просмотреть файл

@ -15,11 +15,13 @@
#include <type_traits>
#include "onnxruntime_c_api.h"
#include "exceptions.h"
#define MIN_ORT_VERSION_SUPPORTED 10
extern "C" int ORT_API_CALL GetActiveOrtAPIVersion();
namespace OrtW {
//

Просмотреть файл

@ -3,6 +3,8 @@
#include <mutex>
#include <set>
#include <cstdlib> // for std::atoi
#include <string>
#include "onnxruntime_extensions.h"
#include "ocos.h"
@ -59,6 +61,34 @@ class ExternalCustomOps {
std::vector<const OrtCustomOp*> op_array_;
};
static int GetOrtVersion(const OrtApiBase* api_base = nullptr) {
static int ort_version = 11; // the default version is 1.11.0
if (api_base != nullptr) {
std::string str_version = api_base->GetVersionString();
std::size_t first_dot = str_version.find('.');
if (first_dot != std::string::npos) {
std::size_t second_dot = str_version.find('.', first_dot + 1);
// If there is no second dot and the string has more than one character after the first dot, set second_dot to the string length
if (second_dot == std::string::npos && first_dot + 1 < str_version.length()) {
second_dot = str_version.length();
}
if (second_dot != std::string::npos) {
std::string str_minor_version = str_version.substr(first_dot + 1, second_dot - first_dot - 1);
int ver = std::atoi(str_minor_version.c_str());
// Only change ort_version if conversion is successful (non-zero value)
if (ver != 0) {
ort_version = ver;
}
}
}
}
return ort_version;
}
extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op) {
OCOS_API_IMPL_BEGIN
ExternalCustomOps::instance().Add(c_op);
@ -66,12 +96,21 @@ extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op) {
return true;
}
extern "C" int ORT_API_CALL GetActiveOrtAPIVersion() {
int ver = 0;
OCOS_API_IMPL_BEGIN
ver = GetOrtVersion();
OCOS_API_IMPL_END
return ver;
}
extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
OrtStatus* status = nullptr;
OCOS_API_IMPL_BEGIN
OrtCustomOpDomain* domain = nullptr;
const OrtApi* ortApi = api->GetApi(ORT_API_VERSION);
auto ver = GetOrtVersion(api);
const OrtApi* ortApi = api->GetApi(ver);
std::set<std::string> pyop_nameset;
#if defined(PYTHON_OP_SUPPORT)

Просмотреть файл

@ -2,3 +2,4 @@ LIBRARY "ortextensions.dll"
EXPORTS
RegisterCustomOps @1
AddExternalCustomOp @2
GetActiveOrtAPIVersion @3

Просмотреть файл

@ -2,5 +2,6 @@
global:
RegisterCustomOps;
AddExternalCustomOp;
GetActiveOrtAPIVersion;
local: *;
};