Fix the bug in enabling ORT CustomOps in ONNXRuntime. (#90)
Co-authored-by: Zuwei Zhao <zuzhao@microsoft.com>
This commit is contained in:
Родитель
5fa95c9485
Коммит
3f377e0911
|
@ -12,14 +12,15 @@ git clone https://github.com/microsoft/onnxruntime.git
|
|||
cd onnxruntime
|
||||
|
||||
# there is no stable version including webassembly
|
||||
git checkout 21c282ed
|
||||
git checkout e6a3308db7c03a13e0f08b221b6770e17fc3a4ef
|
||||
cd cmake/external
|
||||
git clone git@github.com:microsoft/ort-customops.git
|
||||
git clone git@github.com:microsoft/onnxruntime-extensions.git
|
||||
|
||||
cd ../..
|
||||
|
||||
cp cmake/external/ort-customops/test/data/custom_op_string_lower.onnx onnxruntime/test/testdata
|
||||
git apply cmake/external/ort-customops/ci_build/onnxruntime_integration/onnxruntime_v1.8.patch
|
||||
cp cmake/external/onnxruntime-extensions/test/data/custom_op_negpos.onnx onnxruntime/test/testdata
|
||||
cp cmake/external/onnxruntime-extensions/test/data/custom_op_string_lower.onnx onnxruntime/test/testdata
|
||||
git apply cmake/external/onnxruntime-extensions/ci_build/onnxruntime_integration/onnxruntime_v1.8.patch
|
||||
|
||||
#get ready and begin building
|
||||
cd ..
|
||||
|
|
|
@ -1,18 +1,18 @@
|
|||
diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
|
||||
index 46ba57cdc..5bde85a7b 100644
|
||||
index ac9c62fb6..9893f703e 100644
|
||||
--- a/cmake/CMakeLists.txt
|
||||
+++ b/cmake/CMakeLists.txt
|
||||
@@ -951,6 +951,14 @@ if (onnxruntime_USE_TVM)
|
||||
@@ -966,6 +966,14 @@ if (onnxruntime_USE_TVM)
|
||||
list(APPEND onnxruntime_EXTERNAL_DEPENDENCIES tvm nnvm_compiler)
|
||||
endif()
|
||||
|
||||
+
|
||||
+# ONNXRuntime-CustomOps
|
||||
+set(OCOS_ENABLE_CTEST OFF CACHE INTERNAL "")
|
||||
+set(OCOS_ENABLE_STATIC_LIB ON CACHE INTERNAL "")
|
||||
+set(OCOS_ENABLE_SPM_TOKENIZER OFF CACHE INTERNAL "")
|
||||
+add_subdirectory(external/ort-customops EXCLUDE_FROM_ALL)
|
||||
+add_subdirectory(external/onnxruntime-extensions EXCLUDE_FROM_ALL)
|
||||
+target_include_directories(ortcustomops_static PRIVATE ${RE2_INCLUDE_DIR} external/json/include)
|
||||
+target_include_directories(ortcustomops PUBLIC external/ort-customops/shared)
|
||||
+target_include_directories(ortcustomops PUBLIC external/onnxruntime-extensions/shared)
|
||||
+
|
||||
if (APPLE OR CMAKE_SYSTEM_NAME STREQUAL "Android")
|
||||
#onnx/onnx/proto_utils.h:34:16: error: 'SetTotalBytesLimit' is deprecated: Please use the single
|
||||
|
@ -31,21 +31,21 @@ index df7eebf5a..2c511005a 100644
|
|||
set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime")
|
||||
if (onnxruntime_USE_CUDA)
|
||||
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
|
||||
index faca901ce..2fae0fb05 100644
|
||||
index b28c44613..b932d55ce 100644
|
||||
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
|
||||
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
|
||||
@@ -458,6 +458,11 @@ struct OrtApi {
|
||||
ORT_API2_STATUS(RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions* options, _In_ const char* library_path,
|
||||
void** library_handle);
|
||||
|
||||
+ /*
|
||||
+ * Enable the operators in ORT Custom ops: https://github.com/microsoft/ort-customops
|
||||
@@ -1277,6 +1277,11 @@ struct OrtApi {
|
||||
*/
|
||||
ORT_API2_STATUS(KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name,
|
||||
_Out_ int64_t* out, _Inout_ size_t* size);
|
||||
+
|
||||
+ /**
|
||||
+ * Enable custom operators in ORT CustomOps: https://github.com/microsoft/onnxruntime-extensions.git
|
||||
+ */
|
||||
+ ORT_API2_STATUS(EnableOrtCustomOps, _Inout_ OrtSessionOptions* options);
|
||||
+
|
||||
/**
|
||||
* To use additional providers, you must build ORT with the extra providers enabled. Then call one of these
|
||||
* functions to enable them in the session:
|
||||
};
|
||||
|
||||
/*
|
||||
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
|
||||
index d85ecd776..d0e9fe6a3 100644
|
||||
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
|
||||
|
@ -75,19 +75,20 @@ index 64199ac6c..8ee32e4a1 100644
|
|||
ThrowOnError(GetApi().EnableMemPattern(p_));
|
||||
return *this;
|
||||
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
|
||||
index 12174163e..c21a92160 100644
|
||||
index 12174163e..1fb57d39c 100644
|
||||
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
|
||||
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
|
||||
@@ -35,6 +35,8 @@
|
||||
@@ -35,6 +35,9 @@
|
||||
#include "abi_session_options_impl.h"
|
||||
#include "core/framework/TensorSeq.h"
|
||||
#include "core/platform/ort_mutex.h"
|
||||
+
|
||||
+#include "ortcustomops.h"
|
||||
+
|
||||
#ifdef USE_CUDA
|
||||
#include "core/providers/cuda/cuda_provider_factory.h"
|
||||
#endif
|
||||
@@ -403,6 +405,13 @@ ORT_API_STATUS_IMPL(OrtApis::RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions
|
||||
@@ -403,6 +406,13 @@ ORT_API_STATUS_IMPL(OrtApis::RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
|
@ -101,101 +102,142 @@ index 12174163e..c21a92160 100644
|
|||
namespace {
|
||||
// provider either model_path, or modal_data + model_data_length.
|
||||
static ORT_STATUS_PTR CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options,
|
||||
@@ -1963,7 +1972,7 @@ static constexpr OrtApi ort_api_1_to_8 = {
|
||||
&OrtApis::CustomOpDomain_Add,
|
||||
&OrtApis::AddCustomOpDomain,
|
||||
&OrtApis::RegisterCustomOpsLibrary,
|
||||
-
|
||||
@@ -2123,6 +2133,7 @@ static constexpr OrtApi ort_api_1_to_8 = {
|
||||
// Version 8 - In development, feel free to add/remove/rearrange here
|
||||
&OrtApis::KernelInfoGetAttributeArray_float,
|
||||
&OrtApis::KernelInfoGetAttributeArray_int64,
|
||||
+ &OrtApis::EnableOrtCustomOps,
|
||||
&OrtApis::SessionGetInputCount,
|
||||
&OrtApis::SessionGetOutputCount,
|
||||
&OrtApis::SessionGetOverridableInitializerCount,
|
||||
@@ -2127,7 +2136,7 @@ static constexpr OrtApi ort_api_1_to_8 = {
|
||||
};
|
||||
|
||||
// Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)
|
||||
// If this assert hits, read the above 'Rules on how to add a new Ort API version'
|
||||
-static_assert(offsetof(OrtApi, ReleaseCustomOpDomain) / sizeof(void*) == 101, "Size of version 1 API cannot change");
|
||||
+static_assert(offsetof(OrtApi, ReleaseCustomOpDomain) / sizeof(void*) == 102, "Size of version 1 API cannot change");
|
||||
|
||||
ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) {
|
||||
if (version >= 1 && version <= ORT_API_VERSION)
|
||||
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
|
||||
index f19b1d729..99def252c 100644
|
||||
index f19b1d729..0dccf457a 100644
|
||||
--- a/onnxruntime/core/session/ort_apis.h
|
||||
+++ b/onnxruntime/core/session/ort_apis.h
|
||||
@@ -73,6 +73,7 @@ ORT_API_STATUS_IMPL(CreateCustomOpDomain, _In_ const char* domain, _Outptr_ OrtC
|
||||
ORT_API_STATUS_IMPL(CustomOpDomain_Add, _Inout_ OrtCustomOpDomain* custom_op_domain, _In_ const OrtCustomOp* op);
|
||||
ORT_API_STATUS_IMPL(AddCustomOpDomain, _Inout_ OrtSessionOptions* options, _In_ OrtCustomOpDomain* custom_op_domain);
|
||||
ORT_API_STATUS_IMPL(RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions* options, _In_ const char* library_path, void** library_handle);
|
||||
@@ -263,4 +263,6 @@ ORT_API_STATUS_IMPL(SetCurrentGpuDeviceId, _In_ int device_id);
|
||||
ORT_API_STATUS_IMPL(GetCurrentGpuDeviceId, _In_ int* device_id);
|
||||
ORT_API_STATUS_IMPL(KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out, _Inout_ size_t* size);
|
||||
ORT_API_STATUS_IMPL(KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out, _Inout_ size_t* size);
|
||||
+
|
||||
+ORT_API_STATUS_IMPL(EnableOrtCustomOps, _Inout_ OrtSessionOptions* options);
|
||||
|
||||
ORT_API_STATUS_IMPL(SessionGetInputCount, _In_ const OrtSession* sess, _Out_ size_t* out);
|
||||
ORT_API_STATUS_IMPL(SessionGetOutputCount, _In_ const OrtSession* sess, _Out_ size_t* out);
|
||||
} // namespace OrtApis
|
||||
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
|
||||
index 06073e528..ff51c92d8 100644
|
||||
index 021636a16..5e34519a2 100644
|
||||
--- a/onnxruntime/test/shared_lib/test_inference.cc
|
||||
+++ b/onnxruntime/test/shared_lib/test_inference.cc
|
||||
@@ -172,6 +172,7 @@ static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/f
|
||||
@@ -172,6 +172,8 @@ static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/f
|
||||
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_bar_1.onnx");
|
||||
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx");
|
||||
static constexpr PATH_TYPE CUSTOM_OP_MODEL_WITH_ATTRIBUTES_URI = TSTR("testdata/foo_bar_3.onnx");
|
||||
+static constexpr PATH_TYPE ORT_CUSTOM_OPS_MODEL_URL = TSTR("testdata/custom_op_string_lower.onnx");
|
||||
+static constexpr PATH_TYPE ORT_CUSTOM_OPS_MODEL_URI = TSTR("testdata/custom_op_string_lower.onnx");
|
||||
+static constexpr PATH_TYPE ORT_CUSTOM_OPS_MODEL_URI_2 = TSTR("testdata/custom_op_negpos.onnx");
|
||||
|
||||
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
|
||||
static constexpr PATH_TYPE PYOP_FLOAT_MODEL_URI = TSTR("testdata/pyop_1.onnx");
|
||||
@@ -265,6 +266,15 @@ TEST(CApiTest, custom_op_handler) {
|
||||
@@ -265,6 +267,91 @@ TEST(CApiTest, custom_op_handler) {
|
||||
#endif
|
||||
}
|
||||
|
||||
+TEST(CApiTest, enable_custom_op) {
|
||||
+TEST(CApiTest, test_enable_ort_customops_negpos) {
|
||||
+
|
||||
+ Ort::SessionOptions session_option;
|
||||
+ session_option.EnableOrtCustomOps();
|
||||
+ Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
|
||||
+ auto allocator = onnxruntime::make_unique<MockedOrtAllocator>();
|
||||
+
|
||||
+ // only test session load
|
||||
+ Ort::Session session(*ort_env, ORT_CUSTOM_OPS_MODEL_URL, session_option);
|
||||
+ // Create Inputs
|
||||
+ std::vector<Ort::Value> ort_inputs;
|
||||
+ std::vector<float> input_data = {-1.1f, 2.2f, 4.4f, -5.5f};
|
||||
+ std::vector<int64_t> input_dims = {2, 2};
|
||||
+ ort_inputs.emplace_back(Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()), input_data.size(), input_dims.data(), input_dims.size()));
|
||||
+
|
||||
+ // Create Session with ORT CustomOps
|
||||
+ Ort::SessionOptions session_options;
|
||||
+ session_options.EnableOrtCustomOps();
|
||||
+ Ort::Session session(*ort_env, ORT_CUSTOM_OPS_MODEL_URI_2, session_options);
|
||||
+
|
||||
+ // Create Input and Output Names
|
||||
+ std::vector<const char*> input_names = {"X"};
|
||||
+ const char* output_names[] = {"out0", "out1"};
|
||||
+
|
||||
+ // Run Session
|
||||
+ std::vector<Ort::Value> ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, countof(output_names));
|
||||
+
|
||||
+ // Validate Results
|
||||
+ ASSERT_EQ(ort_outputs.size(), 2u);
|
||||
+
|
||||
+ std::vector<int64_t> out_dims = {2, 2};
|
||||
+ std::vector<float> values_out0 = {-1.1f, 0.0f, 0.0f, -5.5f};
|
||||
+ auto type_info = ort_outputs[0].GetTensorTypeAndShapeInfo();
|
||||
+ ASSERT_EQ(type_info.GetShape(), out_dims);
|
||||
+ size_t total_len = type_info.GetElementCount();
|
||||
+ ASSERT_EQ(values_out0.size(), total_len);
|
||||
+
|
||||
+ float* f = ort_outputs[0].GetTensorMutableData<float>();
|
||||
+ for (size_t i = 0; i != total_len; ++i) {
|
||||
+ ASSERT_EQ(values_out0[i], f[i]);
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+TEST(CApiTest, test_enable_ort_customops_stringlower) {
|
||||
+
|
||||
+ auto allocator = onnxruntime::make_unique<MockedOrtAllocator>();
|
||||
+
|
||||
+ // Create Inputs
|
||||
+ std::vector<Ort::Value> ort_inputs;
|
||||
+ std::string input_data{"HI, This is ENGINEER from Microsoft."};
|
||||
+ const char* const input_strings[] = {input_data.c_str()};
|
||||
+ std::vector<int64_t> input_dims = {1, 1};
|
||||
+
|
||||
+ Ort::Value input_tensor = Ort::Value::CreateTensor(allocator.get(), input_dims.data(), input_dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
|
||||
+ input_tensor.FillStringTensor(input_strings, 1U);
|
||||
+ ort_inputs.push_back(std::move(input_tensor));
|
||||
+
|
||||
+ // Create Session with ORT CustomOps
|
||||
+ Ort::SessionOptions session_options;
|
||||
+ session_options.EnableOrtCustomOps();
|
||||
+ Ort::Session session(*ort_env, ORT_CUSTOM_OPS_MODEL_URI, session_options);
|
||||
+
|
||||
+ // Create Input and Output Names
|
||||
+ std::vector<const char*> input_names = {"input_1"};
|
||||
+ const char* output_names[] = {"customout"};
|
||||
+
|
||||
+ // Run Session
|
||||
+ std::vector<Ort::Value> ort_outputs = session.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, countof(output_names));
|
||||
+
|
||||
+ // Validate Results
|
||||
+ ASSERT_EQ(ort_outputs.size(), 1u);
|
||||
+
|
||||
+ std::vector<int64_t> out_dims = {1, 1};
|
||||
+ auto type_info = ort_outputs[0].GetTensorTypeAndShapeInfo();
|
||||
+ ASSERT_EQ(type_info.GetShape(), out_dims);
|
||||
+ ASSERT_EQ(type_info.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
|
||||
+
|
||||
+ std::string output_data{"hi, this is engineer from microsoft."};
|
||||
+ auto expected_string = output_data.c_str();
|
||||
+ size_t expected_string_len = strlen(expected_string);
|
||||
+ auto data_length = ort_outputs[0].GetStringTensorDataLength();
|
||||
+ ASSERT_EQ(expected_string_len, data_length);
|
||||
+
|
||||
+ std::string result(data_length, '\0');
|
||||
+ std::vector<size_t> offsets(type_info.GetElementCount());
|
||||
+ ort_outputs[0].GetStringTensorContent((void*)result.data(), data_length, offsets.data(), offsets.size());
|
||||
+ ASSERT_STREQ(result.c_str(), expected_string);
|
||||
+}
|
||||
+
|
||||
//test custom op which accepts float and double as inputs
|
||||
TEST(CApiTest, varied_input_custom_op_handler) {
|
||||
std::vector<Input> inputs(2);
|
||||
diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc
|
||||
index c3ba4eedd..ff7bbd51e 100644
|
||||
index 523a2fc6f..970ca1552 100644
|
||||
--- a/onnxruntime/wasm/api.cc
|
||||
+++ b/onnxruntime/wasm/api.cc
|
||||
@@ -28,6 +28,16 @@ Ort::Session* OrtCreateSession(void* data, size_t data_length) {
|
||||
return new Ort::Session(*g_env, data, data_length, session_options);
|
||||
}
|
||||
@@ -22,6 +22,10 @@ Ort::Session* OrtCreateSession(void* data, size_t data_length) {
|
||||
Ort::SessionOptions session_options;
|
||||
session_options.SetLogId("onnxruntime");
|
||||
|
||||
+Ort::Session* OrtCreateSessionWithCustomOps(void* data, size_t data_length) {
|
||||
+ Ort::SessionOptions session_options;
|
||||
+ session_options.SetLogId("onnxruntime");
|
||||
+ // Enable ORT CustomOps
|
||||
+ // TODO: add condition check here to enable ORT CustomOps
|
||||
+ session_options.EnableOrtCustomOps();
|
||||
+
|
||||
+ // disable thread pool for now since not all major browsers support WebAssembly threading.
|
||||
+ session_options.SetIntraOpNumThreads(1);
|
||||
+
|
||||
+ return new Ort::Session(*g_env, data, data_length, session_options);
|
||||
+}
|
||||
+
|
||||
void OrtReleaseSession(Ort::Session* session) {
|
||||
delete session;
|
||||
}
|
||||
diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h
|
||||
index 4f98a198a..e9500e1fa 100644
|
||||
--- a/onnxruntime/wasm/api.h
|
||||
+++ b/onnxruntime/wasm/api.h
|
||||
@@ -35,6 +35,14 @@ void EMSCRIPTEN_KEEPALIVE OrtInit();
|
||||
*/
|
||||
ort_session_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSession(void* data, size_t data_length);
|
||||
|
||||
+/**
|
||||
+ * create an instance of ORT session with enabling Ort-Customops.
|
||||
+ * @param data a pointer to a buffer that contains the ONNX or ORT format model.
|
||||
+ * @param data_length the size of the buffer in bytes.
|
||||
+ * @returns a handle of the ORT session.
|
||||
+ */
|
||||
+ort_session_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionWithCustomOps(void* data, size_t data_length);
|
||||
+
|
||||
/**
|
||||
* release the specified ORT session.
|
||||
*/
|
||||
#if !defined(__EMSCRIPTEN_PTHREADS__)
|
||||
// must disable thread pool when WebAssembly multi-threads support is disabled.
|
||||
session_options.SetIntraOpNumThreads(1);
|
||||
|
|
Двоичный файл не отображается.
Загрузка…
Ссылка в новой задаче