Update pytorch_custom_ops_tutorial.ipynb (#196)
* Update pytorch_custom_ops_tutorial.ipynb
* Update mshost.yaml
* tidy up
This commit is contained in:
Parent: f418557d85
Commit: 459c4f7d61
@@ -5,6 +5,7 @@ build_host_protoc
build_android
build_ios
build_*
.venv/
_subbuild/
.build_debug/*
.build_release/*
mshost.yaml:
@@ -258,9 +258,9 @@ jobs:
      py37-170:
        python.version: '3.7'
        ort.version: '1.7.0'
      py36-170:
        python.version: '3.6'
        ort.version: '1.6.0'
      # py36-160:
      # python.version: '3.6'
      # ort.version: '1.6.0'
    maxParallel: 1

  steps:
@ -1,33 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
BASE_DIR=onnxruntime_customops_integration
|
||||
|
||||
if [ ! -d $BASE_DIR ]
|
||||
then
|
||||
mkdir $BASE_DIR
|
||||
fi
|
||||
cd $BASE_DIR
|
||||
|
||||
git clone https://github.com/microsoft/onnxruntime.git
|
||||
cd onnxruntime
|
||||
|
||||
# The latest commit of ONNXRuntime that has been verified for integration (#7941)
|
||||
git checkout fd23b8caaddc4c460463774d696af13bef63aa46
|
||||
cd cmake/external
|
||||
git clone git@github.com:microsoft/onnxruntime-extensions.git
|
||||
|
||||
cd ../..
|
||||
|
||||
cp cmake/external/onnxruntime-extensions/test/data/custom_op_negpos.onnx onnxruntime/test/testdata
|
||||
cp cmake/external/onnxruntime-extensions/test/data/custom_op_string_lower.onnx onnxruntime/test/testdata
|
||||
git apply cmake/external/onnxruntime-extensions/ci_build/onnxruntime_integration/onnxruntime_v1.8.patch
|
||||
|
||||
#get ready and begin building
|
||||
cd ..
|
||||
if [ ! -d build ]
|
||||
then
|
||||
mkdir build
|
||||
fi
|
||||
python3 onnxruntime/tools/ci_build/build.py --build_dir build "$@"
|
||||
|
||||
|
|
Deleted file (ONNX Runtime integration patch):
@@ -1,244 +0,0 @@
diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 95bd1416c4..8c736bd69f 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -1195,6 +1195,15 @@ if (onnxruntime_USE_TVM)
list(APPEND onnxruntime_EXTERNAL_DEPENDENCIES tvm nnvm_compiler)
endif()

+# ONNXRuntime-CustomOps
+set(OCOS_ENABLE_CTEST OFF CACHE INTERNAL "")
+set(OCOS_ENABLE_STATIC_LIB ON CACHE INTERNAL "")
+set(OCOS_ENABLE_SPM_TOKENIZER OFF CACHE INTERNAL "")
+add_subdirectory(external/onnxruntime-extensions EXCLUDE_FROM_ALL)
+# target library or executable are defined in CMakeLists.txt of onnxruntime-extensions
+target_include_directories(ocos_operators PRIVATE ${RE2_INCLUDE_DIR} external/json/include)
+target_include_directories(ortcustomops PUBLIC external/onnxruntime-extensions/shared)
+
if (APPLE OR CMAKE_SYSTEM_NAME STREQUAL "Android")
#onnx/onnx/proto_utils.h:34:16: error: 'SetTotalBytesLimit' is deprecated: Please use the single
#parameter version of SetTotalBytesLimit(). The second parameter is ignored.
diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake
index bb5eb0d166..bbc7a954e9 100644
--- a/cmake/onnxruntime_session.cmake
+++ b/cmake/onnxruntime_session.cmake
@@ -16,7 +16,7 @@ if(onnxruntime_ENABLE_INSTRUMENT)
target_compile_definitions(onnxruntime_session PUBLIC ONNXRUNTIME_ENABLE_INSTRUMENT)
endif()
target_include_directories(onnxruntime_session PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS})
-target_link_libraries(onnxruntime_session PRIVATE nlohmann_json::nlohmann_json)
+target_link_libraries(onnxruntime_session PRIVATE nlohmann_json::nlohmann_json ortcustomops)
add_dependencies(onnxruntime_session ${onnxruntime_EXTERNAL_DEPENDENCIES})
set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime")
if (onnxruntime_USE_CUDA)
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 838d202c28..e6db1b6eaa 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -1382,6 +1382,11 @@ struct OrtApi {
_In_ const void* model_data, size_t model_data_length,
_In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container,
_Outptr_ OrtSession** out);
+
+  /**
+   * Enable custom operators in ORT CustomOps: https://github.com/microsoft/onnxruntime-extensions.git
+   */
+  ORT_API2_STATUS(EnableOrtCustomOps, _Inout_ OrtSessionOptions* options);
};

/*
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index 4bd07b47c6..0fa74bcf38 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -310,6 +310,7 @@ struct SessionOptions : Base<OrtSessionOptions> {

SessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix);
SessionOptions& DisableProfiling();
+  SessionOptions& EnableOrtCustomOps();

SessionOptions& EnableMemPattern();
SessionOptions& DisableMemPattern();
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index d08b890e80..f4b78b562f 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -440,6 +440,11 @@ inline SessionOptions& SessionOptions::DisableProfiling() {
return *this;
}

+inline SessionOptions& SessionOptions::EnableOrtCustomOps() {
+  ThrowOnError(GetApi().EnableOrtCustomOps(p_));
+  return *this;
+}
+
inline SessionOptions& SessionOptions::EnableMemPattern() {
ThrowOnError(GetApi().EnableMemPattern(p_));
return *this;
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index e73bdbe18a..46238f02df 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -35,6 +35,9 @@
#include "abi_session_options_impl.h"
#include "core/framework/TensorSeq.h"
#include "core/platform/ort_mutex.h"
+
+#include "ortcustomops.h"
+
#ifdef USE_CUDA
#include "core/providers/cuda/cuda_provider_factory.h"
#endif
@@ -403,6 +406,13 @@ ORT_API_STATUS_IMPL(OrtApis::RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions
API_IMPL_END
}

+ORT_API_STATUS_IMPL(OrtApis::EnableOrtCustomOps, _Inout_ OrtSessionOptions* options) {
+  API_IMPL_BEGIN
+
+  return RegisterCustomOps(options, OrtGetApiBase());
+  API_IMPL_END
+}
+
namespace {
// provider either model_path, or modal_data + model_data_length.
static ORT_STATUS_PTR CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options,
@@ -2228,6 +2238,7 @@ static constexpr OrtApi ort_api_1_to_8 = {
// End of Version 8 - DO NOT MODIFY ABOVE (see above text for more information)

// Version 9 - In development, feel free to add/remove/rearrange here
+    &OrtApis::EnableOrtCustomOps,
};

// Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index 2d99582724..0588e5af42 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -277,4 +277,5 @@ ORT_API_STATUS_IMPL(CreateSessionFromArrayWithPrepackedWeightsContainer, _In_ co
_In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container,
_Outptr_ OrtSession** out);

+ORT_API_STATUS_IMPL(EnableOrtCustomOps, _Inout_ OrtSessionOptions* options);
} // namespace OrtApis
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index 4f04588d93..d5c287893b 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -173,6 +173,8 @@ static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/f
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_bar_1.onnx");
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx");
static constexpr PATH_TYPE CUSTOM_OP_MODEL_WITH_ATTRIBUTES_URI = TSTR("testdata/foo_bar_3.onnx");
+static constexpr PATH_TYPE ORT_CUSTOM_OPS_MODEL_URI = TSTR("testdata/custom_op_string_lower.onnx");
+static constexpr PATH_TYPE ORT_CUSTOM_OPS_MODEL_URI_2 = TSTR("testdata/custom_op_negpos.onnx");

#ifdef ENABLE_LANGUAGE_INTEROP_OPS
static constexpr PATH_TYPE PYOP_FLOAT_MODEL_URI = TSTR("testdata/pyop_1.onnx");
@@ -266,6 +268,91 @@ TEST(CApiTest, custom_op_handler) {
#endif
}

+TEST(CApiTest, test_enable_ort_customops_negpos) {
+
+  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+  auto allocator = onnxruntime::make_unique<MockedOrtAllocator>();
+
+  // Create Inputs
+  std::vector<Ort::Value> ort_inputs;
+  std::vector<float> input_data = {-1.1f, 2.2f, 4.4f, -5.5f};
+  std::vector<int64_t> input_dims = {2, 2};
+  ort_inputs.emplace_back(Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_data.data()), input_data.size(), input_dims.data(), input_dims.size()));
+
+  // Create Session with ORT CustomOps
+  Ort::SessionOptions session_options;
+  session_options.EnableOrtCustomOps();
+  Ort::Session session(*ort_env, ORT_CUSTOM_OPS_MODEL_URI_2, session_options);
+
+  // Create Input and Output Names
+  std::vector<const char*> input_names = {"X"};
+  const char* output_names[] = {"out0", "out1"};
+
+  // Run Session
+  std::vector<Ort::Value> ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, countof(output_names));
+
+  // Validate Results
+  ASSERT_EQ(ort_outputs.size(), 2u);
+
+  std::vector<int64_t> out_dims = {2, 2};
+  std::vector<float> values_out0 = {-1.1f, 0.0f, 0.0f, -5.5f};
+  auto type_info = ort_outputs[0].GetTensorTypeAndShapeInfo();
+  ASSERT_EQ(type_info.GetShape(), out_dims);
+  size_t total_len = type_info.GetElementCount();
+  ASSERT_EQ(values_out0.size(), total_len);
+
+  float* f = ort_outputs[0].GetTensorMutableData<float>();
+  for (size_t i = 0; i != total_len; ++i) {
+    ASSERT_EQ(values_out0[i], f[i]);
+  }
+}
+
+TEST(CApiTest, test_enable_ort_customops_stringlower) {
+
+  auto allocator = onnxruntime::make_unique<MockedOrtAllocator>();
+
+  // Create Inputs
+  std::vector<Ort::Value> ort_inputs;
+  std::string input_data{"HI, This is ENGINEER from Microsoft."};
+  const char* const input_strings[] = {input_data.c_str()};
+  std::vector<int64_t> input_dims = {1, 1};
+
+  Ort::Value input_tensor = Ort::Value::CreateTensor(allocator.get(), input_dims.data(), input_dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
+  input_tensor.FillStringTensor(input_strings, 1U);
+  ort_inputs.push_back(std::move(input_tensor));
+
+  // Create Session with ORT CustomOps
+  Ort::SessionOptions session_options;
+  session_options.EnableOrtCustomOps();
+  Ort::Session session(*ort_env, ORT_CUSTOM_OPS_MODEL_URI, session_options);
+
+  // Create Input and Output Names
+  std::vector<const char*> input_names = {"input_1"};
+  const char* output_names[] = {"customout"};
+
+  // Run Session
+  std::vector<Ort::Value> ort_outputs = session.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, countof(output_names));
+
+  // Validate Results
+  ASSERT_EQ(ort_outputs.size(), 1u);
+
+  std::vector<int64_t> out_dims = {1, 1};
+  auto type_info = ort_outputs[0].GetTensorTypeAndShapeInfo();
+  ASSERT_EQ(type_info.GetShape(), out_dims);
+  ASSERT_EQ(type_info.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
+
+  std::string output_data{"hi, this is engineer from microsoft."};
+  auto expected_string = output_data.c_str();
+  size_t expected_string_len = strlen(expected_string);
+  auto data_length = ort_outputs[0].GetStringTensorDataLength();
+  ASSERT_EQ(expected_string_len, data_length);
+
+  std::string result(data_length, '\0');
+  std::vector<size_t> offsets(type_info.GetElementCount());
+  ort_outputs[0].GetStringTensorContent((void*)result.data(), data_length, offsets.data(), offsets.size());
+  ASSERT_STREQ(result.c_str(), expected_string);
+}
+
//test custom op which accepts float and double as inputs
TEST(CApiTest, varied_input_custom_op_handler) {
std::vector<Input> inputs(2);
diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc
index 8413197086..16e8479d35 100644
--- a/onnxruntime/wasm/api.cc
+++ b/onnxruntime/wasm/api.cc
@@ -122,6 +122,11 @@ OrtSession* OrtCreateSession(void* data, size_t data_length, OrtSessionOptions*
return nullptr;
}

+  // Enable ORT CustomOps
+  // TODO: add condition check here to enable ORT CustomOps
+  // session_options->EnableOrtCustomOps();
+  RETURN_NULLPTR_IF_ERROR(EnableOrtCustomOps, session_options);
+
#if defined(__EMSCRIPTEN_PTHREADS__)
RETURN_NULLPTR_IF_ERROR(DisablePerSessionThreads, session_options);
#else
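For reference, the net effect of the patch above for an application is a single new call on SessionOptions before the model is loaded. A minimal usage sketch, assuming an ONNX Runtime build that carries this patch (the model file name is illustrative):

// Sketch only: enabling the bundled onnxruntime-extensions operators
// through the EnableOrtCustomOps() API added by the patch above.
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ortcustomops-demo");

  Ort::SessionOptions session_options;
  session_options.EnableOrtCustomOps();  // register the extension custom ops

  // Illustrative model name; any model that uses onnxruntime-extensions ops works.
  Ort::Session session(env, ORT_TSTR("custom_op_string_lower.onnx"), session_options);

  // Inputs are created and session.Run() is called exactly as for built-in operators.
  return 0;
}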
pytorch_custom_ops_tutorial.ipynb:
@@ -160,7 +160,7 @@
"cell_type": "markdown",
"source": [
"## Implement the customop in C++ (optional)\n",
-"To make the ONNX model with the CustomOp runn on all other language supported by ONNX Runtime and be independdent of Python, a C++ implmentation is needed, check here for the [inverse.hpp](../operators/math/inverse.hpp) for an example on how to do that."
+"To make the ONNX model with the CustomOp run on all other languages supported by the ONNX Runtime and be independent of Python, a C++ implementation is needed, check [inverse.hpp](../operators/math/dlib/inverse.hpp) for an example on how to do that."
],
"metadata": {}
},
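The notebook cell above points to inverse.hpp for a C++ counterpart of the Python custom op. As a rough sketch of the general shape of such a standalone C++ custom operator, written against the Ort::CustomOpBase / Ort::CustomOpApi helpers of that ONNX Runtime era (the operator name, the element-wise negation body, and the registration domain are illustrative, not the actual inverse.hpp):

#include <onnxruntime_cxx_api.h>
#include <vector>

// Kernel: computes y = -x element-wise (illustrative body only).
struct NegateKernel {
  NegateKernel(Ort::CustomOpApi ort, const OrtKernelInfo* /*info*/) : ort_(ort) {}

  void Compute(OrtKernelContext* context) {
    const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
    const float* x = ort_.GetTensorData<float>(input);

    // Read the input shape so the output can be allocated with the same dimensions.
    OrtTensorTypeAndShapeInfo* shape_info = ort_.GetTensorTypeAndShape(input);
    size_t dim_count = ort_.GetDimensionsCount(shape_info);
    std::vector<int64_t> shape(dim_count);
    ort_.GetDimensions(shape_info, shape.data(), dim_count);
    size_t count = ort_.GetTensorShapeElementCount(shape_info);
    ort_.ReleaseTensorTypeAndShapeInfo(shape_info);

    OrtValue* output = ort_.KernelContext_GetOutput(context, 0, shape.data(), shape.size());
    float* y = ort_.GetTensorMutableData<float>(output);
    for (size_t i = 0; i < count; ++i) y[i] = -x[i];
  }

  Ort::CustomOpApi ort_;
};

// Schema: one float input, one float output.
struct NegateOp : Ort::CustomOpBase<NegateOp, NegateKernel> {
  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const { return new NegateKernel(api, info); }
  const char* GetName() const { return "Negate"; }
  size_t GetInputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }
  size_t GetOutputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }
};

// Registration sketch: attach the op to a custom domain on the session options.
// Ort::CustomOpDomain domain{"ai.onnx.contrib"};  // domain used by onnxruntime-extensions models
// NegateOp negate_op;
// domain.Add(&negate_op);
// session_options.Add(domain);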
@@ -216,4 +216,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}