Using the header files from the ONNXRuntime package (#322)
* Using the header files from the ONNXRuntime package * Update includes/onnxruntime_customop.hpp Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> * fix the build break. * one more fixing * wired top project * ort 1.9.0 used * switch to 1.10.0 package. * change the vmimage to latest * URL issue * cmake policy * ignore onnxruntime.dll native scan * update the Onebranch exclusedPaths * fixing some build tool issues * update again * typo * undo of ORT dll removal Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
This commit is contained in:
Родитель
69e6ec7cf1
Коммит
c599b00d07
|
@ -59,7 +59,7 @@ jobs:
|
||||||
displayName: Unpack ONNXRuntime package.
|
displayName: Unpack ONNXRuntime package.
|
||||||
|
|
||||||
- script: |
|
- script: |
|
||||||
CPU_NUMBER=2 sh ./build.sh -DONNXRUNTIME_LIB_DIR=onnxruntime-linux-x64-$(ort.version)/lib -DOCOS_ENABLE_CTEST=ON
|
CPU_NUMBER=2 sh ./build.sh -DOCOS_ENABLE_CTEST=ON
|
||||||
displayName: build the customop library with onnxruntime
|
displayName: build the customop library with onnxruntime
|
||||||
|
|
||||||
- script: |
|
- script: |
|
||||||
|
@ -139,7 +139,7 @@ jobs:
|
||||||
displayName: Unpack ONNXRuntime package.
|
displayName: Unpack ONNXRuntime package.
|
||||||
|
|
||||||
- script: |
|
- script: |
|
||||||
sh ./build.sh -DONNXRUNTIME_LIB_DIR=$(ort.dirname)/lib -DOCOS_ENABLE_CTEST=ON
|
sh ./build.sh -DOCOS_ENABLE_CTEST=ON
|
||||||
displayName: build the customop library with onnxruntime
|
displayName: build the customop library with onnxruntime
|
||||||
|
|
||||||
- script: |
|
- script: |
|
||||||
|
@ -266,7 +266,7 @@ jobs:
|
||||||
|
|
||||||
- script: |
|
- script: |
|
||||||
call $(vsdevcmd)
|
call $(vsdevcmd)
|
||||||
call .\build.bat -DONNXRUNTIME_LIB_DIR=.\onnxruntime-win-x64-$(ort.version)\lib -DOCOS_ENABLE_CTEST=ON
|
call .\build.bat -DOCOS_ENABLE_CTEST=ON
|
||||||
displayName: build the customop library with onnxruntime
|
displayName: build the customop library with onnxruntime
|
||||||
|
|
||||||
- script: |
|
- script: |
|
||||||
|
@ -427,6 +427,7 @@ jobs:
|
||||||
|
|
||||||
#############
|
#############
|
||||||
# iOS
|
# iOS
|
||||||
|
# Only cmake==3.25.0 works OpenCV compiling now.
|
||||||
#############
|
#############
|
||||||
- job: IosPackage
|
- job: IosPackage
|
||||||
pool:
|
pool:
|
||||||
|
@ -441,6 +442,7 @@ jobs:
|
||||||
displayName: "Use Python 3.9"
|
displayName: "Use Python 3.9"
|
||||||
|
|
||||||
- script: |
|
- script: |
|
||||||
|
python -m pip install cmake==3.25.0
|
||||||
python ./tools/ios/build_xcframework.py \
|
python ./tools/ios/build_xcframework.py \
|
||||||
--output-dir $(Build.BinariesDirectory)/xcframework_out \
|
--output-dir $(Build.BinariesDirectory)/xcframework_out \
|
||||||
--platform-arch iphonesimulator x86_64 \
|
--platform-arch iphonesimulator x86_64 \
|
||||||
|
|
|
@ -44,6 +44,7 @@ extends:
|
||||||
enabled: true
|
enabled: true
|
||||||
binskim:
|
binskim:
|
||||||
break: true # always break the build on binskim issues in addition to TSA upload
|
break: true # always break the build on binskim issues in addition to TSA upload
|
||||||
|
analyzeTargetGlob: '**\bin\*' # only scan the DLLs in extensions bin folder.
|
||||||
codeql:
|
codeql:
|
||||||
python:
|
python:
|
||||||
enabled: true
|
enabled: true
|
||||||
|
@ -87,7 +88,7 @@ extends:
|
||||||
@echo ##vso[task.setvariable variable=vsdevcmd]%vsdevcmd%
|
@echo ##vso[task.setvariable variable=vsdevcmd]%vsdevcmd%
|
||||||
@echo ##vso[task.setvariable variable=vscmake]%vscmake%
|
@echo ##vso[task.setvariable variable=vscmake]%vscmake%
|
||||||
@echo ##vso[task.setvariable variable=vsmsbuild]%vsmsbuild%
|
@echo ##vso[task.setvariable variable=vsmsbuild]%vsmsbuild%
|
||||||
displayName: 'locate vsdevcmd via vswhere'
|
displayName: 'locate vsdevcmd via vswhere'
|
||||||
- script: |
|
- script: |
|
||||||
call $(vsdevcmd)
|
call $(vsdevcmd)
|
||||||
set PYTHONPATH=
|
set PYTHONPATH=
|
||||||
|
@ -95,6 +96,7 @@ extends:
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
python -m pip install cibuildwheel numpy
|
python -m pip install cibuildwheel numpy
|
||||||
python -m cibuildwheel --platform windows --output-dir $(REPOROOT)\out
|
python -m cibuildwheel --platform windows --output-dir $(REPOROOT)\out
|
||||||
|
del /s /q /f .setuptools-cmake-build\*onnxruntime.dll
|
||||||
displayName: Build wheels
|
displayName: Build wheels
|
||||||
- task: SDLNativeRules@3
|
- task: SDLNativeRules@3
|
||||||
inputs:
|
inputs:
|
||||||
|
|
|
@ -44,6 +44,7 @@ extends:
|
||||||
enabled: false
|
enabled: false
|
||||||
binskim:
|
binskim:
|
||||||
break: true # always break the build on binskim issues in addition to TSA upload
|
break: true # always break the build on binskim issues in addition to TSA upload
|
||||||
|
analyzeTargetGlob: '**\bin\*' # only scan the DLLs in extensions bin folder.
|
||||||
codeql:
|
codeql:
|
||||||
python:
|
python:
|
||||||
enabled: true
|
enabled: true
|
||||||
|
@ -86,8 +87,8 @@ extends:
|
||||||
@echo ##vso[task.setvariable variable=vslatest]%vslatest%
|
@echo ##vso[task.setvariable variable=vslatest]%vslatest%
|
||||||
@echo ##vso[task.setvariable variable=vsdevcmd]%vsdevcmd%
|
@echo ##vso[task.setvariable variable=vsdevcmd]%vsdevcmd%
|
||||||
@echo ##vso[task.setvariable variable=vscmake]%vscmake%
|
@echo ##vso[task.setvariable variable=vscmake]%vscmake%
|
||||||
@echo ##vso[task.setvariable variable=vsmsbuild]%vsmsbuild%
|
@echo ##vso[task.setvariable variable=vsmsbuild]%vsmsbuild%
|
||||||
displayName: 'locate vsdevcmd via vswhere'
|
displayName: 'locate vsdevcmd via vswhere'
|
||||||
- script: |
|
- script: |
|
||||||
call $(vsdevcmd)
|
call $(vsdevcmd)
|
||||||
set PYTHONPATH=
|
set PYTHONPATH=
|
||||||
|
@ -95,6 +96,7 @@ extends:
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
python -m pip install cibuildwheel numpy
|
python -m pip install cibuildwheel numpy
|
||||||
python -m cibuildwheel --platform windows --output-dir $(REPOROOT)\out
|
python -m cibuildwheel --platform windows --output-dir $(REPOROOT)\out
|
||||||
|
del /s /q /f .setuptools-cmake-build\*onnxruntime.dll
|
||||||
displayName: Build wheels
|
displayName: Build wheels
|
||||||
- task: SDLNativeRules@3
|
- task: SDLNativeRules@3
|
||||||
inputs:
|
inputs:
|
||||||
|
|
157
CMakeLists.txt
157
CMakeLists.txt
|
@ -1,13 +1,12 @@
|
||||||
cmake_minimum_required(VERSION 3.20)
|
cmake_minimum_required(VERSION 3.20)
|
||||||
project(onnxruntime_extensions LANGUAGES C CXX)
|
project(onnxruntime_extensions LANGUAGES C CXX)
|
||||||
# set(CMAKE_VERBOSE_MAKEFILE ON)
|
|
||||||
|
|
||||||
|
# set(CMAKE_VERBOSE_MAKEFILE ON)
|
||||||
if(NOT CMAKE_BUILD_TYPE)
|
if(NOT CMAKE_BUILD_TYPE)
|
||||||
message(STATUS "Build type not set - using RelWithDebInfo")
|
message(STATUS "Build type not set - using RelWithDebInfo")
|
||||||
set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING "Choose build type: Debug Release RelWithDebInfo." FORCE)
|
set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING "Choose build type: Debug Release RelWithDebInfo." FORCE)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
set(CPACK_PACKAGE_NAME "onnxruntime_extensions")
|
set(CPACK_PACKAGE_NAME "onnxruntime_extensions")
|
||||||
set(CPACK_PACKAGE_VERSION_MAJOR "0")
|
set(CPACK_PACKAGE_VERSION_MAJOR "0")
|
||||||
set(CPACK_PACKAGE_VERSION_MINOR "5")
|
set(CPACK_PACKAGE_VERSION_MINOR "5")
|
||||||
|
@ -68,7 +67,7 @@ if(NOT CC_OPTIMIZE)
|
||||||
string(REGEX REPLACE "([\-\/]O[123])" "" CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO}")
|
string(REGEX REPLACE "([\-\/]O[123])" "" CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO}")
|
||||||
string(REGEX REPLACE "([\-\/]O[123])" "" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}")
|
string(REGEX REPLACE "([\-\/]O[123])" "" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}")
|
||||||
|
|
||||||
if (NOT WIN32)
|
if(NOT WIN32)
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0")
|
||||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O0")
|
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O0")
|
||||||
else()
|
else()
|
||||||
|
@ -77,15 +76,16 @@ if(NOT CC_OPTIMIZE)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (NOT OCOS_BUILD_PYTHON AND OCOS_ENABLE_PYTHON)
|
if(NOT OCOS_BUILD_PYTHON AND OCOS_ENABLE_PYTHON)
|
||||||
message("OCOS_ENABLE_PYTHON IS DEPRECATED, USE OCOS_BUILD_PYTHON INSTEAD")
|
message("OCOS_ENABLE_PYTHON IS DEPRECATED, USE OCOS_BUILD_PYTHON INSTEAD")
|
||||||
set(OCOS_BUILD_PYTHON ON CACHE INTERNAL "")
|
set(OCOS_BUILD_PYTHON ON CACHE INTERNAL "")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_BUILD_ANDROID)
|
if(OCOS_BUILD_ANDROID)
|
||||||
if (NOT ANDROID_SDK_ROOT OR NOT ANDROID_NDK)
|
if(NOT ANDROID_SDK_ROOT OR NOT ANDROID_NDK)
|
||||||
message("Cannot the find Android SDK/NDK")
|
message("Cannot the find Android SDK/NDK")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(OCOS_BUILD_JAVA ON CACHE INTERNAL "")
|
set(OCOS_BUILD_JAVA ON CACHE INTERNAL "")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
@ -95,6 +95,7 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||||
set_property(GLOBAL PROPERTY USE_FOLDERS ON)
|
set_property(GLOBAL PROPERTY USE_FOLDERS ON)
|
||||||
|
|
||||||
set(CMAKE_FIND_FRAMEWORK NEVER CACHE STRING "...")
|
set(CMAKE_FIND_FRAMEWORK NEVER CACHE STRING "...")
|
||||||
|
|
||||||
if(NOT "${CMAKE_FIND_FRAMEWORK}" STREQUAL "NEVER")
|
if(NOT "${CMAKE_FIND_FRAMEWORK}" STREQUAL "NEVER")
|
||||||
message(FATAL_ERROR "CMAKE_FIND_FRAMEWORK is not NEVER")
|
message(FATAL_ERROR "CMAKE_FIND_FRAMEWORK is not NEVER")
|
||||||
endif()
|
endif()
|
||||||
|
@ -102,7 +103,7 @@ endif()
|
||||||
# External dependencies
|
# External dependencies
|
||||||
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/externals ${PROJECT_SOURCE_DIR}/cmake)
|
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/externals ${PROJECT_SOURCE_DIR}/cmake)
|
||||||
|
|
||||||
if (OCOS_ENABLE_SELECTED_OPLIST)
|
if(OCOS_ENABLE_SELECTED_OPLIST)
|
||||||
# Need to ensure _selectedoplist.cmake file is already generated in folder: ${PROJECT_SOURCE_DIR}/cmake/
|
# Need to ensure _selectedoplist.cmake file is already generated in folder: ${PROJECT_SOURCE_DIR}/cmake/
|
||||||
# You could run gen_selectedops.py in folder: tools/ to generate _selectedoplist.cmake
|
# You could run gen_selectedops.py in folder: tools/ to generate _selectedoplist.cmake
|
||||||
message(STATUS "Looking for the _selectedoplist.cmake")
|
message(STATUS "Looking for the _selectedoplist.cmake")
|
||||||
|
@ -113,6 +114,7 @@ endif()
|
||||||
if(NOT OCOS_ENABLE_CPP_EXCEPTIONS)
|
if(NOT OCOS_ENABLE_CPP_EXCEPTIONS)
|
||||||
include(noexcep_ops)
|
include(noexcep_ops)
|
||||||
add_compile_definitions(OCOS_NO_EXCEPTIONS ORT_NO_EXCEPTIONS)
|
add_compile_definitions(OCOS_NO_EXCEPTIONS ORT_NO_EXCEPTIONS)
|
||||||
|
|
||||||
if(MSVC)
|
if(MSVC)
|
||||||
string(REGEX REPLACE "/EHsc" "/EHs-c-" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
string(REGEX REPLACE "/EHsc" "/EHs-c-" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
||||||
add_compile_definitions("_HAS_EXCEPTIONS=0")
|
add_compile_definitions("_HAS_EXCEPTIONS=0")
|
||||||
|
@ -123,8 +125,15 @@ endif()
|
||||||
|
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
|
|
||||||
if (OCOS_ENABLE_RE2_REGEX)
|
# PROJECT_IS_TOP_LEVEL is available until 3.21
|
||||||
if (NOT TARGET re2::re2)
|
get_property(not_top DIRECTORY PROPERTY PARENT_DIRECTORY)
|
||||||
|
if(not_top)
|
||||||
|
set(_ONNXRUNTIME_EMBEDDED TRUE)
|
||||||
|
endif()
|
||||||
|
include(ext_ortlib)
|
||||||
|
|
||||||
|
if(OCOS_ENABLE_RE2_REGEX)
|
||||||
|
if(NOT TARGET re2::re2)
|
||||||
set(RE2_BUILD_TESTING OFF CACHE INTERNAL "")
|
set(RE2_BUILD_TESTING OFF CACHE INTERNAL "")
|
||||||
message(STATUS "Fetch googlere2")
|
message(STATUS "Fetch googlere2")
|
||||||
include(googlere2)
|
include(googlere2)
|
||||||
|
@ -136,30 +145,29 @@ if (OCOS_ENABLE_RE2_REGEX)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
macro(standardize_output_folder bin_target)
|
macro(standardize_output_folder bin_target)
|
||||||
set_target_properties(${bin_target} PROPERTIES
|
set_target_properties(${bin_target} PROPERTIES
|
||||||
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
|
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
|
||||||
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
|
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
|
||||||
ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
|
ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
|
||||||
PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin")
|
PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin")
|
||||||
endmacro()
|
endmacro()
|
||||||
|
|
||||||
|
# ### scan all source files
|
||||||
#### scan all source files
|
|
||||||
|
|
||||||
file(GLOB TARGET_SRC "operators/*.cc" "operators/*.h")
|
file(GLOB TARGET_SRC "operators/*.cc" "operators/*.h")
|
||||||
if (OCOS_ENABLE_TF_STRING)
|
|
||||||
|
if(OCOS_ENABLE_TF_STRING)
|
||||||
set(farmhash_SOURCE_DIR ${PROJECT_SOURCE_DIR}/cmake/externals/farmhash)
|
set(farmhash_SOURCE_DIR ${PROJECT_SOURCE_DIR}/cmake/externals/farmhash)
|
||||||
file(GLOB TARGET_SRC_KERNELS "operators/text/*.cc" "operators/text/*.h*")
|
file(GLOB TARGET_SRC_KERNELS "operators/text/*.cc" "operators/text/*.h*")
|
||||||
file(GLOB TARGET_SRC_HASH "${farmhash_SOURCE_DIR}/src/farmhash.*")
|
file(GLOB TARGET_SRC_HASH "${farmhash_SOURCE_DIR}/src/farmhash.*")
|
||||||
list(APPEND TARGET_SRC ${TARGET_SRC_KERNELS} ${TARGET_SRC_HASH})
|
list(APPEND TARGET_SRC ${TARGET_SRC_KERNELS} ${TARGET_SRC_HASH})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_RE2_REGEX)
|
if(OCOS_ENABLE_RE2_REGEX)
|
||||||
file(GLOB TARGET_SRC_RE2_KERNELS "operators/text/re2_strings/*.cc" "operators/text/re2_strings/*.h*")
|
file(GLOB TARGET_SRC_RE2_KERNELS "operators/text/re2_strings/*.cc" "operators/text/re2_strings/*.h*")
|
||||||
list(APPEND TARGET_SRC ${TARGET_SRC_RE2_KERNELS})
|
list(APPEND TARGET_SRC ${TARGET_SRC_RE2_KERNELS})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_MATH)
|
if(OCOS_ENABLE_MATH)
|
||||||
if(OCOS_ENABLE_DLIB)
|
if(OCOS_ENABLE_DLIB)
|
||||||
set(DLIB_ISO_CPP_ONLY ON CACHE INTERNAL "")
|
set(DLIB_ISO_CPP_ONLY ON CACHE INTERNAL "")
|
||||||
set(DLIB_NO_GUI_SUPPORT ON CACHE INTERNAL "")
|
set(DLIB_NO_GUI_SUPPORT ON CACHE INTERNAL "")
|
||||||
|
@ -167,31 +175,33 @@ if (OCOS_ENABLE_MATH)
|
||||||
set(DLIB_USE_LAPACK OFF CACHE INTERNAL "")
|
set(DLIB_USE_LAPACK OFF CACHE INTERNAL "")
|
||||||
set(DLIB_USE_BLAS OFF CACHE INTERNAL "")
|
set(DLIB_USE_BLAS OFF CACHE INTERNAL "")
|
||||||
include(dlib)
|
include(dlib)
|
||||||
|
|
||||||
# Ideally, dlib should be included as
|
# Ideally, dlib should be included as
|
||||||
# file(GLOB TARGET_SRC_DLIB "${dlib_SOURCE_DIR}/dlib/all/source.cpp")
|
# file(GLOB TARGET_SRC_DLIB "${dlib_SOURCE_DIR}/dlib/all/source.cpp")
|
||||||
# To avoid the unintentional using some unwanted component, only include
|
# To avoid the unintentional using some unwanted component, only include
|
||||||
file(GLOB TARGET_SRC_DLIB "${dlib_SOURCE_DIR}/dlib/test_for_odr_violations.cpp")
|
file(GLOB TARGET_SRC_DLIB "${dlib_SOURCE_DIR}/dlib/test_for_odr_violations.cpp")
|
||||||
file(GLOB TARGET_SRC_INVERSE "operators/math/dlib/*.cc" "operators/math/dlib/*.h*")
|
file(GLOB TARGET_SRC_INVERSE "operators/math/dlib/*.cc" "operators/math/dlib/*.h*")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
file(GLOB TARGET_SRC_MATH "operators/math/*.cc" "operators/math/*.h*")
|
file(GLOB TARGET_SRC_MATH "operators/math/*.cc" "operators/math/*.h*")
|
||||||
list(APPEND TARGET_SRC ${TARGET_SRC_MATH} ${TARGET_SRC_DLIB} ${TARGET_SRC_INVERSE})
|
list(APPEND TARGET_SRC ${TARGET_SRC_MATH} ${TARGET_SRC_DLIB} ${TARGET_SRC_INVERSE})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# enable the opencv dependency if we have ops that require it
|
# enable the opencv dependency if we have ops that require it
|
||||||
if (OCOS_ENABLE_CV2 OR OCOS_ENABLE_VISION)
|
if(OCOS_ENABLE_CV2 OR OCOS_ENABLE_VISION)
|
||||||
set(_ENABLE_OPENCV ON)
|
set(_ENABLE_OPENCV ON)
|
||||||
message(STATUS "Fetch opencv")
|
message(STATUS "Fetch opencv")
|
||||||
include(opencv)
|
include(opencv)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_CV2)
|
if(OCOS_ENABLE_CV2)
|
||||||
file(GLOB TARGET_SRC_CV2 "operators/cv2/*.cc" "operators/cv2/*.h*")
|
file(GLOB TARGET_SRC_CV2 "operators/cv2/*.cc" "operators/cv2/*.h*")
|
||||||
list(APPEND TARGET_SRC ${TARGET_SRC_CV2})
|
list(APPEND TARGET_SRC ${TARGET_SRC_CV2})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_VISION)
|
if(OCOS_ENABLE_VISION)
|
||||||
if (NOT OCOS_ENABLE_OPENCV_CODECS)
|
if(NOT OCOS_ENABLE_OPENCV_CODECS)
|
||||||
message(FATAL_ERROR "OCOS_ENABLE_VISION requires OCOS_ENABLE_OPENCV_CODECS to be ON")
|
message(FATAL_ERROR "OCOS_ENABLE_VISION requires OCOS_ENABLE_OPENCV_CODECS to be ON")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
file(GLOB TARGET_SRC_VISION "operators/vision/*.cc" "operators/vision/*.h*")
|
file(GLOB TARGET_SRC_VISION "operators/vision/*.cc" "operators/vision/*.h*")
|
||||||
|
@ -199,14 +209,15 @@ if (OCOS_ENABLE_VISION)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(_HAS_TOKENIZER OFF)
|
set(_HAS_TOKENIZER OFF)
|
||||||
if (OCOS_ENABLE_GPT2_TOKENIZER)
|
|
||||||
|
if(OCOS_ENABLE_GPT2_TOKENIZER)
|
||||||
# GPT2
|
# GPT2
|
||||||
set(_HAS_TOKENIZER ON)
|
set(_HAS_TOKENIZER ON)
|
||||||
file(GLOB tok_TARGET_SRC "operators/tokenizer/gpt*.cc" "operators/tokenizer/unicode*.*")
|
file(GLOB tok_TARGET_SRC "operators/tokenizer/gpt*.cc" "operators/tokenizer/unicode*.*")
|
||||||
list(APPEND TARGET_SRC ${tok_TARGET_SRC})
|
list(APPEND TARGET_SRC ${tok_TARGET_SRC})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_SPM_TOKENIZER)
|
if(OCOS_ENABLE_SPM_TOKENIZER)
|
||||||
# SentencePiece
|
# SentencePiece
|
||||||
set(_HAS_TOKENIZER ON)
|
set(_HAS_TOKENIZER ON)
|
||||||
set(SPM_ENABLE_TCMALLOC OFF CACHE INTERNAL "")
|
set(SPM_ENABLE_TCMALLOC OFF CACHE INTERNAL "")
|
||||||
|
@ -218,42 +229,41 @@ if (OCOS_ENABLE_SPM_TOKENIZER)
|
||||||
list(APPEND TARGET_SRC ${stpiece_TARGET_SRC})
|
list(APPEND TARGET_SRC ${stpiece_TARGET_SRC})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_WORDPIECE_TOKENIZER)
|
if(OCOS_ENABLE_WORDPIECE_TOKENIZER)
|
||||||
set(_HAS_TOKENIZER ON)
|
set(_HAS_TOKENIZER ON)
|
||||||
file(GLOB wordpiece_TARGET_SRC "operators/tokenizer/wordpiece*.*")
|
file(GLOB wordpiece_TARGET_SRC "operators/tokenizer/wordpiece*.*")
|
||||||
list(APPEND TARGET_SRC ${wordpiece_TARGET_SRC})
|
list(APPEND TARGET_SRC ${wordpiece_TARGET_SRC})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_BERT_TOKENIZER)
|
if(OCOS_ENABLE_BERT_TOKENIZER)
|
||||||
# Bert
|
# Bert
|
||||||
set(_HAS_TOKENIZER ON)
|
set(_HAS_TOKENIZER ON)
|
||||||
file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*" "operators/tokenizer/bert_tokenizer.*" "operators/tokenizer/bert_tokenizer_decoder.*")
|
file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*" "operators/tokenizer/bert_tokenizer.*" "operators/tokenizer/bert_tokenizer_decoder.*")
|
||||||
list(APPEND TARGET_SRC ${bert_TARGET_SRC})
|
list(APPEND TARGET_SRC ${bert_TARGET_SRC})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_BLINGFIRE)
|
if(OCOS_ENABLE_BLINGFIRE)
|
||||||
# blingfire
|
# blingfire
|
||||||
set(_HAS_TOKENIZER ON)
|
set(_HAS_TOKENIZER ON)
|
||||||
file(GLOB blingfire_TARGET_SRC "operators/tokenizer/blingfire*.*")
|
file(GLOB blingfire_TARGET_SRC "operators/tokenizer/blingfire*.*")
|
||||||
list(APPEND TARGET_SRC ${blingfire_TARGET_SRC})
|
list(APPEND TARGET_SRC ${blingfire_TARGET_SRC})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
|
if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
|
||||||
if (NOT TARGET nlohmann_json)
|
if(NOT TARGET nlohmann_json)
|
||||||
set(JSON_BuildTests OFF CACHE INTERNAL "")
|
set(JSON_BuildTests OFF CACHE INTERNAL "")
|
||||||
message(STATUS "Fetch json")
|
message(STATUS "Fetch json")
|
||||||
include(json)
|
include(json)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (_HAS_TOKENIZER)
|
if(_HAS_TOKENIZER)
|
||||||
message(STATUS "Tokenizer needed.")
|
message(STATUS "Tokenizer needed.")
|
||||||
file(GLOB tokenizer_TARGET_SRC "operators/tokenizer/tokenizers.*")
|
file(GLOB tokenizer_TARGET_SRC "operators/tokenizer/tokenizers.*")
|
||||||
list(APPEND TARGET_SRC ${tokenizer_TARGET_SRC})
|
list(APPEND TARGET_SRC ${tokenizer_TARGET_SRC})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
#### make all compile options.
|
# ### make all compile options.
|
||||||
|
|
||||||
add_compile_options("$<$<C_COMPILER_ID:MSVC>:/utf-8>")
|
add_compile_options("$<$<C_COMPILER_ID:MSVC>:/utf-8>")
|
||||||
add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/utf-8>")
|
add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/utf-8>")
|
||||||
add_library(ocos_operators STATIC ${TARGET_SRC})
|
add_library(ocos_operators STATIC ${TARGET_SRC})
|
||||||
|
@ -274,22 +284,22 @@ source_group(TREE ${PROJECT_SOURCE_DIR} FILES ${_TARGET_SRC_FOR_SOURCE_GROUP})
|
||||||
standardize_output_folder(ocos_operators)
|
standardize_output_folder(ocos_operators)
|
||||||
|
|
||||||
target_include_directories(ocos_operators PUBLIC
|
target_include_directories(ocos_operators PUBLIC
|
||||||
|
${ONNXRUNTIME_INCLUDE_DIR}
|
||||||
${PROJECT_SOURCE_DIR}/includes
|
${PROJECT_SOURCE_DIR}/includes
|
||||||
${PROJECT_SOURCE_DIR}/includes/onnxruntime
|
|
||||||
${PROJECT_SOURCE_DIR}/operators
|
${PROJECT_SOURCE_DIR}/operators
|
||||||
${PROJECT_SOURCE_DIR}/operators/tokenizer)
|
${PROJECT_SOURCE_DIR}/operators/tokenizer)
|
||||||
set(ocos_libraries "")
|
set(ocos_libraries "")
|
||||||
set(OCOS_COMPILE_DEFINITIONS "")
|
set(OCOS_COMPILE_DEFINITIONS "")
|
||||||
|
|
||||||
if (OCOS_ENABLE_DLIB)
|
if(OCOS_ENABLE_DLIB)
|
||||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_DLIB)
|
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_DLIB)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (_HAS_TOKENIZER)
|
if(_HAS_TOKENIZER)
|
||||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_TOKENIZER)
|
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_TOKENIZER)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_TF_STRING)
|
if(OCOS_ENABLE_TF_STRING)
|
||||||
target_include_directories(ocos_operators PUBLIC
|
target_include_directories(ocos_operators PUBLIC
|
||||||
${googlere2_SOURCE_DIR}
|
${googlere2_SOURCE_DIR}
|
||||||
${farmhash_SOURCE_DIR}/src)
|
${farmhash_SOURCE_DIR}/src)
|
||||||
|
@ -297,62 +307,63 @@ if (OCOS_ENABLE_TF_STRING)
|
||||||
list(APPEND ocos_libraries re2)
|
list(APPEND ocos_libraries re2)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_RE2_REGEX)
|
if(OCOS_ENABLE_RE2_REGEX)
|
||||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_RE2_REGEX)
|
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_RE2_REGEX)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_MATH)
|
if(OCOS_ENABLE_MATH)
|
||||||
target_include_directories(ocos_operators PUBLIC ${dlib_SOURCE_DIR})
|
target_include_directories(ocos_operators PUBLIC ${dlib_SOURCE_DIR})
|
||||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_MATH)
|
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_MATH)
|
||||||
|
|
||||||
# The dlib matrix implementation is all in the headers, no library compiling needed.
|
# The dlib matrix implementation is all in the headers, no library compiling needed.
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (_ENABLE_OPENCV)
|
if(_ENABLE_OPENCV)
|
||||||
list(APPEND ocos_libraries ${opencv_LIBS})
|
list(APPEND ocos_libraries ${opencv_LIBS})
|
||||||
target_include_directories(ocos_operators PUBLIC ${opencv_INCLUDE_DIRS})
|
target_include_directories(ocos_operators PUBLIC ${opencv_INCLUDE_DIRS})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_OPENCV_CODECS)
|
if(OCOS_ENABLE_OPENCV_CODECS)
|
||||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_OPENCV_CODECS)
|
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_OPENCV_CODECS)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_CV2)
|
if(OCOS_ENABLE_CV2)
|
||||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_CV2)
|
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_CV2)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_VISION)
|
if(OCOS_ENABLE_VISION)
|
||||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_VISION)
|
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_VISION)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_GPT2_TOKENIZER)
|
if(OCOS_ENABLE_GPT2_TOKENIZER)
|
||||||
# GPT2
|
# GPT2
|
||||||
target_include_directories(ocos_operators PRIVATE ${json_SOURCE_DIR}/single_include)
|
target_include_directories(ocos_operators PRIVATE ${json_SOURCE_DIR}/single_include)
|
||||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_GPT2_TOKENIZER)
|
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_GPT2_TOKENIZER)
|
||||||
list(APPEND ocos_libraries nlohmann_json::nlohmann_json)
|
list(APPEND ocos_libraries nlohmann_json::nlohmann_json)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_WORDPIECE_TOKENIZER)
|
if(OCOS_ENABLE_WORDPIECE_TOKENIZER)
|
||||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_WORDPIECE_TOKENIZER)
|
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_WORDPIECE_TOKENIZER)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_BERT_TOKENIZER)
|
if(OCOS_ENABLE_BERT_TOKENIZER)
|
||||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_BERT_TOKENIZER)
|
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_BERT_TOKENIZER)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_SPM_TOKENIZER)
|
if(OCOS_ENABLE_SPM_TOKENIZER)
|
||||||
# SentencePiece
|
# SentencePiece
|
||||||
target_include_directories(ocos_operators PUBLIC ${spm_INCLUDE_DIRS})
|
target_include_directories(ocos_operators PUBLIC ${spm_INCLUDE_DIRS})
|
||||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_SPM_TOKENIZER)
|
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_SPM_TOKENIZER)
|
||||||
list(APPEND ocos_libraries sentencepiece-static)
|
list(APPEND ocos_libraries sentencepiece-static)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_BLINGFIRE)
|
if(OCOS_ENABLE_BLINGFIRE)
|
||||||
include(blingfire)
|
include(blingfire)
|
||||||
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_BLINGFIRE)
|
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_BLINGFIRE)
|
||||||
list(APPEND ocos_libraries bingfirtinydll_static)
|
list(APPEND ocos_libraries bingfirtinydll_static)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
|
if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
|
||||||
target_include_directories(ocos_operators PRIVATE ${json_SOURCE_DIR}/single_include)
|
target_include_directories(ocos_operators PRIVATE ${json_SOURCE_DIR}/single_include)
|
||||||
list(APPEND ocos_libraries nlohmann_json::nlohmann_json)
|
list(APPEND ocos_libraries nlohmann_json::nlohmann_json)
|
||||||
endif()
|
endif()
|
||||||
|
@ -362,6 +373,7 @@ target_compile_definitions(ocos_operators PRIVATE ${OCOS_COMPILE_DEFINITIONS})
|
||||||
target_link_libraries(ocos_operators PRIVATE ${ocos_libraries})
|
target_link_libraries(ocos_operators PRIVATE ${ocos_libraries})
|
||||||
|
|
||||||
file(GLOB shared_TARGET_LIB_SRC "shared/lib/*.cc" "shared/lib/*.h")
|
file(GLOB shared_TARGET_LIB_SRC "shared/lib/*.cc" "shared/lib/*.h")
|
||||||
|
|
||||||
if(NOT OCOS_ENABLE_STATIC_LIB AND CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
|
if(NOT OCOS_ENABLE_STATIC_LIB AND CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
|
||||||
add_executable(ortcustomops ${shared_TARGET_LIB_SRC})
|
add_executable(ortcustomops ${shared_TARGET_LIB_SRC})
|
||||||
set_target_properties(ortcustomops PROPERTIES LINK_FLAGS " \
|
set_target_properties(ortcustomops PROPERTIES LINK_FLAGS " \
|
||||||
|
@ -375,7 +387,8 @@ if(NOT OCOS_ENABLE_STATIC_LIB AND CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
|
||||||
-s EXPORT_ALL=0 \
|
-s EXPORT_ALL=0 \
|
||||||
-s VERBOSE=0 \
|
-s VERBOSE=0 \
|
||||||
--no-entry")
|
--no-entry")
|
||||||
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
|
|
||||||
|
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
|
||||||
set_property(TARGET ortcustomops APPEND_STRING PROPERTY LINK_FLAGS " -s ASSERTIONS=1 -s DEMANGLE_SUPPORT=1")
|
set_property(TARGET ortcustomops APPEND_STRING PROPERTY LINK_FLAGS " -s ASSERTIONS=1 -s DEMANGLE_SUPPORT=1")
|
||||||
else()
|
else()
|
||||||
set_property(TARGET ortcustomops APPEND_STRING PROPERTY LINK_FLAGS " -s ASSERTIONS=0 -s DEMANGLE_SUPPORT=0")
|
set_property(TARGET ortcustomops APPEND_STRING PROPERTY LINK_FLAGS " -s ASSERTIONS=0 -s DEMANGLE_SUPPORT=0")
|
||||||
|
@ -392,18 +405,19 @@ target_include_directories(ortcustomops PUBLIC
|
||||||
"$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")
|
"$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")
|
||||||
target_link_libraries(ortcustomops PUBLIC ocos_operators)
|
target_link_libraries(ortcustomops PUBLIC ocos_operators)
|
||||||
|
|
||||||
if (_BUILD_SHARED_LIBRARY)
|
if(_BUILD_SHARED_LIBRARY)
|
||||||
file(GLOB shared_TARGET_SRC "shared/*.cc" "shared/*.h" "shared/*.def")
|
file(GLOB shared_TARGET_SRC "shared/*.cc" "shared/*.h" "shared/*.def")
|
||||||
add_library(extensions_shared SHARED ${shared_TARGET_SRC})
|
add_library(extensions_shared SHARED ${shared_TARGET_SRC})
|
||||||
source_group(TREE ${PROJECT_SOURCE_DIR} FILES ${shared_TARGET_SRC})
|
source_group(TREE ${PROJECT_SOURCE_DIR} FILES ${shared_TARGET_SRC})
|
||||||
standardize_output_folder(extensions_shared)
|
standardize_output_folder(extensions_shared)
|
||||||
if (CMAKE_SYSTEM_NAME STREQUAL "Android")
|
|
||||||
if (OCOS_ENABLE_SPM_TOKENIZER)
|
if(CMAKE_SYSTEM_NAME STREQUAL "Android")
|
||||||
|
if(OCOS_ENABLE_SPM_TOKENIZER)
|
||||||
target_link_libraries(extensions_shared PUBLIC log)
|
target_link_libraries(extensions_shared PUBLIC log)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (LINUX OR CMAKE_SYSTEM_NAME STREQUAL "Android")
|
if(LINUX OR CMAKE_SYSTEM_NAME STREQUAL "Android")
|
||||||
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS "-Wl,-s -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
|
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS "-Wl,-s -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
@ -429,14 +443,16 @@ endif()
|
||||||
|
|
||||||
# clean up the requirements.txt files from 3rd party project folder to suppress the code security false alarms
|
# clean up the requirements.txt files from 3rd party project folder to suppress the code security false alarms
|
||||||
file(GLOB_RECURSE NO_USE_FILES ${CMAKE_BINARY_DIR}/_deps/*requirements.txt)
|
file(GLOB_RECURSE NO_USE_FILES ${CMAKE_BINARY_DIR}/_deps/*requirements.txt)
|
||||||
message("Found the follow requirements.txt: ${NO_USE_FILES}")
|
message(STATUS "Found the follow requirements.txt: ${NO_USE_FILES}")
|
||||||
|
|
||||||
foreach(nf ${NO_USE_FILES})
|
foreach(nf ${NO_USE_FILES})
|
||||||
file(TO_NATIVE_PATH ${nf} nf_native)
|
file(TO_NATIVE_PATH ${nf} nf_native)
|
||||||
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
|
|
||||||
execute_process(COMMAND cmd /c "del ${nf_native}")
|
if(CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||||
else()
|
execute_process(COMMAND cmd /c "del ${nf_native}")
|
||||||
execute_process(COMMAND bash -c "rm ${nf_native}")
|
else()
|
||||||
endif()
|
execute_process(COMMAND bash -c "rm ${nf_native}")
|
||||||
|
endif()
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|
||||||
# test section
|
# test section
|
||||||
|
@ -460,10 +476,12 @@ elseif(OCOS_ENABLE_CTEST AND NOT OCOS_ENABLE_SELECTED_OPLIST)
|
||||||
|
|
||||||
SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
|
SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
|
||||||
find_library(ONNXRUNTIME onnxruntime HINTS "${ONNXRUNTIME_LIB_DIR}")
|
find_library(ONNXRUNTIME onnxruntime HINTS "${ONNXRUNTIME_LIB_DIR}")
|
||||||
if (ONNXRUNTIME-NOTFOUND)
|
|
||||||
|
if(ONNXRUNTIME-NOTFOUND)
|
||||||
message(WARNING "The prebuilt onnxruntime libraries directory cannot found (via ONNXRUNTIME_LIB_DIR), the extensions_test will be skipped.")
|
message(WARNING "The prebuilt onnxruntime libraries directory cannot found (via ONNXRUNTIME_LIB_DIR), the extensions_test will be skipped.")
|
||||||
else()
|
else()
|
||||||
set(LINUX_CC_FLAGS "")
|
set(LINUX_CC_FLAGS "")
|
||||||
|
|
||||||
# needs to link with stdc++fs in Linux
|
# needs to link with stdc++fs in Linux
|
||||||
if(UNIX AND NOT APPLE AND NOT CMAKE_SYSTEM_NAME STREQUAL "Android")
|
if(UNIX AND NOT APPLE AND NOT CMAKE_SYSTEM_NAME STREQUAL "Android")
|
||||||
list(APPEND LINUX_CC_FLAGS stdc++fs -pthread)
|
list(APPEND LINUX_CC_FLAGS stdc++fs -pthread)
|
||||||
|
@ -474,11 +492,14 @@ elseif(OCOS_ENABLE_CTEST AND NOT OCOS_ENABLE_SELECTED_OPLIST)
|
||||||
standardize_output_folder(extensions_test)
|
standardize_output_folder(extensions_test)
|
||||||
target_include_directories(extensions_test PRIVATE ${spm_INCLUDE_DIRS}
|
target_include_directories(extensions_test PRIVATE ${spm_INCLUDE_DIRS}
|
||||||
"$<TARGET_PROPERTY:extensions_shared,INTERFACE_INCLUDE_DIRECTORIES>")
|
"$<TARGET_PROPERTY:extensions_shared,INTERFACE_INCLUDE_DIRECTORIES>")
|
||||||
if (ONNXRUNTIME_LIB_DIR)
|
|
||||||
|
if(ONNXRUNTIME_LIB_DIR)
|
||||||
target_link_directories(extensions_test PRIVATE ${ONNXRUNTIME_LIB_DIR})
|
target_link_directories(extensions_test PRIVATE ${ONNXRUNTIME_LIB_DIR})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
target_link_libraries(extensions_test PRIVATE ocos_operators extensions_shared onnxruntime gtest_main ${ocos_libraries} ${LINUX_CC_FLAGS})
|
target_link_libraries(extensions_test PRIVATE ocos_operators extensions_shared onnxruntime gtest_main ${ocos_libraries} ${LINUX_CC_FLAGS})
|
||||||
if (WIN32)
|
|
||||||
|
if(WIN32)
|
||||||
file(TO_CMAKE_PATH "${ONNXRUNTIME_LIB_DIR}/*" ONNXRUNTIME_LIB_FILEPATTERN)
|
file(TO_CMAKE_PATH "${ONNXRUNTIME_LIB_DIR}/*" ONNXRUNTIME_LIB_FILEPATTERN)
|
||||||
file(GLOB ONNXRUNTIME_LIB_FILES CONFIGURE_DEPENDS "${ONNXRUNTIME_LIB_FILEPATTERN}")
|
file(GLOB ONNXRUNTIME_LIB_FILES CONFIGURE_DEPENDS "${ONNXRUNTIME_LIB_FILEPATTERN}")
|
||||||
add_custom_command(
|
add_custom_command(
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
if(_ONNXRUNTIME_EMBEDDED)
|
||||||
|
set(ONNXRUNTIME_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/include/onnxruntime/core/session)
|
||||||
|
set(ONNXRUNTIME_LIB_DIR "")
|
||||||
|
else()
|
||||||
|
set(ONNXRUNTIME_VER "1.10.0" CACHE STRING "ONNX Runtime version")
|
||||||
|
|
||||||
|
if(CMAKE_HOST_APPLE)
|
||||||
|
set(ONNXRUNTIME_URL "v${ONNXRUNTIME_VER}/onnxruntime-osx-universal2-${ONNXRUNTIME_VER}.tgz")
|
||||||
|
elseif(CMAKE_HOST_WIN32)
|
||||||
|
set(ONNXRUNTIME_URL "v${ONNXRUNTIME_VER}/onnxruntime-win-x64-${ONNXRUNTIME_VER}.zip")
|
||||||
|
|
||||||
|
if(CMAKE_SYSTEM_PROCESSOR MATCHES "arm64")
|
||||||
|
set(ONNXRUNTIME_URL "v${ONNXRUNTIME_VER}/onnxruntime-win-arm64-${ONNXRUNTIME_VER}.zip")
|
||||||
|
endif()
|
||||||
|
else()
|
||||||
|
# Linux or other, using Linux package to retrieve the headers
|
||||||
|
set(ONNXRUNTIME_URL "v${ONNXRUNTIME_VER}/onnxruntime-linux-x64-${ONNXRUNTIME_VER}.tgz")
|
||||||
|
|
||||||
|
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
|
||||||
|
set(ONNXRUNTIME_URL "v${ONNXRUNTIME_VER}/onnxruntime-linux-aarch64-${ONNXRUNTIME_VER}.tgz")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24:
|
||||||
|
if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
|
||||||
|
cmake_policy(SET CMP0135 NEW)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
message(STATUS "ONNX Runtime URL suffix: ${ONNXRUNTIME_URL}")
|
||||||
|
FetchContent_Declare(
|
||||||
|
onnxruntime
|
||||||
|
URL https://github.com/microsoft/onnxruntime/releases/download/${ONNXRUNTIME_URL}
|
||||||
|
)
|
||||||
|
FetchContent_makeAvailable(onnxruntime)
|
||||||
|
set(ONNXRUNTIME_INCLUDE_DIR ${onnxruntime_SOURCE_DIR}/include)
|
||||||
|
set(ONNXRUNTIME_LIB_DIR ${onnxruntime_SOURCE_DIR}/lib)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(NOT EXISTS ${ONNXRUNTIME_INCLUDE_DIR})
|
||||||
|
message(FATAL_ERROR "ONNX Runtime headers not found at ${ONNXRUNTIME_INCLUDE_DIR}")
|
||||||
|
endif()
|
|
@ -8,9 +8,7 @@
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#define ORT_API_MANUAL_INIT
|
#include "onnxruntime_customop.hpp"
|
||||||
#include "onnxruntime_cxx_api.h"
|
|
||||||
#undef ORT_API_MANUAL_INIT
|
|
||||||
|
|
||||||
// A helper API to support test kernels.
|
// A helper API to support test kernels.
|
||||||
// Must be invoked before RegisterCustomOps.
|
// Must be invoked before RegisterCustomOps.
|
||||||
|
@ -40,13 +38,13 @@ struct BaseKernel {
|
||||||
protected:
|
protected:
|
||||||
OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status);
|
OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status);
|
||||||
const OrtApi& api_;
|
const OrtApi& api_;
|
||||||
Ort::CustomOpApi ort_;
|
OrtW::CustomOpApi ort_;
|
||||||
const OrtKernelInfo* info_;
|
const OrtKernelInfo* info_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct OrtTensorDimensions : std::vector<int64_t> {
|
struct OrtTensorDimensions : std::vector<int64_t> {
|
||||||
OrtTensorDimensions() = default;
|
OrtTensorDimensions() = default;
|
||||||
OrtTensorDimensions(Ort::CustomOpApi& ort, const OrtValue* value) {
|
OrtTensorDimensions(const OrtW::CustomOpApi& ort, const OrtValue* value) {
|
||||||
OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
|
OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
|
||||||
std::vector<int64_t>::operator=(ort.GetTensorShape(info));
|
std::vector<int64_t>::operator=(ort.GetTensorShape(info));
|
||||||
ort.ReleaseTensorTypeAndShapeInfo(info);
|
ort.ReleaseTensorTypeAndShapeInfo(info);
|
||||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -1,945 +0,0 @@
|
||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
||||||
// Licensed under the MIT License.
|
|
||||||
|
|
||||||
// Summary: The Ort C++ API is a header only wrapper around the Ort C API.
|
|
||||||
//
|
|
||||||
// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
|
|
||||||
// and automatically releasing resources in the destructors.
|
|
||||||
//
|
|
||||||
// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
|
|
||||||
// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};).
|
|
||||||
//
|
|
||||||
// Only move assignment between objects is allowed, there are no copy constructors. Some objects have explicit 'Clone'
|
|
||||||
// methods for this purpose.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
#include "onnxruntime_c_api.h"
|
|
||||||
#include <cstddef>
|
|
||||||
#include <array>
|
|
||||||
#include <memory>
|
|
||||||
#include <stdexcept>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include <utility>
|
|
||||||
#include <type_traits>
|
|
||||||
|
|
||||||
#ifdef ORT_NO_EXCEPTIONS
|
|
||||||
#include <iostream>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace Ort {
|
|
||||||
|
|
||||||
// All C++ methods that can fail will throw an exception of this type
|
|
||||||
struct Exception : std::exception {
|
|
||||||
Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
|
|
||||||
|
|
||||||
OrtErrorCode GetOrtErrorCode() const { return code_; }
|
|
||||||
const char* what() const noexcept override { return message_.c_str(); }
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::string message_;
|
|
||||||
OrtErrorCode code_;
|
|
||||||
};
|
|
||||||
|
|
||||||
#ifdef ORT_NO_EXCEPTIONS
|
|
||||||
#define ORT_CXX_API_THROW(string, code) \
|
|
||||||
do { \
|
|
||||||
std::cerr << Ort::Exception(string, code) \
|
|
||||||
.what() \
|
|
||||||
<< std::endl; \
|
|
||||||
abort(); \
|
|
||||||
} while (false)
|
|
||||||
#else
|
|
||||||
#define ORT_CXX_API_THROW(string, code) \
|
|
||||||
throw Ort::Exception(string, code)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, it's in a template so that we can define a global variable in a header and make
|
|
||||||
// it transparent to the users of the API.
|
|
||||||
template <typename T>
|
|
||||||
struct Global {
|
|
||||||
static const OrtApi* api_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
#ifdef ORT_API_MANUAL_INIT
|
|
||||||
const OrtApi* Global<T>::api_{};
|
|
||||||
inline void InitApi() { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
|
|
||||||
#else
|
|
||||||
const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// This returns a reference to the OrtApi interface in use, in case someone wants to use the C API functions
|
|
||||||
inline const OrtApi& GetApi() { return *Global<void>::api_; }
|
|
||||||
|
|
||||||
// This is a C++ wrapper for GetAvailableProviders() C API and returns
|
|
||||||
// a vector of strings representing the available execution providers.
|
|
||||||
std::vector<std::string> GetAvailableProviders();
|
|
||||||
|
|
||||||
// This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
|
|
||||||
// This can't be done in the C API since C doesn't have function overloading.
|
|
||||||
#define ORT_DEFINE_RELEASE(NAME) \
|
|
||||||
inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
|
|
||||||
|
|
||||||
ORT_DEFINE_RELEASE(Allocator);
|
|
||||||
ORT_DEFINE_RELEASE(MemoryInfo);
|
|
||||||
ORT_DEFINE_RELEASE(CustomOpDomain);
|
|
||||||
ORT_DEFINE_RELEASE(Env);
|
|
||||||
ORT_DEFINE_RELEASE(RunOptions);
|
|
||||||
ORT_DEFINE_RELEASE(Session);
|
|
||||||
ORT_DEFINE_RELEASE(SessionOptions);
|
|
||||||
ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
|
|
||||||
ORT_DEFINE_RELEASE(SequenceTypeInfo);
|
|
||||||
ORT_DEFINE_RELEASE(MapTypeInfo);
|
|
||||||
ORT_DEFINE_RELEASE(TypeInfo);
|
|
||||||
ORT_DEFINE_RELEASE(Value);
|
|
||||||
ORT_DEFINE_RELEASE(ModelMetadata);
|
|
||||||
ORT_DEFINE_RELEASE(ThreadingOptions);
|
|
||||||
ORT_DEFINE_RELEASE(IoBinding);
|
|
||||||
ORT_DEFINE_RELEASE(ArenaCfg);
|
|
||||||
|
|
||||||
/*! \class Ort::Float16_t
|
|
||||||
* \brief it is a structure that represents float16 data.
|
|
||||||
* \details It is necessary for type dispatching to make use of C++ API
|
|
||||||
* The type is implicitly convertible to/from uint16_t.
|
|
||||||
* The size of the structure should align with uint16_t and one can freely cast
|
|
||||||
* uint16_t buffers to/from Ort::Float16_t to feed and retrieve data.
|
|
||||||
*
|
|
||||||
* Generally, you can feed any of your types as float16/blfoat16 data to create a tensor
|
|
||||||
* on top of it, providing it can form a continuous buffer with 16-bit elements with no padding.
|
|
||||||
* And you can also feed a array of uint16_t elements directly. For example,
|
|
||||||
*
|
|
||||||
* \code{.unparsed}
|
|
||||||
* uint16_t values[] = { 15360, 16384, 16896, 17408, 17664};
|
|
||||||
* constexpr size_t values_length = sizeof(values) / sizeof(values[0]);
|
|
||||||
* std::vector<int64_t> dims = {values_length}; // one dimensional example
|
|
||||||
* Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
|
|
||||||
* // Note we are passing bytes count in this api, not number of elements -> sizeof(values)
|
|
||||||
* auto float16_tensor = Ort::Value::CreateTensor(info, values, sizeof(values),
|
|
||||||
* dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
|
|
||||||
* \endcode
|
|
||||||
*
|
|
||||||
* Here is another example, a little bit more elaborate. Let's assume that you use your own float16 type and you want to use
|
|
||||||
* a templated version of the API above so the type is automatically set based on your type. You will need to supply an extra
|
|
||||||
* template specialization.
|
|
||||||
*
|
|
||||||
* \code{.unparsed}
|
|
||||||
* namespace yours { struct half {}; } // assume this is your type, define this:
|
|
||||||
* namespace Ort {
|
|
||||||
* template<>
|
|
||||||
* struct TypeToTensorType<yours::half> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
|
|
||||||
* } //namespace Ort
|
|
||||||
*
|
|
||||||
* std::vector<yours::half> values;
|
|
||||||
* std::vector<int64_t> dims = {values.size()}; // one dimensional example
|
|
||||||
* Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
|
|
||||||
* // Here we are passing element count -> values.size()
|
|
||||||
* auto float16_tensor = Ort::Value::CreateTensor<yours::half>(info, values.data(), values.size(), dims.data(), dims.size());
|
|
||||||
*
|
|
||||||
* \endcode
|
|
||||||
*/
|
|
||||||
struct Float16_t {
|
|
||||||
uint16_t value;
|
|
||||||
constexpr Float16_t() noexcept : value(0) {}
|
|
||||||
constexpr Float16_t(uint16_t v) noexcept : value(v) {}
|
|
||||||
constexpr operator uint16_t() const noexcept { return value; }
|
|
||||||
constexpr bool operator==(const Float16_t& rhs) const noexcept { return value == rhs.value; };
|
|
||||||
constexpr bool operator!=(const Float16_t& rhs) const noexcept { return value != rhs.value; };
|
|
||||||
};
|
|
||||||
|
|
||||||
static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
|
|
||||||
|
|
||||||
/*! \class Ort::BFloat16_t
|
|
||||||
* \brief is a structure that represents bfloat16 data.
|
|
||||||
* \details It is necessary for type dispatching to make use of C++ API
|
|
||||||
* The type is implicitly convertible to/from uint16_t.
|
|
||||||
* The size of the structure should align with uint16_t and one can freely cast
|
|
||||||
* uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data.
|
|
||||||
*
|
|
||||||
* See also code examples for Float16_t above.
|
|
||||||
*/
|
|
||||||
struct BFloat16_t {
|
|
||||||
uint16_t value;
|
|
||||||
constexpr BFloat16_t() noexcept : value(0) {}
|
|
||||||
constexpr BFloat16_t(uint16_t v) noexcept : value(v) {}
|
|
||||||
constexpr operator uint16_t() const noexcept { return value; }
|
|
||||||
constexpr bool operator==(const BFloat16_t& rhs) const noexcept { return value == rhs.value; };
|
|
||||||
constexpr bool operator!=(const BFloat16_t& rhs) const noexcept { return value != rhs.value; };
|
|
||||||
};
|
|
||||||
|
|
||||||
static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
|
|
||||||
|
|
||||||
// This is used internally by the C++ API. This is the common base class used by the wrapper objects.
|
|
||||||
template <typename T>
|
|
||||||
struct Base {
|
|
||||||
using contained_type = T;
|
|
||||||
|
|
||||||
Base() = default;
|
|
||||||
Base(T* p) : p_{p} {
|
|
||||||
if (!p)
|
|
||||||
ORT_CXX_API_THROW("Allocation failure", ORT_FAIL);
|
|
||||||
}
|
|
||||||
~Base() { OrtRelease(p_); }
|
|
||||||
|
|
||||||
operator T*() { return p_; }
|
|
||||||
operator const T*() const { return p_; }
|
|
||||||
|
|
||||||
T* release() {
|
|
||||||
T* p = p_;
|
|
||||||
p_ = nullptr;
|
|
||||||
return p;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
Base(const Base&) = delete;
|
|
||||||
Base& operator=(const Base&) = delete;
|
|
||||||
Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
|
|
||||||
void operator=(Base&& v) noexcept {
|
|
||||||
OrtRelease(p_);
|
|
||||||
p_ = v.p_;
|
|
||||||
v.p_ = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
T* p_{};
|
|
||||||
|
|
||||||
template <typename>
|
|
||||||
friend struct Unowned; // This friend line is needed to keep the centos C++ compiler from giving an error
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct Base<const T> {
|
|
||||||
using contained_type = const T;
|
|
||||||
|
|
||||||
Base() = default;
|
|
||||||
Base(const T* p) : p_{p} {
|
|
||||||
if (!p)
|
|
||||||
ORT_CXX_API_THROW("Invalid instance ptr", ORT_INVALID_ARGUMENT);
|
|
||||||
}
|
|
||||||
~Base() = default;
|
|
||||||
|
|
||||||
operator const T*() const { return p_; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
Base(const Base&) = delete;
|
|
||||||
Base& operator=(const Base&) = delete;
|
|
||||||
Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
|
|
||||||
void operator=(Base&& v) noexcept {
|
|
||||||
p_ = v.p_;
|
|
||||||
v.p_ = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
const T* p_{};
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct Unowned : T {
|
|
||||||
Unowned(decltype(T::p_) p) : T{p} {}
|
|
||||||
Unowned(Unowned&& v) : T{v.p_} {}
|
|
||||||
~Unowned() { this->release(); }
|
|
||||||
};
|
|
||||||
|
|
||||||
struct AllocatorWithDefaultOptions;
|
|
||||||
struct MemoryInfo;
|
|
||||||
struct Env;
|
|
||||||
struct TypeInfo;
|
|
||||||
struct Value;
|
|
||||||
struct ModelMetadata;
|
|
||||||
|
|
||||||
struct Env : Base<OrtEnv> {
|
|
||||||
Env(std::nullptr_t) {}
|
|
||||||
Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
|
|
||||||
Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
|
|
||||||
Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
|
|
||||||
Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
|
|
||||||
OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
|
|
||||||
explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
|
|
||||||
|
|
||||||
Env& EnableTelemetryEvents();
|
|
||||||
Env& DisableTelemetryEvents();
|
|
||||||
|
|
||||||
Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg);
|
|
||||||
|
|
||||||
static const OrtApi* s_api;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct CustomOpDomain : Base<OrtCustomOpDomain> {
|
|
||||||
explicit CustomOpDomain(std::nullptr_t) {}
|
|
||||||
explicit CustomOpDomain(const char* domain);
|
|
||||||
|
|
||||||
void Add(OrtCustomOp* op);
|
|
||||||
};
|
|
||||||
|
|
||||||
struct RunOptions : Base<OrtRunOptions> {
|
|
||||||
RunOptions(std::nullptr_t) {}
|
|
||||||
RunOptions();
|
|
||||||
|
|
||||||
RunOptions& SetRunLogVerbosityLevel(int);
|
|
||||||
int GetRunLogVerbosityLevel() const;
|
|
||||||
|
|
||||||
RunOptions& SetRunLogSeverityLevel(int);
|
|
||||||
int GetRunLogSeverityLevel() const;
|
|
||||||
|
|
||||||
RunOptions& SetRunTag(const char* run_tag);
|
|
||||||
const char* GetRunTag() const;
|
|
||||||
|
|
||||||
RunOptions& AddConfigEntry(const char* config_key, const char* config_value);
|
|
||||||
|
|
||||||
// terminate ALL currently executing Session::Run calls that were made using this RunOptions instance
|
|
||||||
RunOptions& SetTerminate();
|
|
||||||
// unset the terminate flag so this RunOptions instance can be used in a new Session::Run call
|
|
||||||
RunOptions& UnsetTerminate();
|
|
||||||
};
|
|
||||||
|
|
||||||
struct SessionOptions : Base<OrtSessionOptions> {
|
|
||||||
explicit SessionOptions(std::nullptr_t) {}
|
|
||||||
SessionOptions();
|
|
||||||
explicit SessionOptions(OrtSessionOptions* p) : Base<OrtSessionOptions>{p} {}
|
|
||||||
|
|
||||||
SessionOptions Clone() const;
|
|
||||||
|
|
||||||
SessionOptions& SetIntraOpNumThreads(int intra_op_num_threads);
|
|
||||||
SessionOptions& SetInterOpNumThreads(int inter_op_num_threads);
|
|
||||||
SessionOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);
|
|
||||||
|
|
||||||
SessionOptions& EnableCpuMemArena();
|
|
||||||
SessionOptions& DisableCpuMemArena();
|
|
||||||
|
|
||||||
SessionOptions& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file);
|
|
||||||
|
|
||||||
SessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix);
|
|
||||||
SessionOptions& DisableProfiling();
|
|
||||||
|
|
||||||
SessionOptions& EnableOrtCustomOps();
|
|
||||||
|
|
||||||
SessionOptions& EnableMemPattern();
|
|
||||||
SessionOptions& DisableMemPattern();
|
|
||||||
|
|
||||||
SessionOptions& SetExecutionMode(ExecutionMode execution_mode);
|
|
||||||
|
|
||||||
SessionOptions& SetLogId(const char* logid);
|
|
||||||
SessionOptions& SetLogSeverityLevel(int level);
|
|
||||||
|
|
||||||
SessionOptions& Add(OrtCustomOpDomain* custom_op_domain);
|
|
||||||
|
|
||||||
SessionOptions& DisablePerSessionThreads();
|
|
||||||
|
|
||||||
SessionOptions& AddConfigEntry(const char* config_key, const char* config_value);
|
|
||||||
SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val);
|
|
||||||
|
|
||||||
SessionOptions& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options);
|
|
||||||
SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options);
|
|
||||||
SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options);
|
|
||||||
SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options);
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ModelMetadata : Base<OrtModelMetadata> {
|
|
||||||
explicit ModelMetadata(std::nullptr_t) {}
|
|
||||||
explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {}
|
|
||||||
|
|
||||||
char* GetProducerName(OrtAllocator* allocator) const;
|
|
||||||
char* GetGraphName(OrtAllocator* allocator) const;
|
|
||||||
char* GetDomain(OrtAllocator* allocator) const;
|
|
||||||
char* GetDescription(OrtAllocator* allocator) const;
|
|
||||||
char* GetGraphDescription(OrtAllocator* allocator) const;
|
|
||||||
char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const;
|
|
||||||
char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const;
|
|
||||||
int64_t GetVersion() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct Session : Base<OrtSession> {
|
|
||||||
explicit Session(std::nullptr_t) {}
|
|
||||||
Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);
|
|
||||||
Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container);
|
|
||||||
Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options);
|
|
||||||
|
|
||||||
// Run that will allocate the output values
|
|
||||||
std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
|
||||||
const char* const* output_names, size_t output_count);
|
|
||||||
// Run for when there is a list of preallocated outputs
|
|
||||||
void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
|
||||||
const char* const* output_names, Value* output_values, size_t output_count);
|
|
||||||
|
|
||||||
void Run(const RunOptions& run_options, const struct IoBinding&);
|
|
||||||
|
|
||||||
size_t GetInputCount() const;
|
|
||||||
size_t GetOutputCount() const;
|
|
||||||
size_t GetOverridableInitializerCount() const;
|
|
||||||
|
|
||||||
char* GetInputName(size_t index, OrtAllocator* allocator) const;
|
|
||||||
char* GetOutputName(size_t index, OrtAllocator* allocator) const;
|
|
||||||
char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const;
|
|
||||||
char* EndProfiling(OrtAllocator* allocator) const;
|
|
||||||
uint64_t GetProfilingStartTimeNs() const;
|
|
||||||
ModelMetadata GetModelMetadata() const;
|
|
||||||
|
|
||||||
TypeInfo GetInputTypeInfo(size_t index) const;
|
|
||||||
TypeInfo GetOutputTypeInfo(size_t index) const;
|
|
||||||
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TensorTypeAndShapeInfo : Base<OrtTensorTypeAndShapeInfo> {
|
|
||||||
explicit TensorTypeAndShapeInfo(std::nullptr_t) {}
|
|
||||||
explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : Base<OrtTensorTypeAndShapeInfo>{p} {}
|
|
||||||
|
|
||||||
ONNXTensorElementDataType GetElementType() const;
|
|
||||||
size_t GetElementCount() const;
|
|
||||||
|
|
||||||
size_t GetDimensionsCount() const;
|
|
||||||
void GetDimensions(int64_t* values, size_t values_count) const;
|
|
||||||
void GetSymbolicDimensions(const char** values, size_t values_count) const;
|
|
||||||
|
|
||||||
std::vector<int64_t> GetShape() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct SequenceTypeInfo : Base<OrtSequenceTypeInfo> {
|
|
||||||
explicit SequenceTypeInfo(std::nullptr_t) {}
|
|
||||||
explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : Base<OrtSequenceTypeInfo>{p} {}
|
|
||||||
|
|
||||||
TypeInfo GetSequenceElementType() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct MapTypeInfo : Base<OrtMapTypeInfo> {
|
|
||||||
explicit MapTypeInfo(std::nullptr_t) {}
|
|
||||||
explicit MapTypeInfo(OrtMapTypeInfo* p) : Base<OrtMapTypeInfo>{p} {}
|
|
||||||
|
|
||||||
ONNXTensorElementDataType GetMapKeyType() const;
|
|
||||||
TypeInfo GetMapValueType() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TypeInfo : Base<OrtTypeInfo> {
|
|
||||||
explicit TypeInfo(std::nullptr_t) {}
|
|
||||||
explicit TypeInfo(OrtTypeInfo* p) : Base<OrtTypeInfo>{p} {}
|
|
||||||
|
|
||||||
Unowned<TensorTypeAndShapeInfo> GetTensorTypeAndShapeInfo() const;
|
|
||||||
Unowned<SequenceTypeInfo> GetSequenceTypeInfo() const;
|
|
||||||
Unowned<MapTypeInfo> GetMapTypeInfo() const;
|
|
||||||
|
|
||||||
ONNXType GetONNXType() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct Value : Base<OrtValue> {
|
|
||||||
// This structure is used to feed sparse tensor values
|
|
||||||
// information for use with FillSparseTensor<Format>() API
|
|
||||||
// if the data type for the sparse tensor values is numeric
|
|
||||||
// use data.p_data, otherwise, use data.str pointer to feed
|
|
||||||
// values. data.str is an array of const char* that are zero terminated.
|
|
||||||
// number of strings in the array must match shape size.
|
|
||||||
// For fully sparse tensors use shape {0} and set p_data/str
|
|
||||||
// to nullptr.
|
|
||||||
struct OrtSparseValuesParam {
|
|
||||||
const int64_t* values_shape;
|
|
||||||
size_t values_shape_len;
|
|
||||||
union {
|
|
||||||
const void* p_data;
|
|
||||||
const char** str;
|
|
||||||
} data;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Provides a way to pass shape in a single
|
|
||||||
// argument
|
|
||||||
struct Shape {
|
|
||||||
const int64_t* shape;
|
|
||||||
size_t shape_len;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
|
|
||||||
static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
|
|
||||||
ONNXTensorElementDataType type);
|
|
||||||
|
|
||||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
|
||||||
/// <summary>
|
|
||||||
/// This is a simple forwarding method to the other overload that helps deducing
|
|
||||||
/// data type enum value from the type of the buffer.
|
|
||||||
/// </summary>
|
|
||||||
/// <typeparam name="T">numeric datatype. This API is not suitable for strings.</typeparam>
|
|
||||||
/// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
|
|
||||||
/// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
|
|
||||||
/// <param name="dense_shape">a would be dense shape of the tensor</param>
|
|
||||||
/// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
|
|
||||||
/// <returns></returns>
|
|
||||||
template <typename T>
|
|
||||||
static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
|
|
||||||
const Shape& values_shape);
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// Creates an OrtValue instance containing SparseTensor. This constructs
|
|
||||||
/// a sparse tensor that makes use of user allocated buffers. It does not make copies
|
|
||||||
/// of the user provided data and does not modify it. The lifespan of user provided buffers should
|
|
||||||
/// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contain
|
|
||||||
/// a pointer to non-zero values. To fully populate the sparse tensor call Use<Format>Indices() API below
|
|
||||||
/// to supply a sparse format specific indices.
|
|
||||||
/// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings
|
|
||||||
/// can be properly copied into the allocated buffer.
|
|
||||||
/// </summary>
|
|
||||||
/// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
|
|
||||||
/// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
|
|
||||||
/// <param name="dense_shape">a would be dense shape of the tensor</param>
|
|
||||||
/// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
|
|
||||||
/// <param name="type">data type</param>
|
|
||||||
/// <returns>Ort::Value instance containing SparseTensor</returns>
|
|
||||||
static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
|
|
||||||
const Shape& values_shape, ONNXTensorElementDataType type);
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor.
|
|
||||||
/// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
|
|
||||||
/// allocated buffers lifespan must eclipse that of the OrtValue.
|
|
||||||
/// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
|
|
||||||
/// </summary>
|
|
||||||
/// <param name="indices_data">pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors.</param>
|
|
||||||
/// <param name="indices_num">number of indices entries. Use 0 for fully sparse tensors</param>
|
|
||||||
void UseCooIndices(int64_t* indices_data, size_t indices_num);
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor.
|
|
||||||
/// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
|
|
||||||
/// allocated buffers lifespan must eclipse that of the OrtValue.
|
|
||||||
/// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
|
|
||||||
/// </summary>
|
|
||||||
/// <param name="inner_data">pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors</param>
|
|
||||||
/// <param name="inner_num">number of csr inner indices or 0 for fully sparse tensors</param>
|
|
||||||
/// <param name="outer_data">pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors</param>
|
|
||||||
/// <param name="outer_num">number of csr outer indices or 0 for fully sparse tensors</param>
|
|
||||||
void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor.
|
|
||||||
/// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
|
|
||||||
/// allocated buffers lifespan must eclipse that of the OrtValue.
|
|
||||||
/// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
|
|
||||||
/// </summary>
|
|
||||||
/// <param name="indices_shape">indices shape or a {0} for fully sparse</param>
|
|
||||||
/// <param name="indices_data">user allocated buffer with indices or nullptr for fully spare tensors</param>
|
|
||||||
void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);
|
|
||||||
|
|
||||||
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
|
|
||||||
static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
|
|
||||||
|
|
||||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
|
||||||
/// <summary>
|
|
||||||
/// This is a simple forwarding method the below CreateSparseTensor.
|
|
||||||
/// This helps to specify data type enum in terms of C++ data type.
|
|
||||||
/// Use CreateSparseTensor<T>
|
|
||||||
/// </summary>
|
|
||||||
/// <typeparam name="T">numeric data type only. String data enum must be specified explicitly.</typeparam>
|
|
||||||
/// <param name="allocator">allocator to use</param>
|
|
||||||
/// <param name="dense_shape">a would be dense shape of the tensor</param>
|
|
||||||
/// <returns>Ort::Value</returns>
|
|
||||||
template <typename T>
|
|
||||||
static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// Creates an instance of OrtValue containing sparse tensor. The created instance has no data.
|
|
||||||
/// The data must be supplied by on of the FillSparseTensor<Format>() methods that take both non-zero values
|
|
||||||
/// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator.
|
|
||||||
/// Use this API to create OrtValues that contain sparse tensors with all supported data types including
|
|
||||||
/// strings.
|
|
||||||
/// </summary>
|
|
||||||
/// <param name="allocator">allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue</param>
|
|
||||||
/// <param name="dense_shape">a would be dense shape of the tensor</param>
|
|
||||||
/// <param name="type">data type</param>
|
|
||||||
/// <returns>an instance of Ort::Value</returns>
|
|
||||||
static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
|
|
||||||
/// and copy the values and COO indices into it. If data_mem_info specifies that the data is located
|
|
||||||
/// at difference device than the allocator, a X-device copy will be performed if possible.
|
|
||||||
/// </summary>
|
|
||||||
/// <param name="data_mem_info">specified buffer memory description</param>
|
|
||||||
/// <param name="values_param">values buffer information.</param>
|
|
||||||
/// <param name="indices_data">coo indices buffer or nullptr for fully sparse data</param>
|
|
||||||
/// <param name="indices_num">number of COO indices or 0 for fully sparse data</param>
|
|
||||||
void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
|
|
||||||
const int64_t* indices_data, size_t indices_num);
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
|
|
||||||
/// and copy the values and CSR indices into it. If data_mem_info specifies that the data is located
|
|
||||||
/// at difference device than the allocator, a X-device copy will be performed if possible.
|
|
||||||
/// </summary>
|
|
||||||
/// <param name="data_mem_info">specified buffer memory description</param>
|
|
||||||
/// <param name="values_param">values buffer information</param>
|
|
||||||
/// <param name="inner_indices_data">csr inner indices pointer or nullptr for fully sparse tensors</param>
|
|
||||||
/// <param name="inner_indices_num">number of csr inner indices or 0 for fully sparse tensors</param>
|
|
||||||
/// <param name="outer_indices_data">pointer to csr indices data or nullptr for fully sparse tensors</param>
|
|
||||||
/// <param name="outer_indices_num">number of csr outer indices or 0</param>
|
|
||||||
void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
|
|
||||||
const OrtSparseValuesParam& values,
|
|
||||||
const int64_t* inner_indices_data, size_t inner_indices_num,
|
|
||||||
const int64_t* outer_indices_data, size_t outer_indices_num);
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
|
|
||||||
/// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located
|
|
||||||
/// at difference device than the allocator, a X-device copy will be performed if possible.
|
|
||||||
/// </summary>
|
|
||||||
/// <param name="data_mem_info">specified buffer memory description</param>
|
|
||||||
/// <param name="values_param">values buffer information</param>
|
|
||||||
/// <param name="indices_shape">indices shape. use {0} for fully sparse tensors</param>
|
|
||||||
/// <param name="indices_data">pointer to indices data or nullptr for fully sparse tensors</param>
|
|
||||||
void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
|
|
||||||
const OrtSparseValuesParam& values,
|
|
||||||
const Shape& indices_shape,
|
|
||||||
const int32_t* indices_data);
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// The API returns the sparse data format this OrtValue holds in a sparse tensor.
|
|
||||||
/// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used
|
|
||||||
/// the value returned is ORT_SPARSE_UNDEFINED.
|
|
||||||
/// </summary>
|
|
||||||
/// <returns>Format enum</returns>
|
|
||||||
OrtSparseFormat GetSparseFormat() const;
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// The API returns type and shape information for stored non-zero values of the
|
|
||||||
/// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer.
|
|
||||||
/// </summary>
|
|
||||||
/// <returns>TensorTypeAndShapeInfo values information</returns>
|
|
||||||
TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const;
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// The API returns type and shape information for the specified indices. Each supported
|
|
||||||
/// indices have their own enum values even if a give format has more than one kind of indices.
|
|
||||||
/// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer.
|
|
||||||
/// </summary>
|
|
||||||
/// <param name="">enum requested</param>
|
|
||||||
/// <returns>type and shape information</returns>
|
|
||||||
TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat) const;
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// The API retrieves a pointer to the internal indices buffer. The API merely performs
|
|
||||||
/// a convenience data type casting on the return type pointer. Make sure you are requesting
|
|
||||||
/// the right type, use GetSparseTensorIndicesTypeShapeInfo();
|
|
||||||
/// </summary>
|
|
||||||
/// <typeparam name="T">type to cast to</typeparam>
|
|
||||||
/// <param name="indices_format">requested indices kind</param>
|
|
||||||
/// <param name="num_indices">number of indices entries</param>
|
|
||||||
/// <returns>Pinter to the internal sparse tensor buffer containing indices. Do not free this pointer.</returns>
|
|
||||||
template <typename T>
|
|
||||||
const T* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;
|
|
||||||
|
|
||||||
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
|
||||||
|
|
||||||
static Value CreateMap(Value& keys, Value& values);
|
|
||||||
static Value CreateSequence(std::vector<Value>& values);
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static Value CreateOpaque(const char* domain, const char* type_name, const T&);
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void GetOpaqueData(const char* domain, const char* type_name, T&) const;
|
|
||||||
|
|
||||||
explicit Value(std::nullptr_t) {}
|
|
||||||
explicit Value(OrtValue* p) : Base<OrtValue>{p} {}
|
|
||||||
Value(Value&&) = default;
|
|
||||||
Value& operator=(Value&&) = default;
|
|
||||||
|
|
||||||
bool IsTensor() const;
|
|
||||||
|
|
||||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
|
||||||
/// <summary>
|
|
||||||
/// Returns true if the OrtValue contains a sparse tensor
|
|
||||||
/// </summary>
|
|
||||||
/// <returns></returns>
|
|
||||||
bool IsSparseTensor() const;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
|
|
||||||
Value GetValue(int index, OrtAllocator* allocator) const;
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// This API returns a full length of string data contained within either a tensor or a sparse Tensor.
|
|
||||||
/// For sparse tensor it returns a full length of stored non-empty strings (values). The API is useful
|
|
||||||
/// for allocating necessary memory and calling GetStringTensorContent().
|
|
||||||
/// </summary>
|
|
||||||
/// <returns>total length of UTF-8 encoded bytes contained. No zero terminators counted.</returns>
|
|
||||||
size_t GetStringTensorDataLength() const;
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor
|
|
||||||
/// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate.
|
|
||||||
/// The user must also allocate offsets buffer with the number of entries equal to that of the contained
|
|
||||||
/// strings.
|
|
||||||
///
|
|
||||||
/// Strings are always assumed to be on CPU, no X-device copy.
|
|
||||||
/// </summary>
|
|
||||||
/// <param name="buffer">user allocated buffer</param>
|
|
||||||
/// <param name="buffer_length">length in bytes of the allocated buffer</param>
|
|
||||||
/// <param name="offsets">a pointer to the offsets user allocated buffer</param>
|
|
||||||
/// <param name="offsets_count">count of offsets, must be equal to the number of strings contained.
|
|
||||||
/// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo()
|
|
||||||
/// for sparse tensors</param>
|
|
||||||
void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T* GetTensorMutableData();
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
const T* GetTensorData() const;
|
|
||||||
|
|
||||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
|
||||||
/// <summary>
|
|
||||||
/// The API returns a pointer to an internal buffer of the sparse tensor
|
|
||||||
/// containing non-zero values. The API merely does casting. Make sure you
|
|
||||||
/// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo()
|
|
||||||
/// first.
|
|
||||||
/// </summary>
|
|
||||||
/// <typeparam name="T">numeric data types only. Use GetStringTensor*() to retrieve strings.</typeparam>
|
|
||||||
/// <returns>a pointer to the internal values buffer. Do not free this pointer.</returns>
|
|
||||||
template <typename T>
|
|
||||||
const T* GetSparseTensorValues() const;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T& At(const std::vector<int64_t>& location);
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// The API returns type information for data contained in a tensor. For sparse
|
|
||||||
/// tensors it returns type information for contained non-zero values.
|
|
||||||
/// It returns dense shape for sparse tensors.
|
|
||||||
/// </summary>
|
|
||||||
/// <returns>TypeInfo</returns>
|
|
||||||
TypeInfo GetTypeInfo() const;
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// The API returns type information for data contained in a tensor. For sparse
|
|
||||||
/// tensors it returns type information for contained non-zero values.
|
|
||||||
/// It returns dense shape for sparse tensors.
|
|
||||||
/// </summary>
|
|
||||||
/// <returns>TensorTypeAndShapeInfo</returns>
|
|
||||||
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// The API returns a byte length of UTF-8 encoded string element
|
|
||||||
/// contained in either a tensor or a spare tensor values.
|
|
||||||
/// </summary>
|
|
||||||
/// <param name="element_index"></param>
|
|
||||||
/// <returns>byte length for the specified string element</returns>
|
|
||||||
size_t GetStringTensorElementLength(size_t element_index) const;
|
|
||||||
|
|
||||||
/// <summary>
|
|
||||||
/// The API copies UTF-8 encoded bytes for the requested string element
|
|
||||||
/// contained within a tensor or a sparse tensor into a provided buffer.
|
|
||||||
/// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate.
|
|
||||||
/// </summary>
|
|
||||||
/// <param name="buffer_length"></param>
|
|
||||||
/// <param name="element_index"></param>
|
|
||||||
/// <param name="buffer"></param>
|
|
||||||
void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
|
|
||||||
|
|
||||||
void FillStringTensor(const char* const* s, size_t s_len);
|
|
||||||
void FillStringTensorElement(const char* s, size_t index);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Represents native memory allocation
|
|
||||||
struct MemoryAllocation {
|
|
||||||
MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
|
|
||||||
~MemoryAllocation();
|
|
||||||
MemoryAllocation(const MemoryAllocation&) = delete;
|
|
||||||
MemoryAllocation& operator=(const MemoryAllocation&) = delete;
|
|
||||||
MemoryAllocation(MemoryAllocation&&) noexcept;
|
|
||||||
MemoryAllocation& operator=(MemoryAllocation&&) noexcept;
|
|
||||||
|
|
||||||
void* get() { return p_; }
|
|
||||||
size_t size() const { return size_; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
OrtAllocator* allocator_;
|
|
||||||
void* p_;
|
|
||||||
size_t size_;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct AllocatorWithDefaultOptions {
|
|
||||||
AllocatorWithDefaultOptions();
|
|
||||||
|
|
||||||
operator OrtAllocator*() { return p_; }
|
|
||||||
operator const OrtAllocator*() const { return p_; }
|
|
||||||
|
|
||||||
void* Alloc(size_t size);
|
|
||||||
// The return value will own the allocation
|
|
||||||
MemoryAllocation GetAllocation(size_t size);
|
|
||||||
void Free(void* p);
|
|
||||||
|
|
||||||
const OrtMemoryInfo* GetInfo() const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
OrtAllocator* p_{};
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename B>
|
|
||||||
struct BaseMemoryInfo : B {
|
|
||||||
BaseMemoryInfo() = default;
|
|
||||||
explicit BaseMemoryInfo(typename B::contained_type* p) : B(p) {}
|
|
||||||
~BaseMemoryInfo() = default;
|
|
||||||
BaseMemoryInfo(BaseMemoryInfo&&) = default;
|
|
||||||
BaseMemoryInfo& operator=(BaseMemoryInfo&&) = default;
|
|
||||||
|
|
||||||
std::string GetAllocatorName() const;
|
|
||||||
OrtAllocatorType GetAllocatorType() const;
|
|
||||||
int GetDeviceId() const;
|
|
||||||
OrtMemType GetMemoryType() const;
|
|
||||||
template <typename U>
|
|
||||||
bool operator==(const BaseMemoryInfo<U>& o) const;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct UnownedMemoryInfo : BaseMemoryInfo<Base<const OrtMemoryInfo> > {
|
|
||||||
explicit UnownedMemoryInfo(std::nullptr_t) {}
|
|
||||||
explicit UnownedMemoryInfo(const OrtMemoryInfo* p) : BaseMemoryInfo(p) {}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct MemoryInfo : BaseMemoryInfo<Base<OrtMemoryInfo> > {
|
|
||||||
static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1);
|
|
||||||
|
|
||||||
explicit MemoryInfo(std::nullptr_t) {}
|
|
||||||
explicit MemoryInfo(OrtMemoryInfo* p) : BaseMemoryInfo(p) {}
|
|
||||||
MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
|
|
||||||
};
|
|
||||||
|
|
||||||
struct Allocator : public Base<OrtAllocator> {
|
|
||||||
Allocator(const Session& session, const MemoryInfo&);
|
|
||||||
|
|
||||||
void* Alloc(size_t size) const;
|
|
||||||
// The return value will own the allocation
|
|
||||||
MemoryAllocation GetAllocation(size_t size);
|
|
||||||
void Free(void* p) const;
|
|
||||||
UnownedMemoryInfo GetInfo() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct IoBinding : public Base<OrtIoBinding> {
|
|
||||||
private:
|
|
||||||
std::vector<std::string> GetOutputNamesHelper(OrtAllocator*) const;
|
|
||||||
std::vector<Value> GetOutputValuesHelper(OrtAllocator*) const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
explicit IoBinding(Session& session);
|
|
||||||
void BindInput(const char* name, const Value&);
|
|
||||||
void BindOutput(const char* name, const Value&);
|
|
||||||
void BindOutput(const char* name, const MemoryInfo&);
|
|
||||||
std::vector<std::string> GetOutputNames() const;
|
|
||||||
std::vector<std::string> GetOutputNames(Allocator&) const;
|
|
||||||
std::vector<Value> GetOutputValues() const;
|
|
||||||
std::vector<Value> GetOutputValues(Allocator&) const;
|
|
||||||
void ClearBoundInputs();
|
|
||||||
void ClearBoundOutputs();
|
|
||||||
};
|
|
||||||
|
|
||||||
/*! \struct Ort::ArenaCfg
|
|
||||||
* \brief it is a structure that represents the configuration of an arena based allocator
|
|
||||||
* \details Please see docs/C_API.md for details
|
|
||||||
*/
|
|
||||||
struct ArenaCfg : Base<OrtArenaCfg> {
|
|
||||||
explicit ArenaCfg(std::nullptr_t) {}
|
|
||||||
/**
|
|
||||||
* \param max_mem - use 0 to allow ORT to choose the default
|
|
||||||
* \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
|
|
||||||
* \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
|
|
||||||
* \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default
|
|
||||||
* See docs/C_API.md for details on what the following parameters mean and how to choose these values
|
|
||||||
*/
|
|
||||||
ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
|
|
||||||
};
|
|
||||||
|
|
||||||
//
|
|
||||||
// Custom OPs (only needed to implement custom OPs)
|
|
||||||
//
|
|
||||||
|
|
||||||
struct CustomOpApi {
|
|
||||||
CustomOpApi(const OrtApi& api) : api_(api) {}
|
|
||||||
|
|
||||||
template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
|
|
||||||
T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name);
|
|
||||||
|
|
||||||
OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value);
|
|
||||||
size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info);
|
|
||||||
ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info);
|
|
||||||
size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info);
|
|
||||||
void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
|
|
||||||
void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count);
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T* GetTensorMutableData(_Inout_ OrtValue* value);
|
|
||||||
template <typename T>
|
|
||||||
const T* GetTensorData(_Inout_ const OrtValue* value);
|
|
||||||
|
|
||||||
std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info);
|
|
||||||
void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input);
|
|
||||||
size_t KernelContext_GetInputCount(const OrtKernelContext* context);
|
|
||||||
const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index);
|
|
||||||
size_t KernelContext_GetOutputCount(const OrtKernelContext* context);
|
|
||||||
OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count);
|
|
||||||
|
|
||||||
void ThrowOnError(OrtStatus* result);
|
|
||||||
|
|
||||||
private:
|
|
||||||
const OrtApi& api_;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename TOp, typename TKernel>
|
|
||||||
struct CustomOpBase : OrtCustomOp {
|
|
||||||
CustomOpBase() {
|
|
||||||
OrtCustomOp::version = ORT_API_VERSION;
|
|
||||||
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
|
|
||||||
OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
|
|
||||||
|
|
||||||
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
|
|
||||||
|
|
||||||
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
|
|
||||||
OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
|
|
||||||
|
|
||||||
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
|
|
||||||
OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
|
|
||||||
|
|
||||||
|
|
||||||
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
|
|
||||||
#if defined(_MSC_VER) && !defined(__clang__)
|
|
||||||
#pragma warning(push)
|
|
||||||
#pragma warning(disable : 26409)
|
|
||||||
#endif
|
|
||||||
OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
|
|
||||||
#if defined(_MSC_VER) && !defined(__clang__)
|
|
||||||
#pragma warning(pop)
|
|
||||||
#endif
|
|
||||||
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
|
|
||||||
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename... Args>
|
|
||||||
TKernel* CreateKernelImpl(Args&&... args) const {
|
|
||||||
#if defined(_MSC_VER) && !defined(__clang__)
|
|
||||||
#pragma warning(push)
|
|
||||||
#pragma warning(disable : 26409)
|
|
||||||
#endif
|
|
||||||
return new TKernel(std::forward<Args>(args)...);
|
|
||||||
#if defined(_MSC_VER) && !defined(__clang__)
|
|
||||||
#pragma warning(pop)
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
|
||||||
return CreateKernelImpl(api);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
|
|
||||||
const char* GetExecutionProviderType() const { return nullptr; }
|
|
||||||
|
|
||||||
// Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
|
|
||||||
// (inputs and outputs are required by default)
|
|
||||||
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
|
|
||||||
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
|
|
||||||
}
|
|
||||||
|
|
||||||
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
|
|
||||||
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace Ort
|
|
||||||
|
|
||||||
#include "onnxruntime_cxx_inline.h"
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,306 @@
|
||||||
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
|
//.A very thin wrapper of ONNXRuntime Custom Operator Callback ABI, which
|
||||||
|
// is only used in the custom-op kernels. For the general ORT C++ invocation, like end-to-end
|
||||||
|
// testing, the ONNXRuntime public C++ APIs should be used since there is no binary compatible requirement.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <cstddef>
|
||||||
|
#include <array>
|
||||||
|
#include <memory>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <utility>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#ifdef ORT_NO_EXCEPTIONS
|
||||||
|
#include <iostream>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "onnxruntime_c_api.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
namespace OrtW {
|
||||||
|
|
||||||
|
// All C++ methods that can fail will throw an exception of this type
|
||||||
|
struct Exception : std::exception {
|
||||||
|
Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
|
||||||
|
|
||||||
|
OrtErrorCode GetOrtErrorCode() const { return code_; }
|
||||||
|
const char* what() const noexcept override { return message_.c_str(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string message_;
|
||||||
|
OrtErrorCode code_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifdef ORT_NO_EXCEPTIONS
|
||||||
|
#define ORTX_CXX_API_THROW(string, code) \
|
||||||
|
do { \
|
||||||
|
std::cerr << OrtW::Exception(string, code) \
|
||||||
|
.what() \
|
||||||
|
<< std::endl; \
|
||||||
|
abort(); \
|
||||||
|
} while (false)
|
||||||
|
#else
|
||||||
|
#define ORTX_CXX_API_THROW(string, code) \
|
||||||
|
throw OrtW::Exception(string, code)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
inline void ThrowOnError(const OrtApi& ort, OrtStatus* status) {
|
||||||
|
if (status) {
|
||||||
|
std::string error_message = ort.GetErrorMessage(status);
|
||||||
|
OrtErrorCode error_code = ort.GetErrorCode(status);
|
||||||
|
ort.ReleaseStatus(status);
|
||||||
|
ORTX_CXX_API_THROW(std::move(error_message), error_code);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//
|
||||||
|
// Custom OPs (only needed to implement custom OPs)
|
||||||
|
//
|
||||||
|
struct CustomOpApi {
|
||||||
|
CustomOpApi(const OrtApi& api) : api_(api) {}
|
||||||
|
|
||||||
|
template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
|
||||||
|
T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const;
|
||||||
|
|
||||||
|
OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value) const;
|
||||||
|
size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
|
||||||
|
ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const;
|
||||||
|
size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
|
||||||
|
void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const;
|
||||||
|
void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T* GetTensorMutableData(_Inout_ OrtValue* value) const;
|
||||||
|
template <typename T>
|
||||||
|
const T* GetTensorData(_Inout_ const OrtValue* value) const;
|
||||||
|
|
||||||
|
std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const;
|
||||||
|
void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const;
|
||||||
|
size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
|
||||||
|
const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const;
|
||||||
|
size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const;
|
||||||
|
OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) const;
|
||||||
|
|
||||||
|
void ThrowOnError(OrtStatus* status) const {
|
||||||
|
OrtW::ThrowOnError(api_, status);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const OrtApi& api_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename TOp, typename TKernel>
|
||||||
|
struct CustomOpBase : OrtCustomOp {
|
||||||
|
CustomOpBase() {
|
||||||
|
OrtCustomOp::version = ORT_API_VERSION;
|
||||||
|
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
|
||||||
|
OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
|
||||||
|
|
||||||
|
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
|
||||||
|
|
||||||
|
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
|
||||||
|
OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
|
||||||
|
|
||||||
|
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
|
||||||
|
OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
|
||||||
|
|
||||||
|
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
|
||||||
|
#if defined(_MSC_VER) && !defined(__clang__)
|
||||||
|
#pragma warning(push)
|
||||||
|
#pragma warning(disable : 26409)
|
||||||
|
#endif
|
||||||
|
OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
|
||||||
|
#if defined(_MSC_VER) && !defined(__clang__)
|
||||||
|
#pragma warning(pop)
|
||||||
|
#endif
|
||||||
|
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
|
||||||
|
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... Args>
|
||||||
|
TKernel* CreateKernelImpl(Args&&... args) const {
|
||||||
|
#if defined(_MSC_VER) && !defined(__clang__)
|
||||||
|
#pragma warning(push)
|
||||||
|
#pragma warning(disable : 26409)
|
||||||
|
#endif
|
||||||
|
return new TKernel(std::forward<Args>(args)...);
|
||||||
|
#if defined(_MSC_VER) && !defined(__clang__)
|
||||||
|
#pragma warning(pop)
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
|
return CreateKernelImpl(api);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
|
||||||
|
const char* GetExecutionProviderType() const { return nullptr; }
|
||||||
|
|
||||||
|
// Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
|
||||||
|
// (inputs and outputs are required by default)
|
||||||
|
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
|
||||||
|
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
|
||||||
|
}
|
||||||
|
|
||||||
|
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
|
||||||
|
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
//
|
||||||
|
// Custom OP API Inlines
|
||||||
|
//
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
|
||||||
|
float out;
|
||||||
|
ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
|
||||||
|
int64_t out;
|
||||||
|
ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
|
||||||
|
size_t size = 0;
|
||||||
|
std::string out;
|
||||||
|
|
||||||
|
// Feed nullptr for the data buffer to query the true size of the string attribute
|
||||||
|
OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
|
||||||
|
|
||||||
|
if (status == nullptr) {
|
||||||
|
out.resize(size);
|
||||||
|
ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
|
||||||
|
out.resize(size - 1); // remove the terminating character '\0'
|
||||||
|
} else {
|
||||||
|
ThrowOnError(status);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
|
||||||
|
size_t size = 0;
|
||||||
|
std::vector<float> out;
|
||||||
|
|
||||||
|
// Feed nullptr for the data buffer to query the true size of the attribute
|
||||||
|
OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
|
||||||
|
|
||||||
|
if (status == nullptr) {
|
||||||
|
out.resize(size);
|
||||||
|
ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
|
||||||
|
} else {
|
||||||
|
ThrowOnError(status);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
|
||||||
|
size_t size = 0;
|
||||||
|
std::vector<int64_t> out;
|
||||||
|
|
||||||
|
// Feed nullptr for the data buffer to query the true size of the attribute
|
||||||
|
OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
|
||||||
|
|
||||||
|
if (status == nullptr) {
|
||||||
|
out.resize(size);
|
||||||
|
ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
|
||||||
|
} else {
|
||||||
|
ThrowOnError(status);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) const {
|
||||||
|
OrtTensorTypeAndShapeInfo* out;
|
||||||
|
ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
|
||||||
|
size_t out;
|
||||||
|
ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const {
|
||||||
|
ONNXTensorElementDataType out;
|
||||||
|
ThrowOnError(api_.GetTensorElementType(info, &out));
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
|
||||||
|
size_t out;
|
||||||
|
ThrowOnError(api_.GetDimensionsCount(info, &out));
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const {
|
||||||
|
ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const {
|
||||||
|
ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
|
||||||
|
T* data;
|
||||||
|
ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const {
|
||||||
|
return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
|
||||||
|
std::vector<int64_t> output(GetDimensionsCount(info));
|
||||||
|
GetDimensions(info, output.data(), output.size());
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const {
|
||||||
|
api_.ReleaseTensorTypeAndShapeInfo(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) const {
|
||||||
|
size_t out;
|
||||||
|
ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const {
|
||||||
|
const OrtValue* out;
|
||||||
|
ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) const {
|
||||||
|
size_t out;
|
||||||
|
ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
|
||||||
|
_In_ const int64_t* dim_values, size_t dim_count) const {
|
||||||
|
OrtValue* out;
|
||||||
|
ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace OrtW
|
|
@ -16,7 +16,7 @@ struct KernelGaussianBlur : BaseKernel {
|
||||||
const OrtValue* input_ksize = ort_.KernelContext_GetInput(context, 1);
|
const OrtValue* input_ksize = ort_.KernelContext_GetInput(context, 1);
|
||||||
OrtTensorDimensions dim_ksize(ort_, input_ksize);
|
OrtTensorDimensions dim_ksize(ort_, input_ksize);
|
||||||
if (dim_ksize.size() != 1 || dim_ksize[0] != 2) {
|
if (dim_ksize.size() != 1 || dim_ksize[0] != 2) {
|
||||||
ORT_CXX_API_THROW("[GaussianBlur]: ksize shape is (2,)", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("[GaussianBlur]: ksize shape is (2,)", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
std::copy_n(ort_.GetTensorData<std::int64_t>(input_ksize), 2, ksize);
|
std::copy_n(ort_.GetTensorData<std::int64_t>(input_ksize), 2, ksize);
|
||||||
}
|
}
|
||||||
|
@ -25,7 +25,7 @@ struct KernelGaussianBlur : BaseKernel {
|
||||||
const OrtValue* input_sigma = ort_.KernelContext_GetInput(context, 2);
|
const OrtValue* input_sigma = ort_.KernelContext_GetInput(context, 2);
|
||||||
OrtTensorDimensions dim_sigma(ort_, input_sigma);
|
OrtTensorDimensions dim_sigma(ort_, input_sigma);
|
||||||
if (dim_sigma.size() != 1 || dim_sigma[0] != 2) {
|
if (dim_sigma.size() != 1 || dim_sigma[0] != 2) {
|
||||||
ORT_CXX_API_THROW("[GaussianBlur]: sigma shape is (2,)", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("[GaussianBlur]: sigma shape is (2,)", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
std::copy_n(ort_.GetTensorData<double>(input_sigma), 2, sigma);
|
std::copy_n(ort_.GetTensorData<double>(input_sigma), 2, sigma);
|
||||||
}
|
}
|
||||||
|
@ -53,7 +53,7 @@ struct KernelGaussianBlur : BaseKernel {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpGaussianBlur : Ort::CustomOpBase<CustomOpGaussianBlur, KernelGaussianBlur> {
|
struct CustomOpGaussianBlur : OrtW::CustomOpBase<CustomOpGaussianBlur, KernelGaussianBlur> {
|
||||||
size_t GetInputTypeCount() const {
|
size_t GetInputTypeCount() const {
|
||||||
return 3;
|
return 3;
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,7 @@ struct KernelImageDecoder : BaseKernel {
|
||||||
const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL);
|
const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL);
|
||||||
OrtTensorDimensions dimensions(ort_, inputs);
|
OrtTensorDimensions dimensions(ort_, inputs);
|
||||||
if (dimensions.size() != 1ULL) {
|
if (dimensions.size() != 1ULL) {
|
||||||
ORT_CXX_API_THROW("[ImageDecoder]: Only raw image formats are supported.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("[ImageDecoder]: Only raw image formats are supported.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get data & the length
|
// Get data & the length
|
||||||
|
@ -48,7 +48,7 @@ struct KernelImageDecoder : BaseKernel {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpImageDecoder : Ort::CustomOpBase<CustomOpImageDecoder, KernelImageDecoder> {
|
struct CustomOpImageDecoder : OrtW::CustomOpBase<CustomOpImageDecoder, KernelImageDecoder> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelImageDecoder(api);
|
return new KernelImageDecoder(api);
|
||||||
}
|
}
|
||||||
|
@ -66,7 +66,7 @@ struct CustomOpImageDecoder : Ort::CustomOpBase<CustomOpImageDecoder, KernelImag
|
||||||
case 0:
|
case 0:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,7 +79,7 @@ struct CustomOpImageDecoder : Ort::CustomOpBase<CustomOpImageDecoder, KernelImag
|
||||||
case 0:
|
case 0:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -14,7 +14,7 @@ struct KernelImageReader : BaseKernel {
|
||||||
|
|
||||||
int n = input_data_dimensions[0];
|
int n = input_data_dimensions[0];
|
||||||
if (n != 1) {
|
if (n != 1) {
|
||||||
ORT_CXX_API_THROW("[ImageReader]: the dimension of input value can only be 1 now.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("[ImageReader]: the dimension of input value can only be 1 now.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> image_paths;
|
std::vector<std::string> image_paths;
|
||||||
|
@ -28,7 +28,7 @@ struct KernelImageReader : BaseKernel {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpImageReader : Ort::CustomOpBase<CustomOpImageReader, KernelImageReader> {
|
struct CustomOpImageReader : OrtW::CustomOpBase<CustomOpImageReader, KernelImageReader> {
|
||||||
size_t GetInputTypeCount() const {
|
size_t GetInputTypeCount() const {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
|
@ -81,7 +81,7 @@ ONNXTensorElementDataType CustomOpSuperResolutionPostProcess::GetInputType(size_
|
||||||
case 2:
|
case 2:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,6 +94,6 @@ ONNXTensorElementDataType CustomOpSuperResolutionPostProcess::GetOutputType(size
|
||||||
case 0:
|
case 0:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,7 +10,7 @@ struct KernelSuperResolutionPostProcess : BaseKernel {
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpSuperResolutionPostProcess : Ort::CustomOpBase<CustomOpSuperResolutionPostProcess, KernelSuperResolutionPostProcess> {
|
struct CustomOpSuperResolutionPostProcess : OrtW::CustomOpBase<CustomOpSuperResolutionPostProcess, KernelSuperResolutionPostProcess> {
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -71,7 +71,7 @@ ONNXTensorElementDataType CustomOpSuperResolutionPreProcess::GetInputType(size_t
|
||||||
case 0:
|
case 0:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,6 +86,6 @@ ONNXTensorElementDataType CustomOpSuperResolutionPreProcess::GetOutputType(size_
|
||||||
case 2:
|
case 2:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,7 +10,7 @@ struct KernelSuperResolutionPreProcess : BaseKernel {
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpSuperResolutionPreProcess : Ort::CustomOpBase<CustomOpSuperResolutionPreProcess, KernelSuperResolutionPreProcess> {
|
struct CustomOpSuperResolutionPreProcess : OrtW::CustomOpBase<CustomOpSuperResolutionPreProcess, KernelSuperResolutionPreProcess> {
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -33,7 +33,7 @@ struct KernelInverse : BaseKernel {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpInverse : Ort::CustomOpBase<CustomOpInverse, KernelInverse> {
|
struct CustomOpInverse : OrtW::CustomOpBase<CustomOpInverse, KernelInverse> {
|
||||||
const char* GetName() const {
|
const char* GetName() const {
|
||||||
return "Inverse";
|
return "Inverse";
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,7 +39,7 @@ struct KernelNegPos : BaseKernel {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpNegPos : Ort::CustomOpBase<CustomOpNegPos, KernelNegPos> {
|
struct CustomOpNegPos : OrtW::CustomOpBase<CustomOpNegPos, KernelNegPos> {
|
||||||
const char* GetName() const{
|
const char* GetName() const{
|
||||||
return "NegPos";
|
return "NegPos";
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,7 @@ void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
|
||||||
const int64_t* p_data = ort_.GetTensorData<int64_t>(input);
|
const int64_t* p_data = ort_.GetTensorData<int64_t>(input);
|
||||||
OrtTensorDimensions input_dim(ort_, input);
|
OrtTensorDimensions input_dim(ort_, input);
|
||||||
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
|
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
|
||||||
ORT_CXX_API_THROW("[SegmentExtraction]: Expect input dimension [n] or [1,n]." , ORT_INVALID_GRAPH);
|
ORTX_CXX_API_THROW("[SegmentExtraction]: Expect input dimension [n] or [1,n]." , ORT_INVALID_GRAPH);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::int64_t> segment_value;
|
std::vector<std::int64_t> segment_value;
|
||||||
|
|
|
@ -11,7 +11,7 @@ struct KernelSegmentExtraction : BaseKernel {
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpSegmentExtraction : Ort::CustomOpBase<CustomOpSegmentExtraction, KernelSegmentExtraction> {
|
struct CustomOpSegmentExtraction : OrtW::CustomOpBase<CustomOpSegmentExtraction, KernelSegmentExtraction> {
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
size_t GetOutputTypeCount() const;
|
size_t GetOutputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
#include "segment_sum.hpp"
|
#include "segment_sum.hpp"
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context) {
|
void KernelSegmentSum_Compute(OrtW::CustomOpApi& ort_, OrtKernelContext* context) {
|
||||||
// Setup inputs
|
// Setup inputs
|
||||||
const OrtValue* data = ort_.KernelContext_GetInput(context, 0);
|
const OrtValue* data = ort_.KernelContext_GetInput(context, 0);
|
||||||
const T* p_data = ort_.GetTensorData<T>(data);
|
const T* p_data = ort_.GetTensorData<T>(data);
|
||||||
|
@ -15,11 +15,11 @@ void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context)
|
||||||
OrtTensorDimensions dim_data(ort_, data);
|
OrtTensorDimensions dim_data(ort_, data);
|
||||||
OrtTensorDimensions dim_seg(ort_, segment_ids);
|
OrtTensorDimensions dim_seg(ort_, segment_ids);
|
||||||
if (dim_data.size() == 0 || dim_seg.size() == 0)
|
if (dim_data.size() == 0 || dim_seg.size() == 0)
|
||||||
ORT_CXX_API_THROW("Both inputs cannot be empty.", ORT_INVALID_GRAPH);
|
ORTX_CXX_API_THROW("Both inputs cannot be empty.", ORT_INVALID_GRAPH);
|
||||||
if (dim_seg.size() != 1)
|
if (dim_seg.size() != 1)
|
||||||
ORT_CXX_API_THROW("segment_ids must a single tensor", ORT_INVALID_GRAPH);
|
ORTX_CXX_API_THROW("segment_ids must a single tensor", ORT_INVALID_GRAPH);
|
||||||
if (dim_data[0] != dim_seg[0])
|
if (dim_data[0] != dim_seg[0])
|
||||||
ORT_CXX_API_THROW(MakeString(
|
ORTX_CXX_API_THROW(MakeString(
|
||||||
"First dimensions of data and segment_ids should be the same, data shape: ", dim_data,
|
"First dimensions of data and segment_ids should be the same, data shape: ", dim_data,
|
||||||
" segment_ids shape: ", dim_seg), ORT_INVALID_GRAPH);
|
" segment_ids shape: ", dim_seg), ORT_INVALID_GRAPH);
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context)
|
||||||
const int64_t* p_seg = p_segment_ids;
|
const int64_t* p_seg = p_segment_ids;
|
||||||
for (; begin != end; ++p_seg) {
|
for (; begin != end; ++p_seg) {
|
||||||
if ((p_seg != p_segment_ids) && (*p_seg != *(p_seg - 1)) && (*p_seg != *(p_seg - 1) + 1))
|
if ((p_seg != p_segment_ids) && (*p_seg != *(p_seg - 1)) && (*p_seg != *(p_seg - 1) + 1))
|
||||||
ORT_CXX_API_THROW(MakeString("segment_ids must be increasing but found ",
|
ORTX_CXX_API_THROW(MakeString("segment_ids must be increasing but found ",
|
||||||
*(p_seg - 1), " and ", *p_seg, " at position ",
|
*(p_seg - 1), " and ", *p_seg, " at position ",
|
||||||
std::distance(p_segment_ids, p_seg), "."), ORT_RUNTIME_EXCEPTION);
|
std::distance(p_segment_ids, p_seg), "."), ORT_RUNTIME_EXCEPTION);
|
||||||
p_out = p_output + *p_seg * in_stride;
|
p_out = p_output + *p_seg * in_stride;
|
||||||
|
@ -82,6 +82,6 @@ ONNXTensorElementDataType CustomOpSegmentSum::GetInputType(size_t index) const {
|
||||||
case 1:
|
case 1:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW("Operator SegmentSum has 2 inputs.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Operator SegmentSum has 2 inputs.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -11,7 +11,7 @@ struct KernelSegmentSum : BaseKernel {
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpSegmentSum : Ort::CustomOpBase<CustomOpSegmentSum, KernelSegmentSum> {
|
struct CustomOpSegmentSum : OrtW::CustomOpBase<CustomOpSegmentSum, KernelSegmentSum> {
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
size_t GetOutputTypeCount() const;
|
size_t GetOutputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
bool BaseKernel::HasAttribute(const char* name) const {
|
bool BaseKernel::HasAttribute(const char* name) const {
|
||||||
if (info_ == nullptr) {
|
if (info_ == nullptr) {
|
||||||
ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
size_t size;
|
size_t size;
|
||||||
std::string out;
|
std::string out;
|
||||||
|
@ -46,7 +46,7 @@ void BaseKernel::SetOutput(OrtKernelContext* ctx, size_t output_idx, const std:
|
||||||
template <>
|
template <>
|
||||||
bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
|
bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
|
||||||
if (info_ == nullptr) {
|
if (info_ == nullptr) {
|
||||||
ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t size = 0;
|
size_t size = 0;
|
||||||
|
@ -71,7 +71,7 @@ bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
|
||||||
template <>
|
template <>
|
||||||
bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) {
|
bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) {
|
||||||
if (info_ == nullptr) {
|
if (info_ == nullptr) {
|
||||||
ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(info_, name, &value)) == ORT_OK;
|
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(info_, name, &value)) == ORT_OK;
|
||||||
|
@ -80,7 +80,7 @@ bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) {
|
||||||
template <>
|
template <>
|
||||||
bool BaseKernel::TryToGetAttribute(const char* name, float& value) {
|
bool BaseKernel::TryToGetAttribute(const char* name, float& value) {
|
||||||
if (info_ == nullptr) {
|
if (info_ == nullptr) {
|
||||||
ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(info_, name, &value)) == ORT_OK;
|
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(info_, name, &value)) == ORT_OK;
|
||||||
|
@ -89,7 +89,7 @@ bool BaseKernel::TryToGetAttribute(const char* name, float& value) {
|
||||||
template <>
|
template <>
|
||||||
bool BaseKernel::TryToGetAttribute(const char* name, bool& value) {
|
bool BaseKernel::TryToGetAttribute(const char* name, bool& value) {
|
||||||
if (info_ == nullptr) {
|
if (info_ == nullptr) {
|
||||||
ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t origin_value = 0;
|
int64_t origin_value = 0;
|
||||||
|
|
|
@ -3,17 +3,17 @@
|
||||||
#include "string_utils.h"
|
#include "string_utils.h"
|
||||||
#include "string_tensor.h"
|
#include "string_tensor.h"
|
||||||
|
|
||||||
void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context,
|
void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
|
||||||
const OrtValue* value, std::vector<std::string>& output) {
|
const OrtValue* value, std::vector<std::string>& output) {
|
||||||
(void)context;
|
(void)context;
|
||||||
OrtTensorDimensions dimensions(ort, value);
|
OrtTensorDimensions dimensions(ort, value);
|
||||||
size_t len = static_cast<size_t>(dimensions.Size());
|
size_t len = static_cast<size_t>(dimensions.Size());
|
||||||
size_t data_len;
|
size_t data_len;
|
||||||
Ort::ThrowOnError(api, api.GetStringTensorDataLength(value, &data_len));
|
OrtW::ThrowOnError(api, api.GetStringTensorDataLength(value, &data_len));
|
||||||
output.resize(len);
|
output.resize(len);
|
||||||
std::vector<char> result(data_len + len + 1, '\0');
|
std::vector<char> result(data_len + len + 1, '\0');
|
||||||
std::vector<size_t> offsets(len);
|
std::vector<size_t> offsets(len);
|
||||||
Ort::ThrowOnError(api, api.GetStringTensorContent(value, (void*)result.data(), data_len, offsets.data(), offsets.size()));
|
OrtW::ThrowOnError(api, api.GetStringTensorContent(value, (void*)result.data(), data_len, offsets.data(), offsets.size()));
|
||||||
output.resize(len);
|
output.resize(len);
|
||||||
for (int64_t i = (int64_t)len - 1; i >= 0; --i) {
|
for (int64_t i = (int64_t)len - 1; i >= 0; --i) {
|
||||||
if (i < static_cast<int64_t>(len) - 1)
|
if (i < static_cast<int64_t>(len) - 1)
|
||||||
|
@ -22,7 +22,7 @@ void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKer
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FillTensorDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context,
|
void FillTensorDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
|
||||||
const std::vector<std::string>& value, OrtValue* output) {
|
const std::vector<std::string>& value, OrtValue* output) {
|
||||||
(void)ort;
|
(void)ort;
|
||||||
(void)context;
|
(void)context;
|
||||||
|
@ -31,10 +31,10 @@ void FillTensorDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelCon
|
||||||
temp[i] = value[i].c_str();
|
temp[i] = value[i].c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
Ort::ThrowOnError(api,api.FillStringTensor(output, temp.data(), value.size()));
|
OrtW::ThrowOnError(api,api.FillStringTensor(output, temp.data(), value.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context,
|
void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
|
||||||
const OrtValue* value, std::vector<ustring>& output) {
|
const OrtValue* value, std::vector<ustring>& output) {
|
||||||
std::vector<std::string> utf8_strings;
|
std::vector<std::string> utf8_strings;
|
||||||
GetTensorMutableDataString(api, ort, context, value, utf8_strings);
|
GetTensorMutableDataString(api, ort, context, value, utf8_strings);
|
||||||
|
@ -46,7 +46,7 @@ void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void FillTensorDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context,
|
void FillTensorDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
|
||||||
const std::vector<ustring>& value, OrtValue* output) {
|
const std::vector<ustring>& value, OrtValue* output) {
|
||||||
std::vector<std::string> utf8_strings;
|
std::vector<std::string> utf8_strings;
|
||||||
utf8_strings.reserve(value.size());
|
utf8_strings.reserve(value.size());
|
||||||
|
|
|
@ -10,14 +10,14 @@
|
||||||
|
|
||||||
// Retrieves a vector of strings if the input type is std::string.
|
// Retrieves a vector of strings if the input type is std::string.
|
||||||
// It is a copy of the input data and can be modified to compute the output.
|
// It is a copy of the input data and can be modified to compute the output.
|
||||||
void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context,
|
void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
|
||||||
const OrtValue* value, std::vector<std::string>& output);
|
const OrtValue* value, std::vector<std::string>& output);
|
||||||
|
|
||||||
void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context,
|
void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
|
||||||
const OrtValue* value, std::vector<ustring>& output);
|
const OrtValue* value, std::vector<ustring>& output);
|
||||||
|
|
||||||
void FillTensorDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context,
|
void FillTensorDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
|
||||||
const std::vector<std::string>& value, OrtValue* output);
|
const std::vector<std::string>& value, OrtValue* output);
|
||||||
|
|
||||||
void FillTensorDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context,
|
void FillTensorDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
|
||||||
const std::vector<ustring>& value, OrtValue* output);
|
const std::vector<ustring>& value, OrtValue* output);
|
||||||
|
|
|
@ -21,11 +21,11 @@ void KernelMaskedFill::Compute(OrtKernelContext* context) {
|
||||||
OrtTensorDimensions mask_dimensions(ort_, input_mask);
|
OrtTensorDimensions mask_dimensions(ort_, input_mask);
|
||||||
|
|
||||||
if (!(value_dimensions.IsScalar() || value_dimensions.IsVector())) {
|
if (!(value_dimensions.IsScalar() || value_dimensions.IsVector())) {
|
||||||
ORT_CXX_API_THROW("[MaskedFill]: the dimension of input value should be vector or scalar.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("[MaskedFill]: the dimension of input value should be vector or scalar.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (value_dimensions != mask_dimensions) {
|
if (value_dimensions != mask_dimensions) {
|
||||||
ORT_CXX_API_THROW("[MaskedFill]: the dimension of input value and mask should be same.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("[MaskedFill]: the dimension of input value and mask should be same.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> value;
|
std::vector<std::string> value;
|
||||||
|
@ -68,7 +68,7 @@ ONNXTensorElementDataType CustomOpMaskedFill::GetInputType(size_t index) const {
|
||||||
case 1:
|
case 1:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}};
|
}};
|
||||||
|
|
||||||
size_t CustomOpMaskedFill::GetOutputTypeCount() const {
|
size_t CustomOpMaskedFill::GetOutputTypeCount() const {
|
||||||
|
|
|
@ -14,7 +14,7 @@ struct KernelMaskedFill : BaseKernel {
|
||||||
std::unordered_map<std::string, std::string> map_;
|
std::unordered_map<std::string, std::string> map_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpMaskedFill : Ort::CustomOpBase<CustomOpMaskedFill, KernelMaskedFill> {
|
struct CustomOpMaskedFill : OrtW::CustomOpBase<CustomOpMaskedFill, KernelMaskedFill> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
|
|
|
@ -11,7 +11,7 @@ struct KernelStringEqual : BaseKernel {
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringEqual : Ort::CustomOpBase<CustomOpStringEqual, KernelStringEqual> {
|
struct CustomOpStringEqual : OrtW::CustomOpBase<CustomOpStringEqual, KernelStringEqual> {
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
size_t GetOutputTypeCount() const;
|
size_t GetOutputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||||
|
|
|
@ -13,7 +13,7 @@ class BroadcastIteratorRight {
|
||||||
const std::vector<int64_t>& shape2,
|
const std::vector<int64_t>& shape2,
|
||||||
const T1* p1, const T2* p2, T3* p3) : shape1_(shape1), p1_(p1), p2_(p2), p3_(p3) {
|
const T1* p1, const T2* p2, T3* p3) : shape1_(shape1), p1_(p1), p2_(p2), p3_(p3) {
|
||||||
if (shape2.size() > shape1.size())
|
if (shape2.size() > shape1.size())
|
||||||
ORT_CXX_API_THROW("shape2 must have less dimensions than shape1", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("shape2 must have less dimensions than shape1", ORT_INVALID_ARGUMENT);
|
||||||
shape2_.resize(shape1_.size());
|
shape2_.resize(shape1_.size());
|
||||||
cum_shape2_.resize(shape1_.size());
|
cum_shape2_.resize(shape1_.size());
|
||||||
total_ = 1;
|
total_ = 1;
|
||||||
|
@ -26,7 +26,7 @@ class BroadcastIteratorRight {
|
||||||
shape2_[i] = shape2[i];
|
shape2_[i] = shape2[i];
|
||||||
}
|
}
|
||||||
if (shape2[i] != 1 && shape1[i] != shape2[i]) {
|
if (shape2[i] != 1 && shape1[i] != shape2[i]) {
|
||||||
ORT_CXX_API_THROW(MakeString(
|
ORTX_CXX_API_THROW(MakeString(
|
||||||
"Cannot broadcast dimension ", i, " left:", shape1[i], " right:", shape2[i]), ORT_INVALID_ARGUMENT);
|
"Cannot broadcast dimension ", i, " left:", shape1[i], " right:", shape2[i]), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -84,7 +84,7 @@ class BroadcastIteratorRight {
|
||||||
template <typename TCMP>
|
template <typename TCMP>
|
||||||
void loop(TCMP& cmp, BroadcastIteratorRightState& /*it*/, int64_t pos = 0) {
|
void loop(TCMP& cmp, BroadcastIteratorRightState& /*it*/, int64_t pos = 0) {
|
||||||
if (pos != 0)
|
if (pos != 0)
|
||||||
ORT_CXX_API_THROW("Not implemented yet.", ORT_NOT_IMPLEMENTED);
|
ORTX_CXX_API_THROW("Not implemented yet.", ORT_NOT_IMPLEMENTED);
|
||||||
while (!end()) {
|
while (!end()) {
|
||||||
*p3 = cmp(*p1, *p2);
|
*p3 = cmp(*p1, *p2);
|
||||||
next();
|
next();
|
||||||
|
@ -114,7 +114,7 @@ inline bool Compare<std::string>::operator()(const std::string& s1, const std::s
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void KernelEqual_Compute(const OrtApi& api, Ort::CustomOpApi& ort_, OrtKernelContext* context) {
|
void KernelEqual_Compute(const OrtApi& api, OrtW::CustomOpApi& ort_, OrtKernelContext* context) {
|
||||||
// Setup inputs
|
// Setup inputs
|
||||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||||
const T* X = ort_.GetTensorData<T>(input_X);
|
const T* X = ort_.GetTensorData<T>(input_X);
|
||||||
|
@ -144,7 +144,7 @@ void KernelEqual_Compute(const OrtApi& api, Ort::CustomOpApi& ort_, OrtKernelCon
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void KernelEqual_Compute<std::string>(const OrtApi& api, Ort::CustomOpApi& ort_, OrtKernelContext* context) {
|
void KernelEqual_Compute<std::string>(const OrtApi& api, OrtW::CustomOpApi& ort_, OrtKernelContext* context) {
|
||||||
// Setup inputs
|
// Setup inputs
|
||||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||||
const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);
|
const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);
|
||||||
|
|
|
@ -11,7 +11,7 @@ void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
|
||||||
OrtTensorDimensions d_length(ort_, n_elements);
|
OrtTensorDimensions d_length(ort_, n_elements);
|
||||||
|
|
||||||
if (d_length.size() != 1)
|
if (d_length.size() != 1)
|
||||||
ORT_CXX_API_THROW(MakeString(
|
ORTX_CXX_API_THROW(MakeString(
|
||||||
"First input must have one dimension not ", d_length.size(), "."), ORT_INVALID_ARGUMENT);
|
"First input must have one dimension not ", d_length.size(), "."), ORT_INVALID_ARGUMENT);
|
||||||
int64_t n_els = d_length[0] - 1;
|
int64_t n_els = d_length[0] - 1;
|
||||||
int64_t n_values = p_n_elements[n_els];
|
int64_t n_values = p_n_elements[n_els];
|
||||||
|
@ -103,7 +103,7 @@ void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
||||||
for (int64_t i = 0; i < size - 1; ++i) {
|
for (int64_t i = 0; i < size - 1; ++i) {
|
||||||
pos_end = pos + max_col;
|
pos_end = pos + max_col;
|
||||||
if (pos_end > shape_out_size)
|
if (pos_end > shape_out_size)
|
||||||
ORT_CXX_API_THROW(MakeString(
|
ORTX_CXX_API_THROW(MakeString(
|
||||||
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
|
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
|
||||||
" - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
|
" - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
|
||||||
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
|
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
|
||||||
|
@ -161,7 +161,7 @@ void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
||||||
for (int64_t i = 0; i < size - 1; ++i) {
|
for (int64_t i = 0; i < size - 1; ++i) {
|
||||||
pos_end = pos + max_col;
|
pos_end = pos + max_col;
|
||||||
if (pos_end > shape_out_size)
|
if (pos_end > shape_out_size)
|
||||||
ORT_CXX_API_THROW(MakeString(
|
ORTX_CXX_API_THROW(MakeString(
|
||||||
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
|
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
|
||||||
" - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
|
" - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
|
||||||
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
|
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
|
||||||
|
@ -203,6 +203,6 @@ ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetInputType(size_t
|
||||||
case 3:
|
case 3:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("[StringRaggedTensorToDense] Unexpected output index ", index, "."), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("[StringRaggedTensorToDense] Unexpected output index ", index, "."), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -12,7 +12,7 @@ struct KernelRaggedTensorToSparse : BaseKernel {
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpRaggedTensorToSparse : Ort::CustomOpBase<CustomOpRaggedTensorToSparse, KernelRaggedTensorToSparse> {
|
struct CustomOpRaggedTensorToSparse : OrtW::CustomOpBase<CustomOpRaggedTensorToSparse, KernelRaggedTensorToSparse> {
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
size_t GetOutputTypeCount() const;
|
size_t GetOutputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||||
|
@ -36,7 +36,7 @@ struct KernelRaggedTensorToDense : CommonRaggedTensorToDense {
|
||||||
int64_t missing_value_;
|
int64_t missing_value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpRaggedTensorToDense : Ort::CustomOpBase<CustomOpRaggedTensorToDense, KernelRaggedTensorToDense> {
|
struct CustomOpRaggedTensorToDense : OrtW::CustomOpBase<CustomOpRaggedTensorToDense, KernelRaggedTensorToDense> {
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
size_t GetOutputTypeCount() const;
|
size_t GetOutputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||||
|
@ -50,7 +50,7 @@ struct KernelStringRaggedTensorToDense : CommonRaggedTensorToDense {
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringRaggedTensorToDense : Ort::CustomOpBase<CustomOpStringRaggedTensorToDense, KernelStringRaggedTensorToDense> {
|
struct CustomOpStringRaggedTensorToDense : OrtW::CustomOpBase<CustomOpStringRaggedTensorToDense, KernelStringRaggedTensorToDense> {
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
size_t GetOutputTypeCount() const;
|
size_t GetOutputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||||
|
|
|
@ -27,15 +27,15 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
|
||||||
OrtTensorDimensions pattern_dimensions(ort_, pattern);
|
OrtTensorDimensions pattern_dimensions(ort_, pattern);
|
||||||
OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
|
OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
|
||||||
if (pattern_dimensions.size() != 1 || pattern_dimensions[0] != 1)
|
if (pattern_dimensions.size() != 1 || pattern_dimensions[0] != 1)
|
||||||
ORT_CXX_API_THROW(MakeString(
|
ORTX_CXX_API_THROW(MakeString(
|
||||||
"pattern (second input) must contain only one element. It has ",
|
"pattern (second input) must contain only one element. It has ",
|
||||||
pattern_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
|
pattern_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
|
||||||
if (rewrite_dimensions.size() != 1 || rewrite_dimensions[0] != 1)
|
if (rewrite_dimensions.size() != 1 || rewrite_dimensions[0] != 1)
|
||||||
ORT_CXX_API_THROW(MakeString(
|
ORTX_CXX_API_THROW(MakeString(
|
||||||
"rewrite (third input) must contain only one element. It has ",
|
"rewrite (third input) must contain only one element. It has ",
|
||||||
rewrite_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
|
rewrite_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
|
||||||
if (str_pattern[0].empty())
|
if (str_pattern[0].empty())
|
||||||
ORT_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_ARGUMENT);
|
||||||
|
|
||||||
// Setup output
|
// Setup output
|
||||||
OrtTensorDimensions dimensions(ort_, input);
|
OrtTensorDimensions dimensions(ort_, input);
|
||||||
|
|
|
@ -14,7 +14,7 @@ struct KernelStringRegexReplace : BaseKernel {
|
||||||
int64_t global_replace_;
|
int64_t global_replace_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringRegexReplace : Ort::CustomOpBase<CustomOpStringRegexReplace, KernelStringRegexReplace> {
|
struct CustomOpStringRegexReplace : OrtW::CustomOpBase<CustomOpStringRegexReplace, KernelStringRegexReplace> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
|
|
|
@ -24,15 +24,15 @@ void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
||||||
// Verifications
|
// Verifications
|
||||||
OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern);
|
OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern);
|
||||||
if (str_pattern.size() != 1)
|
if (str_pattern.size() != 1)
|
||||||
ORT_CXX_API_THROW(MakeString(
|
ORTX_CXX_API_THROW(MakeString(
|
||||||
"pattern (second input) must contain only one element. It has ",
|
"pattern (second input) must contain only one element. It has ",
|
||||||
str_pattern.size(), " values."), ORT_INVALID_ARGUMENT);
|
str_pattern.size(), " values."), ORT_INVALID_ARGUMENT);
|
||||||
if (str_keep_pattern.size() > 1)
|
if (str_keep_pattern.size() > 1)
|
||||||
ORT_CXX_API_THROW(MakeString(
|
ORTX_CXX_API_THROW(MakeString(
|
||||||
"Third input must contain only one element. It has ",
|
"Third input must contain only one element. It has ",
|
||||||
str_keep_pattern.size(), " values."), ORT_INVALID_ARGUMENT);
|
str_keep_pattern.size(), " values."), ORT_INVALID_ARGUMENT);
|
||||||
if (str_pattern[0].empty())
|
if (str_pattern[0].empty())
|
||||||
ORT_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_ARGUMENT);
|
||||||
|
|
||||||
OrtTensorDimensions dimensions(ort_, input);
|
OrtTensorDimensions dimensions(ort_, input);
|
||||||
bool include_delimiter = (str_keep_pattern.size() == 1) && (!str_keep_pattern[0].empty());
|
bool include_delimiter = (str_keep_pattern.size() == 1) && (!str_keep_pattern[0].empty());
|
||||||
|
@ -106,7 +106,7 @@ ONNXTensorElementDataType CustomOpStringRegexSplitWithOffsets::GetOutputType(siz
|
||||||
case 3:
|
case 3:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString(
|
ORTX_CXX_API_THROW(MakeString(
|
||||||
"StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."), ORT_INVALID_ARGUMENT);
|
"StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -12,7 +12,7 @@ struct KernelStringRegexSplitWithOffsets : BaseKernel {
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringRegexSplitWithOffsets : Ort::CustomOpBase<CustomOpStringRegexSplitWithOffsets, KernelStringRegexSplitWithOffsets> {
|
struct CustomOpStringRegexSplitWithOffsets : OrtW::CustomOpBase<CustomOpStringRegexSplitWithOffsets, KernelStringRegexSplitWithOffsets> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
|
|
|
@ -20,7 +20,7 @@ void KernelStringConcat::Compute(OrtKernelContext* context) {
|
||||||
OrtTensorDimensions right_dim(ort_, right);
|
OrtTensorDimensions right_dim(ort_, right);
|
||||||
|
|
||||||
if (left_dim != right_dim) {
|
if (left_dim != right_dim) {
|
||||||
ORT_CXX_API_THROW("Two input tensor should have the same dimension.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Two input tensor should have the same dimension.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> left_value;
|
std::vector<std::string> left_value;
|
||||||
|
|
|
@ -11,7 +11,7 @@ struct KernelStringConcat : BaseKernel {
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringConcat : Ort::CustomOpBase<CustomOpStringConcat, KernelStringConcat> {
|
struct CustomOpStringConcat : OrtW::CustomOpBase<CustomOpStringConcat, KernelStringConcat> {
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -29,17 +29,17 @@ void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
|
||||||
OrtTensorDimensions pattern_dimensions(ort_, pattern);
|
OrtTensorDimensions pattern_dimensions(ort_, pattern);
|
||||||
OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
|
OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
|
||||||
if (pattern_dimensions.Size() != 1) {
|
if (pattern_dimensions.Size() != 1) {
|
||||||
ORT_CXX_API_THROW(MakeString(
|
ORTX_CXX_API_THROW(MakeString(
|
||||||
"pattern (second input) must contain only one element. It has ",
|
"pattern (second input) must contain only one element. It has ",
|
||||||
pattern_dimensions.size(), " dimensions."), ORT_INVALID_GRAPH);
|
pattern_dimensions.size(), " dimensions."), ORT_INVALID_GRAPH);
|
||||||
}
|
}
|
||||||
if (rewrite_dimensions.Size() != 1) {
|
if (rewrite_dimensions.Size() != 1) {
|
||||||
ORT_CXX_API_THROW(MakeString(
|
ORTX_CXX_API_THROW(MakeString(
|
||||||
"rewrite (third input) must contain only one element. It has ",
|
"rewrite (third input) must contain only one element. It has ",
|
||||||
rewrite_dimensions.size(), " dimensions."), ORT_INVALID_GRAPH);
|
rewrite_dimensions.size(), " dimensions."), ORT_INVALID_GRAPH);
|
||||||
}
|
}
|
||||||
if (str_pattern[0].empty()) {
|
if (str_pattern[0].empty()) {
|
||||||
ORT_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_GRAPH);
|
ORTX_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_GRAPH);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup output
|
// Setup output
|
||||||
|
|
|
@ -15,7 +15,7 @@ struct KernelStringECMARegexReplace : BaseKernel {
|
||||||
bool ignore_case_;
|
bool ignore_case_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringECMARegexReplace : Ort::CustomOpBase<CustomOpStringECMARegexReplace, KernelStringECMARegexReplace> {
|
struct CustomOpStringECMARegexReplace : OrtW::CustomOpBase<CustomOpStringECMARegexReplace, KernelStringECMARegexReplace> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
|
|
|
@ -28,11 +28,11 @@ void KernelStringECMARegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
||||||
// Verifications
|
// Verifications
|
||||||
OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern);
|
OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern);
|
||||||
if (str_pattern.size() != 1)
|
if (str_pattern.size() != 1)
|
||||||
ORT_CXX_API_THROW(MakeString("pattern (second input) must contain only one element. It has ", str_pattern.size(), " values."), ORT_INVALID_GRAPH);
|
ORTX_CXX_API_THROW(MakeString("pattern (second input) must contain only one element. It has ", str_pattern.size(), " values."), ORT_INVALID_GRAPH);
|
||||||
if (str_keep_pattern.size() > 1)
|
if (str_keep_pattern.size() > 1)
|
||||||
ORT_CXX_API_THROW(MakeString("Third input must contain only one element. It has ", str_keep_pattern.size(), " values."), ORT_INVALID_GRAPH);
|
ORTX_CXX_API_THROW(MakeString("Third input must contain only one element. It has ", str_keep_pattern.size(), " values."), ORT_INVALID_GRAPH);
|
||||||
if (str_pattern[0].empty())
|
if (str_pattern[0].empty())
|
||||||
ORT_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_GRAPH);
|
ORTX_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_GRAPH);
|
||||||
|
|
||||||
OrtTensorDimensions dimensions(ort_, input);
|
OrtTensorDimensions dimensions(ort_, input);
|
||||||
bool include_delimiter = (str_keep_pattern.size() == 1) && (!str_keep_pattern[0].empty());
|
bool include_delimiter = (str_keep_pattern.size() == 1) && (!str_keep_pattern[0].empty());
|
||||||
|
@ -111,7 +111,7 @@ ONNXTensorElementDataType CustomOpStringECMARegexSplitWithOffsets::GetOutputType
|
||||||
case 3:
|
case 3:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString(
|
ORTX_CXX_API_THROW(MakeString(
|
||||||
"StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."),
|
"StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."),
|
||||||
ORT_INVALID_ARGUMENT);
|
ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ struct KernelStringECMARegexSplitWithOffsets : BaseKernel {
|
||||||
bool ignore_case_;
|
bool ignore_case_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringECMARegexSplitWithOffsets : Ort::CustomOpBase<CustomOpStringECMARegexSplitWithOffsets, KernelStringECMARegexSplitWithOffsets> {
|
struct CustomOpStringECMARegexSplitWithOffsets : OrtW::CustomOpBase<CustomOpStringECMARegexSplitWithOffsets, KernelStringECMARegexSplitWithOffsets> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
|
|
|
@ -23,7 +23,7 @@ void KernelStringHash::Compute(OrtKernelContext* context) {
|
||||||
// Verifications
|
// Verifications
|
||||||
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
|
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
|
||||||
if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1)
|
if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1)
|
||||||
ORT_CXX_API_THROW(MakeString(
|
ORTX_CXX_API_THROW(MakeString(
|
||||||
"num_buckets must contain only one element. It has ",
|
"num_buckets must contain only one element. It has ",
|
||||||
num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
|
num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ ONNXTensorElementDataType CustomOpStringHash::GetInputType(size_t index) const {
|
||||||
case 1:
|
case 1:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -82,7 +82,7 @@ void KernelStringHashFast::Compute(OrtKernelContext* context) {
|
||||||
// Verifications
|
// Verifications
|
||||||
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
|
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
|
||||||
if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1)
|
if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1)
|
||||||
ORT_CXX_API_THROW(MakeString(
|
ORTX_CXX_API_THROW(MakeString(
|
||||||
"num_buckets must contain only one element. It has ",
|
"num_buckets must contain only one element. It has ",
|
||||||
num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
|
num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
|
||||||
|
|
||||||
|
@ -115,7 +115,7 @@ ONNXTensorElementDataType CustomOpStringHashFast::GetInputType(size_t index) con
|
||||||
case 1:
|
case 1:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ struct KernelStringHash : BaseKernel {
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringHash : Ort::CustomOpBase<CustomOpStringHash, KernelStringHash> {
|
struct CustomOpStringHash : OrtW::CustomOpBase<CustomOpStringHash, KernelStringHash> {
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
@ -24,7 +24,7 @@ struct KernelStringHashFast : BaseKernel {
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringHashFast : Ort::CustomOpBase<CustomOpStringHashFast, KernelStringHashFast> {
|
struct CustomOpStringHashFast : OrtW::CustomOpBase<CustomOpStringHashFast, KernelStringHashFast> {
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -20,18 +20,18 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
|
||||||
// Check input
|
// Check input
|
||||||
OrtTensorDimensions dimensions_sep(ort_, input_sep);
|
OrtTensorDimensions dimensions_sep(ort_, input_sep);
|
||||||
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
|
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
|
||||||
ORT_CXX_API_THROW("Input 2 is the separator, it should have 1 element.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Input 2 is the separator, it should have 1 element.", ORT_INVALID_ARGUMENT);
|
||||||
OrtTensorDimensions dimensions_axis(ort_, input_axis);
|
OrtTensorDimensions dimensions_axis(ort_, input_axis);
|
||||||
if (dimensions_axis.size() != 1 || dimensions_axis[0] != 1)
|
if (dimensions_axis.size() != 1 || dimensions_axis[0] != 1)
|
||||||
ORT_CXX_API_THROW("Input 3 is the axis, it should have 1 element.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Input 3 is the axis, it should have 1 element.", ORT_INVALID_ARGUMENT);
|
||||||
OrtTensorDimensions dimensions(ort_, input_X);
|
OrtTensorDimensions dimensions(ort_, input_X);
|
||||||
if (dimensions.size() == 0) {
|
if (dimensions.size() == 0) {
|
||||||
// dimensions size 0 means input 1 is scalar, input 1 must have 1 element. See issue: https://github.com/onnx/onnx/issues/3724
|
// dimensions size 0 means input 1 is scalar, input 1 must have 1 element. See issue: https://github.com/onnx/onnx/issues/3724
|
||||||
if (X.size() != 1)
|
if (X.size() != 1)
|
||||||
ORT_CXX_API_THROW(MakeString("Input 1's dimensions size is 0 (scalar), it must has 1 element but it has ", X.size()), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Input 1's dimensions size is 0 (scalar), it must has 1 element but it has ", X.size()), ORT_INVALID_ARGUMENT);
|
||||||
} else {
|
} else {
|
||||||
if (*axis < 0 || *axis >= static_cast<int64_t>(dimensions.size()))
|
if (*axis < 0 || *axis >= static_cast<int64_t>(dimensions.size()))
|
||||||
ORT_CXX_API_THROW(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> dimensions_out(dimensions.size() > 1 ? dimensions.size() - 1 : 1);
|
std::vector<int64_t> dimensions_out(dimensions.size() > 1 ? dimensions.size() - 1 : 1);
|
||||||
|
@ -102,7 +102,7 @@ ONNXTensorElementDataType CustomOpStringJoin::GetInputType(size_t index) const {
|
||||||
case 2:
|
case 2:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ struct KernelStringJoin : BaseKernel {
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringJoin : Ort::CustomOpBase<CustomOpStringJoin, KernelStringJoin> {
|
struct CustomOpStringJoin : OrtW::CustomOpBase<CustomOpStringJoin, KernelStringJoin> {
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -11,7 +11,7 @@ struct KernelStringLength : BaseKernel {
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringLength : Ort::CustomOpBase<CustomOpStringLength, KernelStringLength> {
|
struct CustomOpStringLength : OrtW::CustomOpBase<CustomOpStringLength, KernelStringLength> {
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -11,7 +11,7 @@ struct KernelStringLower : BaseKernel {
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringLower : Ort::CustomOpBase<CustomOpStringLower, KernelStringLower> {
|
struct CustomOpStringLower : OrtW::CustomOpBase<CustomOpStringLower, KernelStringLower> {
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -16,7 +16,7 @@ KernelStringMapping::KernelStringMapping(const OrtApi& api, const OrtKernelInfo*
|
||||||
auto items = SplitString(line, "\t", true);
|
auto items = SplitString(line, "\t", true);
|
||||||
|
|
||||||
if (items.size() != 2) {
|
if (items.size() != 2) {
|
||||||
ORT_CXX_API_THROW(std::string("[StringMapping]: Should only exist two items in one line, find error in line: ") + std::string(line), ORT_INVALID_GRAPH);
|
ORTX_CXX_API_THROW(std::string("[StringMapping]: Should only exist two items in one line, find error in line: ") + std::string(line), ORT_INVALID_GRAPH);
|
||||||
}
|
}
|
||||||
map_[std::string(items[0])] = std::string(items[1]);
|
map_[std::string(items[0])] = std::string(items[1]);
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ struct KernelStringMapping : BaseKernel {
|
||||||
std::unordered_map<std::string, std::string> map_;
|
std::unordered_map<std::string, std::string> map_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringMapping : Ort::CustomOpBase<CustomOpStringMapping, KernelStringMapping> {
|
struct CustomOpStringMapping : OrtW::CustomOpBase<CustomOpStringMapping, KernelStringMapping> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
|
|
|
@ -20,13 +20,13 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
|
||||||
// Setup output
|
// Setup output
|
||||||
OrtTensorDimensions dimensions_sep(ort_, input_sep);
|
OrtTensorDimensions dimensions_sep(ort_, input_sep);
|
||||||
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
|
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
|
||||||
ORT_CXX_API_THROW("Input 2 is the delimiter, it has 1 element.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Input 2 is the delimiter, it has 1 element.", ORT_INVALID_ARGUMENT);
|
||||||
OrtTensorDimensions dimensions_skip_empty(ort_, input_skip_empty);
|
OrtTensorDimensions dimensions_skip_empty(ort_, input_skip_empty);
|
||||||
if (dimensions_skip_empty.size() != 1 || dimensions_skip_empty[0] != 1)
|
if (dimensions_skip_empty.size() != 1 || dimensions_skip_empty[0] != 1)
|
||||||
ORT_CXX_API_THROW("Input 3 is skip_empty, it has 1 element.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Input 3 is skip_empty, it has 1 element.", ORT_INVALID_ARGUMENT);
|
||||||
OrtTensorDimensions dimensions(ort_, input_X);
|
OrtTensorDimensions dimensions(ort_, input_X);
|
||||||
if (dimensions.size() != 1)
|
if (dimensions.size() != 1)
|
||||||
ORT_CXX_API_THROW("Only 1D tensor are supported as input.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Only 1D tensor are supported as input.", ORT_INVALID_ARGUMENT);
|
||||||
|
|
||||||
std::vector<std::string> words;
|
std::vector<std::string> words;
|
||||||
std::vector<int64_t> indices;
|
std::vector<int64_t> indices;
|
||||||
|
@ -112,7 +112,7 @@ ONNXTensorElementDataType CustomOpStringSplit::GetInputType(size_t index) const
|
||||||
case 2:
|
case 2:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -128,6 +128,6 @@ ONNXTensorElementDataType CustomOpStringSplit::GetOutputType(size_t index) const
|
||||||
case 1:
|
case 1:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("[StringSplit] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("[StringSplit] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -11,7 +11,7 @@ struct KernelStringSplit : BaseKernel {
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringSplit : Ort::CustomOpBase<CustomOpStringSplit, KernelStringSplit> {
|
struct CustomOpStringSplit : OrtW::CustomOpBase<CustomOpStringSplit, KernelStringSplit> {
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -41,7 +41,7 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
|
||||||
|
|
||||||
vector_len_ = ParseVectorLen(lines[0]);
|
vector_len_ = ParseVectorLen(lines[0]);
|
||||||
if (vector_len_ == 0) {
|
if (vector_len_ == 0) {
|
||||||
ORT_CXX_API_THROW(MakeString("The mapped value of string input cannot be empty: ", lines[0]), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("The mapped value of string input cannot be empty: ", lines[0]), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> values(vector_len_);
|
std::vector<int64_t> values(vector_len_);
|
||||||
|
@ -49,7 +49,7 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
|
||||||
auto kv = SplitString(line, "\t", true);
|
auto kv = SplitString(line, "\t", true);
|
||||||
|
|
||||||
if (kv.size() != 2) {
|
if (kv.size() != 2) {
|
||||||
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseValues(kv[1], values);
|
ParseValues(kv[1], values);
|
||||||
|
@ -62,14 +62,14 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
|
||||||
void StringToVectorImpl::ParseUnkownValue(std::string& unk) {
|
void StringToVectorImpl::ParseUnkownValue(std::string& unk) {
|
||||||
auto unk_strs = SplitString(unk, " ", true);
|
auto unk_strs = SplitString(unk, " ", true);
|
||||||
if (unk_strs.size() != vector_len_) {
|
if (unk_strs.size() != vector_len_) {
|
||||||
ORT_CXX_API_THROW(MakeString("Incompatible dimension: required vector length of unknown_value should be: ", vector_len_), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Incompatible dimension: required vector length of unknown_value should be: ", vector_len_), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& str : unk_strs) {
|
for (auto& str : unk_strs) {
|
||||||
int64_t value;
|
int64_t value;
|
||||||
auto [end, ec] = std::from_chars(str.data(), str.data() + str.size(), value);
|
auto [end, ec] = std::from_chars(str.data(), str.data() + str.size(), value);
|
||||||
if (end != str.data() + str.size()) {
|
if (end != str.data() + str.size()) {
|
||||||
ORT_CXX_API_THROW(MakeString("Failed to parse unknown_value when processing the number: ", str), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Failed to parse unknown_value when processing the number: ", str), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
unk_value_.push_back(value);
|
unk_value_.push_back(value);
|
||||||
|
@ -80,7 +80,7 @@ size_t StringToVectorImpl::ParseVectorLen(const std::string_view& line) {
|
||||||
auto kv = SplitString(line, "\t", true);
|
auto kv = SplitString(line, "\t", true);
|
||||||
|
|
||||||
if (kv.size() != 2) {
|
if (kv.size() != 2) {
|
||||||
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto value_strs = SplitString(kv[1], " ", true);
|
auto value_strs = SplitString(kv[1], " ", true);
|
||||||
|
@ -94,7 +94,7 @@ void StringToVectorImpl::ParseValues(const std::string_view& v, std::vector<int6
|
||||||
for (size_t i = 0; i < value_strs.size(); i++) {
|
for (size_t i = 0; i < value_strs.size(); i++) {
|
||||||
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
|
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
|
||||||
if (end != value_strs[i].data() + value_strs[i].size()) {
|
if (end != value_strs[i].data() + value_strs[i].size()) {
|
||||||
ORT_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
values[i] = value;
|
values[i] = value;
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ struct KernelStringToVector : BaseKernel {
|
||||||
std::shared_ptr<StringToVectorImpl> impl_;
|
std::shared_ptr<StringToVectorImpl> impl_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringToVector : Ort::CustomOpBase<CustomOpStringToVector, KernelStringToVector> {
|
struct CustomOpStringToVector : OrtW::CustomOpBase<CustomOpStringToVector, KernelStringToVector> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
|
|
|
@ -11,7 +11,7 @@ struct KernelStringUpper : BaseKernel {
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpStringUpper : Ort::CustomOpBase<CustomOpStringUpper, KernelStringUpper> {
|
struct CustomOpStringUpper : OrtW::CustomOpBase<CustomOpStringUpper, KernelStringUpper> {
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||||
|
|
|
@ -29,7 +29,7 @@ std::vector<std::string> VectorToStringImpl::Compute(const void* input, const Or
|
||||||
output_dim = input_dim;
|
output_dim = input_dim;
|
||||||
} else {
|
} else {
|
||||||
if (input_dim.IsScalar() || input_dim[input_dim.size() - 1] != static_cast<int64_t>(vector_len_)) {
|
if (input_dim.IsScalar() || input_dim[input_dim.size() - 1] != static_cast<int64_t>(vector_len_)) {
|
||||||
ORT_CXX_API_THROW(MakeString("Incompatible dimension: required vector length should be ", vector_len_), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Incompatible dimension: required vector length should be ", vector_len_), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
output_dim = input_dim;
|
output_dim = input_dim;
|
||||||
|
@ -70,7 +70,7 @@ void VectorToStringImpl::ParseMappingTable(std::string& map) {
|
||||||
auto kv = SplitString(line, "\t", true);
|
auto kv = SplitString(line, "\t", true);
|
||||||
|
|
||||||
if (kv.size() != 2) {
|
if (kv.size() != 2) {
|
||||||
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
ParseValues(kv[1], values);
|
ParseValues(kv[1], values);
|
||||||
|
@ -83,7 +83,7 @@ size_t VectorToStringImpl::ParseVectorLen(const std::string_view& line) {
|
||||||
auto kv = SplitString(line, "\t", true);
|
auto kv = SplitString(line, "\t", true);
|
||||||
|
|
||||||
if (kv.size() != 2) {
|
if (kv.size() != 2) {
|
||||||
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto value_strs = SplitString(kv[1], " ", true);
|
auto value_strs = SplitString(kv[1], " ", true);
|
||||||
|
@ -97,7 +97,7 @@ void VectorToStringImpl::ParseValues(const std::string_view& v, std::vector<int6
|
||||||
for (size_t i = 0; i < value_strs.size(); i++) {
|
for (size_t i = 0; i < value_strs.size(); i++) {
|
||||||
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
|
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
|
||||||
if (end != value_strs[i].data() + value_strs[i].size()) {
|
if (end != value_strs[i].data() + value_strs[i].size()) {
|
||||||
ORT_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
values[i] = value;
|
values[i] = value;
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,7 +40,7 @@ struct KernelVectorToString : BaseKernel {
|
||||||
std::shared_ptr<VectorToStringImpl> impl_;
|
std::shared_ptr<VectorToStringImpl> impl_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpVectorToString : Ort::CustomOpBase<CustomOpVectorToString, KernelVectorToString> {
|
struct CustomOpVectorToString : OrtW::CustomOpBase<CustomOpVectorToString, KernelVectorToString> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
|
|
|
@ -100,7 +100,7 @@ void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
|
||||||
|
|
||||||
OrtTensorDimensions dimensions(ort_, input);
|
OrtTensorDimensions dimensions(ort_, input);
|
||||||
if (dimensions.size() != 1 && dimensions[0] != 1) {
|
if (dimensions.size() != 1 && dimensions[0] != 1) {
|
||||||
ORT_CXX_API_THROW("[BasicTokenizer]: only support string scalar.", ORT_INVALID_GRAPH);
|
ORTX_CXX_API_THROW("[BasicTokenizer]: only support string scalar.", ORT_INVALID_GRAPH);
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
|
||||||
|
|
|
@ -27,7 +27,7 @@ struct KernelBasicTokenizer : BaseKernel {
|
||||||
std::shared_ptr<BasicTokenizer> tokenizer_;
|
std::shared_ptr<BasicTokenizer> tokenizer_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpBasicTokenizer : Ort::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
|
struct CustomOpBasicTokenizer : OrtW::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
|
|
|
@ -33,7 +33,7 @@ int32_t BertTokenizerVocab::FindTokenId(const ustring& token) {
|
||||||
|
|
||||||
auto it = vocab_.find(utf8_token);
|
auto it = vocab_.find(utf8_token);
|
||||||
if (it == vocab_.end()) {
|
if (it == vocab_.end()) {
|
||||||
ORT_CXX_API_THROW("[BertTokenizerVocab]: can not find tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION);
|
ORTX_CXX_API_THROW("[BertTokenizerVocab]: can not find tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION);
|
||||||
}
|
}
|
||||||
|
|
||||||
return it->second;
|
return it->second;
|
||||||
|
@ -305,7 +305,7 @@ void KernelBertTokenizer::Compute(OrtKernelContext* context) {
|
||||||
GetTensorMutableDataString(api_, ort_, context, input, input_data);
|
GetTensorMutableDataString(api_, ort_, context, input, input_data);
|
||||||
|
|
||||||
if (input_data.size() != 1 && input_data.size() != 2) {
|
if (input_data.size() != 1 && input_data.size() != 2) {
|
||||||
ORT_CXX_API_THROW("[BertTokenizer]: only support one or two query.", ORT_INVALID_GRAPH);
|
ORTX_CXX_API_THROW("[BertTokenizer]: only support one or two query.", ORT_INVALID_GRAPH);
|
||||||
}
|
}
|
||||||
std::vector<int64_t> input_ids;
|
std::vector<int64_t> input_ids;
|
||||||
std::vector<int64_t> token_type_ids;
|
std::vector<int64_t> token_type_ids;
|
||||||
|
@ -365,7 +365,7 @@ void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
|
||||||
GetTensorMutableDataString(api_, ort_, context, input, input_data);
|
GetTensorMutableDataString(api_, ort_, context, input, input_data);
|
||||||
|
|
||||||
if (input_data.size() != 2) {
|
if (input_data.size() != 2) {
|
||||||
ORT_CXX_API_THROW("[HfBertTokenizer]: Support only two input strings.", ORT_INVALID_GRAPH);
|
ORTX_CXX_API_THROW("[HfBertTokenizer]: Support only two input strings.", ORT_INVALID_GRAPH);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
|
std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));
|
||||||
|
|
|
@ -97,7 +97,7 @@ struct KernelBertTokenizer : BaseKernel {
|
||||||
std::unique_ptr<BertTokenizer> tokenizer_;
|
std::unique_ptr<BertTokenizer> tokenizer_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
|
struct CustomOpBertTokenizer : OrtW::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
|
@ -111,7 +111,7 @@ struct KernelHfBertTokenizer : KernelBertTokenizer {
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpHfBertTokenizer : Ort::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
|
struct CustomOpHfBertTokenizer : OrtW::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
|
|
|
@ -146,7 +146,7 @@ void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
|
||||||
OrtTensorDimensions ids_dim(ort_, ids);
|
OrtTensorDimensions ids_dim(ort_, ids);
|
||||||
|
|
||||||
if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) {
|
if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) {
|
||||||
ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
|
ORTX_CXX_API_THROW("[BertTokenizerDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
|
||||||
}
|
}
|
||||||
|
|
||||||
// const int64_t* p_row_indices = ort_row_indices_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(ort_row_indices);
|
// const int64_t* p_row_indices = ort_row_indices_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(ort_row_indices);
|
||||||
|
@ -155,7 +155,7 @@ void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
|
||||||
if (use_indices_ &&
|
if (use_indices_ &&
|
||||||
(!((positions_dim.Size() == 0) ||
|
(!((positions_dim.Size() == 0) ||
|
||||||
(positions_dim.size() == 2 && positions_dim[1] == 2)))) {
|
(positions_dim.size() == 2 && positions_dim[1] == 2)))) {
|
||||||
ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices", ORT_INVALID_GRAPH);
|
ORTX_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices", ORT_INVALID_GRAPH);
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64_t* p_positions = positions_dim.Size() == 0 ? nullptr : ort_.GetTensorData<int64_t>(positions);
|
const int64_t* p_positions = positions_dim.Size() == 0 ? nullptr : ort_.GetTensorData<int64_t>(positions);
|
||||||
|
|
|
@ -42,7 +42,7 @@ struct KernelBertTokenizerDecoder : BaseKernel {
|
||||||
bool clean_up_tokenization_spaces_;
|
bool clean_up_tokenization_spaces_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpBertTokenizerDecoder : Ort::CustomOpBase<CustomOpBertTokenizerDecoder, KernelBertTokenizerDecoder> {
|
struct CustomOpBertTokenizerDecoder : OrtW::CustomOpBase<CustomOpBertTokenizerDecoder, KernelBertTokenizerDecoder> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
|
|
|
@ -12,13 +12,13 @@
|
||||||
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
|
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
|
||||||
model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
|
model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
|
||||||
if (model_data_.empty()) {
|
if (model_data_.empty()) {
|
||||||
ORT_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
void* model_ptr = SetModel(reinterpret_cast<const unsigned char*>(model_data_.data()), static_cast<int>(model_data_.size()));
|
void* model_ptr = SetModel(reinterpret_cast<const unsigned char*>(model_data_.data()), static_cast<int>(model_data_.size()));
|
||||||
|
|
||||||
if (model_ptr == nullptr) {
|
if (model_ptr == nullptr) {
|
||||||
ORT_CXX_API_THROW("Invalid model", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Invalid model", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
model_ = std::shared_ptr<void>(model_ptr, FreeModel);
|
model_ = std::shared_ptr<void>(model_ptr, FreeModel);
|
||||||
|
@ -35,7 +35,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
|
||||||
|
|
||||||
// TODO: fix this scalar check.
|
// TODO: fix this scalar check.
|
||||||
if (dimensions.Size() != 1 && dimensions[0] != 1) {
|
if (dimensions.Size() != 1 && dimensions[0] != 1) {
|
||||||
ORT_CXX_API_THROW("We only support string scalar.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("We only support string scalar.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> input_data;
|
std::vector<std::string> input_data;
|
||||||
|
@ -47,7 +47,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
|
||||||
|
|
||||||
int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), static_cast<int>(input_string.size()), output_str.get(), nullptr, nullptr, max_length, model_.get());
|
int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), static_cast<int>(input_string.size()), output_str.get(), nullptr, nullptr, max_length, model_.get());
|
||||||
if (output_length < 0) {
|
if (output_length < 0) {
|
||||||
ORT_CXX_API_THROW(MakeString("splitting input:\"", input_string, "\" failed"), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("splitting input:\"", input_string, "\" failed"), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
// inline split output_str by newline '\n'
|
// inline split output_str by newline '\n'
|
||||||
|
@ -75,7 +75,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
|
||||||
output_dimensions[0] = output_sentences.size();
|
output_dimensions[0] = output_sentences.size();
|
||||||
|
|
||||||
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dimensions.data(), output_dimensions.size());
|
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dimensions.data(), output_dimensions.size());
|
||||||
Ort::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size()));
|
OrtW::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomOpBlingFireSentenceBreaker::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
void* CustomOpBlingFireSentenceBreaker::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
|
|
|
@ -25,7 +25,7 @@ struct KernelBlingFireSentenceBreaker : BaseKernel {
|
||||||
int max_sentence;
|
int max_sentence;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpBlingFireSentenceBreaker : Ort::CustomOpBase<CustomOpBlingFireSentenceBreaker, KernelBlingFireSentenceBreaker> {
|
struct CustomOpBlingFireSentenceBreaker : OrtW::CustomOpBase<CustomOpBlingFireSentenceBreaker, KernelBlingFireSentenceBreaker> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
|
|
|
@ -29,7 +29,7 @@ class SpecialTokenMap {
|
||||||
auto it = token_map_.find(p_str);
|
auto it = token_map_.find(p_str);
|
||||||
if (it != token_map_.end()) {
|
if (it != token_map_.end()) {
|
||||||
if (it->second != p_id) {
|
if (it->second != p_id) {
|
||||||
ORT_CXX_API_THROW("Duplicate special tokens.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Duplicate special tokens.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
token_map_[p_str] = p_id;
|
token_map_[p_str] = p_id;
|
||||||
|
@ -84,7 +84,7 @@ class SpecialTokenMap {
|
||||||
SpecialTokenInfo(ustring p_str, int p_id)
|
SpecialTokenInfo(ustring p_str, int p_id)
|
||||||
: str(std::move(p_str)), id(p_id) {
|
: str(std::move(p_str)), id(p_id) {
|
||||||
if (str.empty()) {
|
if (str.empty()) {
|
||||||
ORT_CXX_API_THROW("Empty special token.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Empty special token.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -147,7 +147,7 @@ class VocabData {
|
||||||
if ((line[0] == '#') && (index == 0)) continue;
|
if ((line[0] == '#') && (index == 0)) continue;
|
||||||
auto pos = line.find(' ');
|
auto pos = line.find(' ');
|
||||||
if (pos == std::string::npos) {
|
if (pos == std::string::npos) {
|
||||||
ORT_CXX_API_THROW("Cannot know how to parse line: " + line, ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Cannot know how to parse line: " + line, ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
std::string w1 = line.substr(0, pos);
|
std::string w1 = line.substr(0, pos);
|
||||||
std::string w2 = line.substr(pos + 1);
|
std::string w2 = line.substr(pos + 1);
|
||||||
|
@ -231,14 +231,14 @@ class VocabData {
|
||||||
int TokenToID(const std::string& input) const {
|
int TokenToID(const std::string& input) const {
|
||||||
auto it = vocab_map_.find(input);
|
auto it = vocab_map_.find(input);
|
||||||
if (it == vocab_map_.end()) {
|
if (it == vocab_map_.end()) {
|
||||||
ORT_CXX_API_THROW("Token not found: " + input, ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Token not found: " + input, ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string& IdToToken(int id) const {
|
const std::string& IdToToken(int id) const {
|
||||||
if ((id < 0) || (static_cast<size_t>(id) >= id2token_map_.size())) {
|
if ((id < 0) || (static_cast<size_t>(id) >= id2token_map_.size())) {
|
||||||
ORT_CXX_API_THROW("Invalid ID: " + std::to_string(id), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Invalid ID: " + std::to_string(id), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
return id2token_map_[id];
|
return id2token_map_[id];
|
||||||
}
|
}
|
||||||
|
@ -247,7 +247,7 @@ class VocabData {
|
||||||
int GetVocabIndex(const std::string& str) {
|
int GetVocabIndex(const std::string& str) {
|
||||||
auto it = vocab_map_.find(str);
|
auto it = vocab_map_.find(str);
|
||||||
if (it == vocab_map_.end()) {
|
if (it == vocab_map_.end()) {
|
||||||
ORT_CXX_API_THROW("Cannot find word in vocabulary: " + str, ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("Cannot find word in vocabulary: " + str, ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
@ -467,12 +467,12 @@ KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo* i
|
||||||
: BaseKernel(api, info) {
|
: BaseKernel(api, info) {
|
||||||
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
|
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
|
||||||
if (vocab.empty()) {
|
if (vocab.empty()) {
|
||||||
ORT_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string merges = ort_.KernelInfoGetAttribute<std::string>(info, "merges");
|
std::string merges = ort_.KernelInfoGetAttribute<std::string>(info, "merges");
|
||||||
if (merges.empty()) {
|
if (merges.empty()) {
|
||||||
ORT_CXX_API_THROW("merges shouldn't be empty.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("merges shouldn't be empty.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!TryToGetAttribute<int64_t>("padding_length", padding_length_)) {
|
if (!TryToGetAttribute<int64_t>("padding_length", padding_length_)) {
|
||||||
|
@ -480,7 +480,7 @@ KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo* i
|
||||||
}
|
}
|
||||||
|
|
||||||
if (padding_length_ != -1 && padding_length_ <= 0) {
|
if (padding_length_ != -1 && padding_length_ <= 0) {
|
||||||
ORT_CXX_API_THROW("padding_length should be more than 0 or equal -1", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("padding_length should be more than 0 or equal -1", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::stringstream vocabu_stream(vocab);
|
std::stringstream vocabu_stream(vocab);
|
||||||
|
|
|
@ -17,7 +17,7 @@ struct KernelBpeTokenizer : BaseKernel {
|
||||||
std::shared_ptr<VocabData> bbpe_tokenizer_;
|
std::shared_ptr<VocabData> bbpe_tokenizer_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpBpeTokenizer : Ort::CustomOpBase<CustomOpBpeTokenizer, KernelBpeTokenizer> {
|
struct CustomOpBpeTokenizer : OrtW::CustomOpBase<CustomOpBpeTokenizer, KernelBpeTokenizer> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
|
|
|
@ -16,7 +16,7 @@ struct KernelSentencepieceDecoder : BaseKernel {
|
||||||
model_proto.ParseFromArray(model_blob.data(), static_cast<int>(model_blob.size()));
|
model_proto.ParseFromArray(model_blob.data(), static_cast<int>(model_blob.size()));
|
||||||
sentencepiece::util::Status status = tokenizer_.Load(model_proto);
|
sentencepiece::util::Status status = tokenizer_.Load(model_proto);
|
||||||
if (!status.ok()){
|
if (!status.ok()){
|
||||||
ORT_CXX_API_THROW(MakeString(
|
ORTX_CXX_API_THROW(MakeString(
|
||||||
"Failed to create SentencePieceProcessor instance. Error code is ",
|
"Failed to create SentencePieceProcessor instance. Error code is ",
|
||||||
(int)status.code(), ". Message is '", status.error_message(), "'."),
|
(int)status.code(), ". Message is '", status.error_message(), "'."),
|
||||||
ORT_INVALID_PROTOBUF);
|
ORT_INVALID_PROTOBUF);
|
||||||
|
@ -29,7 +29,7 @@ struct KernelSentencepieceDecoder : BaseKernel {
|
||||||
OrtTensorDimensions ids_dim(ort_, ids);
|
OrtTensorDimensions ids_dim(ort_, ids);
|
||||||
|
|
||||||
if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) {
|
if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) {
|
||||||
ORT_CXX_API_THROW("[SentencePieceDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
|
ORTX_CXX_API_THROW("[SentencePieceDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto count = ids_dim[0];
|
auto count = ids_dim[0];
|
||||||
|
@ -41,7 +41,7 @@ struct KernelSentencepieceDecoder : BaseKernel {
|
||||||
[](auto _id) { return static_cast<int>(_id); });
|
[](auto _id) { return static_cast<int>(_id); });
|
||||||
auto status = tokenizer_.Decode(tids, &decoded_string);
|
auto status = tokenizer_.Decode(tids, &decoded_string);
|
||||||
if (!status.ok()){
|
if (!status.ok()){
|
||||||
ORT_CXX_API_THROW("[SentencePieceDecoder] model decoding failed.", ORT_RUNTIME_EXCEPTION);
|
ORTX_CXX_API_THROW("[SentencePieceDecoder] model decoding failed.", ORT_RUNTIME_EXCEPTION);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> result = {decoded_string};
|
std::vector<std::string> result = {decoded_string};
|
||||||
|
@ -53,7 +53,7 @@ struct KernelSentencepieceDecoder : BaseKernel {
|
||||||
sentencepiece::SentencePieceProcessor tokenizer_;
|
sentencepiece::SentencePieceProcessor tokenizer_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpSentencepieceDecoder : Ort::CustomOpBase<CustomOpSentencepieceDecoder, KernelSentencepieceDecoder> {
|
struct CustomOpSentencepieceDecoder : OrtW::CustomOpBase<CustomOpSentencepieceDecoder, KernelSentencepieceDecoder> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return CreateKernelImpl(api, info);
|
return CreateKernelImpl(api, info);
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,7 @@ KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(const OrtApi& api, co
|
||||||
(int)status.code(), ". Message is '", status.error_message(), "'."));
|
(int)status.code(), ". Message is '", status.error_message(), "'."));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void _check_dimension_constant(Ort::CustomOpApi ort, const OrtValue* ort_value, const char* name) {
|
static void _check_dimension_constant(OrtW::CustomOpApi ort, const OrtValue* ort_value, const char* name) {
|
||||||
OrtTensorDimensions dimensions(ort, ort_value);
|
OrtTensorDimensions dimensions(ort, ort_value);
|
||||||
if (dimensions.size() != 1 || dimensions[0] != 1)
|
if (dimensions.size() != 1 || dimensions[0] != 1)
|
||||||
throw std::runtime_error(MakeString(
|
throw std::runtime_error(MakeString(
|
||||||
|
|
|
@ -16,7 +16,7 @@ struct KernelSentencepieceTokenizer : BaseKernel {
|
||||||
sentencepiece::SentencePieceProcessor tokenizer_;
|
sentencepiece::SentencePieceProcessor tokenizer_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpSentencepieceTokenizer : Ort::CustomOpBase<CustomOpSentencepieceTokenizer, KernelSentencepieceTokenizer> {
|
struct CustomOpSentencepieceTokenizer : OrtW::CustomOpBase<CustomOpSentencepieceTokenizer, KernelSentencepieceTokenizer> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
|
|
|
@ -72,7 +72,7 @@ void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string,
|
||||||
rows.push_back(indices.size());
|
rows.push_back(indices.size());
|
||||||
} else if (text_index == existing_rows[row_index]) {
|
} else if (text_index == existing_rows[row_index]) {
|
||||||
if (row_index >= n_existing_rows)
|
if (row_index >= n_existing_rows)
|
||||||
ORT_CXX_API_THROW(MakeString(
|
ORTX_CXX_API_THROW(MakeString(
|
||||||
"row_index=", row_index, " is out of range=", n_existing_rows, "."), ORT_INVALID_ARGUMENT);
|
"row_index=", row_index, " is out of range=", n_existing_rows, "."), ORT_INVALID_ARGUMENT);
|
||||||
rows.push_back(indices.size());
|
rows.push_back(indices.size());
|
||||||
++row_index;
|
++row_index;
|
||||||
|
@ -181,7 +181,7 @@ ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetInputType(size_t index)
|
||||||
case 1:
|
case 1:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -198,6 +198,6 @@ ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetOutputType(size_t index
|
||||||
case 3:
|
case 3:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("[WordpieceTokenizer] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("[WordpieceTokenizer] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -21,7 +21,7 @@ struct KernelWordpieceTokenizer : BaseKernel {
|
||||||
std::unordered_map<std::u32string, int32_t> vocab_;
|
std::unordered_map<std::u32string, int32_t> vocab_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpWordpieceTokenizer : Ort::CustomOpBase<CustomOpWordpieceTokenizer, KernelWordpieceTokenizer> {
|
struct CustomOpWordpieceTokenizer : OrtW::CustomOpBase<CustomOpWordpieceTokenizer, KernelWordpieceTokenizer> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||||
const char* GetName() const;
|
const char* GetName() const;
|
||||||
size_t GetInputTypeCount() const;
|
size_t GetInputTypeCount() const;
|
||||||
|
|
|
@ -42,4 +42,4 @@ struct hash<ustring> {
|
||||||
return standard_hash(static_cast<u32string>(__str));
|
return standard_hash(static_cast<u32string>(__str));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace std
|
} // namespace std
|
||||||
|
|
|
@ -12,7 +12,7 @@ void KernelDecodeImage::Compute(OrtKernelContext* context) {
|
||||||
const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL);
|
const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL);
|
||||||
OrtTensorDimensions dimensions(ort_, inputs);
|
OrtTensorDimensions dimensions(ort_, inputs);
|
||||||
if (dimensions.size() != 1ULL) {
|
if (dimensions.size() != 1ULL) {
|
||||||
ORT_CXX_API_THROW("[DecodeImage]: Raw image bytes with 1D shape expected.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("[DecodeImage]: Raw image bytes with 1D shape expected.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtTensorTypeAndShapeInfo* input_info = ort_.GetTensorTypeAndShape(inputs);
|
OrtTensorTypeAndShapeInfo* input_info = ort_.GetTensorTypeAndShape(inputs);
|
||||||
|
@ -26,7 +26,7 @@ void KernelDecodeImage::Compute(OrtKernelContext* context) {
|
||||||
const cv::Mat decoded_image = cv::imdecode(encoded_image, cv::IMREAD_COLOR);
|
const cv::Mat decoded_image = cv::imdecode(encoded_image, cv::IMREAD_COLOR);
|
||||||
|
|
||||||
if (decoded_image.data == nullptr) {
|
if (decoded_image.data == nullptr) {
|
||||||
ORT_CXX_API_THROW("[DecodeImage] Invalid input. Failed to decode image.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("[DecodeImage] Invalid input. Failed to decode image.", ORT_INVALID_ARGUMENT);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Setup output & copy to destination
|
// Setup output & copy to destination
|
||||||
|
|
|
@ -15,7 +15,7 @@ struct KernelDecodeImage : BaseKernel {
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpDecodeImage : Ort::CustomOpBase<CustomOpDecodeImage, KernelDecodeImage> {
|
struct CustomOpDecodeImage : OrtW::CustomOpBase<CustomOpDecodeImage, KernelDecodeImage> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new KernelDecodeImage(api);
|
return new KernelDecodeImage(api);
|
||||||
}
|
}
|
||||||
|
@ -37,7 +37,7 @@ struct CustomOpDecodeImage : Ort::CustomOpBase<CustomOpDecodeImage, KernelDecode
|
||||||
case 0:
|
case 0:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("Invalid input index ", index), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Invalid input index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ struct CustomOpDecodeImage : Ort::CustomOpBase<CustomOpDecodeImage, KernelDecode
|
||||||
case 0:
|
case 0:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("Invalid output index ", index), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Invalid output index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -15,7 +15,7 @@ void KernelEncodeImage ::Compute(OrtKernelContext* context) {
|
||||||
if (dimensions_bgr.size() != 3 || dimensions_bgr[2] != 3) {
|
if (dimensions_bgr.size() != 3 || dimensions_bgr[2] != 3) {
|
||||||
// expect {H, W, C} as that's the inverse of what decode_image produces.
|
// expect {H, W, C} as that's the inverse of what decode_image produces.
|
||||||
// we have no way to check if it's BGR or RGB though
|
// we have no way to check if it's BGR or RGB though
|
||||||
ORT_CXX_API_THROW("[EncodeImage] requires rank 3 BGR input in channels last format.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("[EncodeImage] requires rank 3 BGR input in channels last format.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get data & the length
|
// Get data & the length
|
||||||
|
@ -29,7 +29,7 @@ void KernelEncodeImage ::Compute(OrtKernelContext* context) {
|
||||||
// don't know output size ahead of time so need to encode and then copy to output
|
// don't know output size ahead of time so need to encode and then copy to output
|
||||||
std::vector<uint8_t> encoded_image;
|
std::vector<uint8_t> encoded_image;
|
||||||
if (!cv::imencode(extension_, bgr_image, encoded_image)) {
|
if (!cv::imencode(extension_, bgr_image, encoded_image)) {
|
||||||
ORT_CXX_API_THROW("[EncodeImage] Image encoding failed.", ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW("[EncodeImage] Image encoding failed.", ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup output & copy to destination
|
// Setup output & copy to destination
|
||||||
|
|
|
@ -27,12 +27,12 @@ struct KernelEncodeImage : BaseKernel {
|
||||||
/// Converts rank 3 BGR input with channels last ordering to the requested file type.
|
/// Converts rank 3 BGR input with channels last ordering to the requested file type.
|
||||||
/// Default is 'jpg'
|
/// Default is 'jpg'
|
||||||
/// </summary>
|
/// </summary>
|
||||||
struct CustomOpEncodeImage : Ort::CustomOpBase<CustomOpEncodeImage, KernelEncodeImage> {
|
struct CustomOpEncodeImage : OrtW::CustomOpBase<CustomOpEncodeImage, KernelEncodeImage> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
Ort::CustomOpApi op_api{api};
|
OrtW::CustomOpApi op_api{api};
|
||||||
std::string format = op_api.KernelInfoGetAttribute<std::string>(info, "format");
|
std::string format = op_api.KernelInfoGetAttribute<std::string>(info, "format");
|
||||||
if (format != "jpg" && format != "png") {
|
if (format != "jpg" && format != "png") {
|
||||||
ORT_CXX_API_THROW("[EncodeImage] 'format' attribute value must be 'jpg' or 'png'.", ORT_RUNTIME_EXCEPTION);
|
ORTX_CXX_API_THROW("[EncodeImage] 'format' attribute value must be 'jpg' or 'png'.", ORT_RUNTIME_EXCEPTION);
|
||||||
}
|
}
|
||||||
|
|
||||||
return new KernelEncodeImage(api, format);
|
return new KernelEncodeImage(api, format);
|
||||||
|
@ -51,7 +51,7 @@ struct CustomOpEncodeImage : Ort::CustomOpBase<CustomOpEncodeImage, KernelEncode
|
||||||
case 0:
|
case 0:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("Invalid input index ", index), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Invalid input index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ struct CustomOpEncodeImage : Ort::CustomOpBase<CustomOpEncodeImage, KernelEncode
|
||||||
case 0:
|
case 0:
|
||||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||||
default:
|
default:
|
||||||
ORT_CXX_API_THROW(MakeString("Invalid output index ", index), ORT_INVALID_ARGUMENT);
|
ORTX_CXX_API_THROW(MakeString("Invalid output index ", index), ORT_INVALID_ARGUMENT);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -167,7 +167,7 @@ struct PyCustomOpDefImpl : public PyCustomOpDef {
|
||||||
}
|
}
|
||||||
|
|
||||||
static py::object BuildPyObjFromTensor(
|
static py::object BuildPyObjFromTensor(
|
||||||
const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, const OrtValue* value,
|
const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context, const OrtValue* value,
|
||||||
const shape_t& shape, ONNXTensorElementDataType dtype) {
|
const shape_t& shape, ONNXTensorElementDataType dtype) {
|
||||||
std::vector<npy_intp> npy_dims;
|
std::vector<npy_intp> npy_dims;
|
||||||
for (auto n : shape) {
|
for (auto n : shape) {
|
||||||
|
|
|
@ -42,12 +42,12 @@ struct PyCustomOpKernel {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const OrtApi& api_;
|
const OrtApi& api_;
|
||||||
Ort::CustomOpApi ort_;
|
OrtW::CustomOpApi ort_;
|
||||||
uint64_t obj_id_;
|
uint64_t obj_id_;
|
||||||
std::map<std::string, std::string> attrs_values_;
|
std::map<std::string, std::string> attrs_values_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct PyCustomOpFactory : Ort::CustomOpBase<PyCustomOpFactory, PyCustomOpKernel> {
|
struct PyCustomOpFactory : OrtW::CustomOpBase<PyCustomOpFactory, PyCustomOpKernel> {
|
||||||
PyCustomOpFactory() {
|
PyCustomOpFactory() {
|
||||||
// STL vector needs it.
|
// STL vector needs it.
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,8 @@
|
||||||
#include "onnxruntime_extensions.h"
|
#include "onnxruntime_extensions.h"
|
||||||
#include "ocos.h"
|
#include "ocos.h"
|
||||||
|
|
||||||
|
using namespace OrtW;
|
||||||
|
|
||||||
struct OrtCustomOpDomainDeleter {
|
struct OrtCustomOpDomainDeleter {
|
||||||
explicit OrtCustomOpDomainDeleter(const OrtApi* ort_api) {
|
explicit OrtCustomOpDomainDeleter(const OrtApi* ort_api) {
|
||||||
ort_api_ = ort_api;
|
ort_api_ = ort_api;
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
|
#include "onnxruntime_cxx_api.h"
|
||||||
|
|
||||||
const char* GetLibraryPath();
|
const char* GetLibraryPath();
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,5 @@
|
||||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
// Licensed under the MIT License.
|
// Licensed under the MIT License.
|
||||||
|
|
||||||
#include "onnxruntime_cxx_api.h"
|
|
||||||
|
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
#include "ocos.h"
|
#include "ocos.h"
|
||||||
|
@ -50,7 +47,7 @@ struct KernelOne : BaseKernel {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {
|
struct CustomOpOne : OrtW::CustomOpBase<CustomOpOne, KernelOne> {
|
||||||
const char* GetName() const {
|
const char* GetName() const {
|
||||||
return "CustomOpOne";
|
return "CustomOpOne";
|
||||||
};
|
};
|
||||||
|
@ -93,7 +90,7 @@ struct KernelTwo : BaseKernel {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> {
|
struct CustomOpTwo : OrtW::CustomOpBase<CustomOpTwo, KernelTwo> {
|
||||||
const char* GetName() const {
|
const char* GetName() const {
|
||||||
return "CustomOpTwo";
|
return "CustomOpTwo";
|
||||||
};
|
};
|
||||||
|
@ -138,7 +135,7 @@ struct KernelThree : BaseKernel {
|
||||||
std::string substr_;
|
std::string substr_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CustomOpThree : Ort::CustomOpBase<CustomOpThree, KernelThree> {
|
struct CustomOpThree : OrtW::CustomOpBase<CustomOpThree, KernelThree> {
|
||||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return CreateKernelImpl(api, info);
|
return CreateKernelImpl(api, info);
|
||||||
};
|
};
|
||||||
|
@ -188,8 +185,7 @@ void _assert_eq(Ort::Value& output_tensor, const std::vector<T>& expected, size_
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetTensorMutableDataString(const OrtApi& api, const OrtValue* value, std::vector<std::string>& output) {
|
void GetTensorMutableDataString(const OrtApi& api, const OrtValue* value, std::vector<std::string>& output) {
|
||||||
Ort::CustomOpApi ort(api);
|
OrtTensorDimensions dimensions(OrtW::CustomOpApi(api), value);
|
||||||
OrtTensorDimensions dimensions(ort, value);
|
|
||||||
size_t len = static_cast<size_t>(dimensions.Size());
|
size_t len = static_cast<size_t>(dimensions.Size());
|
||||||
size_t data_len;
|
size_t data_len;
|
||||||
Ort::ThrowOnError(api, api.GetStringTensorDataLength(value, &data_len));
|
Ort::ThrowOnError(api, api.GetStringTensorDataLength(value, &data_len));
|
||||||
|
@ -397,7 +393,7 @@ TEST(ustring, tensor_operator) {
|
||||||
}
|
}
|
||||||
EXPECT_EQ(err_code, ORT_OK);
|
EXPECT_EQ(err_code, ORT_OK);
|
||||||
|
|
||||||
Ort::CustomOpApi custom_api(*api);
|
OrtW::CustomOpApi custom_api(*api);
|
||||||
|
|
||||||
std::vector<int64_t> dim{2, 2};
|
std::vector<int64_t> dim{2, 2};
|
||||||
status = api->CreateTensorAsOrtValue(allocator, dim.data(), dim.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, &tensor);
|
status = api->CreateTensorAsOrtValue(allocator, dim.data(), dim.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, &tensor);
|
||||||
|
|
Загрузка…
Ссылка в новой задаче