Fix the bug in enabling ORT CustomOps in ONNXRuntime. (#90)

Co-authored-by: Zuwei Zhao <zuzhao@microsoft.com>
This commit is contained in:
Zuwei Zhao 2021-05-08 00:21:34 +08:00 коммит произвёл GitHub
Родитель 5fa95c9485
Коммит 3f377e0911
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 133 добавлений и 90 удалений

Просмотреть файл

@ -12,14 +12,15 @@ git clone https://github.com/microsoft/onnxruntime.git
cd onnxruntime
# there is no stable version including webassembly
git checkout 21c282ed
git checkout e6a3308db7c03a13e0f08b221b6770e17fc3a4ef
cd cmake/external
git clone git@github.com:microsoft/ort-customops.git
git clone git@github.com:microsoft/onnxruntime-extensions.git
cd ../..
cp cmake/external/ort-customops/test/data/custom_op_string_lower.onnx onnxruntime/test/testdata
git apply cmake/external/ort-customops/ci_build/onnxruntime_integration/onnxruntime_v1.8.patch
cp cmake/external/onnxruntime-extensions/test/data/custom_op_negpos.onnx onnxruntime/test/testdata
cp cmake/external/onnxruntime-extensions/test/data/custom_op_string_lower.onnx onnxruntime/test/testdata
git apply cmake/external/onnxruntime-extensions/ci_build/onnxruntime_integration/onnxruntime_v1.8.patch
#get ready and begin building
cd ..

Просмотреть файл

@ -1,18 +1,18 @@
diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 46ba57cdc..5bde85a7b 100644
index ac9c62fb6..9893f703e 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -951,6 +951,14 @@ if (onnxruntime_USE_TVM)
@@ -966,6 +966,14 @@ if (onnxruntime_USE_TVM)
list(APPEND onnxruntime_EXTERNAL_DEPENDENCIES tvm nnvm_compiler)
endif()
+
+# ONNXRuntime-CustomOps
+set(OCOS_ENABLE_CTEST OFF CACHE INTERNAL "")
+set(OCOS_ENABLE_STATIC_LIB ON CACHE INTERNAL "")
+set(OCOS_ENABLE_SPM_TOKENIZER OFF CACHE INTERNAL "")
+add_subdirectory(external/ort-customops EXCLUDE_FROM_ALL)
+add_subdirectory(external/onnxruntime-extensions EXCLUDE_FROM_ALL)
+target_include_directories(ortcustomops_static PRIVATE ${RE2_INCLUDE_DIR} external/json/include)
+target_include_directories(ortcustomops PUBLIC external/ort-customops/shared)
+target_include_directories(ortcustomops PUBLIC external/onnxruntime-extensions/shared)
+
if (APPLE OR CMAKE_SYSTEM_NAME STREQUAL "Android")
#onnx/onnx/proto_utils.h:34:16: error: 'SetTotalBytesLimit' is deprecated: Please use the single
@ -31,21 +31,21 @@ index df7eebf5a..2c511005a 100644
set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime")
if (onnxruntime_USE_CUDA)
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index faca901ce..2fae0fb05 100644
index b28c44613..b932d55ce 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -458,6 +458,11 @@ struct OrtApi {
ORT_API2_STATUS(RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions* options, _In_ const char* library_path,
void** library_handle);
+ /*
+ * Enable the operators in ORT Custom ops: https://github.com/microsoft/ort-customops
@@ -1277,6 +1277,11 @@ struct OrtApi {
*/
ORT_API2_STATUS(KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name,
_Out_ int64_t* out, _Inout_ size_t* size);
+
+ /**
+ * Enable custom operators in ORT CustomOps: https://github.com/microsoft/onnxruntime-extensions.git
+ */
+ ORT_API2_STATUS(EnableOrtCustomOps, _Inout_ OrtSessionOptions* options);
+
/**
* To use additional providers, you must build ORT with the extra providers enabled. Then call one of these
* functions to enable them in the session:
};
/*
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index d85ecd776..d0e9fe6a3 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@ -75,19 +75,20 @@ index 64199ac6c..8ee32e4a1 100644
ThrowOnError(GetApi().EnableMemPattern(p_));
return *this;
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 12174163e..c21a92160 100644
index 12174163e..1fb57d39c 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -35,6 +35,8 @@
@@ -35,6 +35,9 @@
#include "abi_session_options_impl.h"
#include "core/framework/TensorSeq.h"
#include "core/platform/ort_mutex.h"
+
+#include "ortcustomops.h"
+
#ifdef USE_CUDA
#include "core/providers/cuda/cuda_provider_factory.h"
#endif
@@ -403,6 +405,13 @@ ORT_API_STATUS_IMPL(OrtApis::RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions
@@ -403,6 +406,13 @@ ORT_API_STATUS_IMPL(OrtApis::RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions
API_IMPL_END
}
@ -101,101 +102,142 @@ index 12174163e..c21a92160 100644
namespace {
// provider either model_path, or modal_data + model_data_length.
static ORT_STATUS_PTR CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options,
@@ -1963,7 +1972,7 @@ static constexpr OrtApi ort_api_1_to_8 = {
&OrtApis::CustomOpDomain_Add,
&OrtApis::AddCustomOpDomain,
&OrtApis::RegisterCustomOpsLibrary,
-
@@ -2123,6 +2133,7 @@ static constexpr OrtApi ort_api_1_to_8 = {
// Version 8 - In development, feel free to add/remove/rearrange here
&OrtApis::KernelInfoGetAttributeArray_float,
&OrtApis::KernelInfoGetAttributeArray_int64,
+ &OrtApis::EnableOrtCustomOps,
&OrtApis::SessionGetInputCount,
&OrtApis::SessionGetOutputCount,
&OrtApis::SessionGetOverridableInitializerCount,
@@ -2127,7 +2136,7 @@ static constexpr OrtApi ort_api_1_to_8 = {
};
// Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)
// If this assert hits, read the above 'Rules on how to add a new Ort API version'
-static_assert(offsetof(OrtApi, ReleaseCustomOpDomain) / sizeof(void*) == 101, "Size of version 1 API cannot change");
+static_assert(offsetof(OrtApi, ReleaseCustomOpDomain) / sizeof(void*) == 102, "Size of version 1 API cannot change");
ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) {
if (version >= 1 && version <= ORT_API_VERSION)
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index f19b1d729..99def252c 100644
index f19b1d729..0dccf457a 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -73,6 +73,7 @@ ORT_API_STATUS_IMPL(CreateCustomOpDomain, _In_ const char* domain, _Outptr_ OrtC
ORT_API_STATUS_IMPL(CustomOpDomain_Add, _Inout_ OrtCustomOpDomain* custom_op_domain, _In_ const OrtCustomOp* op);
ORT_API_STATUS_IMPL(AddCustomOpDomain, _Inout_ OrtSessionOptions* options, _In_ OrtCustomOpDomain* custom_op_domain);
ORT_API_STATUS_IMPL(RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions* options, _In_ const char* library_path, void** library_handle);
@@ -263,4 +263,6 @@ ORT_API_STATUS_IMPL(SetCurrentGpuDeviceId, _In_ int device_id);
ORT_API_STATUS_IMPL(GetCurrentGpuDeviceId, _In_ int* device_id);
ORT_API_STATUS_IMPL(KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out, _Inout_ size_t* size);
ORT_API_STATUS_IMPL(KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out, _Inout_ size_t* size);
+
+ORT_API_STATUS_IMPL(EnableOrtCustomOps, _Inout_ OrtSessionOptions* options);
ORT_API_STATUS_IMPL(SessionGetInputCount, _In_ const OrtSession* sess, _Out_ size_t* out);
ORT_API_STATUS_IMPL(SessionGetOutputCount, _In_ const OrtSession* sess, _Out_ size_t* out);
} // namespace OrtApis
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index 06073e528..ff51c92d8 100644
index 021636a16..5e34519a2 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -172,6 +172,7 @@ static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/f
@@ -172,6 +172,8 @@ static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/f
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_bar_1.onnx");
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx");
static constexpr PATH_TYPE CUSTOM_OP_MODEL_WITH_ATTRIBUTES_URI = TSTR("testdata/foo_bar_3.onnx");
+static constexpr PATH_TYPE ORT_CUSTOM_OPS_MODEL_URL = TSTR("testdata/custom_op_string_lower.onnx");
+static constexpr PATH_TYPE ORT_CUSTOM_OPS_MODEL_URI = TSTR("testdata/custom_op_string_lower.onnx");
+static constexpr PATH_TYPE ORT_CUSTOM_OPS_MODEL_URI_2 = TSTR("testdata/custom_op_negpos.onnx");
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
static constexpr PATH_TYPE PYOP_FLOAT_MODEL_URI = TSTR("testdata/pyop_1.onnx");
@@ -265,6 +266,15 @@ TEST(CApiTest, custom_op_handler) {
@@ -265,6 +267,91 @@ TEST(CApiTest, custom_op_handler) {
#endif
}
+TEST(CApiTest, enable_custom_op) {
+TEST(CApiTest, test_enable_ort_customops_negpos) {
+
+ Ort::SessionOptions session_option;
+ session_option.EnableOrtCustomOps();
+ Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+ auto allocator = onnxruntime::make_unique<MockedOrtAllocator>();
+
+ // only test session load
+ Ort::Session session(*ort_env, ORT_CUSTOM_OPS_MODEL_URL, session_option);
+ // Create Inputs
+ std::vector<Ort::Value> ort_inputs;
+ std::vector<float> input_data = {-1.1f, 2.2f, 4.4f, -5.5f};
+ std::vector<int64_t> input_dims = {2, 2};
+ ort_inputs.emplace_back(Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()), input_data.size(), input_dims.data(), input_dims.size()));
+
+ // Create Session with ORT CustomOps
+ Ort::SessionOptions session_options;
+ session_options.EnableOrtCustomOps();
+ Ort::Session session(*ort_env, ORT_CUSTOM_OPS_MODEL_URI_2, session_options);
+
+ // Create Input and Output Names
+ std::vector<const char*> input_names = {"X"};
+ const char* output_names[] = {"out0", "out1"};
+
+ // Run Session
+ std::vector<Ort::Value> ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, countof(output_names));
+
+ // Validate Results
+ ASSERT_EQ(ort_outputs.size(), 2u);
+
+ std::vector<int64_t> out_dims = {2, 2};
+ std::vector<float> values_out0 = {-1.1f, 0.0f, 0.0f, -5.5f};
+ auto type_info = ort_outputs[0].GetTensorTypeAndShapeInfo();
+ ASSERT_EQ(type_info.GetShape(), out_dims);
+ size_t total_len = type_info.GetElementCount();
+ ASSERT_EQ(values_out0.size(), total_len);
+
+ float* f = ort_outputs[0].GetTensorMutableData<float>();
+ for (size_t i = 0; i != total_len; ++i) {
+ ASSERT_EQ(values_out0[i], f[i]);
+ }
+}
+
+TEST(CApiTest, test_enable_ort_customops_stringlower) {
+
+ auto allocator = onnxruntime::make_unique<MockedOrtAllocator>();
+
+ // Create Inputs
+ std::vector<Ort::Value> ort_inputs;
+ std::string input_data{"HI, This is ENGINEER from Microsoft."};
+ const char* const input_strings[] = {input_data.c_str()};
+ std::vector<int64_t> input_dims = {1, 1};
+
+ Ort::Value input_tensor = Ort::Value::CreateTensor(allocator.get(), input_dims.data(), input_dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
+ input_tensor.FillStringTensor(input_strings, 1U);
+ ort_inputs.push_back(std::move(input_tensor));
+
+ // Create Session with ORT CustomOps
+ Ort::SessionOptions session_options;
+ session_options.EnableOrtCustomOps();
+ Ort::Session session(*ort_env, ORT_CUSTOM_OPS_MODEL_URI, session_options);
+
+ // Create Input and Output Names
+ std::vector<const char*> input_names = {"input_1"};
+ const char* output_names[] = {"customout"};
+
+ // Run Session
+ std::vector<Ort::Value> ort_outputs = session.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, countof(output_names));
+
+ // Validate Results
+ ASSERT_EQ(ort_outputs.size(), 1u);
+
+ std::vector<int64_t> out_dims = {1, 1};
+ auto type_info = ort_outputs[0].GetTensorTypeAndShapeInfo();
+ ASSERT_EQ(type_info.GetShape(), out_dims);
+ ASSERT_EQ(type_info.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
+
+ std::string output_data{"hi, this is engineer from microsoft."};
+ auto expected_string = output_data.c_str();
+ size_t expected_string_len = strlen(expected_string);
+ auto data_length = ort_outputs[0].GetStringTensorDataLength();
+ ASSERT_EQ(expected_string_len, data_length);
+
+ std::string result(data_length, '\0');
+ std::vector<size_t> offsets(type_info.GetElementCount());
+ ort_outputs[0].GetStringTensorContent((void*)result.data(), data_length, offsets.data(), offsets.size());
+ ASSERT_STREQ(result.c_str(), expected_string);
+}
+
//test custom op which accepts float and double as inputs
TEST(CApiTest, varied_input_custom_op_handler) {
std::vector<Input> inputs(2);
diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc
index c3ba4eedd..ff7bbd51e 100644
index 523a2fc6f..970ca1552 100644
--- a/onnxruntime/wasm/api.cc
+++ b/onnxruntime/wasm/api.cc
@@ -28,6 +28,16 @@ Ort::Session* OrtCreateSession(void* data, size_t data_length) {
return new Ort::Session(*g_env, data, data_length, session_options);
}
@@ -22,6 +22,10 @@ Ort::Session* OrtCreateSession(void* data, size_t data_length) {
Ort::SessionOptions session_options;
session_options.SetLogId("onnxruntime");
+Ort::Session* OrtCreateSessionWithCustomOps(void* data, size_t data_length) {
+ Ort::SessionOptions session_options;
+ session_options.SetLogId("onnxruntime");
+ // Enable ORT CustomOps
+ // TODO: add condition check here to enable ORT CustomOps
+ session_options.EnableOrtCustomOps();
+
+ // disable thread pool for now since not all major browsers support WebAssembly threading.
+ session_options.SetIntraOpNumThreads(1);
+
+ return new Ort::Session(*g_env, data, data_length, session_options);
+}
+
void OrtReleaseSession(Ort::Session* session) {
delete session;
}
diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h
index 4f98a198a..e9500e1fa 100644
--- a/onnxruntime/wasm/api.h
+++ b/onnxruntime/wasm/api.h
@@ -35,6 +35,14 @@ void EMSCRIPTEN_KEEPALIVE OrtInit();
*/
ort_session_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSession(void* data, size_t data_length);
+/**
+ * create an instance of ORT session with enabling Ort-Customops.
+ * @param data a pointer to a buffer that contains the ONNX or ORT format model.
+ * @param data_length the size of the buffer in bytes.
+ * @returns a handle of the ORT session.
+ */
+ort_session_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionWithCustomOps(void* data, size_t data_length);
+
/**
* release the specified ORT session.
*/
#if !defined(__EMSCRIPTEN_PTHREADS__)
// must disable thread pool when WebAssembly multi-threads support is disabled.
session_options.SetIntraOpNumThreads(1);

Двоичные данные
test/data/custom_op_negpos.onnx Normal file

Двоичный файл не отображается.