Enable C++ end to end test with onnxruntime (#56)
* basic changes build issues fixing runing on Windows Platform deploy the ort library in CMake update gitignore * Add C++ shared tests * enable ctest * fixing the python build issue * remove cc test * why does macos needs openmp package?
This commit is contained in:
Родитель
ddf9b873ad
Коммит
33027d2578
|
@ -18,6 +18,7 @@ gen
|
|||
.DS_Store
|
||||
*~
|
||||
.vs
|
||||
Testing/
|
||||
TestResults/
|
||||
.idea/
|
||||
nuget_root/
|
||||
|
@ -28,6 +29,7 @@ __pycache__
|
|||
out/
|
||||
*.egg-info/
|
||||
.setuptools-cmake-build/
|
||||
onnxruntime-*-*-*/
|
||||
|
||||
# Compiled Dynamic libraries
|
||||
*.so
|
||||
|
|
217
CMakeLists.txt
217
CMakeLists.txt
|
@ -2,16 +2,11 @@ cmake_minimum_required(VERSION 3.16.0)
|
|||
project(ortcustomops VERSION 0.1.0 LANGUAGES C CXX)
|
||||
# set(CMAKE_VERBOSE_MAKEFILE ON)
|
||||
|
||||
# Enable CTest
|
||||
enable_testing()
|
||||
include(CTest)
|
||||
|
||||
if(NOT CMAKE_BUILD_TYPE)
|
||||
message(STATUS "Build type not set - using RelWithDebInfo")
|
||||
set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING "Choose build type: Debug Release RelWithDebInfo." FORCE)
|
||||
endif()
|
||||
|
||||
set(ONNX_ML 1)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
|
@ -19,8 +14,15 @@ include(CheckCXXCompilerFlag)
|
|||
include(CheckLanguage)
|
||||
|
||||
option(CC_OPTIMIZE "Allow compiler optimizations, Set to OFF to disable" ON)
|
||||
option(ENABLE_PYTHON "Enable Python component building" OFF)
|
||||
option(ENABLE_TOKENIZER "Enable the tokenizer building" ON)
|
||||
option(OCOS_ENABLE_PYTHON "Enable Python component building" OFF)
|
||||
option(OCOS_ENABLE_TF_STRING "Enable String Operator Set" ON)
|
||||
option(OCOS_ENABLE_GPT2_TOKENIZER "Enable the GPT2 tokenizer building" ON)
|
||||
option(OCOS_ENABLE_SPM_TOKENIZER "Enable the SentencePiece tokenizer building" ON)
|
||||
|
||||
find_library(ONNXRUNTIME onnxruntime HINTS "${ONNXRUNTIME_LIB_DIR}")
|
||||
if ((NOT OCOS_ENABLE_PYTHON) AND (NOT ONNXRUNTIME))
|
||||
message(FATAL_ERROR "Cannot find onnxruntime in the default library paths, please specify the ONNXRUNTIME_LIB_DIR.")
|
||||
endif()
|
||||
|
||||
if(NOT CC_OPTIMIZE)
|
||||
message("!!!THE COMPILER OPTIMIZATION HAS BEEN DISABLED, DEBUG-ONLY!!!")
|
||||
|
@ -46,41 +48,87 @@ if(NOT "${CMAKE_FIND_FRAMEWORK}" STREQUAL "NEVER")
|
|||
message(FATAL_ERROR "CMAKE_FIND_FRAMEWORK is not NEVER")
|
||||
endif()
|
||||
|
||||
set(CPACK_PROJECT_NAME ${PROJECT_NAME})
|
||||
set(CPACK_PROJECT_VERSION ${PROJECT_VERSION})
|
||||
include(CPack)
|
||||
|
||||
# External dependencies
|
||||
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/externals)
|
||||
include(FetchContent)
|
||||
include(googlere2)
|
||||
include(farmhash)
|
||||
FetchContent_GetProperties(googlere2)
|
||||
FetchContent_GetProperties(farmhash)
|
||||
if (OCOS_ENABLE_TF_STRING)
|
||||
set(RE2_BUILD_TESTING OFF CACHE INTERNAL "")
|
||||
include(googlere2)
|
||||
include(farmhash)
|
||||
FetchContent_GetProperties(googlere2)
|
||||
FetchContent_GetProperties(farmhash)
|
||||
endif()
|
||||
|
||||
file(GLOB TARGET_SRC "./ocos/*.cc" "./ocos/*.h*")
|
||||
file(GLOB TARGET_SRC_KERNELS "./ocos/kernels/*.cc" "./ocos/kernels/*.h*")
|
||||
file(GLOB TARGET_SRC_PYOPS "./ocos/pyfunc/*.cc" "./ocos/pyfunc/*.h*")
|
||||
file(GLOB TARGET_SRC_HASH "${farmhash_SOURCE_DIR}/src/farmhash.*")
|
||||
if (OCOS_ENABLE_TF_STRING)
|
||||
file(GLOB TARGET_SRC_KERNELS "./ocos/kernels/*.cc" "./ocos/kernels/*.h*")
|
||||
file(GLOB TARGET_SRC_HASH "${farmhash_SOURCE_DIR}/src/farmhash.*")
|
||||
list(APPEND TARGET_SRC ${TARGET_SRC_KERNELS} ${TARGET_SRC_HASH})
|
||||
endif()
|
||||
|
||||
if (ENABLE_TOKENIZER)
|
||||
if (OCOS_ENABLE_GPT2_TOKENIZER)
|
||||
# GPT2
|
||||
set(JSON_BuildTests OFF CACHE INTERNAL "")
|
||||
include(json)
|
||||
file(GLOB tok_TARGET_SRC "tokenizer/gpt*.cc" "tokenizer/unicode*.*")
|
||||
list(APPEND TARGET_SRC ${tok_TARGET_SRC})
|
||||
endif()
|
||||
if (OCOS_ENABLE_SPM_TOKENIZER)
|
||||
# SentencePiece
|
||||
set(SPM_ENABLE_TCMALLOC OFF CACHE INTERNAL "")
|
||||
set(SPM_ENABLE_SHARED OFF CACHE INTERNAL "")
|
||||
include(sentencepieceproject)
|
||||
file(GLOB stpiece_TARGET_SRC "sentencepiece/*.cc" "tokenizer/sentencepiece*")
|
||||
list(REMOVE_ITEM stpiece_TARGET_SRC INCLUDE REGEX ".*((spm)|(train)).*")
|
||||
list(APPEND TARGET_SRC ${stpiece_TARGET_SRC})
|
||||
endif()
|
||||
|
||||
add_library(ortcustomops_static STATIC
|
||||
${TARGET_SRC}
|
||||
${TARGET_SRC_KERNELS}
|
||||
${TARGET_SRC_HASH})
|
||||
add_library(ortcustomops_static STATIC ${TARGET_SRC})
|
||||
|
||||
if(ENABLE_PYTHON)
|
||||
set(ocos_libraries ortcustomops_static)
|
||||
if (OCOS_ENABLE_TF_STRING)
|
||||
list(APPEND ocos_libraries re2)
|
||||
endif()
|
||||
|
||||
target_include_directories(ortcustomops_static PUBLIC
|
||||
${PROJECT_SOURCE_DIR}/includes
|
||||
${PROJECT_SOURCE_DIR}/includes/onnxruntime
|
||||
${PROJECT_SOURCE_DIR}/ocos)
|
||||
|
||||
set(OCOS_COMPILE_DEFINITIONS "")
|
||||
|
||||
if (OCOS_ENABLE_TF_STRING)
|
||||
target_include_directories(ortcustomops_static PUBLIC
|
||||
${googlere2_SOURCE_DIR}
|
||||
${farmhash_SOURCE_DIR}/src)
|
||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_TF_STRING)
|
||||
endif()
|
||||
|
||||
if (OCOS_ENABLE_GPT2_TOKENIZER)
|
||||
# GPT2
|
||||
target_include_directories(ortcustomops_static PRIVATE ${json_SOURCE_DIR}/single_include)
|
||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_GPT2_TOKENIZER)
|
||||
list(APPEND ocos_libraries nlohmann_json::nlohmann_json)
|
||||
endif()
|
||||
|
||||
if (OCOS_ENABLE_SPM_TOKENIZER)
|
||||
# SentencePiece
|
||||
target_include_directories(ortcustomops_static PRIVATE ${PROJECT_SOURCE_DIR}/tokenizer ${sentencepieceproject_INCLUDE_DIRS})
|
||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_SPM_TOKENIZER)
|
||||
list(APPEND ocos_libraries sentencepiece-static)
|
||||
endif()
|
||||
|
||||
if (OCOS_ENABLE_TF_STRING)
|
||||
target_compile_definitions(ortcustomops_static PRIVATE
|
||||
NOMINMAX
|
||||
FARMHASH_NO_BUILTIN_EXPECT)
|
||||
endif()
|
||||
|
||||
target_compile_definitions(ortcustomops_static PRIVATE ${OCOS_COMPILE_DEFINITIONS})
|
||||
|
||||
file(GLOB shared_TARGET_SRC "shared/*.cc" "shared/*.h")
|
||||
if(OCOS_ENABLE_PYTHON)
|
||||
file(GLOB TARGET_SRC_PYOPS "./ocos/pyfunc/*.cc" "./ocos/pyfunc/*.h*")
|
||||
set(Python3_FIND_REGISTRY NEVER CACHE STRING "...")
|
||||
if(NOT "${Python3_FIND_REGISTRY}" STREQUAL "NEVER")
|
||||
message(FATAL_ERROR "Python3_FIND_REGISTRY is not NEVER")
|
||||
|
@ -88,75 +136,23 @@ if(ENABLE_PYTHON)
|
|||
find_package(Python3 COMPONENTS Interpreter Development)
|
||||
|
||||
if (WIN32)
|
||||
list(APPEND TARGET_SRC "${PROJECT_SOURCE_DIR}/onnxruntime_customops/ortcustomops.def")
|
||||
list(APPEND shared_TARGET_SRC "${PROJECT_SOURCE_DIR}/onnxruntime_customops/ortcustomops.def")
|
||||
endif()
|
||||
Python3_add_library(ortcustomops SHARED
|
||||
${TARGET_SRC}
|
||||
${TARGET_SRC_KERNELS}
|
||||
${TARGET_SRC_PYOPS}
|
||||
${TARGET_SRC_HASH})
|
||||
target_compile_definitions(ortcustomops PRIVATE PYTHON_OP_SUPPORT)
|
||||
|
||||
Python3_add_library(ortcustomops SHARED ${TARGET_SRC_PYOPS} ${shared_TARGET_SRC})
|
||||
list(APPEND OCOS_COMPILE_DEFINITIONS PYTHON_OP_SUPPORT)
|
||||
else()
|
||||
list(APPEND TARGET_SRC "${PROJECT_SOURCE_DIR}/ocos/ortcustomops.def")
|
||||
add_library(ortcustomops SHARED
|
||||
${TARGET_SRC}
|
||||
${TARGET_SRC_KERNELS}
|
||||
${TARGET_SRC_HASH})
|
||||
list(APPEND shared_TARGET_SRC "${PROJECT_SOURCE_DIR}/shared/ortcustomops.def")
|
||||
add_library(ortcustomops SHARED ${shared_TARGET_SRC})
|
||||
endif()
|
||||
|
||||
if (WIN32)
|
||||
set_source_files_properties(ortcustomops_pyd.def PROPERTIES HEADER_FILE_ONLY TRUE)
|
||||
target_compile_definitions(ortcustomops PRIVATE ${OCOS_COMPILE_DEFINITIONS})
|
||||
if (OCOS_ENABLE_SPM_TOKENIZER) # FIXME: this include path is not recommendeded.
|
||||
target_include_directories(ortcustomops PRIVATE ${PROJECT_SOURCE_DIR}/tokenizer ${sentencepieceproject_INCLUDE_DIRS})
|
||||
endif()
|
||||
target_link_libraries(ortcustomops PRIVATE ${ocos_libraries})
|
||||
|
||||
set(external_libraries re2)
|
||||
target_include_directories(ortcustomops PUBLIC
|
||||
${PROJECT_SOURCE_DIR}/includes
|
||||
${PROJECT_SOURCE_DIR}/includes/onnxruntime
|
||||
${PROJECT_SOURCE_DIR}/ocos
|
||||
${googlere2_SOURCE_DIR}
|
||||
${farmhash_SOURCE_DIR}/src)
|
||||
|
||||
target_include_directories(ortcustomops_static PUBLIC
|
||||
${PROJECT_SOURCE_DIR}/includes
|
||||
${PROJECT_SOURCE_DIR}/includes/onnxruntime
|
||||
${PROJECT_SOURCE_DIR}/ocos
|
||||
${googlere2_SOURCE_DIR}
|
||||
${farmhash_SOURCE_DIR}/src)
|
||||
|
||||
if (ENABLE_TOKENIZER)
|
||||
target_compile_definitions(ortcustomops PRIVATE ENABLE_TOKENIZER)
|
||||
# GPT2
|
||||
list(APPEND external_libraries nlohmann_json::nlohmann_json)
|
||||
# SentencePiece
|
||||
target_include_directories(ortcustomops PRIVATE
|
||||
${sentencepieceproject_INCLUDE_DIRS}
|
||||
${PROJECT_SOURCE_DIR}/tokenizer)
|
||||
list(APPEND external_libraries sentencepiece-static)
|
||||
|
||||
target_compile_definitions(ortcustomops_static PRIVATE ENABLE_TOKENIZER)
|
||||
target_include_directories(ortcustomops_static PRIVATE
|
||||
${sentencepieceproject_INCLUDE_DIRS}
|
||||
${PROJECT_SOURCE_DIR}/tokenizer
|
||||
${json_SOURCE_DIR}/single_include)
|
||||
endif()
|
||||
|
||||
target_link_libraries(ortcustomops PRIVATE ${external_libraries})
|
||||
|
||||
target_compile_definitions(ortcustomops PRIVATE
|
||||
ONNX_NAMESPACE=onnx
|
||||
ONNX_ML
|
||||
NOMINMAX
|
||||
FARMHASH_NO_BUILTIN_EXPECT)
|
||||
target_compile_features(ortcustomops PUBLIC cxx_std_11)
|
||||
|
||||
target_compile_definitions(ortcustomops_static PRIVATE
|
||||
ONNX_NAMESPACE=onnx
|
||||
ONNX_ML
|
||||
NOMINMAX
|
||||
FARMHASH_NO_BUILTIN_EXPECT)
|
||||
target_compile_features(ortcustomops_static PUBLIC cxx_std_11)
|
||||
|
||||
if(ENABLE_PYTHON)
|
||||
if(OCOS_ENABLE_PYTHON)
|
||||
include(pybind11)
|
||||
set(NUMPY_NOT_FOUND false)
|
||||
exec_program("${Python3_EXECUTABLE}"
|
||||
|
@ -183,10 +179,45 @@ if(ENABLE_PYTHON)
|
|||
endif()
|
||||
endif()
|
||||
|
||||
# test section
|
||||
include(googletest)
|
||||
file(GLOB TEST_SRC "${PROJECT_SOURCE_DIR}/test/test*.cc")
|
||||
add_executable(ortcustomops_test ${TEST_SRC})
|
||||
target_link_libraries(ortcustomops_test gtest_main ortcustomops_static ${external_libraries})
|
||||
set(CPACK_PROJECT_NAME ${PROJECT_NAME})
|
||||
set(CPACK_PROJECT_VERSION ${PROJECT_VERSION})
|
||||
include(CPack)
|
||||
|
||||
add_test(NAME ortcustomops_test COMMAND test)
|
||||
# test section
|
||||
if (NOT OCOS_ENABLE_PYTHON)
|
||||
# Enable CTest
|
||||
enable_testing()
|
||||
include(CTest)
|
||||
|
||||
set(TEST_SRC_DIR ${PROJECT_SOURCE_DIR}/test)
|
||||
include(googletest)
|
||||
file(GLOB static_TEST_SRC "${TEST_SRC_DIR}/static_test/*.cc")
|
||||
add_executable(ortcustomops_static_test ${static_TEST_SRC})
|
||||
target_link_libraries(ortcustomops_static_test gtest_main ${ocos_libraries})
|
||||
add_test(NAME ortcustomops_static_test COMMAND $<TARGET_FILE:ortcustomops_static_test>)
|
||||
|
||||
file(GLOB shared_TEST_SRC "${TEST_SRC_DIR}/shared_test/*.cc")
|
||||
add_executable(ortcustomops_test ${shared_TEST_SRC})
|
||||
if (ONNXRUNTIME_LIB_DIR)
|
||||
target_link_directories(ortcustomops_test PRIVATE ${ONNXRUNTIME_LIB_DIR})
|
||||
target_link_libraries(ortcustomops_test ortcustomops onnxruntime gtest_main ${ocos_libraries})
|
||||
if (WIN32)
|
||||
file(TO_CMAKE_PATH "${ONNXRUNTIME_LIB_DIR}/*" ONNXRUNTIME_LIB_FILEPATTERN)
|
||||
file(GLOB ONNXRUNTIME_LIB_FILES CONFIGURE_DEPENDS "${ONNXRUNTIME_LIB_FILEPATTERN}")
|
||||
add_custom_command(
|
||||
TARGET ortcustomops_test POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy ${ONNXRUNTIME_LIB_FILES} $<TARGET_FILE_DIR:ortcustomops_test>)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(TEST_DATA_SRC ${TEST_SRC_DIR}/data)
|
||||
set(TEST_DATA_DES ${ortcustomops_BINARY_DIR}/data)
|
||||
|
||||
# Copy test data from source to destination.
|
||||
add_custom_command(
|
||||
TARGET ortcustomops_test POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_directory
|
||||
${TEST_DATA_SRC}
|
||||
${TEST_DATA_DES})
|
||||
add_test(NAME ortcustomops_test COMMAND $<TARGET_FILE:ortcustomops_test>)
|
||||
endif()
|
||||
|
|
|
@ -26,15 +26,9 @@ jobs:
|
|||
displayName: Install requirements.txt
|
||||
|
||||
- script: |
|
||||
sh ./build.sh
|
||||
python setup.py develop
|
||||
displayName: Build the library and tests
|
||||
|
||||
- script: |
|
||||
cd out/Linux
|
||||
./ortcustomops_test
|
||||
displayName: Run the native only unit tests
|
||||
|
||||
- script: python -m pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
displayName: Install pytorch
|
||||
|
||||
|
@ -78,16 +72,10 @@ jobs:
|
|||
displayName: Check installation
|
||||
|
||||
- script: |
|
||||
sh ./build.sh
|
||||
call activate pyenv
|
||||
python setup.py develop
|
||||
displayName: Build the library and tests
|
||||
|
||||
- script: |
|
||||
cd out/Darwin
|
||||
./ortcustomops_test
|
||||
displayName: Run the native only unit tests
|
||||
|
||||
- script: python -m pip install -r requirements-dev.txt
|
||||
displayName: Install requirements-dev.txt
|
||||
|
||||
|
@ -136,22 +124,17 @@ jobs:
|
|||
python -m pip install -r requirements-dev.txt
|
||||
displayName: Install requirements.txt
|
||||
|
||||
- script: |
|
||||
call activate pyenv
|
||||
echo Test numpy installation... && python -c "import numpy"
|
||||
call .\build.bat
|
||||
python setup.py develop
|
||||
displayName: Build the custom-op library
|
||||
|
||||
- script: |
|
||||
.\out\Windows\RelWithDebInfo\ortcustomops_test.exe
|
||||
displayName: Run C++ Test
|
||||
|
||||
- script: |
|
||||
call activate pyenv
|
||||
python -m pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio===0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
displayName: Install pytorch
|
||||
|
||||
- script: |
|
||||
call activate pyenv
|
||||
echo Test numpy installation... && python -c "import numpy"
|
||||
python setup.py develop
|
||||
displayName: Build the custom-op library
|
||||
|
||||
- script: |
|
||||
call activate pyenv
|
||||
python -m pytest test
|
||||
|
|
|
@ -4,16 +4,47 @@
|
|||
#pragma once
|
||||
|
||||
#define ORT_API_MANUAL_INIT
|
||||
#define EXCLUDE_REFERENCE_TO_ORT_DLL
|
||||
#include "onnxruntime_cxx_api.h"
|
||||
#undef EXCLUDE_REFERENCE_TO_ORT_DLL
|
||||
#undef ORT_API_MANUAL_INIT
|
||||
|
||||
|
||||
#if defined(ENABLE_GPT2_TOKENIZER)
|
||||
const OrtCustomOp** LoadTokenizerSchemaList();
|
||||
#endif // ENABLE_GPT2_TOKENIZER
|
||||
|
||||
|
||||
#if defined(PYTHON_OP_SUPPORT)
|
||||
const OrtCustomOp* FetchPyCustomOps(size_t& count);
|
||||
bool EnablePyCustomOps(bool enable=true);
|
||||
#endif
|
||||
|
||||
// A helper API to support test kernels.
|
||||
// Must be invoked before RegisterCustomOps.
|
||||
extern "C" bool AddExternalCustomOp(const OrtCustomOp* c_op);
|
||||
|
||||
const char c_OpDomain[] = "ai.onnx.contrib";
|
||||
|
||||
#if defined(PYTHON_OP_SUPPORT)
|
||||
struct BaseKernel {
|
||||
BaseKernel(OrtApi api) : api_(api), info_(nullptr), ort_(api_) {}
|
||||
BaseKernel(OrtApi api, const OrtKernelInfo *info) : api_(api), info_(info), ort_(api_) {}
|
||||
|
||||
const OrtCustomOp* FetchPyCustomOps(size_t& count);
|
||||
bool EnablePyCustomOps(bool enable=true);
|
||||
protected:
|
||||
OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_
|
||||
Ort::CustomOpApi ort_;
|
||||
const OrtKernelInfo* info_;
|
||||
};
|
||||
|
||||
#endif
|
||||
struct OrtTensorDimensions : std::vector<int64_t> {
|
||||
OrtTensorDimensions(Ort::CustomOpApi& ort, const OrtValue* value) {
|
||||
OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
|
||||
std::vector<int64_t>::operator=(ort.GetTensorShape(info));
|
||||
ort.ReleaseTensorTypeAndShapeInfo(info);
|
||||
}
|
||||
const std::vector<int64_t>& GetDims() const { return *this; }
|
||||
int64_t Size() const {
|
||||
int64_t s = 1.;
|
||||
for (auto it = begin(); it != end(); ++it)
|
||||
s *= *it;
|
||||
return s;
|
||||
}
|
||||
};
|
||||
|
|
|
@ -4,37 +4,3 @@
|
|||
#pragma once
|
||||
|
||||
#include "ocos.h"
|
||||
|
||||
typedef OrtCustomOp const* CPTR_OrtCustomOp;
|
||||
typedef CPTR_OrtCustomOp (*FxGetSchemaInstance)();
|
||||
|
||||
FxGetSchemaInstance const* GetCustomOpSchemaList();
|
||||
|
||||
struct BaseKernel {
|
||||
BaseKernel(OrtApi api) : api_(api), info_(nullptr), ort_(api_) {}
|
||||
BaseKernel(OrtApi api, const OrtKernelInfo *info) : api_(api), info_(info), ort_(api_) {}
|
||||
|
||||
protected:
|
||||
OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_
|
||||
Ort::CustomOpApi ort_;
|
||||
const OrtKernelInfo* info_;
|
||||
};
|
||||
|
||||
struct OrtTensorDimensions : std::vector<int64_t> {
|
||||
OrtTensorDimensions(Ort::CustomOpApi& ort, const OrtValue* value) {
|
||||
OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
|
||||
std::vector<int64_t>::operator=(ort.GetTensorShape(info));
|
||||
ort.ReleaseTensorTypeAndShapeInfo(info);
|
||||
}
|
||||
const std::vector<int64_t>& GetDims() const { return *this; }
|
||||
int64_t Size() const {
|
||||
int64_t s = 1.;
|
||||
for (auto it = begin(); it != end(); ++it)
|
||||
s *= *it;
|
||||
return s;
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(ENABLE_TOKENIZER)
|
||||
const OrtCustomOp** LoadTokenizerSchemaList();
|
||||
#endif // ENABLE_TEXT_DOMAIN
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
|
||||
struct KernelNegPos : BaseKernel {
|
||||
KernelNegPos(OrtApi api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context){
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const float* X = ort_.GetTensorData<float>(input_X);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
|
||||
OrtValue* output0 = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
float* out0 = ort_.GetTensorMutableData<float>(output0);
|
||||
OrtValue* output1 = ort_.KernelContext_GetOutput(context, 1, dimensions.data(), dimensions.size());
|
||||
float* out1 = ort_.GetTensorMutableData<float>(output1);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output0);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
// Do computation
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
if (X[i] > 0) {
|
||||
out0[i] = 0;
|
||||
out1[i] = X[i];
|
||||
} else {
|
||||
out0[i] = X[i];
|
||||
out1[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct CustomOpNegPos : Ort::CustomOpBase<CustomOpNegPos, KernelNegPos> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const{
|
||||
return new KernelNegPos(api);
|
||||
}
|
||||
|
||||
const char* GetName() const{
|
||||
return "NegPos";
|
||||
}
|
||||
|
||||
size_t GetInputTypeCount() const{
|
||||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
|
||||
size_t GetOutputTypeCount() const{
|
||||
return 2;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
};
|
|
@ -1,158 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#include <math.h>
|
||||
#include "test_output.hpp"
|
||||
|
||||
KernelOne::KernelOne(OrtApi api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelOne::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);
|
||||
const float* X = ort_.GetTensorData<float>(input_X);
|
||||
const float* Y = ort_.GetTensorData<float>(input_Y);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
float* out = ort_.GetTensorMutableData<float>(output);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
// Do computation
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
out[i] = X[i] + Y[i];
|
||||
}
|
||||
}
|
||||
|
||||
void* CustomOpOne::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
return new KernelOne(api);
|
||||
};
|
||||
|
||||
const char* CustomOpOne::GetName() const {
|
||||
return "CustomOpOne";
|
||||
};
|
||||
|
||||
size_t CustomOpOne::GetInputTypeCount() const {
|
||||
return 2;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpOne::GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
};
|
||||
|
||||
size_t CustomOpOne::GetOutputTypeCount() const {
|
||||
return 1;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpOne::GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
};
|
||||
|
||||
KernelTwo::KernelTwo(OrtApi api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelTwo::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const float* X = ort_.GetTensorData<float>(input_X);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
int32_t* out = ort_.GetTensorMutableData<int32_t>(output);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
// Do computation
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
out[i] = (int32_t)(round(X[i]));
|
||||
}
|
||||
}
|
||||
|
||||
void* CustomOpTwo::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
return new KernelTwo(api);
|
||||
};
|
||||
|
||||
const char* CustomOpTwo::GetName() const {
|
||||
return "CustomOpTwo";
|
||||
};
|
||||
|
||||
size_t CustomOpTwo::GetInputTypeCount() const {
|
||||
return 1;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpTwo::GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
};
|
||||
|
||||
size_t CustomOpTwo::GetOutputTypeCount() const {
|
||||
return 1;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpTwo::GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
|
||||
};
|
||||
|
||||
KernelNegPos::KernelNegPos(OrtApi api) : BaseKernel(api) {
|
||||
}
|
||||
|
||||
void KernelNegPos::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const float* X = ort_.GetTensorData<float>(input_X);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
|
||||
OrtValue* output0 = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
float* out0 = ort_.GetTensorMutableData<float>(output0);
|
||||
OrtValue* output1 = ort_.KernelContext_GetOutput(context, 1, dimensions.data(), dimensions.size());
|
||||
float* out1 = ort_.GetTensorMutableData<float>(output1);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output0);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
// Do computation
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
if (X[i] > 0) {
|
||||
out0[i] = 0;
|
||||
out1[i] = X[i];
|
||||
} else {
|
||||
out0[i] = X[i];
|
||||
out1[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void* CustomOpNegPos::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
|
||||
return new KernelNegPos(api);
|
||||
};
|
||||
|
||||
const char* CustomOpNegPos::GetName() const {
|
||||
return "NegPos";
|
||||
};
|
||||
|
||||
size_t CustomOpNegPos::GetInputTypeCount() const {
|
||||
return 1;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpNegPos::GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
};
|
||||
|
||||
size_t CustomOpNegPos::GetOutputTypeCount() const {
|
||||
return 2;
|
||||
};
|
||||
|
||||
ONNXTensorElementDataType CustomOpNegPos::GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
};
|
|
@ -1,49 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "kernels.h"
|
||||
#include "utils.h"
|
||||
|
||||
struct KernelOne : BaseKernel {
|
||||
KernelOne(OrtApi api);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
};
|
||||
|
||||
struct KernelTwo : BaseKernel {
|
||||
KernelTwo(OrtApi api);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
};
|
||||
|
||||
struct KernelNegPos : BaseKernel {
|
||||
KernelNegPos(OrtApi api);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpNegPos : Ort::CustomOpBase<CustomOpNegPos, KernelNegPos> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
};
|
2
setup.py
2
setup.py
|
@ -68,7 +68,7 @@ class BuildCMakeExt(_build_ext):
|
|||
config = 'RelWithDebInfo' if self.debug else 'Release'
|
||||
cmake_args = [
|
||||
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + str(ext_fullpath.parent.absolute()),
|
||||
'-DENABLE_PYTHON=ON',
|
||||
'-DOCOS_ENABLE_PYTHON=ON',
|
||||
'-DOCOS_EXTENTION_NAME=' + pathlib.Path(self.get_ext_filename(extension.name)).name,
|
||||
'-DCMAKE_BUILD_TYPE=' + config
|
||||
]
|
||||
|
|
|
@ -11,17 +11,17 @@
|
|||
#include "kernels/string_regex_replace.hpp"
|
||||
#include "kernels/string_split.hpp"
|
||||
#include "kernels/string_upper.hpp"
|
||||
#include "kernels/test_output.hpp"
|
||||
#include "kernels/negpos.hpp"
|
||||
#include "utils.h"
|
||||
|
||||
#ifdef ENABLE_TOKENIZER
|
||||
#ifdef ENABLE_SPM_TOKENIZER
|
||||
#include "sentencepiece_tokenizer.hpp"
|
||||
#endif
|
||||
|
||||
CustomOpNegPos c_CustomOpNegPos;
|
||||
CustomOpSegmentSum c_CustomOpSegmentSum;
|
||||
CustomOpRaggedTensorToSparse c_CustomOpRaggedTensorToSparse;
|
||||
#ifdef ENABLE_TOKENIZER
|
||||
#ifdef ENABLE_SPM_TOKENIZER
|
||||
CustomOpSentencepieceTokenizer c_CustomOpSentencepieceTokenizer;
|
||||
#endif
|
||||
CustomOpStringEqual c_CustomOpStringEqual;
|
||||
|
@ -31,14 +31,12 @@ CustomOpStringJoin c_CustomOpStringJoin;
|
|||
CustomOpStringRegexReplace c_CustomOpStringRegexReplace;
|
||||
CustomOpStringSplit c_CustomOpStringSplit;
|
||||
CustomOpStringUpper c_CustomOpStringUpper;
|
||||
CustomOpOne c_CustomOpOne;
|
||||
CustomOpTwo c_CustomOpTwo;
|
||||
|
||||
OrtCustomOp* operator_lists[] = {
|
||||
&c_CustomOpNegPos,
|
||||
&c_CustomOpRaggedTensorToSparse,
|
||||
&c_CustomOpSegmentSum,
|
||||
#ifdef ENABLE_TOKENIZER
|
||||
#ifdef ENABLE_SPM_TOKENIZER
|
||||
&c_CustomOpSentencepieceTokenizer,
|
||||
#endif
|
||||
&c_CustomOpStringEqual,
|
||||
|
@ -48,10 +46,46 @@ OrtCustomOp* operator_lists[] = {
|
|||
&c_CustomOpStringRegexReplace,
|
||||
&c_CustomOpStringSplit,
|
||||
&c_CustomOpStringUpper,
|
||||
&c_CustomOpOne,
|
||||
&c_CustomOpTwo,
|
||||
nullptr};
|
||||
|
||||
|
||||
class ExternalCustomOps
|
||||
{
|
||||
public:
|
||||
ExternalCustomOps(){
|
||||
}
|
||||
|
||||
static ExternalCustomOps& instance() {
|
||||
static ExternalCustomOps g_instance;
|
||||
return g_instance;
|
||||
}
|
||||
|
||||
void Add(const OrtCustomOp* c_op) {
|
||||
op_array_.push_back(c_op);
|
||||
}
|
||||
|
||||
const OrtCustomOp* GetNextOp(size_t& idx) {
|
||||
if (idx >= op_array_.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return op_array_[idx ++];
|
||||
}
|
||||
|
||||
ExternalCustomOps(ExternalCustomOps const&) = delete;
|
||||
void operator=(ExternalCustomOps const&) = delete;
|
||||
|
||||
private:
|
||||
std::vector<const OrtCustomOp*> op_array_;
|
||||
};
|
||||
|
||||
|
||||
extern "C" bool AddExternalCustomOp(const OrtCustomOp* c_op) {
|
||||
ExternalCustomOps::instance().Add(c_op);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
|
||||
OrtCustomOpDomain* domain = nullptr;
|
||||
const OrtApi* ortApi = api->GetApi(ORT_API_VERSION);
|
||||
|
@ -85,7 +119,7 @@ extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options,
|
|||
++ops;
|
||||
}
|
||||
|
||||
#if defined(ENABLE_TOKENIZER)
|
||||
#if defined(ENABLE_GPT2_TOKENIZER)
|
||||
const OrtCustomOp** t_ops = LoadTokenizerSchemaList();
|
||||
while (*t_ops != nullptr) {
|
||||
if (pyop_nameset.find((*t_ops)->GetName(*t_ops)) == pyop_nameset.end()) {
|
||||
|
@ -97,5 +131,16 @@ extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options,
|
|||
}
|
||||
#endif
|
||||
|
||||
size_t idx = 0;
|
||||
const OrtCustomOp* e_ops = ExternalCustomOps::instance().GetNextOp(idx);
|
||||
while (e_ops != nullptr) {
|
||||
if (pyop_nameset.find(e_ops->GetName(e_ops)) == pyop_nameset.end()) {
|
||||
if (auto status = ortApi->CustomOpDomain_Add(domain, e_ops)){
|
||||
return status;
|
||||
}
|
||||
e_ops = ExternalCustomOps::instance().GetNextOp(idx);
|
||||
}
|
||||
}
|
||||
|
||||
return ortApi->AddCustomOpDomain(options, domain);
|
||||
}
|
|
@ -1,3 +1,4 @@
|
|||
LIBRARY "ortcustomops.dll"
|
||||
EXPORTS
|
||||
RegisterCustomOps @1
|
||||
AddExternalCustomOp @2
|
|
@ -0,0 +1,101 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <math.h>
|
||||
|
||||
struct KernelOne : BaseKernel {
|
||||
KernelOne(OrtApi api): BaseKernel(api) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);
|
||||
const float* X = ort_.GetTensorData<float>(input_X);
|
||||
const float* Y = ort_.GetTensorData<float>(input_Y);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
float* out = ort_.GetTensorMutableData<float>(output);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
// Do computation
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
out[i] = X[i] + Y[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
return new KernelOne(api);
|
||||
};
|
||||
const char* GetName() const {
|
||||
return "CustomOpOne";
|
||||
};
|
||||
size_t GetInputTypeCount() const {
|
||||
return 2;
|
||||
};
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
};
|
||||
size_t GetOutputTypeCount() const {
|
||||
return 1;
|
||||
};
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
};
|
||||
};
|
||||
|
||||
struct KernelTwo : BaseKernel {
|
||||
KernelTwo(OrtApi api) : BaseKernel(api) {
|
||||
}
|
||||
void Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const float* X = ort_.GetTensorData<float>(input_X);
|
||||
|
||||
// Setup output
|
||||
OrtTensorDimensions dimensions(ort_, input_X);
|
||||
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||
int32_t* out = ort_.GetTensorMutableData<int32_t>(output);
|
||||
|
||||
OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
|
||||
int64_t size = ort_.GetTensorShapeElementCount(output_info);
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(output_info);
|
||||
|
||||
// Do computation
|
||||
for (int64_t i = 0; i < size; i++) {
|
||||
out[i] = (int32_t)(round(X[i]));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
return new KernelTwo(api);
|
||||
};
|
||||
const char* GetName() const {
|
||||
return "CustomOpTwo";
|
||||
};
|
||||
size_t GetInputTypeCount() const {
|
||||
return 1;
|
||||
};
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
};
|
||||
size_t GetOutputTypeCount() const {
|
||||
return 1;
|
||||
};
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
|
||||
};
|
||||
};
|
|
@ -0,0 +1,113 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "onnxruntime_cxx_api.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ocos.h"
|
||||
|
||||
#include "test_kernel.hpp"
|
||||
|
||||
|
||||
struct Input {
|
||||
const char* name = nullptr;
|
||||
std::vector<int64_t> dims;
|
||||
std::vector<float> values;
|
||||
};
|
||||
|
||||
void RunSession(Ort::Session& session_object,
|
||||
const std::vector<Input>& inputs,
|
||||
const char* output_name,
|
||||
const std::vector<int64_t>& dims_y,
|
||||
const std::vector<int32_t>& values_y) {
|
||||
|
||||
std::vector<Ort::Value> ort_inputs;
|
||||
std::vector<const char*> input_names;
|
||||
|
||||
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
input_names.emplace_back(inputs[i].name);
|
||||
ort_inputs.emplace_back(Ort::Value::CreateTensor<float>(memory_info,
|
||||
const_cast<float*>(inputs[i].values.data()), inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size()));
|
||||
}
|
||||
|
||||
std::vector<Ort::Value> ort_outputs;
|
||||
ort_outputs = session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, 1);
|
||||
ASSERT_EQ(ort_outputs.size(), 1u);
|
||||
auto output_tensor = &ort_outputs[0];
|
||||
|
||||
auto type_info = output_tensor->GetTensorTypeAndShapeInfo();
|
||||
ASSERT_EQ(type_info.GetShape(), dims_y);
|
||||
size_t total_len = type_info.GetElementCount();
|
||||
ASSERT_EQ(values_y.size(), total_len);
|
||||
|
||||
int32_t* f = output_tensor->GetTensorMutableData<int32_t>();
|
||||
for (size_t i = 0; i != total_len; ++i) {
|
||||
ASSERT_EQ(values_y[i], f[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void TestInference(Ort::Env& env, const ORTCHAR_T* model_uri,
|
||||
const std::vector<Input>& inputs,
|
||||
const char* output_name,
|
||||
const std::vector<int64_t>& expected_dims_y,
|
||||
const std::vector<int32_t>& expected_values_y,
|
||||
const char* custom_op_library_filename) {
|
||||
Ort::SessionOptions session_options;
|
||||
void* handle = nullptr;
|
||||
if (custom_op_library_filename) {
|
||||
Ort::ThrowOnError(Ort::GetApi().RegisterCustomOpsLibrary((OrtSessionOptions*)session_options, custom_op_library_filename, &handle));
|
||||
}
|
||||
|
||||
// if session creation passes, model loads fine
|
||||
Ort::Session session(env, model_uri, session_options);
|
||||
|
||||
// Now run
|
||||
RunSession(session,
|
||||
inputs,
|
||||
output_name,
|
||||
expected_dims_y,
|
||||
expected_values_y);
|
||||
}
|
||||
|
||||
static CustomOpOne op_1st;
|
||||
static CustomOpTwo op_2nd;
|
||||
|
||||
TEST(utils, test_ort_case) {
|
||||
|
||||
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
|
||||
std::cout << "Running custom op inference" << std::endl;
|
||||
|
||||
std::vector<Input> inputs(2);
|
||||
inputs[0].name = "input_1";
|
||||
inputs[0].dims = {3, 5};
|
||||
inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f,
|
||||
6.6f, 7.7f, 8.8f, 9.9f, 10.0f,
|
||||
11.1f, 12.2f, 13.3f, 14.4f, 15.5f};
|
||||
inputs[1].name = "input_2";
|
||||
inputs[1].dims = {3, 5};
|
||||
inputs[1].values = {15.5f, 14.4f, 13.3f, 12.2f, 11.1f,
|
||||
10.0f, 9.9f, 8.8f, 7.7f, 6.6f,
|
||||
5.5f, 4.4f, 3.3f, 2.2f, 1.1f};
|
||||
|
||||
// prepare expected inputs and outputs
|
||||
std::vector<int64_t> expected_dims_y = {3, 5};
|
||||
std::vector<int32_t> expected_values_y =
|
||||
{17, 17, 17, 17, 17,
|
||||
17, 18, 18, 18, 17,
|
||||
17, 17, 17, 17, 17};
|
||||
|
||||
#if defined(_WIN32)
|
||||
const char lib_name[] = "ortcustomops.dll";
|
||||
const ORTCHAR_T model_path[] = L"data\\custom_op_test.onnx";
|
||||
#elif defined(__APPLE__)
|
||||
const char lib_name[] = "libortcustomops.dylib";
|
||||
const ORTCHAR_T model_path[] = "data/custom_op_test.onnx";
|
||||
#else
|
||||
const char lib_name[] = "./libortcustomops.so";
|
||||
const ORTCHAR_T model_path[] = "data/custom_op_test.onnx";
|
||||
#endif
|
||||
AddExternalCustomOp(&op_1st);
|
||||
AddExternalCustomOp(&op_2nd);
|
||||
TestInference(*ort_env, model_path, inputs, "output", expected_dims_y, expected_values_y, lib_name);
|
||||
}
|
|
@ -1,10 +1,10 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
import unittest
|
||||
import numpy as np
|
||||
import onnxruntime as _ort
|
||||
|
||||
from pathlib import Path
|
||||
from onnx import helper, onnx_pb as onnx_proto
|
||||
from transformers import GPT2Tokenizer
|
||||
import onnxruntime as _ort
|
||||
from onnxruntime_customops import (
|
||||
onnx_op,
|
||||
enable_custom_op,
|
||||
|
|
|
@ -93,6 +93,16 @@ class TestPythonOp(unittest.TestCase):
|
|||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
@onnx_op(op_type="CustomOpOne",
|
||||
inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_float])
|
||||
def custom_one_op(x, y):
|
||||
return np.add(x, y)
|
||||
|
||||
@onnx_op(op_type="CustomOpTwo",
|
||||
outputs=[PyCustomOpDef.dt_int32])
|
||||
def custom_two_op(f):
|
||||
return np.round(f).astype(np.int32)
|
||||
|
||||
@onnx_op(op_type="PyReverseMatrix")
|
||||
def reverse_matrix(x):
|
||||
# The user custom op implementation here.
|
||||
|
|
|
@ -602,13 +602,13 @@ struct CustomOpBpeTokenizer : Ort::CustomOpBase<CustomOpBpeTokenizer, KernelBpeT
|
|||
const OrtCustomOp** LoadTokenizerSchemaList() {
|
||||
// create the global objects here to let the ORT catch the expection if any
|
||||
static std::unique_ptr<CustomOpBpeTokenizer> p_CoBpeTokenizer;
|
||||
static const OrtCustomOp* c_DomainList[2] = {nullptr}; // {&c_CoBpeTokenizer, nullptr};
|
||||
static const OrtCustomOp* c_CustomOpList[2] = {nullptr}; // {&c_CoBpeTokenizer, nullptr};
|
||||
static std::mutex mtx_loaded;
|
||||
std::lock_guard<std::mutex> lck(mtx_loaded);
|
||||
if (p_CoBpeTokenizer.get() == nullptr) {
|
||||
p_CoBpeTokenizer = std::make_unique<CustomOpBpeTokenizer>();
|
||||
c_DomainList[0] = p_CoBpeTokenizer.get();
|
||||
c_CustomOpList[0] = p_CoBpeTokenizer.get();
|
||||
}
|
||||
|
||||
return c_DomainList;
|
||||
return c_CustomOpList;
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче