Using the header files from the ONNXRuntime package (#322)

* Using the header files from the ONNXRuntime package

* Update includes/onnxruntime_customop.hpp

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>

* fix the build break.

* one more fixing

* wired top project

* ort 1.9.0 used

* switch to 1.10.0 package.

* change the vmimage to latest

* URL issue

* cmake policy

* ignore onnxruntime.dll native scan

* update the Onebranch exclusedPaths

* fixing some build tool issues

* update again

* typo

* undo of ORT dll removal

Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
This commit is contained in:
Wenbing Li 2022-12-09 14:30:24 -08:00 коммит произвёл GitHub
Родитель 69e6ec7cf1
Коммит c599b00d07
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
82 изменённых файлов: 625 добавлений и 4193 удалений

Просмотреть файл

@ -59,7 +59,7 @@ jobs:
displayName: Unpack ONNXRuntime package. displayName: Unpack ONNXRuntime package.
- script: | - script: |
CPU_NUMBER=2 sh ./build.sh -DONNXRUNTIME_LIB_DIR=onnxruntime-linux-x64-$(ort.version)/lib -DOCOS_ENABLE_CTEST=ON CPU_NUMBER=2 sh ./build.sh -DOCOS_ENABLE_CTEST=ON
displayName: build the customop library with onnxruntime displayName: build the customop library with onnxruntime
- script: | - script: |
@ -139,7 +139,7 @@ jobs:
displayName: Unpack ONNXRuntime package. displayName: Unpack ONNXRuntime package.
- script: | - script: |
sh ./build.sh -DONNXRUNTIME_LIB_DIR=$(ort.dirname)/lib -DOCOS_ENABLE_CTEST=ON sh ./build.sh -DOCOS_ENABLE_CTEST=ON
displayName: build the customop library with onnxruntime displayName: build the customop library with onnxruntime
- script: | - script: |
@ -266,7 +266,7 @@ jobs:
- script: | - script: |
call $(vsdevcmd) call $(vsdevcmd)
call .\build.bat -DONNXRUNTIME_LIB_DIR=.\onnxruntime-win-x64-$(ort.version)\lib -DOCOS_ENABLE_CTEST=ON call .\build.bat -DOCOS_ENABLE_CTEST=ON
displayName: build the customop library with onnxruntime displayName: build the customop library with onnxruntime
- script: | - script: |
@ -427,6 +427,7 @@ jobs:
############# #############
# iOS # iOS
# Only cmake==3.25.0 works OpenCV compiling now.
############# #############
- job: IosPackage - job: IosPackage
pool: pool:
@ -441,6 +442,7 @@ jobs:
displayName: "Use Python 3.9" displayName: "Use Python 3.9"
- script: | - script: |
python -m pip install cmake==3.25.0
python ./tools/ios/build_xcframework.py \ python ./tools/ios/build_xcframework.py \
--output-dir $(Build.BinariesDirectory)/xcframework_out \ --output-dir $(Build.BinariesDirectory)/xcframework_out \
--platform-arch iphonesimulator x86_64 \ --platform-arch iphonesimulator x86_64 \

Просмотреть файл

@ -44,6 +44,7 @@ extends:
enabled: true enabled: true
binskim: binskim:
break: true # always break the build on binskim issues in addition to TSA upload break: true # always break the build on binskim issues in addition to TSA upload
analyzeTargetGlob: '**\bin\*' # only scan the DLLs in extensions bin folder.
codeql: codeql:
python: python:
enabled: true enabled: true
@ -87,7 +88,7 @@ extends:
@echo ##vso[task.setvariable variable=vsdevcmd]%vsdevcmd% @echo ##vso[task.setvariable variable=vsdevcmd]%vsdevcmd%
@echo ##vso[task.setvariable variable=vscmake]%vscmake% @echo ##vso[task.setvariable variable=vscmake]%vscmake%
@echo ##vso[task.setvariable variable=vsmsbuild]%vsmsbuild% @echo ##vso[task.setvariable variable=vsmsbuild]%vsmsbuild%
displayName: 'locate vsdevcmd via vswhere' displayName: 'locate vsdevcmd via vswhere'
- script: | - script: |
call $(vsdevcmd) call $(vsdevcmd)
set PYTHONPATH= set PYTHONPATH=
@ -95,6 +96,7 @@ extends:
python -m pip install --upgrade pip python -m pip install --upgrade pip
python -m pip install cibuildwheel numpy python -m pip install cibuildwheel numpy
python -m cibuildwheel --platform windows --output-dir $(REPOROOT)\out python -m cibuildwheel --platform windows --output-dir $(REPOROOT)\out
del /s /q /f .setuptools-cmake-build\*onnxruntime.dll
displayName: Build wheels displayName: Build wheels
- task: SDLNativeRules@3 - task: SDLNativeRules@3
inputs: inputs:

Просмотреть файл

@ -44,6 +44,7 @@ extends:
enabled: false enabled: false
binskim: binskim:
break: true # always break the build on binskim issues in addition to TSA upload break: true # always break the build on binskim issues in addition to TSA upload
analyzeTargetGlob: '**\bin\*' # only scan the DLLs in extensions bin folder.
codeql: codeql:
python: python:
enabled: true enabled: true
@ -86,8 +87,8 @@ extends:
@echo ##vso[task.setvariable variable=vslatest]%vslatest% @echo ##vso[task.setvariable variable=vslatest]%vslatest%
@echo ##vso[task.setvariable variable=vsdevcmd]%vsdevcmd% @echo ##vso[task.setvariable variable=vsdevcmd]%vsdevcmd%
@echo ##vso[task.setvariable variable=vscmake]%vscmake% @echo ##vso[task.setvariable variable=vscmake]%vscmake%
@echo ##vso[task.setvariable variable=vsmsbuild]%vsmsbuild% @echo ##vso[task.setvariable variable=vsmsbuild]%vsmsbuild%
displayName: 'locate vsdevcmd via vswhere' displayName: 'locate vsdevcmd via vswhere'
- script: | - script: |
call $(vsdevcmd) call $(vsdevcmd)
set PYTHONPATH= set PYTHONPATH=
@ -95,6 +96,7 @@ extends:
python -m pip install --upgrade pip python -m pip install --upgrade pip
python -m pip install cibuildwheel numpy python -m pip install cibuildwheel numpy
python -m cibuildwheel --platform windows --output-dir $(REPOROOT)\out python -m cibuildwheel --platform windows --output-dir $(REPOROOT)\out
del /s /q /f .setuptools-cmake-build\*onnxruntime.dll
displayName: Build wheels displayName: Build wheels
- task: SDLNativeRules@3 - task: SDLNativeRules@3
inputs: inputs:

Просмотреть файл

@ -1,13 +1,12 @@
cmake_minimum_required(VERSION 3.20) cmake_minimum_required(VERSION 3.20)
project(onnxruntime_extensions LANGUAGES C CXX) project(onnxruntime_extensions LANGUAGES C CXX)
# set(CMAKE_VERBOSE_MAKEFILE ON)
# set(CMAKE_VERBOSE_MAKEFILE ON)
if(NOT CMAKE_BUILD_TYPE) if(NOT CMAKE_BUILD_TYPE)
message(STATUS "Build type not set - using RelWithDebInfo") message(STATUS "Build type not set - using RelWithDebInfo")
set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING "Choose build type: Debug Release RelWithDebInfo." FORCE) set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING "Choose build type: Debug Release RelWithDebInfo." FORCE)
endif() endif()
set(CPACK_PACKAGE_NAME "onnxruntime_extensions") set(CPACK_PACKAGE_NAME "onnxruntime_extensions")
set(CPACK_PACKAGE_VERSION_MAJOR "0") set(CPACK_PACKAGE_VERSION_MAJOR "0")
set(CPACK_PACKAGE_VERSION_MINOR "5") set(CPACK_PACKAGE_VERSION_MINOR "5")
@ -68,7 +67,7 @@ if(NOT CC_OPTIMIZE)
string(REGEX REPLACE "([\-\/]O[123])" "" CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO}") string(REGEX REPLACE "([\-\/]O[123])" "" CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO}")
string(REGEX REPLACE "([\-\/]O[123])" "" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}") string(REGEX REPLACE "([\-\/]O[123])" "" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}")
if (NOT WIN32) if(NOT WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O0") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O0")
else() else()
@ -77,15 +76,16 @@ if(NOT CC_OPTIMIZE)
endif() endif()
endif() endif()
if (NOT OCOS_BUILD_PYTHON AND OCOS_ENABLE_PYTHON) if(NOT OCOS_BUILD_PYTHON AND OCOS_ENABLE_PYTHON)
message("OCOS_ENABLE_PYTHON IS DEPRECATED, USE OCOS_BUILD_PYTHON INSTEAD") message("OCOS_ENABLE_PYTHON IS DEPRECATED, USE OCOS_BUILD_PYTHON INSTEAD")
set(OCOS_BUILD_PYTHON ON CACHE INTERNAL "") set(OCOS_BUILD_PYTHON ON CACHE INTERNAL "")
endif() endif()
if (OCOS_BUILD_ANDROID) if(OCOS_BUILD_ANDROID)
if (NOT ANDROID_SDK_ROOT OR NOT ANDROID_NDK) if(NOT ANDROID_SDK_ROOT OR NOT ANDROID_NDK)
message("Cannot the find Android SDK/NDK") message("Cannot the find Android SDK/NDK")
endif() endif()
set(OCOS_BUILD_JAVA ON CACHE INTERNAL "") set(OCOS_BUILD_JAVA ON CACHE INTERNAL "")
endif() endif()
@ -95,6 +95,7 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set_property(GLOBAL PROPERTY USE_FOLDERS ON) set_property(GLOBAL PROPERTY USE_FOLDERS ON)
set(CMAKE_FIND_FRAMEWORK NEVER CACHE STRING "...") set(CMAKE_FIND_FRAMEWORK NEVER CACHE STRING "...")
if(NOT "${CMAKE_FIND_FRAMEWORK}" STREQUAL "NEVER") if(NOT "${CMAKE_FIND_FRAMEWORK}" STREQUAL "NEVER")
message(FATAL_ERROR "CMAKE_FIND_FRAMEWORK is not NEVER") message(FATAL_ERROR "CMAKE_FIND_FRAMEWORK is not NEVER")
endif() endif()
@ -102,7 +103,7 @@ endif()
# External dependencies # External dependencies
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/externals ${PROJECT_SOURCE_DIR}/cmake) list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/externals ${PROJECT_SOURCE_DIR}/cmake)
if (OCOS_ENABLE_SELECTED_OPLIST) if(OCOS_ENABLE_SELECTED_OPLIST)
# Need to ensure _selectedoplist.cmake file is already generated in folder: ${PROJECT_SOURCE_DIR}/cmake/ # Need to ensure _selectedoplist.cmake file is already generated in folder: ${PROJECT_SOURCE_DIR}/cmake/
# You could run gen_selectedops.py in folder: tools/ to generate _selectedoplist.cmake # You could run gen_selectedops.py in folder: tools/ to generate _selectedoplist.cmake
message(STATUS "Looking for the _selectedoplist.cmake") message(STATUS "Looking for the _selectedoplist.cmake")
@ -113,6 +114,7 @@ endif()
if(NOT OCOS_ENABLE_CPP_EXCEPTIONS) if(NOT OCOS_ENABLE_CPP_EXCEPTIONS)
include(noexcep_ops) include(noexcep_ops)
add_compile_definitions(OCOS_NO_EXCEPTIONS ORT_NO_EXCEPTIONS) add_compile_definitions(OCOS_NO_EXCEPTIONS ORT_NO_EXCEPTIONS)
if(MSVC) if(MSVC)
string(REGEX REPLACE "/EHsc" "/EHs-c-" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") string(REGEX REPLACE "/EHsc" "/EHs-c-" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
add_compile_definitions("_HAS_EXCEPTIONS=0") add_compile_definitions("_HAS_EXCEPTIONS=0")
@ -123,8 +125,15 @@ endif()
include(FetchContent) include(FetchContent)
if (OCOS_ENABLE_RE2_REGEX) # PROJECT_IS_TOP_LEVEL is available until 3.21
if (NOT TARGET re2::re2) get_property(not_top DIRECTORY PROPERTY PARENT_DIRECTORY)
if(not_top)
set(_ONNXRUNTIME_EMBEDDED TRUE)
endif()
include(ext_ortlib)
if(OCOS_ENABLE_RE2_REGEX)
if(NOT TARGET re2::re2)
set(RE2_BUILD_TESTING OFF CACHE INTERNAL "") set(RE2_BUILD_TESTING OFF CACHE INTERNAL "")
message(STATUS "Fetch googlere2") message(STATUS "Fetch googlere2")
include(googlere2) include(googlere2)
@ -136,30 +145,29 @@ if (OCOS_ENABLE_RE2_REGEX)
endif() endif()
macro(standardize_output_folder bin_target) macro(standardize_output_folder bin_target)
set_target_properties(${bin_target} PROPERTIES set_target_properties(${bin_target} PROPERTIES
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin" RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin") PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin")
endmacro() endmacro()
# ### scan all source files
#### scan all source files
file(GLOB TARGET_SRC "operators/*.cc" "operators/*.h") file(GLOB TARGET_SRC "operators/*.cc" "operators/*.h")
if (OCOS_ENABLE_TF_STRING)
if(OCOS_ENABLE_TF_STRING)
set(farmhash_SOURCE_DIR ${PROJECT_SOURCE_DIR}/cmake/externals/farmhash) set(farmhash_SOURCE_DIR ${PROJECT_SOURCE_DIR}/cmake/externals/farmhash)
file(GLOB TARGET_SRC_KERNELS "operators/text/*.cc" "operators/text/*.h*") file(GLOB TARGET_SRC_KERNELS "operators/text/*.cc" "operators/text/*.h*")
file(GLOB TARGET_SRC_HASH "${farmhash_SOURCE_DIR}/src/farmhash.*") file(GLOB TARGET_SRC_HASH "${farmhash_SOURCE_DIR}/src/farmhash.*")
list(APPEND TARGET_SRC ${TARGET_SRC_KERNELS} ${TARGET_SRC_HASH}) list(APPEND TARGET_SRC ${TARGET_SRC_KERNELS} ${TARGET_SRC_HASH})
endif() endif()
if (OCOS_ENABLE_RE2_REGEX) if(OCOS_ENABLE_RE2_REGEX)
file(GLOB TARGET_SRC_RE2_KERNELS "operators/text/re2_strings/*.cc" "operators/text/re2_strings/*.h*") file(GLOB TARGET_SRC_RE2_KERNELS "operators/text/re2_strings/*.cc" "operators/text/re2_strings/*.h*")
list(APPEND TARGET_SRC ${TARGET_SRC_RE2_KERNELS}) list(APPEND TARGET_SRC ${TARGET_SRC_RE2_KERNELS})
endif() endif()
if (OCOS_ENABLE_MATH) if(OCOS_ENABLE_MATH)
if(OCOS_ENABLE_DLIB) if(OCOS_ENABLE_DLIB)
set(DLIB_ISO_CPP_ONLY ON CACHE INTERNAL "") set(DLIB_ISO_CPP_ONLY ON CACHE INTERNAL "")
set(DLIB_NO_GUI_SUPPORT ON CACHE INTERNAL "") set(DLIB_NO_GUI_SUPPORT ON CACHE INTERNAL "")
@ -167,31 +175,33 @@ if (OCOS_ENABLE_MATH)
set(DLIB_USE_LAPACK OFF CACHE INTERNAL "") set(DLIB_USE_LAPACK OFF CACHE INTERNAL "")
set(DLIB_USE_BLAS OFF CACHE INTERNAL "") set(DLIB_USE_BLAS OFF CACHE INTERNAL "")
include(dlib) include(dlib)
# Ideally, dlib should be included as # Ideally, dlib should be included as
# file(GLOB TARGET_SRC_DLIB "${dlib_SOURCE_DIR}/dlib/all/source.cpp") # file(GLOB TARGET_SRC_DLIB "${dlib_SOURCE_DIR}/dlib/all/source.cpp")
# To avoid the unintentional using some unwanted component, only include # To avoid the unintentional using some unwanted component, only include
file(GLOB TARGET_SRC_DLIB "${dlib_SOURCE_DIR}/dlib/test_for_odr_violations.cpp") file(GLOB TARGET_SRC_DLIB "${dlib_SOURCE_DIR}/dlib/test_for_odr_violations.cpp")
file(GLOB TARGET_SRC_INVERSE "operators/math/dlib/*.cc" "operators/math/dlib/*.h*") file(GLOB TARGET_SRC_INVERSE "operators/math/dlib/*.cc" "operators/math/dlib/*.h*")
endif() endif()
file(GLOB TARGET_SRC_MATH "operators/math/*.cc" "operators/math/*.h*") file(GLOB TARGET_SRC_MATH "operators/math/*.cc" "operators/math/*.h*")
list(APPEND TARGET_SRC ${TARGET_SRC_MATH} ${TARGET_SRC_DLIB} ${TARGET_SRC_INVERSE}) list(APPEND TARGET_SRC ${TARGET_SRC_MATH} ${TARGET_SRC_DLIB} ${TARGET_SRC_INVERSE})
endif() endif()
# enable the opencv dependency if we have ops that require it # enable the opencv dependency if we have ops that require it
if (OCOS_ENABLE_CV2 OR OCOS_ENABLE_VISION) if(OCOS_ENABLE_CV2 OR OCOS_ENABLE_VISION)
set(_ENABLE_OPENCV ON) set(_ENABLE_OPENCV ON)
message(STATUS "Fetch opencv") message(STATUS "Fetch opencv")
include(opencv) include(opencv)
endif() endif()
if (OCOS_ENABLE_CV2) if(OCOS_ENABLE_CV2)
file(GLOB TARGET_SRC_CV2 "operators/cv2/*.cc" "operators/cv2/*.h*") file(GLOB TARGET_SRC_CV2 "operators/cv2/*.cc" "operators/cv2/*.h*")
list(APPEND TARGET_SRC ${TARGET_SRC_CV2}) list(APPEND TARGET_SRC ${TARGET_SRC_CV2})
endif() endif()
if (OCOS_ENABLE_VISION) if(OCOS_ENABLE_VISION)
if (NOT OCOS_ENABLE_OPENCV_CODECS) if(NOT OCOS_ENABLE_OPENCV_CODECS)
message(FATAL_ERROR "OCOS_ENABLE_VISION requires OCOS_ENABLE_OPENCV_CODECS to be ON") message(FATAL_ERROR "OCOS_ENABLE_VISION requires OCOS_ENABLE_OPENCV_CODECS to be ON")
endif() endif()
file(GLOB TARGET_SRC_VISION "operators/vision/*.cc" "operators/vision/*.h*") file(GLOB TARGET_SRC_VISION "operators/vision/*.cc" "operators/vision/*.h*")
@ -199,14 +209,15 @@ if (OCOS_ENABLE_VISION)
endif() endif()
set(_HAS_TOKENIZER OFF) set(_HAS_TOKENIZER OFF)
if (OCOS_ENABLE_GPT2_TOKENIZER)
if(OCOS_ENABLE_GPT2_TOKENIZER)
# GPT2 # GPT2
set(_HAS_TOKENIZER ON) set(_HAS_TOKENIZER ON)
file(GLOB tok_TARGET_SRC "operators/tokenizer/gpt*.cc" "operators/tokenizer/unicode*.*") file(GLOB tok_TARGET_SRC "operators/tokenizer/gpt*.cc" "operators/tokenizer/unicode*.*")
list(APPEND TARGET_SRC ${tok_TARGET_SRC}) list(APPEND TARGET_SRC ${tok_TARGET_SRC})
endif() endif()
if (OCOS_ENABLE_SPM_TOKENIZER) if(OCOS_ENABLE_SPM_TOKENIZER)
# SentencePiece # SentencePiece
set(_HAS_TOKENIZER ON) set(_HAS_TOKENIZER ON)
set(SPM_ENABLE_TCMALLOC OFF CACHE INTERNAL "") set(SPM_ENABLE_TCMALLOC OFF CACHE INTERNAL "")
@ -218,42 +229,41 @@ if (OCOS_ENABLE_SPM_TOKENIZER)
list(APPEND TARGET_SRC ${stpiece_TARGET_SRC}) list(APPEND TARGET_SRC ${stpiece_TARGET_SRC})
endif() endif()
if (OCOS_ENABLE_WORDPIECE_TOKENIZER) if(OCOS_ENABLE_WORDPIECE_TOKENIZER)
set(_HAS_TOKENIZER ON) set(_HAS_TOKENIZER ON)
file(GLOB wordpiece_TARGET_SRC "operators/tokenizer/wordpiece*.*") file(GLOB wordpiece_TARGET_SRC "operators/tokenizer/wordpiece*.*")
list(APPEND TARGET_SRC ${wordpiece_TARGET_SRC}) list(APPEND TARGET_SRC ${wordpiece_TARGET_SRC})
endif() endif()
if (OCOS_ENABLE_BERT_TOKENIZER) if(OCOS_ENABLE_BERT_TOKENIZER)
# Bert # Bert
set(_HAS_TOKENIZER ON) set(_HAS_TOKENIZER ON)
file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*" "operators/tokenizer/bert_tokenizer.*" "operators/tokenizer/bert_tokenizer_decoder.*") file(GLOB bert_TARGET_SRC "operators/tokenizer/basic_tokenizer.*" "operators/tokenizer/bert_tokenizer.*" "operators/tokenizer/bert_tokenizer_decoder.*")
list(APPEND TARGET_SRC ${bert_TARGET_SRC}) list(APPEND TARGET_SRC ${bert_TARGET_SRC})
endif() endif()
if (OCOS_ENABLE_BLINGFIRE) if(OCOS_ENABLE_BLINGFIRE)
# blingfire # blingfire
set(_HAS_TOKENIZER ON) set(_HAS_TOKENIZER ON)
file(GLOB blingfire_TARGET_SRC "operators/tokenizer/blingfire*.*") file(GLOB blingfire_TARGET_SRC "operators/tokenizer/blingfire*.*")
list(APPEND TARGET_SRC ${blingfire_TARGET_SRC}) list(APPEND TARGET_SRC ${blingfire_TARGET_SRC})
endif() endif()
if (OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER) if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
if (NOT TARGET nlohmann_json) if(NOT TARGET nlohmann_json)
set(JSON_BuildTests OFF CACHE INTERNAL "") set(JSON_BuildTests OFF CACHE INTERNAL "")
message(STATUS "Fetch json") message(STATUS "Fetch json")
include(json) include(json)
endif() endif()
endif() endif()
if (_HAS_TOKENIZER) if(_HAS_TOKENIZER)
message(STATUS "Tokenizer needed.") message(STATUS "Tokenizer needed.")
file(GLOB tokenizer_TARGET_SRC "operators/tokenizer/tokenizers.*") file(GLOB tokenizer_TARGET_SRC "operators/tokenizer/tokenizers.*")
list(APPEND TARGET_SRC ${tokenizer_TARGET_SRC}) list(APPEND TARGET_SRC ${tokenizer_TARGET_SRC})
endif() endif()
#### make all compile options. # ### make all compile options.
add_compile_options("$<$<C_COMPILER_ID:MSVC>:/utf-8>") add_compile_options("$<$<C_COMPILER_ID:MSVC>:/utf-8>")
add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/utf-8>") add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/utf-8>")
add_library(ocos_operators STATIC ${TARGET_SRC}) add_library(ocos_operators STATIC ${TARGET_SRC})
@ -274,22 +284,22 @@ source_group(TREE ${PROJECT_SOURCE_DIR} FILES ${_TARGET_SRC_FOR_SOURCE_GROUP})
standardize_output_folder(ocos_operators) standardize_output_folder(ocos_operators)
target_include_directories(ocos_operators PUBLIC target_include_directories(ocos_operators PUBLIC
${ONNXRUNTIME_INCLUDE_DIR}
${PROJECT_SOURCE_DIR}/includes ${PROJECT_SOURCE_DIR}/includes
${PROJECT_SOURCE_DIR}/includes/onnxruntime
${PROJECT_SOURCE_DIR}/operators ${PROJECT_SOURCE_DIR}/operators
${PROJECT_SOURCE_DIR}/operators/tokenizer) ${PROJECT_SOURCE_DIR}/operators/tokenizer)
set(ocos_libraries "") set(ocos_libraries "")
set(OCOS_COMPILE_DEFINITIONS "") set(OCOS_COMPILE_DEFINITIONS "")
if (OCOS_ENABLE_DLIB) if(OCOS_ENABLE_DLIB)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_DLIB) list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_DLIB)
endif() endif()
if (_HAS_TOKENIZER) if(_HAS_TOKENIZER)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_TOKENIZER) list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_TOKENIZER)
endif() endif()
if (OCOS_ENABLE_TF_STRING) if(OCOS_ENABLE_TF_STRING)
target_include_directories(ocos_operators PUBLIC target_include_directories(ocos_operators PUBLIC
${googlere2_SOURCE_DIR} ${googlere2_SOURCE_DIR}
${farmhash_SOURCE_DIR}/src) ${farmhash_SOURCE_DIR}/src)
@ -297,62 +307,63 @@ if (OCOS_ENABLE_TF_STRING)
list(APPEND ocos_libraries re2) list(APPEND ocos_libraries re2)
endif() endif()
if (OCOS_ENABLE_RE2_REGEX) if(OCOS_ENABLE_RE2_REGEX)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_RE2_REGEX) list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_RE2_REGEX)
endif() endif()
if (OCOS_ENABLE_MATH) if(OCOS_ENABLE_MATH)
target_include_directories(ocos_operators PUBLIC ${dlib_SOURCE_DIR}) target_include_directories(ocos_operators PUBLIC ${dlib_SOURCE_DIR})
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_MATH) list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_MATH)
# The dlib matrix implementation is all in the headers, no library compiling needed. # The dlib matrix implementation is all in the headers, no library compiling needed.
endif() endif()
if (_ENABLE_OPENCV) if(_ENABLE_OPENCV)
list(APPEND ocos_libraries ${opencv_LIBS}) list(APPEND ocos_libraries ${opencv_LIBS})
target_include_directories(ocos_operators PUBLIC ${opencv_INCLUDE_DIRS}) target_include_directories(ocos_operators PUBLIC ${opencv_INCLUDE_DIRS})
endif() endif()
if (OCOS_ENABLE_OPENCV_CODECS) if(OCOS_ENABLE_OPENCV_CODECS)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_OPENCV_CODECS) list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_OPENCV_CODECS)
endif() endif()
if (OCOS_ENABLE_CV2) if(OCOS_ENABLE_CV2)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_CV2) list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_CV2)
endif() endif()
if (OCOS_ENABLE_VISION) if(OCOS_ENABLE_VISION)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_VISION) list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_VISION)
endif() endif()
if (OCOS_ENABLE_GPT2_TOKENIZER) if(OCOS_ENABLE_GPT2_TOKENIZER)
# GPT2 # GPT2
target_include_directories(ocos_operators PRIVATE ${json_SOURCE_DIR}/single_include) target_include_directories(ocos_operators PRIVATE ${json_SOURCE_DIR}/single_include)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_GPT2_TOKENIZER) list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_GPT2_TOKENIZER)
list(APPEND ocos_libraries nlohmann_json::nlohmann_json) list(APPEND ocos_libraries nlohmann_json::nlohmann_json)
endif() endif()
if (OCOS_ENABLE_WORDPIECE_TOKENIZER) if(OCOS_ENABLE_WORDPIECE_TOKENIZER)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_WORDPIECE_TOKENIZER) list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_WORDPIECE_TOKENIZER)
endif() endif()
if (OCOS_ENABLE_BERT_TOKENIZER) if(OCOS_ENABLE_BERT_TOKENIZER)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_BERT_TOKENIZER) list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_BERT_TOKENIZER)
endif() endif()
if (OCOS_ENABLE_SPM_TOKENIZER) if(OCOS_ENABLE_SPM_TOKENIZER)
# SentencePiece # SentencePiece
target_include_directories(ocos_operators PUBLIC ${spm_INCLUDE_DIRS}) target_include_directories(ocos_operators PUBLIC ${spm_INCLUDE_DIRS})
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_SPM_TOKENIZER) list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_SPM_TOKENIZER)
list(APPEND ocos_libraries sentencepiece-static) list(APPEND ocos_libraries sentencepiece-static)
endif() endif()
if (OCOS_ENABLE_BLINGFIRE) if(OCOS_ENABLE_BLINGFIRE)
include(blingfire) include(blingfire)
list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_BLINGFIRE) list(APPEND OCOS_COMPILE_DEFINITIONS ENABLE_BLINGFIRE)
list(APPEND ocos_libraries bingfirtinydll_static) list(APPEND ocos_libraries bingfirtinydll_static)
endif() endif()
if (OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER) if(OCOS_ENABLE_GPT2_TOKENIZER OR OCOS_ENABLE_WORDPIECE_TOKENIZER)
target_include_directories(ocos_operators PRIVATE ${json_SOURCE_DIR}/single_include) target_include_directories(ocos_operators PRIVATE ${json_SOURCE_DIR}/single_include)
list(APPEND ocos_libraries nlohmann_json::nlohmann_json) list(APPEND ocos_libraries nlohmann_json::nlohmann_json)
endif() endif()
@ -362,6 +373,7 @@ target_compile_definitions(ocos_operators PRIVATE ${OCOS_COMPILE_DEFINITIONS})
target_link_libraries(ocos_operators PRIVATE ${ocos_libraries}) target_link_libraries(ocos_operators PRIVATE ${ocos_libraries})
file(GLOB shared_TARGET_LIB_SRC "shared/lib/*.cc" "shared/lib/*.h") file(GLOB shared_TARGET_LIB_SRC "shared/lib/*.cc" "shared/lib/*.h")
if(NOT OCOS_ENABLE_STATIC_LIB AND CMAKE_SYSTEM_NAME STREQUAL "Emscripten") if(NOT OCOS_ENABLE_STATIC_LIB AND CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
add_executable(ortcustomops ${shared_TARGET_LIB_SRC}) add_executable(ortcustomops ${shared_TARGET_LIB_SRC})
set_target_properties(ortcustomops PROPERTIES LINK_FLAGS " \ set_target_properties(ortcustomops PROPERTIES LINK_FLAGS " \
@ -375,7 +387,8 @@ if(NOT OCOS_ENABLE_STATIC_LIB AND CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
-s EXPORT_ALL=0 \ -s EXPORT_ALL=0 \
-s VERBOSE=0 \ -s VERBOSE=0 \
--no-entry") --no-entry")
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
set_property(TARGET ortcustomops APPEND_STRING PROPERTY LINK_FLAGS " -s ASSERTIONS=1 -s DEMANGLE_SUPPORT=1") set_property(TARGET ortcustomops APPEND_STRING PROPERTY LINK_FLAGS " -s ASSERTIONS=1 -s DEMANGLE_SUPPORT=1")
else() else()
set_property(TARGET ortcustomops APPEND_STRING PROPERTY LINK_FLAGS " -s ASSERTIONS=0 -s DEMANGLE_SUPPORT=0") set_property(TARGET ortcustomops APPEND_STRING PROPERTY LINK_FLAGS " -s ASSERTIONS=0 -s DEMANGLE_SUPPORT=0")
@ -392,18 +405,19 @@ target_include_directories(ortcustomops PUBLIC
"$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>") "$<TARGET_PROPERTY:ocos_operators,INTERFACE_INCLUDE_DIRECTORIES>")
target_link_libraries(ortcustomops PUBLIC ocos_operators) target_link_libraries(ortcustomops PUBLIC ocos_operators)
if (_BUILD_SHARED_LIBRARY) if(_BUILD_SHARED_LIBRARY)
file(GLOB shared_TARGET_SRC "shared/*.cc" "shared/*.h" "shared/*.def") file(GLOB shared_TARGET_SRC "shared/*.cc" "shared/*.h" "shared/*.def")
add_library(extensions_shared SHARED ${shared_TARGET_SRC}) add_library(extensions_shared SHARED ${shared_TARGET_SRC})
source_group(TREE ${PROJECT_SOURCE_DIR} FILES ${shared_TARGET_SRC}) source_group(TREE ${PROJECT_SOURCE_DIR} FILES ${shared_TARGET_SRC})
standardize_output_folder(extensions_shared) standardize_output_folder(extensions_shared)
if (CMAKE_SYSTEM_NAME STREQUAL "Android")
if (OCOS_ENABLE_SPM_TOKENIZER) if(CMAKE_SYSTEM_NAME STREQUAL "Android")
if(OCOS_ENABLE_SPM_TOKENIZER)
target_link_libraries(extensions_shared PUBLIC log) target_link_libraries(extensions_shared PUBLIC log)
endif() endif()
endif() endif()
if (LINUX OR CMAKE_SYSTEM_NAME STREQUAL "Android") if(LINUX OR CMAKE_SYSTEM_NAME STREQUAL "Android")
set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS "-Wl,-s -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver") set_property(TARGET extensions_shared APPEND_STRING PROPERTY LINK_FLAGS "-Wl,-s -Wl,--version-script -Wl,${PROJECT_SOURCE_DIR}/shared/ortcustomops.ver")
endif() endif()
@ -429,14 +443,16 @@ endif()
# clean up the requirements.txt files from 3rd party project folder to suppress the code security false alarms # clean up the requirements.txt files from 3rd party project folder to suppress the code security false alarms
file(GLOB_RECURSE NO_USE_FILES ${CMAKE_BINARY_DIR}/_deps/*requirements.txt) file(GLOB_RECURSE NO_USE_FILES ${CMAKE_BINARY_DIR}/_deps/*requirements.txt)
message("Found the follow requirements.txt: ${NO_USE_FILES}") message(STATUS "Found the follow requirements.txt: ${NO_USE_FILES}")
foreach(nf ${NO_USE_FILES}) foreach(nf ${NO_USE_FILES})
file(TO_NATIVE_PATH ${nf} nf_native) file(TO_NATIVE_PATH ${nf} nf_native)
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
execute_process(COMMAND cmd /c "del ${nf_native}") if(CMAKE_SYSTEM_NAME MATCHES "Windows")
else() execute_process(COMMAND cmd /c "del ${nf_native}")
execute_process(COMMAND bash -c "rm ${nf_native}") else()
endif() execute_process(COMMAND bash -c "rm ${nf_native}")
endif()
endforeach() endforeach()
# test section # test section
@ -460,10 +476,12 @@ elseif(OCOS_ENABLE_CTEST AND NOT OCOS_ENABLE_SELECTED_OPLIST)
SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH) SET(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
find_library(ONNXRUNTIME onnxruntime HINTS "${ONNXRUNTIME_LIB_DIR}") find_library(ONNXRUNTIME onnxruntime HINTS "${ONNXRUNTIME_LIB_DIR}")
if (ONNXRUNTIME-NOTFOUND)
if(ONNXRUNTIME-NOTFOUND)
message(WARNING "The prebuilt onnxruntime libraries directory cannot found (via ONNXRUNTIME_LIB_DIR), the extensions_test will be skipped.") message(WARNING "The prebuilt onnxruntime libraries directory cannot found (via ONNXRUNTIME_LIB_DIR), the extensions_test will be skipped.")
else() else()
set(LINUX_CC_FLAGS "") set(LINUX_CC_FLAGS "")
# needs to link with stdc++fs in Linux # needs to link with stdc++fs in Linux
if(UNIX AND NOT APPLE AND NOT CMAKE_SYSTEM_NAME STREQUAL "Android") if(UNIX AND NOT APPLE AND NOT CMAKE_SYSTEM_NAME STREQUAL "Android")
list(APPEND LINUX_CC_FLAGS stdc++fs -pthread) list(APPEND LINUX_CC_FLAGS stdc++fs -pthread)
@ -474,11 +492,14 @@ elseif(OCOS_ENABLE_CTEST AND NOT OCOS_ENABLE_SELECTED_OPLIST)
standardize_output_folder(extensions_test) standardize_output_folder(extensions_test)
target_include_directories(extensions_test PRIVATE ${spm_INCLUDE_DIRS} target_include_directories(extensions_test PRIVATE ${spm_INCLUDE_DIRS}
"$<TARGET_PROPERTY:extensions_shared,INTERFACE_INCLUDE_DIRECTORIES>") "$<TARGET_PROPERTY:extensions_shared,INTERFACE_INCLUDE_DIRECTORIES>")
if (ONNXRUNTIME_LIB_DIR)
if(ONNXRUNTIME_LIB_DIR)
target_link_directories(extensions_test PRIVATE ${ONNXRUNTIME_LIB_DIR}) target_link_directories(extensions_test PRIVATE ${ONNXRUNTIME_LIB_DIR})
endif() endif()
target_link_libraries(extensions_test PRIVATE ocos_operators extensions_shared onnxruntime gtest_main ${ocos_libraries} ${LINUX_CC_FLAGS}) target_link_libraries(extensions_test PRIVATE ocos_operators extensions_shared onnxruntime gtest_main ${ocos_libraries} ${LINUX_CC_FLAGS})
if (WIN32)
if(WIN32)
file(TO_CMAKE_PATH "${ONNXRUNTIME_LIB_DIR}/*" ONNXRUNTIME_LIB_FILEPATTERN) file(TO_CMAKE_PATH "${ONNXRUNTIME_LIB_DIR}/*" ONNXRUNTIME_LIB_FILEPATTERN)
file(GLOB ONNXRUNTIME_LIB_FILES CONFIGURE_DEPENDS "${ONNXRUNTIME_LIB_FILEPATTERN}") file(GLOB ONNXRUNTIME_LIB_FILES CONFIGURE_DEPENDS "${ONNXRUNTIME_LIB_FILEPATTERN}")
add_custom_command( add_custom_command(

41
cmake/ext_ortlib.cmake Normal file
Просмотреть файл

@ -0,0 +1,41 @@
if(_ONNXRUNTIME_EMBEDDED)
set(ONNXRUNTIME_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/include/onnxruntime/core/session)
set(ONNXRUNTIME_LIB_DIR "")
else()
set(ONNXRUNTIME_VER "1.10.0" CACHE STRING "ONNX Runtime version")
if(CMAKE_HOST_APPLE)
set(ONNXRUNTIME_URL "v${ONNXRUNTIME_VER}/onnxruntime-osx-universal2-${ONNXRUNTIME_VER}.tgz")
elseif(CMAKE_HOST_WIN32)
set(ONNXRUNTIME_URL "v${ONNXRUNTIME_VER}/onnxruntime-win-x64-${ONNXRUNTIME_VER}.zip")
if(CMAKE_SYSTEM_PROCESSOR MATCHES "arm64")
set(ONNXRUNTIME_URL "v${ONNXRUNTIME_VER}/onnxruntime-win-arm64-${ONNXRUNTIME_VER}.zip")
endif()
else()
# Linux or other, using Linux package to retrieve the headers
set(ONNXRUNTIME_URL "v${ONNXRUNTIME_VER}/onnxruntime-linux-x64-${ONNXRUNTIME_VER}.tgz")
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
set(ONNXRUNTIME_URL "v${ONNXRUNTIME_VER}/onnxruntime-linux-aarch64-${ONNXRUNTIME_VER}.tgz")
endif()
endif()
# Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24:
if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
cmake_policy(SET CMP0135 NEW)
endif()
message(STATUS "ONNX Runtime URL suffix: ${ONNXRUNTIME_URL}")
FetchContent_Declare(
onnxruntime
URL https://github.com/microsoft/onnxruntime/releases/download/${ONNXRUNTIME_URL}
)
FetchContent_makeAvailable(onnxruntime)
set(ONNXRUNTIME_INCLUDE_DIR ${onnxruntime_SOURCE_DIR}/include)
set(ONNXRUNTIME_LIB_DIR ${onnxruntime_SOURCE_DIR}/lib)
endif()
if(NOT EXISTS ${ONNXRUNTIME_INCLUDE_DIR})
message(FATAL_ERROR "ONNX Runtime headers not found at ${ONNXRUNTIME_INCLUDE_DIR}")
endif()

Просмотреть файл

@ -8,9 +8,7 @@
#include <iterator> #include <iterator>
#include <vector> #include <vector>
#define ORT_API_MANUAL_INIT #include "onnxruntime_customop.hpp"
#include "onnxruntime_cxx_api.h"
#undef ORT_API_MANUAL_INIT
// A helper API to support test kernels. // A helper API to support test kernels.
// Must be invoked before RegisterCustomOps. // Must be invoked before RegisterCustomOps.
@ -40,13 +38,13 @@ struct BaseKernel {
protected: protected:
OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status); OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status);
const OrtApi& api_; const OrtApi& api_;
Ort::CustomOpApi ort_; OrtW::CustomOpApi ort_;
const OrtKernelInfo* info_; const OrtKernelInfo* info_;
}; };
struct OrtTensorDimensions : std::vector<int64_t> { struct OrtTensorDimensions : std::vector<int64_t> {
OrtTensorDimensions() = default; OrtTensorDimensions() = default;
OrtTensorDimensions(Ort::CustomOpApi& ort, const OrtValue* value) { OrtTensorDimensions(const OrtW::CustomOpApi& ort, const OrtValue* value) {
OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value); OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
std::vector<int64_t>::operator=(ort.GetTensorShape(info)); std::vector<int64_t>::operator=(ort.GetTensorShape(info));
ort.ReleaseTensorTypeAndShapeInfo(info); ort.ReleaseTensorTypeAndShapeInfo(info);

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -1,945 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Summary: The Ort C++ API is a header only wrapper around the Ort C API.
//
// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
// and automatically releasing resources in the destructors.
//
// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};).
//
// Only move assignment between objects is allowed, there are no copy constructors. Some objects have explicit 'Clone'
// methods for this purpose.
#pragma once
#include "onnxruntime_c_api.h"
#include <cstddef>
#include <array>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include <utility>
#include <type_traits>
#ifdef ORT_NO_EXCEPTIONS
#include <iostream>
#endif
namespace Ort {
// All C++ methods that can fail will throw an exception of this type
struct Exception : std::exception {
Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
OrtErrorCode GetOrtErrorCode() const { return code_; }
const char* what() const noexcept override { return message_.c_str(); }
private:
std::string message_;
OrtErrorCode code_;
};
#ifdef ORT_NO_EXCEPTIONS
#define ORT_CXX_API_THROW(string, code) \
do { \
std::cerr << Ort::Exception(string, code) \
.what() \
<< std::endl; \
abort(); \
} while (false)
#else
#define ORT_CXX_API_THROW(string, code) \
throw Ort::Exception(string, code)
#endif
// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, it's in a template so that we can define a global variable in a header and make
// it transparent to the users of the API.
template <typename T>
struct Global {
static const OrtApi* api_;
};
// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
template <typename T>
#ifdef ORT_API_MANUAL_INIT
const OrtApi* Global<T>::api_{};
inline void InitApi() { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
#else
const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
#endif
// This returns a reference to the OrtApi interface in use, in case someone wants to use the C API functions
inline const OrtApi& GetApi() { return *Global<void>::api_; }
// This is a C++ wrapper for GetAvailableProviders() C API and returns
// a vector of strings representing the available execution providers.
std::vector<std::string> GetAvailableProviders();
// This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
// This can't be done in the C API since C doesn't have function overloading.
#define ORT_DEFINE_RELEASE(NAME) \
inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
ORT_DEFINE_RELEASE(Allocator);
ORT_DEFINE_RELEASE(MemoryInfo);
ORT_DEFINE_RELEASE(CustomOpDomain);
ORT_DEFINE_RELEASE(Env);
ORT_DEFINE_RELEASE(RunOptions);
ORT_DEFINE_RELEASE(Session);
ORT_DEFINE_RELEASE(SessionOptions);
ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
ORT_DEFINE_RELEASE(SequenceTypeInfo);
ORT_DEFINE_RELEASE(MapTypeInfo);
ORT_DEFINE_RELEASE(TypeInfo);
ORT_DEFINE_RELEASE(Value);
ORT_DEFINE_RELEASE(ModelMetadata);
ORT_DEFINE_RELEASE(ThreadingOptions);
ORT_DEFINE_RELEASE(IoBinding);
ORT_DEFINE_RELEASE(ArenaCfg);
/*! \class Ort::Float16_t
* \brief it is a structure that represents float16 data.
* \details It is necessary for type dispatching to make use of C++ API
* The type is implicitly convertible to/from uint16_t.
* The size of the structure should align with uint16_t and one can freely cast
* uint16_t buffers to/from Ort::Float16_t to feed and retrieve data.
*
* Generally, you can feed any of your types as float16/blfoat16 data to create a tensor
* on top of it, providing it can form a continuous buffer with 16-bit elements with no padding.
* And you can also feed a array of uint16_t elements directly. For example,
*
* \code{.unparsed}
* uint16_t values[] = { 15360, 16384, 16896, 17408, 17664};
* constexpr size_t values_length = sizeof(values) / sizeof(values[0]);
* std::vector<int64_t> dims = {values_length}; // one dimensional example
* Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
* // Note we are passing bytes count in this api, not number of elements -> sizeof(values)
* auto float16_tensor = Ort::Value::CreateTensor(info, values, sizeof(values),
* dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
* \endcode
*
* Here is another example, a little bit more elaborate. Let's assume that you use your own float16 type and you want to use
* a templated version of the API above so the type is automatically set based on your type. You will need to supply an extra
* template specialization.
*
* \code{.unparsed}
* namespace yours { struct half {}; } // assume this is your type, define this:
* namespace Ort {
* template<>
* struct TypeToTensorType<yours::half> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
* } //namespace Ort
*
* std::vector<yours::half> values;
* std::vector<int64_t> dims = {values.size()}; // one dimensional example
* Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
* // Here we are passing element count -> values.size()
* auto float16_tensor = Ort::Value::CreateTensor<yours::half>(info, values.data(), values.size(), dims.data(), dims.size());
*
* \endcode
*/
struct Float16_t {
uint16_t value;
constexpr Float16_t() noexcept : value(0) {}
constexpr Float16_t(uint16_t v) noexcept : value(v) {}
constexpr operator uint16_t() const noexcept { return value; }
constexpr bool operator==(const Float16_t& rhs) const noexcept { return value == rhs.value; };
constexpr bool operator!=(const Float16_t& rhs) const noexcept { return value != rhs.value; };
};
static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
/*! \class Ort::BFloat16_t
* \brief is a structure that represents bfloat16 data.
* \details It is necessary for type dispatching to make use of C++ API
* The type is implicitly convertible to/from uint16_t.
* The size of the structure should align with uint16_t and one can freely cast
* uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data.
*
* See also code examples for Float16_t above.
*/
struct BFloat16_t {
uint16_t value;
constexpr BFloat16_t() noexcept : value(0) {}
constexpr BFloat16_t(uint16_t v) noexcept : value(v) {}
constexpr operator uint16_t() const noexcept { return value; }
constexpr bool operator==(const BFloat16_t& rhs) const noexcept { return value == rhs.value; };
constexpr bool operator!=(const BFloat16_t& rhs) const noexcept { return value != rhs.value; };
};
static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
// This is used internally by the C++ API. This is the common base class used by the wrapper objects.
template <typename T>
struct Base {
using contained_type = T;
Base() = default;
Base(T* p) : p_{p} {
if (!p)
ORT_CXX_API_THROW("Allocation failure", ORT_FAIL);
}
~Base() { OrtRelease(p_); }
operator T*() { return p_; }
operator const T*() const { return p_; }
T* release() {
T* p = p_;
p_ = nullptr;
return p;
}
protected:
Base(const Base&) = delete;
Base& operator=(const Base&) = delete;
Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
void operator=(Base&& v) noexcept {
OrtRelease(p_);
p_ = v.p_;
v.p_ = nullptr;
}
T* p_{};
template <typename>
friend struct Unowned; // This friend line is needed to keep the centos C++ compiler from giving an error
};
template <typename T>
struct Base<const T> {
using contained_type = const T;
Base() = default;
Base(const T* p) : p_{p} {
if (!p)
ORT_CXX_API_THROW("Invalid instance ptr", ORT_INVALID_ARGUMENT);
}
~Base() = default;
operator const T*() const { return p_; }
protected:
Base(const Base&) = delete;
Base& operator=(const Base&) = delete;
Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
void operator=(Base&& v) noexcept {
p_ = v.p_;
v.p_ = nullptr;
}
const T* p_{};
};
template <typename T>
struct Unowned : T {
Unowned(decltype(T::p_) p) : T{p} {}
Unowned(Unowned&& v) : T{v.p_} {}
~Unowned() { this->release(); }
};
struct AllocatorWithDefaultOptions;
struct MemoryInfo;
struct Env;
struct TypeInfo;
struct Value;
struct ModelMetadata;
struct Env : Base<OrtEnv> {
Env(std::nullptr_t) {}
Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
Env& EnableTelemetryEvents();
Env& DisableTelemetryEvents();
Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg);
static const OrtApi* s_api;
};
struct CustomOpDomain : Base<OrtCustomOpDomain> {
explicit CustomOpDomain(std::nullptr_t) {}
explicit CustomOpDomain(const char* domain);
void Add(OrtCustomOp* op);
};
struct RunOptions : Base<OrtRunOptions> {
RunOptions(std::nullptr_t) {}
RunOptions();
RunOptions& SetRunLogVerbosityLevel(int);
int GetRunLogVerbosityLevel() const;
RunOptions& SetRunLogSeverityLevel(int);
int GetRunLogSeverityLevel() const;
RunOptions& SetRunTag(const char* run_tag);
const char* GetRunTag() const;
RunOptions& AddConfigEntry(const char* config_key, const char* config_value);
// terminate ALL currently executing Session::Run calls that were made using this RunOptions instance
RunOptions& SetTerminate();
// unset the terminate flag so this RunOptions instance can be used in a new Session::Run call
RunOptions& UnsetTerminate();
};
struct SessionOptions : Base<OrtSessionOptions> {
explicit SessionOptions(std::nullptr_t) {}
SessionOptions();
explicit SessionOptions(OrtSessionOptions* p) : Base<OrtSessionOptions>{p} {}
SessionOptions Clone() const;
SessionOptions& SetIntraOpNumThreads(int intra_op_num_threads);
SessionOptions& SetInterOpNumThreads(int inter_op_num_threads);
SessionOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);
SessionOptions& EnableCpuMemArena();
SessionOptions& DisableCpuMemArena();
SessionOptions& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file);
SessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix);
SessionOptions& DisableProfiling();
SessionOptions& EnableOrtCustomOps();
SessionOptions& EnableMemPattern();
SessionOptions& DisableMemPattern();
SessionOptions& SetExecutionMode(ExecutionMode execution_mode);
SessionOptions& SetLogId(const char* logid);
SessionOptions& SetLogSeverityLevel(int level);
SessionOptions& Add(OrtCustomOpDomain* custom_op_domain);
SessionOptions& DisablePerSessionThreads();
SessionOptions& AddConfigEntry(const char* config_key, const char* config_value);
SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val);
SessionOptions& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options);
SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options);
SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options);
SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options);
};
struct ModelMetadata : Base<OrtModelMetadata> {
explicit ModelMetadata(std::nullptr_t) {}
explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {}
char* GetProducerName(OrtAllocator* allocator) const;
char* GetGraphName(OrtAllocator* allocator) const;
char* GetDomain(OrtAllocator* allocator) const;
char* GetDescription(OrtAllocator* allocator) const;
char* GetGraphDescription(OrtAllocator* allocator) const;
char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const;
char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const;
int64_t GetVersion() const;
};
struct Session : Base<OrtSession> {
explicit Session(std::nullptr_t) {}
Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);
Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container);
Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options);
// Run that will allocate the output values
std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, size_t output_count);
// Run for when there is a list of preallocated outputs
void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, Value* output_values, size_t output_count);
void Run(const RunOptions& run_options, const struct IoBinding&);
size_t GetInputCount() const;
size_t GetOutputCount() const;
size_t GetOverridableInitializerCount() const;
char* GetInputName(size_t index, OrtAllocator* allocator) const;
char* GetOutputName(size_t index, OrtAllocator* allocator) const;
char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const;
char* EndProfiling(OrtAllocator* allocator) const;
uint64_t GetProfilingStartTimeNs() const;
ModelMetadata GetModelMetadata() const;
TypeInfo GetInputTypeInfo(size_t index) const;
TypeInfo GetOutputTypeInfo(size_t index) const;
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const;
};
struct TensorTypeAndShapeInfo : Base<OrtTensorTypeAndShapeInfo> {
explicit TensorTypeAndShapeInfo(std::nullptr_t) {}
explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : Base<OrtTensorTypeAndShapeInfo>{p} {}
ONNXTensorElementDataType GetElementType() const;
size_t GetElementCount() const;
size_t GetDimensionsCount() const;
void GetDimensions(int64_t* values, size_t values_count) const;
void GetSymbolicDimensions(const char** values, size_t values_count) const;
std::vector<int64_t> GetShape() const;
};
struct SequenceTypeInfo : Base<OrtSequenceTypeInfo> {
explicit SequenceTypeInfo(std::nullptr_t) {}
explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : Base<OrtSequenceTypeInfo>{p} {}
TypeInfo GetSequenceElementType() const;
};
struct MapTypeInfo : Base<OrtMapTypeInfo> {
explicit MapTypeInfo(std::nullptr_t) {}
explicit MapTypeInfo(OrtMapTypeInfo* p) : Base<OrtMapTypeInfo>{p} {}
ONNXTensorElementDataType GetMapKeyType() const;
TypeInfo GetMapValueType() const;
};
struct TypeInfo : Base<OrtTypeInfo> {
explicit TypeInfo(std::nullptr_t) {}
explicit TypeInfo(OrtTypeInfo* p) : Base<OrtTypeInfo>{p} {}
Unowned<TensorTypeAndShapeInfo> GetTensorTypeAndShapeInfo() const;
Unowned<SequenceTypeInfo> GetSequenceTypeInfo() const;
Unowned<MapTypeInfo> GetMapTypeInfo() const;
ONNXType GetONNXType() const;
};
struct Value : Base<OrtValue> {
// This structure is used to feed sparse tensor values
// information for use with FillSparseTensor<Format>() API
// if the data type for the sparse tensor values is numeric
// use data.p_data, otherwise, use data.str pointer to feed
// values. data.str is an array of const char* that are zero terminated.
// number of strings in the array must match shape size.
// For fully sparse tensors use shape {0} and set p_data/str
// to nullptr.
struct OrtSparseValuesParam {
const int64_t* values_shape;
size_t values_shape_len;
union {
const void* p_data;
const char** str;
} data;
};
// Provides a way to pass shape in a single
// argument
struct Shape {
const int64_t* shape;
size_t shape_len;
};
template <typename T>
static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
ONNXTensorElementDataType type);
#if !defined(DISABLE_SPARSE_TENSORS)
/// <summary>
/// This is a simple forwarding method to the other overload that helps deducing
/// data type enum value from the type of the buffer.
/// </summary>
/// <typeparam name="T">numeric datatype. This API is not suitable for strings.</typeparam>
/// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
/// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
/// <param name="dense_shape">a would be dense shape of the tensor</param>
/// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
/// <returns></returns>
template <typename T>
static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
const Shape& values_shape);
/// <summary>
/// Creates an OrtValue instance containing SparseTensor. This constructs
/// a sparse tensor that makes use of user allocated buffers. It does not make copies
/// of the user provided data and does not modify it. The lifespan of user provided buffers should
/// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contain
/// a pointer to non-zero values. To fully populate the sparse tensor call Use<Format>Indices() API below
/// to supply a sparse format specific indices.
/// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings
/// can be properly copied into the allocated buffer.
/// </summary>
/// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
/// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
/// <param name="dense_shape">a would be dense shape of the tensor</param>
/// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
/// <param name="type">data type</param>
/// <returns>Ort::Value instance containing SparseTensor</returns>
static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
const Shape& values_shape, ONNXTensorElementDataType type);
/// <summary>
/// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor.
/// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
/// allocated buffers lifespan must eclipse that of the OrtValue.
/// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
/// </summary>
/// <param name="indices_data">pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors.</param>
/// <param name="indices_num">number of indices entries. Use 0 for fully sparse tensors</param>
void UseCooIndices(int64_t* indices_data, size_t indices_num);
/// <summary>
/// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor.
/// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
/// allocated buffers lifespan must eclipse that of the OrtValue.
/// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
/// </summary>
/// <param name="inner_data">pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors</param>
/// <param name="inner_num">number of csr inner indices or 0 for fully sparse tensors</param>
/// <param name="outer_data">pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors</param>
/// <param name="outer_num">number of csr outer indices or 0 for fully sparse tensors</param>
void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);
/// <summary>
/// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor.
/// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
/// allocated buffers lifespan must eclipse that of the OrtValue.
/// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
/// </summary>
/// <param name="indices_shape">indices shape or a {0} for fully sparse</param>
/// <param name="indices_data">user allocated buffer with indices or nullptr for fully spare tensors</param>
void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);
#endif // !defined(DISABLE_SPARSE_TENSORS)
template <typename T>
static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
#if !defined(DISABLE_SPARSE_TENSORS)
/// <summary>
/// This is a simple forwarding method the below CreateSparseTensor.
/// This helps to specify data type enum in terms of C++ data type.
/// Use CreateSparseTensor<T>
/// </summary>
/// <typeparam name="T">numeric data type only. String data enum must be specified explicitly.</typeparam>
/// <param name="allocator">allocator to use</param>
/// <param name="dense_shape">a would be dense shape of the tensor</param>
/// <returns>Ort::Value</returns>
template <typename T>
static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);
/// <summary>
/// Creates an instance of OrtValue containing sparse tensor. The created instance has no data.
/// The data must be supplied by on of the FillSparseTensor<Format>() methods that take both non-zero values
/// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator.
/// Use this API to create OrtValues that contain sparse tensors with all supported data types including
/// strings.
/// </summary>
/// <param name="allocator">allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue</param>
/// <param name="dense_shape">a would be dense shape of the tensor</param>
/// <param name="type">data type</param>
/// <returns>an instance of Ort::Value</returns>
static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);
/// <summary>
/// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
/// and copy the values and COO indices into it. If data_mem_info specifies that the data is located
/// at difference device than the allocator, a X-device copy will be performed if possible.
/// </summary>
/// <param name="data_mem_info">specified buffer memory description</param>
/// <param name="values_param">values buffer information.</param>
/// <param name="indices_data">coo indices buffer or nullptr for fully sparse data</param>
/// <param name="indices_num">number of COO indices or 0 for fully sparse data</param>
void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
const int64_t* indices_data, size_t indices_num);
/// <summary>
/// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
/// and copy the values and CSR indices into it. If data_mem_info specifies that the data is located
/// at difference device than the allocator, a X-device copy will be performed if possible.
/// </summary>
/// <param name="data_mem_info">specified buffer memory description</param>
/// <param name="values_param">values buffer information</param>
/// <param name="inner_indices_data">csr inner indices pointer or nullptr for fully sparse tensors</param>
/// <param name="inner_indices_num">number of csr inner indices or 0 for fully sparse tensors</param>
/// <param name="outer_indices_data">pointer to csr indices data or nullptr for fully sparse tensors</param>
/// <param name="outer_indices_num">number of csr outer indices or 0</param>
void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
const OrtSparseValuesParam& values,
const int64_t* inner_indices_data, size_t inner_indices_num,
const int64_t* outer_indices_data, size_t outer_indices_num);
/// <summary>
/// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
/// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located
/// at difference device than the allocator, a X-device copy will be performed if possible.
/// </summary>
/// <param name="data_mem_info">specified buffer memory description</param>
/// <param name="values_param">values buffer information</param>
/// <param name="indices_shape">indices shape. use {0} for fully sparse tensors</param>
/// <param name="indices_data">pointer to indices data or nullptr for fully sparse tensors</param>
void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
const OrtSparseValuesParam& values,
const Shape& indices_shape,
const int32_t* indices_data);
/// <summary>
/// The API returns the sparse data format this OrtValue holds in a sparse tensor.
/// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used
/// the value returned is ORT_SPARSE_UNDEFINED.
/// </summary>
/// <returns>Format enum</returns>
OrtSparseFormat GetSparseFormat() const;
/// <summary>
/// The API returns type and shape information for stored non-zero values of the
/// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer.
/// </summary>
/// <returns>TensorTypeAndShapeInfo values information</returns>
TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const;
/// <summary>
/// The API returns type and shape information for the specified indices. Each supported
/// indices have their own enum values even if a give format has more than one kind of indices.
/// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer.
/// </summary>
/// <param name="">enum requested</param>
/// <returns>type and shape information</returns>
TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat) const;
/// <summary>
/// The API retrieves a pointer to the internal indices buffer. The API merely performs
/// a convenience data type casting on the return type pointer. Make sure you are requesting
/// the right type, use GetSparseTensorIndicesTypeShapeInfo();
/// </summary>
/// <typeparam name="T">type to cast to</typeparam>
/// <param name="indices_format">requested indices kind</param>
/// <param name="num_indices">number of indices entries</param>
/// <returns>Pinter to the internal sparse tensor buffer containing indices. Do not free this pointer.</returns>
template <typename T>
const T* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;
#endif // !defined(DISABLE_SPARSE_TENSORS)
static Value CreateMap(Value& keys, Value& values);
static Value CreateSequence(std::vector<Value>& values);
template <typename T>
static Value CreateOpaque(const char* domain, const char* type_name, const T&);
template <typename T>
void GetOpaqueData(const char* domain, const char* type_name, T&) const;
explicit Value(std::nullptr_t) {}
explicit Value(OrtValue* p) : Base<OrtValue>{p} {}
Value(Value&&) = default;
Value& operator=(Value&&) = default;
bool IsTensor() const;
#if !defined(DISABLE_SPARSE_TENSORS)
/// <summary>
/// Returns true if the OrtValue contains a sparse tensor
/// </summary>
/// <returns></returns>
bool IsSparseTensor() const;
#endif
size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
Value GetValue(int index, OrtAllocator* allocator) const;
/// <summary>
/// This API returns a full length of string data contained within either a tensor or a sparse Tensor.
/// For sparse tensor it returns a full length of stored non-empty strings (values). The API is useful
/// for allocating necessary memory and calling GetStringTensorContent().
/// </summary>
/// <returns>total length of UTF-8 encoded bytes contained. No zero terminators counted.</returns>
size_t GetStringTensorDataLength() const;
/// <summary>
/// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor
/// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate.
/// The user must also allocate offsets buffer with the number of entries equal to that of the contained
/// strings.
///
/// Strings are always assumed to be on CPU, no X-device copy.
/// </summary>
/// <param name="buffer">user allocated buffer</param>
/// <param name="buffer_length">length in bytes of the allocated buffer</param>
/// <param name="offsets">a pointer to the offsets user allocated buffer</param>
/// <param name="offsets_count">count of offsets, must be equal to the number of strings contained.
/// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo()
/// for sparse tensors</param>
void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
template <typename T>
T* GetTensorMutableData();
template <typename T>
const T* GetTensorData() const;
#if !defined(DISABLE_SPARSE_TENSORS)
/// <summary>
/// The API returns a pointer to an internal buffer of the sparse tensor
/// containing non-zero values. The API merely does casting. Make sure you
/// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo()
/// first.
/// </summary>
/// <typeparam name="T">numeric data types only. Use GetStringTensor*() to retrieve strings.</typeparam>
/// <returns>a pointer to the internal values buffer. Do not free this pointer.</returns>
template <typename T>
const T* GetSparseTensorValues() const;
#endif
template <typename T>
T& At(const std::vector<int64_t>& location);
/// <summary>
/// The API returns type information for data contained in a tensor. For sparse
/// tensors it returns type information for contained non-zero values.
/// It returns dense shape for sparse tensors.
/// </summary>
/// <returns>TypeInfo</returns>
TypeInfo GetTypeInfo() const;
/// <summary>
/// The API returns type information for data contained in a tensor. For sparse
/// tensors it returns type information for contained non-zero values.
/// It returns dense shape for sparse tensors.
/// </summary>
/// <returns>TensorTypeAndShapeInfo</returns>
TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;
/// <summary>
/// The API returns a byte length of UTF-8 encoded string element
/// contained in either a tensor or a spare tensor values.
/// </summary>
/// <param name="element_index"></param>
/// <returns>byte length for the specified string element</returns>
size_t GetStringTensorElementLength(size_t element_index) const;
/// <summary>
/// The API copies UTF-8 encoded bytes for the requested string element
/// contained within a tensor or a sparse tensor into a provided buffer.
/// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate.
/// </summary>
/// <param name="buffer_length"></param>
/// <param name="element_index"></param>
/// <param name="buffer"></param>
void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
void FillStringTensor(const char* const* s, size_t s_len);
void FillStringTensorElement(const char* s, size_t index);
};
// Represents native memory allocation
struct MemoryAllocation {
MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
~MemoryAllocation();
MemoryAllocation(const MemoryAllocation&) = delete;
MemoryAllocation& operator=(const MemoryAllocation&) = delete;
MemoryAllocation(MemoryAllocation&&) noexcept;
MemoryAllocation& operator=(MemoryAllocation&&) noexcept;
void* get() { return p_; }
size_t size() const { return size_; }
private:
OrtAllocator* allocator_;
void* p_;
size_t size_;
};
struct AllocatorWithDefaultOptions {
AllocatorWithDefaultOptions();
operator OrtAllocator*() { return p_; }
operator const OrtAllocator*() const { return p_; }
void* Alloc(size_t size);
// The return value will own the allocation
MemoryAllocation GetAllocation(size_t size);
void Free(void* p);
const OrtMemoryInfo* GetInfo() const;
private:
OrtAllocator* p_{};
};
template <typename B>
struct BaseMemoryInfo : B {
BaseMemoryInfo() = default;
explicit BaseMemoryInfo(typename B::contained_type* p) : B(p) {}
~BaseMemoryInfo() = default;
BaseMemoryInfo(BaseMemoryInfo&&) = default;
BaseMemoryInfo& operator=(BaseMemoryInfo&&) = default;
std::string GetAllocatorName() const;
OrtAllocatorType GetAllocatorType() const;
int GetDeviceId() const;
OrtMemType GetMemoryType() const;
template <typename U>
bool operator==(const BaseMemoryInfo<U>& o) const;
};
struct UnownedMemoryInfo : BaseMemoryInfo<Base<const OrtMemoryInfo> > {
explicit UnownedMemoryInfo(std::nullptr_t) {}
explicit UnownedMemoryInfo(const OrtMemoryInfo* p) : BaseMemoryInfo(p) {}
};
struct MemoryInfo : BaseMemoryInfo<Base<OrtMemoryInfo> > {
static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1);
explicit MemoryInfo(std::nullptr_t) {}
explicit MemoryInfo(OrtMemoryInfo* p) : BaseMemoryInfo(p) {}
MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
};
struct Allocator : public Base<OrtAllocator> {
Allocator(const Session& session, const MemoryInfo&);
void* Alloc(size_t size) const;
// The return value will own the allocation
MemoryAllocation GetAllocation(size_t size);
void Free(void* p) const;
UnownedMemoryInfo GetInfo() const;
};
struct IoBinding : public Base<OrtIoBinding> {
private:
std::vector<std::string> GetOutputNamesHelper(OrtAllocator*) const;
std::vector<Value> GetOutputValuesHelper(OrtAllocator*) const;
public:
explicit IoBinding(Session& session);
void BindInput(const char* name, const Value&);
void BindOutput(const char* name, const Value&);
void BindOutput(const char* name, const MemoryInfo&);
std::vector<std::string> GetOutputNames() const;
std::vector<std::string> GetOutputNames(Allocator&) const;
std::vector<Value> GetOutputValues() const;
std::vector<Value> GetOutputValues(Allocator&) const;
void ClearBoundInputs();
void ClearBoundOutputs();
};
/*! \struct Ort::ArenaCfg
* \brief it is a structure that represents the configuration of an arena based allocator
* \details Please see docs/C_API.md for details
*/
struct ArenaCfg : Base<OrtArenaCfg> {
explicit ArenaCfg(std::nullptr_t) {}
/**
* \param max_mem - use 0 to allow ORT to choose the default
* \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
* \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
* \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default
* See docs/C_API.md for details on what the following parameters mean and how to choose these values
*/
ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
};
//
// Custom OPs (only needed to implement custom OPs)
//
struct CustomOpApi {
CustomOpApi(const OrtApi& api) : api_(api) {}
template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name);
OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value);
size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info);
ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info);
size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info);
void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count);
template <typename T>
T* GetTensorMutableData(_Inout_ OrtValue* value);
template <typename T>
const T* GetTensorData(_Inout_ const OrtValue* value);
std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info);
void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input);
size_t KernelContext_GetInputCount(const OrtKernelContext* context);
const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index);
size_t KernelContext_GetOutputCount(const OrtKernelContext* context);
OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count);
void ThrowOnError(OrtStatus* result);
private:
const OrtApi& api_;
};
template <typename TOp, typename TKernel>
struct CustomOpBase : OrtCustomOp {
CustomOpBase() {
OrtCustomOp::version = ORT_API_VERSION;
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 26409)
#endif
OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
}
template <typename... Args>
TKernel* CreateKernelImpl(Args&&... args) const {
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 26409)
#endif
return new TKernel(std::forward<Args>(args)...);
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif
}
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api);
}
// Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
const char* GetExecutionProviderType() const { return nullptr; }
// Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
// (inputs and outputs are required by default)
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
}
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
}
};
} // namespace Ort
#include "onnxruntime_cxx_inline.h"

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -0,0 +1,306 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//.A very thin wrapper of ONNXRuntime Custom Operator Callback ABI, which
// is only used in the custom-op kernels. For the general ORT C++ invocation, like end-to-end
// testing, the ONNXRuntime public C++ APIs should be used since there is no binary compatible requirement.
#pragma once
#include <cstddef>
#include <array>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
#include <utility>
#include <type_traits>
#ifdef ORT_NO_EXCEPTIONS
#include <iostream>
#endif
#include "onnxruntime_c_api.h"
namespace OrtW {
// All C++ methods that can fail will throw an exception of this type
struct Exception : std::exception {
Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
OrtErrorCode GetOrtErrorCode() const { return code_; }
const char* what() const noexcept override { return message_.c_str(); }
private:
std::string message_;
OrtErrorCode code_;
};
#ifdef ORT_NO_EXCEPTIONS
#define ORTX_CXX_API_THROW(string, code) \
do { \
std::cerr << OrtW::Exception(string, code) \
.what() \
<< std::endl; \
abort(); \
} while (false)
#else
#define ORTX_CXX_API_THROW(string, code) \
throw OrtW::Exception(string, code)
#endif
inline void ThrowOnError(const OrtApi& ort, OrtStatus* status) {
if (status) {
std::string error_message = ort.GetErrorMessage(status);
OrtErrorCode error_code = ort.GetErrorCode(status);
ort.ReleaseStatus(status);
ORTX_CXX_API_THROW(std::move(error_message), error_code);
}
}
//
// Custom OPs (only needed to implement custom OPs)
//
struct CustomOpApi {
CustomOpApi(const OrtApi& api) : api_(api) {}
template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const;
OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value) const;
size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const;
size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const;
void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const;
template <typename T>
T* GetTensorMutableData(_Inout_ OrtValue* value) const;
template <typename T>
const T* GetTensorData(_Inout_ const OrtValue* value) const;
std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const;
void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const;
size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const;
size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const;
OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) const;
void ThrowOnError(OrtStatus* status) const {
OrtW::ThrowOnError(api_, status);
}
private:
const OrtApi& api_;
};
template <typename TOp, typename TKernel>
struct CustomOpBase : OrtCustomOp {
CustomOpBase() {
OrtCustomOp::version = ORT_API_VERSION;
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 26409)
#endif
OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
}
template <typename... Args>
TKernel* CreateKernelImpl(Args&&... args) const {
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 26409)
#endif
return new TKernel(std::forward<Args>(args)...);
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif
}
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api);
}
// Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
const char* GetExecutionProviderType() const { return nullptr; }
// Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
// (inputs and outputs are required by default)
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
}
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
}
};
//
// Custom OP API Inlines
//
template <>
inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
float out;
ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
return out;
}
template <>
inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
int64_t out;
ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
return out;
}
template <>
inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
size_t size = 0;
std::string out;
// Feed nullptr for the data buffer to query the true size of the string attribute
OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
if (status == nullptr) {
out.resize(size);
ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
out.resize(size - 1); // remove the terminating character '\0'
} else {
ThrowOnError(status);
}
return out;
}
template <>
inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
size_t size = 0;
std::vector<float> out;
// Feed nullptr for the data buffer to query the true size of the attribute
OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
if (status == nullptr) {
out.resize(size);
ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
} else {
ThrowOnError(status);
}
return out;
}
template <>
inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
size_t size = 0;
std::vector<int64_t> out;
// Feed nullptr for the data buffer to query the true size of the attribute
OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
if (status == nullptr) {
out.resize(size);
ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
} else {
ThrowOnError(status);
}
return out;
}
inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) const {
OrtTensorTypeAndShapeInfo* out;
ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
return out;
}
inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
size_t out;
ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
return out;
}
inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const {
ONNXTensorElementDataType out;
ThrowOnError(api_.GetTensorElementType(info, &out));
return out;
}
inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
size_t out;
ThrowOnError(api_.GetDimensionsCount(info, &out));
return out;
}
inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const {
ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
}
inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const {
ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
}
template <typename T>
inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
T* data;
ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
return data;
}
template <typename T>
inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const {
return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
}
inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
std::vector<int64_t> output(GetDimensionsCount(info));
GetDimensions(info, output.data(), output.size());
return output;
}
inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const {
api_.ReleaseTensorTypeAndShapeInfo(input);
}
inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) const {
size_t out;
ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
return out;
}
inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const {
const OrtValue* out;
ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
return out;
}
inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) const {
size_t out;
ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
return out;
}
inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
_In_ const int64_t* dim_values, size_t dim_count) const {
OrtValue* out;
ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
return out;
}
} // namespace OrtW

Просмотреть файл

@ -16,7 +16,7 @@ struct KernelGaussianBlur : BaseKernel {
const OrtValue* input_ksize = ort_.KernelContext_GetInput(context, 1); const OrtValue* input_ksize = ort_.KernelContext_GetInput(context, 1);
OrtTensorDimensions dim_ksize(ort_, input_ksize); OrtTensorDimensions dim_ksize(ort_, input_ksize);
if (dim_ksize.size() != 1 || dim_ksize[0] != 2) { if (dim_ksize.size() != 1 || dim_ksize[0] != 2) {
ORT_CXX_API_THROW("[GaussianBlur]: ksize shape is (2,)", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("[GaussianBlur]: ksize shape is (2,)", ORT_INVALID_ARGUMENT);
} }
std::copy_n(ort_.GetTensorData<std::int64_t>(input_ksize), 2, ksize); std::copy_n(ort_.GetTensorData<std::int64_t>(input_ksize), 2, ksize);
} }
@ -25,7 +25,7 @@ struct KernelGaussianBlur : BaseKernel {
const OrtValue* input_sigma = ort_.KernelContext_GetInput(context, 2); const OrtValue* input_sigma = ort_.KernelContext_GetInput(context, 2);
OrtTensorDimensions dim_sigma(ort_, input_sigma); OrtTensorDimensions dim_sigma(ort_, input_sigma);
if (dim_sigma.size() != 1 || dim_sigma[0] != 2) { if (dim_sigma.size() != 1 || dim_sigma[0] != 2) {
ORT_CXX_API_THROW("[GaussianBlur]: sigma shape is (2,)", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("[GaussianBlur]: sigma shape is (2,)", ORT_INVALID_ARGUMENT);
} }
std::copy_n(ort_.GetTensorData<double>(input_sigma), 2, sigma); std::copy_n(ort_.GetTensorData<double>(input_sigma), 2, sigma);
} }
@ -53,7 +53,7 @@ struct KernelGaussianBlur : BaseKernel {
} }
}; };
struct CustomOpGaussianBlur : Ort::CustomOpBase<CustomOpGaussianBlur, KernelGaussianBlur> { struct CustomOpGaussianBlur : OrtW::CustomOpBase<CustomOpGaussianBlur, KernelGaussianBlur> {
size_t GetInputTypeCount() const { size_t GetInputTypeCount() const {
return 3; return 3;
} }

Просмотреть файл

@ -20,7 +20,7 @@ struct KernelImageDecoder : BaseKernel {
const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL); const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL);
OrtTensorDimensions dimensions(ort_, inputs); OrtTensorDimensions dimensions(ort_, inputs);
if (dimensions.size() != 1ULL) { if (dimensions.size() != 1ULL) {
ORT_CXX_API_THROW("[ImageDecoder]: Only raw image formats are supported.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("[ImageDecoder]: Only raw image formats are supported.", ORT_INVALID_ARGUMENT);
} }
// Get data & the length // Get data & the length
@ -48,7 +48,7 @@ struct KernelImageDecoder : BaseKernel {
} }
}; };
struct CustomOpImageDecoder : Ort::CustomOpBase<CustomOpImageDecoder, KernelImageDecoder> { struct CustomOpImageDecoder : OrtW::CustomOpBase<CustomOpImageDecoder, KernelImageDecoder> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const { void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelImageDecoder(api); return new KernelImageDecoder(api);
} }
@ -66,7 +66,7 @@ struct CustomOpImageDecoder : Ort::CustomOpBase<CustomOpImageDecoder, KernelImag
case 0: case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
default: default:
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
} }
} }
@ -79,7 +79,7 @@ struct CustomOpImageDecoder : Ort::CustomOpBase<CustomOpImageDecoder, KernelImag
case 0: case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
default: default:
ORT_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
} }
} }
}; };

Просмотреть файл

@ -14,7 +14,7 @@ struct KernelImageReader : BaseKernel {
int n = input_data_dimensions[0]; int n = input_data_dimensions[0];
if (n != 1) { if (n != 1) {
ORT_CXX_API_THROW("[ImageReader]: the dimension of input value can only be 1 now.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("[ImageReader]: the dimension of input value can only be 1 now.", ORT_INVALID_ARGUMENT);
} }
std::vector<std::string> image_paths; std::vector<std::string> image_paths;
@ -28,7 +28,7 @@ struct KernelImageReader : BaseKernel {
} }
}; };
struct CustomOpImageReader : Ort::CustomOpBase<CustomOpImageReader, KernelImageReader> { struct CustomOpImageReader : OrtW::CustomOpBase<CustomOpImageReader, KernelImageReader> {
size_t GetInputTypeCount() const { size_t GetInputTypeCount() const {
return 1; return 1;
} }

Просмотреть файл

@ -81,7 +81,7 @@ ONNXTensorElementDataType CustomOpSuperResolutionPostProcess::GetInputType(size_
case 2: case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
default: default:
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
} }
} }
@ -94,6 +94,6 @@ ONNXTensorElementDataType CustomOpSuperResolutionPostProcess::GetOutputType(size
case 0: case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
default: default:
ORT_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
} }
} }

Просмотреть файл

@ -10,7 +10,7 @@ struct KernelSuperResolutionPostProcess : BaseKernel {
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpSuperResolutionPostProcess : Ort::CustomOpBase<CustomOpSuperResolutionPostProcess, KernelSuperResolutionPostProcess> { struct CustomOpSuperResolutionPostProcess : OrtW::CustomOpBase<CustomOpSuperResolutionPostProcess, KernelSuperResolutionPostProcess> {
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -71,7 +71,7 @@ ONNXTensorElementDataType CustomOpSuperResolutionPreProcess::GetInputType(size_t
case 0: case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
default: default:
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
} }
} }
@ -86,6 +86,6 @@ ONNXTensorElementDataType CustomOpSuperResolutionPreProcess::GetOutputType(size_
case 2: case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
default: default:
ORT_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Unexpected output index ", index), ORT_INVALID_ARGUMENT);
} }
} }

Просмотреть файл

@ -10,7 +10,7 @@ struct KernelSuperResolutionPreProcess : BaseKernel {
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpSuperResolutionPreProcess : Ort::CustomOpBase<CustomOpSuperResolutionPreProcess, KernelSuperResolutionPreProcess> { struct CustomOpSuperResolutionPreProcess : OrtW::CustomOpBase<CustomOpSuperResolutionPreProcess, KernelSuperResolutionPreProcess> {
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -33,7 +33,7 @@ struct KernelInverse : BaseKernel {
} }
}; };
struct CustomOpInverse : Ort::CustomOpBase<CustomOpInverse, KernelInverse> { struct CustomOpInverse : OrtW::CustomOpBase<CustomOpInverse, KernelInverse> {
const char* GetName() const { const char* GetName() const {
return "Inverse"; return "Inverse";
} }

Просмотреть файл

@ -39,7 +39,7 @@ struct KernelNegPos : BaseKernel {
} }
}; };
struct CustomOpNegPos : Ort::CustomOpBase<CustomOpNegPos, KernelNegPos> { struct CustomOpNegPos : OrtW::CustomOpBase<CustomOpNegPos, KernelNegPos> {
const char* GetName() const{ const char* GetName() const{
return "NegPos"; return "NegPos";
} }

Просмотреть файл

@ -11,7 +11,7 @@ void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
const int64_t* p_data = ort_.GetTensorData<int64_t>(input); const int64_t* p_data = ort_.GetTensorData<int64_t>(input);
OrtTensorDimensions input_dim(ort_, input); OrtTensorDimensions input_dim(ort_, input);
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) { if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
ORT_CXX_API_THROW("[SegmentExtraction]: Expect input dimension [n] or [1,n]." , ORT_INVALID_GRAPH); ORTX_CXX_API_THROW("[SegmentExtraction]: Expect input dimension [n] or [1,n]." , ORT_INVALID_GRAPH);
} }
std::vector<std::int64_t> segment_value; std::vector<std::int64_t> segment_value;

Просмотреть файл

@ -11,7 +11,7 @@ struct KernelSegmentExtraction : BaseKernel {
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpSegmentExtraction : Ort::CustomOpBase<CustomOpSegmentExtraction, KernelSegmentExtraction> { struct CustomOpSegmentExtraction : OrtW::CustomOpBase<CustomOpSegmentExtraction, KernelSegmentExtraction> {
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
size_t GetOutputTypeCount() const; size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const; ONNXTensorElementDataType GetOutputType(size_t index) const;

Просмотреть файл

@ -4,7 +4,7 @@
#include "segment_sum.hpp" #include "segment_sum.hpp"
template <typename T> template <typename T>
void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context) { void KernelSegmentSum_Compute(OrtW::CustomOpApi& ort_, OrtKernelContext* context) {
// Setup inputs // Setup inputs
const OrtValue* data = ort_.KernelContext_GetInput(context, 0); const OrtValue* data = ort_.KernelContext_GetInput(context, 0);
const T* p_data = ort_.GetTensorData<T>(data); const T* p_data = ort_.GetTensorData<T>(data);
@ -15,11 +15,11 @@ void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context)
OrtTensorDimensions dim_data(ort_, data); OrtTensorDimensions dim_data(ort_, data);
OrtTensorDimensions dim_seg(ort_, segment_ids); OrtTensorDimensions dim_seg(ort_, segment_ids);
if (dim_data.size() == 0 || dim_seg.size() == 0) if (dim_data.size() == 0 || dim_seg.size() == 0)
ORT_CXX_API_THROW("Both inputs cannot be empty.", ORT_INVALID_GRAPH); ORTX_CXX_API_THROW("Both inputs cannot be empty.", ORT_INVALID_GRAPH);
if (dim_seg.size() != 1) if (dim_seg.size() != 1)
ORT_CXX_API_THROW("segment_ids must a single tensor", ORT_INVALID_GRAPH); ORTX_CXX_API_THROW("segment_ids must a single tensor", ORT_INVALID_GRAPH);
if (dim_data[0] != dim_seg[0]) if (dim_data[0] != dim_seg[0])
ORT_CXX_API_THROW(MakeString( ORTX_CXX_API_THROW(MakeString(
"First dimensions of data and segment_ids should be the same, data shape: ", dim_data, "First dimensions of data and segment_ids should be the same, data shape: ", dim_data,
" segment_ids shape: ", dim_seg), ORT_INVALID_GRAPH); " segment_ids shape: ", dim_seg), ORT_INVALID_GRAPH);
@ -42,7 +42,7 @@ void KernelSegmentSum_Compute(Ort::CustomOpApi& ort_, OrtKernelContext* context)
const int64_t* p_seg = p_segment_ids; const int64_t* p_seg = p_segment_ids;
for (; begin != end; ++p_seg) { for (; begin != end; ++p_seg) {
if ((p_seg != p_segment_ids) && (*p_seg != *(p_seg - 1)) && (*p_seg != *(p_seg - 1) + 1)) if ((p_seg != p_segment_ids) && (*p_seg != *(p_seg - 1)) && (*p_seg != *(p_seg - 1) + 1))
ORT_CXX_API_THROW(MakeString("segment_ids must be increasing but found ", ORTX_CXX_API_THROW(MakeString("segment_ids must be increasing but found ",
*(p_seg - 1), " and ", *p_seg, " at position ", *(p_seg - 1), " and ", *p_seg, " at position ",
std::distance(p_segment_ids, p_seg), "."), ORT_RUNTIME_EXCEPTION); std::distance(p_segment_ids, p_seg), "."), ORT_RUNTIME_EXCEPTION);
p_out = p_output + *p_seg * in_stride; p_out = p_output + *p_seg * in_stride;
@ -82,6 +82,6 @@ ONNXTensorElementDataType CustomOpSegmentSum::GetInputType(size_t index) const {
case 1: case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default: default:
ORT_CXX_API_THROW("Operator SegmentSum has 2 inputs.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Operator SegmentSum has 2 inputs.", ORT_INVALID_ARGUMENT);
} }
}; };

Просмотреть файл

@ -11,7 +11,7 @@ struct KernelSegmentSum : BaseKernel {
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpSegmentSum : Ort::CustomOpBase<CustomOpSegmentSum, KernelSegmentSum> { struct CustomOpSegmentSum : OrtW::CustomOpBase<CustomOpSegmentSum, KernelSegmentSum> {
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
size_t GetOutputTypeCount() const; size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const; ONNXTensorElementDataType GetOutputType(size_t index) const;

Просмотреть файл

@ -5,7 +5,7 @@
bool BaseKernel::HasAttribute(const char* name) const { bool BaseKernel::HasAttribute(const char* name) const {
if (info_ == nullptr) { if (info_ == nullptr) {
ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
} }
size_t size; size_t size;
std::string out; std::string out;
@ -46,7 +46,7 @@ void BaseKernel::SetOutput(OrtKernelContext* ctx, size_t output_idx, const std:
template <> template <>
bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) { bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
if (info_ == nullptr) { if (info_ == nullptr) {
ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
} }
size_t size = 0; size_t size = 0;
@ -71,7 +71,7 @@ bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
template <> template <>
bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) { bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) {
if (info_ == nullptr) { if (info_ == nullptr) {
ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
} }
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(info_, name, &value)) == ORT_OK; return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(info_, name, &value)) == ORT_OK;
@ -80,7 +80,7 @@ bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) {
template <> template <>
bool BaseKernel::TryToGetAttribute(const char* name, float& value) { bool BaseKernel::TryToGetAttribute(const char* name, float& value) {
if (info_ == nullptr) { if (info_ == nullptr) {
ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
} }
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(info_, name, &value)) == ORT_OK; return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(info_, name, &value)) == ORT_OK;
@ -89,7 +89,7 @@ bool BaseKernel::TryToGetAttribute(const char* name, float& value) {
template <> template <>
bool BaseKernel::TryToGetAttribute(const char* name, bool& value) { bool BaseKernel::TryToGetAttribute(const char* name, bool& value) {
if (info_ == nullptr) { if (info_ == nullptr) {
ORT_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
} }
int64_t origin_value = 0; int64_t origin_value = 0;

Просмотреть файл

@ -3,17 +3,17 @@
#include "string_utils.h" #include "string_utils.h"
#include "string_tensor.h" #include "string_tensor.h"
void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
const OrtValue* value, std::vector<std::string>& output) { const OrtValue* value, std::vector<std::string>& output) {
(void)context; (void)context;
OrtTensorDimensions dimensions(ort, value); OrtTensorDimensions dimensions(ort, value);
size_t len = static_cast<size_t>(dimensions.Size()); size_t len = static_cast<size_t>(dimensions.Size());
size_t data_len; size_t data_len;
Ort::ThrowOnError(api, api.GetStringTensorDataLength(value, &data_len)); OrtW::ThrowOnError(api, api.GetStringTensorDataLength(value, &data_len));
output.resize(len); output.resize(len);
std::vector<char> result(data_len + len + 1, '\0'); std::vector<char> result(data_len + len + 1, '\0');
std::vector<size_t> offsets(len); std::vector<size_t> offsets(len);
Ort::ThrowOnError(api, api.GetStringTensorContent(value, (void*)result.data(), data_len, offsets.data(), offsets.size())); OrtW::ThrowOnError(api, api.GetStringTensorContent(value, (void*)result.data(), data_len, offsets.data(), offsets.size()));
output.resize(len); output.resize(len);
for (int64_t i = (int64_t)len - 1; i >= 0; --i) { for (int64_t i = (int64_t)len - 1; i >= 0; --i) {
if (i < static_cast<int64_t>(len) - 1) if (i < static_cast<int64_t>(len) - 1)
@ -22,7 +22,7 @@ void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKer
} }
} }
void FillTensorDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, void FillTensorDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
const std::vector<std::string>& value, OrtValue* output) { const std::vector<std::string>& value, OrtValue* output) {
(void)ort; (void)ort;
(void)context; (void)context;
@ -31,10 +31,10 @@ void FillTensorDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelCon
temp[i] = value[i].c_str(); temp[i] = value[i].c_str();
} }
Ort::ThrowOnError(api,api.FillStringTensor(output, temp.data(), value.size())); OrtW::ThrowOnError(api,api.FillStringTensor(output, temp.data(), value.size()));
} }
void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
const OrtValue* value, std::vector<ustring>& output) { const OrtValue* value, std::vector<ustring>& output) {
std::vector<std::string> utf8_strings; std::vector<std::string> utf8_strings;
GetTensorMutableDataString(api, ort, context, value, utf8_strings); GetTensorMutableDataString(api, ort, context, value, utf8_strings);
@ -46,7 +46,7 @@ void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKer
} }
void FillTensorDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, void FillTensorDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
const std::vector<ustring>& value, OrtValue* output) { const std::vector<ustring>& value, OrtValue* output) {
std::vector<std::string> utf8_strings; std::vector<std::string> utf8_strings;
utf8_strings.reserve(value.size()); utf8_strings.reserve(value.size());

Просмотреть файл

@ -10,14 +10,14 @@
// Retrieves a vector of strings if the input type is std::string. // Retrieves a vector of strings if the input type is std::string.
// It is a copy of the input data and can be modified to compute the output. // It is a copy of the input data and can be modified to compute the output.
void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
const OrtValue* value, std::vector<std::string>& output); const OrtValue* value, std::vector<std::string>& output);
void GetTensorMutableDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, void GetTensorMutableDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
const OrtValue* value, std::vector<ustring>& output); const OrtValue* value, std::vector<ustring>& output);
void FillTensorDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, void FillTensorDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
const std::vector<std::string>& value, OrtValue* output); const std::vector<std::string>& value, OrtValue* output);
void FillTensorDataString(const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, void FillTensorDataString(const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context,
const std::vector<ustring>& value, OrtValue* output); const std::vector<ustring>& value, OrtValue* output);

Просмотреть файл

@ -21,11 +21,11 @@ void KernelMaskedFill::Compute(OrtKernelContext* context) {
OrtTensorDimensions mask_dimensions(ort_, input_mask); OrtTensorDimensions mask_dimensions(ort_, input_mask);
if (!(value_dimensions.IsScalar() || value_dimensions.IsVector())) { if (!(value_dimensions.IsScalar() || value_dimensions.IsVector())) {
ORT_CXX_API_THROW("[MaskedFill]: the dimension of input value should be vector or scalar.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("[MaskedFill]: the dimension of input value should be vector or scalar.", ORT_INVALID_ARGUMENT);
} }
if (value_dimensions != mask_dimensions) { if (value_dimensions != mask_dimensions) {
ORT_CXX_API_THROW("[MaskedFill]: the dimension of input value and mask should be same.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("[MaskedFill]: the dimension of input value and mask should be same.", ORT_INVALID_ARGUMENT);
} }
std::vector<std::string> value; std::vector<std::string> value;
@ -68,7 +68,7 @@ ONNXTensorElementDataType CustomOpMaskedFill::GetInputType(size_t index) const {
case 1: case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
default: default:
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
}}; }};
size_t CustomOpMaskedFill::GetOutputTypeCount() const { size_t CustomOpMaskedFill::GetOutputTypeCount() const {

Просмотреть файл

@ -14,7 +14,7 @@ struct KernelMaskedFill : BaseKernel {
std::unordered_map<std::string, std::string> map_; std::unordered_map<std::string, std::string> map_;
}; };
struct CustomOpMaskedFill : Ort::CustomOpBase<CustomOpMaskedFill, KernelMaskedFill> { struct CustomOpMaskedFill : OrtW::CustomOpBase<CustomOpMaskedFill, KernelMaskedFill> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;

Просмотреть файл

@ -11,7 +11,7 @@ struct KernelStringEqual : BaseKernel {
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringEqual : Ort::CustomOpBase<CustomOpStringEqual, KernelStringEqual> { struct CustomOpStringEqual : OrtW::CustomOpBase<CustomOpStringEqual, KernelStringEqual> {
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
size_t GetOutputTypeCount() const; size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const; ONNXTensorElementDataType GetOutputType(size_t index) const;

Просмотреть файл

@ -13,7 +13,7 @@ class BroadcastIteratorRight {
const std::vector<int64_t>& shape2, const std::vector<int64_t>& shape2,
const T1* p1, const T2* p2, T3* p3) : shape1_(shape1), p1_(p1), p2_(p2), p3_(p3) { const T1* p1, const T2* p2, T3* p3) : shape1_(shape1), p1_(p1), p2_(p2), p3_(p3) {
if (shape2.size() > shape1.size()) if (shape2.size() > shape1.size())
ORT_CXX_API_THROW("shape2 must have less dimensions than shape1", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("shape2 must have less dimensions than shape1", ORT_INVALID_ARGUMENT);
shape2_.resize(shape1_.size()); shape2_.resize(shape1_.size());
cum_shape2_.resize(shape1_.size()); cum_shape2_.resize(shape1_.size());
total_ = 1; total_ = 1;
@ -26,7 +26,7 @@ class BroadcastIteratorRight {
shape2_[i] = shape2[i]; shape2_[i] = shape2[i];
} }
if (shape2[i] != 1 && shape1[i] != shape2[i]) { if (shape2[i] != 1 && shape1[i] != shape2[i]) {
ORT_CXX_API_THROW(MakeString( ORTX_CXX_API_THROW(MakeString(
"Cannot broadcast dimension ", i, " left:", shape1[i], " right:", shape2[i]), ORT_INVALID_ARGUMENT); "Cannot broadcast dimension ", i, " left:", shape1[i], " right:", shape2[i]), ORT_INVALID_ARGUMENT);
} }
} }
@ -84,7 +84,7 @@ class BroadcastIteratorRight {
template <typename TCMP> template <typename TCMP>
void loop(TCMP& cmp, BroadcastIteratorRightState& /*it*/, int64_t pos = 0) { void loop(TCMP& cmp, BroadcastIteratorRightState& /*it*/, int64_t pos = 0) {
if (pos != 0) if (pos != 0)
ORT_CXX_API_THROW("Not implemented yet.", ORT_NOT_IMPLEMENTED); ORTX_CXX_API_THROW("Not implemented yet.", ORT_NOT_IMPLEMENTED);
while (!end()) { while (!end()) {
*p3 = cmp(*p1, *p2); *p3 = cmp(*p1, *p2);
next(); next();
@ -114,7 +114,7 @@ inline bool Compare<std::string>::operator()(const std::string& s1, const std::s
} }
template <typename T> template <typename T>
void KernelEqual_Compute(const OrtApi& api, Ort::CustomOpApi& ort_, OrtKernelContext* context) { void KernelEqual_Compute(const OrtApi& api, OrtW::CustomOpApi& ort_, OrtKernelContext* context) {
// Setup inputs // Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0); const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const T* X = ort_.GetTensorData<T>(input_X); const T* X = ort_.GetTensorData<T>(input_X);
@ -144,7 +144,7 @@ void KernelEqual_Compute(const OrtApi& api, Ort::CustomOpApi& ort_, OrtKernelCon
} }
template <> template <>
void KernelEqual_Compute<std::string>(const OrtApi& api, Ort::CustomOpApi& ort_, OrtKernelContext* context) { void KernelEqual_Compute<std::string>(const OrtApi& api, OrtW::CustomOpApi& ort_, OrtKernelContext* context) {
// Setup inputs // Setup inputs
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0); const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1); const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);

Просмотреть файл

@ -11,7 +11,7 @@ void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
OrtTensorDimensions d_length(ort_, n_elements); OrtTensorDimensions d_length(ort_, n_elements);
if (d_length.size() != 1) if (d_length.size() != 1)
ORT_CXX_API_THROW(MakeString( ORTX_CXX_API_THROW(MakeString(
"First input must have one dimension not ", d_length.size(), "."), ORT_INVALID_ARGUMENT); "First input must have one dimension not ", d_length.size(), "."), ORT_INVALID_ARGUMENT);
int64_t n_els = d_length[0] - 1; int64_t n_els = d_length[0] - 1;
int64_t n_values = p_n_elements[n_els]; int64_t n_values = p_n_elements[n_els];
@ -103,7 +103,7 @@ void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) {
for (int64_t i = 0; i < size - 1; ++i) { for (int64_t i = 0; i < size - 1; ++i) {
pos_end = pos + max_col; pos_end = pos + max_col;
if (pos_end > shape_out_size) if (pos_end > shape_out_size)
ORT_CXX_API_THROW(MakeString( ORTX_CXX_API_THROW(MakeString(
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1], "Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
" - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT); " - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) { for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
@ -161,7 +161,7 @@ void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
for (int64_t i = 0; i < size - 1; ++i) { for (int64_t i = 0; i < size - 1; ++i) {
pos_end = pos + max_col; pos_end = pos + max_col;
if (pos_end > shape_out_size) if (pos_end > shape_out_size)
ORT_CXX_API_THROW(MakeString( ORTX_CXX_API_THROW(MakeString(
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1], "Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
" - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT); " - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) { for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
@ -203,6 +203,6 @@ ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetInputType(size_t
case 3: case 3:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default: default:
ORT_CXX_API_THROW(MakeString("[StringRaggedTensorToDense] Unexpected output index ", index, "."), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("[StringRaggedTensorToDense] Unexpected output index ", index, "."), ORT_INVALID_ARGUMENT);
} }
}; };

Просмотреть файл

@ -12,7 +12,7 @@ struct KernelRaggedTensorToSparse : BaseKernel {
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpRaggedTensorToSparse : Ort::CustomOpBase<CustomOpRaggedTensorToSparse, KernelRaggedTensorToSparse> { struct CustomOpRaggedTensorToSparse : OrtW::CustomOpBase<CustomOpRaggedTensorToSparse, KernelRaggedTensorToSparse> {
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
size_t GetOutputTypeCount() const; size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const; ONNXTensorElementDataType GetOutputType(size_t index) const;
@ -36,7 +36,7 @@ struct KernelRaggedTensorToDense : CommonRaggedTensorToDense {
int64_t missing_value_; int64_t missing_value_;
}; };
struct CustomOpRaggedTensorToDense : Ort::CustomOpBase<CustomOpRaggedTensorToDense, KernelRaggedTensorToDense> { struct CustomOpRaggedTensorToDense : OrtW::CustomOpBase<CustomOpRaggedTensorToDense, KernelRaggedTensorToDense> {
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
size_t GetOutputTypeCount() const; size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const; ONNXTensorElementDataType GetOutputType(size_t index) const;
@ -50,7 +50,7 @@ struct KernelStringRaggedTensorToDense : CommonRaggedTensorToDense {
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringRaggedTensorToDense : Ort::CustomOpBase<CustomOpStringRaggedTensorToDense, KernelStringRaggedTensorToDense> { struct CustomOpStringRaggedTensorToDense : OrtW::CustomOpBase<CustomOpStringRaggedTensorToDense, KernelStringRaggedTensorToDense> {
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
size_t GetOutputTypeCount() const; size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(size_t index) const; ONNXTensorElementDataType GetOutputType(size_t index) const;

Просмотреть файл

@ -27,15 +27,15 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
OrtTensorDimensions pattern_dimensions(ort_, pattern); OrtTensorDimensions pattern_dimensions(ort_, pattern);
OrtTensorDimensions rewrite_dimensions(ort_, rewrite); OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
if (pattern_dimensions.size() != 1 || pattern_dimensions[0] != 1) if (pattern_dimensions.size() != 1 || pattern_dimensions[0] != 1)
ORT_CXX_API_THROW(MakeString( ORTX_CXX_API_THROW(MakeString(
"pattern (second input) must contain only one element. It has ", "pattern (second input) must contain only one element. It has ",
pattern_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT); pattern_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
if (rewrite_dimensions.size() != 1 || rewrite_dimensions[0] != 1) if (rewrite_dimensions.size() != 1 || rewrite_dimensions[0] != 1)
ORT_CXX_API_THROW(MakeString( ORTX_CXX_API_THROW(MakeString(
"rewrite (third input) must contain only one element. It has ", "rewrite (third input) must contain only one element. It has ",
rewrite_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT); rewrite_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
if (str_pattern[0].empty()) if (str_pattern[0].empty())
ORT_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_ARGUMENT);
// Setup output // Setup output
OrtTensorDimensions dimensions(ort_, input); OrtTensorDimensions dimensions(ort_, input);

Просмотреть файл

@ -14,7 +14,7 @@ struct KernelStringRegexReplace : BaseKernel {
int64_t global_replace_; int64_t global_replace_;
}; };
struct CustomOpStringRegexReplace : Ort::CustomOpBase<CustomOpStringRegexReplace, KernelStringRegexReplace> { struct CustomOpStringRegexReplace : OrtW::CustomOpBase<CustomOpStringRegexReplace, KernelStringRegexReplace> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;

Просмотреть файл

@ -24,15 +24,15 @@ void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
// Verifications // Verifications
OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern); OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern);
if (str_pattern.size() != 1) if (str_pattern.size() != 1)
ORT_CXX_API_THROW(MakeString( ORTX_CXX_API_THROW(MakeString(
"pattern (second input) must contain only one element. It has ", "pattern (second input) must contain only one element. It has ",
str_pattern.size(), " values."), ORT_INVALID_ARGUMENT); str_pattern.size(), " values."), ORT_INVALID_ARGUMENT);
if (str_keep_pattern.size() > 1) if (str_keep_pattern.size() > 1)
ORT_CXX_API_THROW(MakeString( ORTX_CXX_API_THROW(MakeString(
"Third input must contain only one element. It has ", "Third input must contain only one element. It has ",
str_keep_pattern.size(), " values."), ORT_INVALID_ARGUMENT); str_keep_pattern.size(), " values."), ORT_INVALID_ARGUMENT);
if (str_pattern[0].empty()) if (str_pattern[0].empty())
ORT_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions(ort_, input); OrtTensorDimensions dimensions(ort_, input);
bool include_delimiter = (str_keep_pattern.size() == 1) && (!str_keep_pattern[0].empty()); bool include_delimiter = (str_keep_pattern.size() == 1) && (!str_keep_pattern[0].empty());
@ -106,7 +106,7 @@ ONNXTensorElementDataType CustomOpStringRegexSplitWithOffsets::GetOutputType(siz
case 3: case 3:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default: default:
ORT_CXX_API_THROW(MakeString( ORTX_CXX_API_THROW(MakeString(
"StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."), ORT_INVALID_ARGUMENT); "StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."), ORT_INVALID_ARGUMENT);
} }
}; };

Просмотреть файл

@ -12,7 +12,7 @@ struct KernelStringRegexSplitWithOffsets : BaseKernel {
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringRegexSplitWithOffsets : Ort::CustomOpBase<CustomOpStringRegexSplitWithOffsets, KernelStringRegexSplitWithOffsets> { struct CustomOpStringRegexSplitWithOffsets : OrtW::CustomOpBase<CustomOpStringRegexSplitWithOffsets, KernelStringRegexSplitWithOffsets> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;

Просмотреть файл

@ -20,7 +20,7 @@ void KernelStringConcat::Compute(OrtKernelContext* context) {
OrtTensorDimensions right_dim(ort_, right); OrtTensorDimensions right_dim(ort_, right);
if (left_dim != right_dim) { if (left_dim != right_dim) {
ORT_CXX_API_THROW("Two input tensor should have the same dimension.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Two input tensor should have the same dimension.", ORT_INVALID_ARGUMENT);
} }
std::vector<std::string> left_value; std::vector<std::string> left_value;

Просмотреть файл

@ -11,7 +11,7 @@ struct KernelStringConcat : BaseKernel {
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringConcat : Ort::CustomOpBase<CustomOpStringConcat, KernelStringConcat> { struct CustomOpStringConcat : OrtW::CustomOpBase<CustomOpStringConcat, KernelStringConcat> {
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -29,17 +29,17 @@ void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
OrtTensorDimensions pattern_dimensions(ort_, pattern); OrtTensorDimensions pattern_dimensions(ort_, pattern);
OrtTensorDimensions rewrite_dimensions(ort_, rewrite); OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
if (pattern_dimensions.Size() != 1) { if (pattern_dimensions.Size() != 1) {
ORT_CXX_API_THROW(MakeString( ORTX_CXX_API_THROW(MakeString(
"pattern (second input) must contain only one element. It has ", "pattern (second input) must contain only one element. It has ",
pattern_dimensions.size(), " dimensions."), ORT_INVALID_GRAPH); pattern_dimensions.size(), " dimensions."), ORT_INVALID_GRAPH);
} }
if (rewrite_dimensions.Size() != 1) { if (rewrite_dimensions.Size() != 1) {
ORT_CXX_API_THROW(MakeString( ORTX_CXX_API_THROW(MakeString(
"rewrite (third input) must contain only one element. It has ", "rewrite (third input) must contain only one element. It has ",
rewrite_dimensions.size(), " dimensions."), ORT_INVALID_GRAPH); rewrite_dimensions.size(), " dimensions."), ORT_INVALID_GRAPH);
} }
if (str_pattern[0].empty()) { if (str_pattern[0].empty()) {
ORT_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_GRAPH); ORTX_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_GRAPH);
} }
// Setup output // Setup output

Просмотреть файл

@ -15,7 +15,7 @@ struct KernelStringECMARegexReplace : BaseKernel {
bool ignore_case_; bool ignore_case_;
}; };
struct CustomOpStringECMARegexReplace : Ort::CustomOpBase<CustomOpStringECMARegexReplace, KernelStringECMARegexReplace> { struct CustomOpStringECMARegexReplace : OrtW::CustomOpBase<CustomOpStringECMARegexReplace, KernelStringECMARegexReplace> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;

Просмотреть файл

@ -28,11 +28,11 @@ void KernelStringECMARegexSplitWithOffsets::Compute(OrtKernelContext* context) {
// Verifications // Verifications
OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern); OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern);
if (str_pattern.size() != 1) if (str_pattern.size() != 1)
ORT_CXX_API_THROW(MakeString("pattern (second input) must contain only one element. It has ", str_pattern.size(), " values."), ORT_INVALID_GRAPH); ORTX_CXX_API_THROW(MakeString("pattern (second input) must contain only one element. It has ", str_pattern.size(), " values."), ORT_INVALID_GRAPH);
if (str_keep_pattern.size() > 1) if (str_keep_pattern.size() > 1)
ORT_CXX_API_THROW(MakeString("Third input must contain only one element. It has ", str_keep_pattern.size(), " values."), ORT_INVALID_GRAPH); ORTX_CXX_API_THROW(MakeString("Third input must contain only one element. It has ", str_keep_pattern.size(), " values."), ORT_INVALID_GRAPH);
if (str_pattern[0].empty()) if (str_pattern[0].empty())
ORT_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_GRAPH); ORTX_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_GRAPH);
OrtTensorDimensions dimensions(ort_, input); OrtTensorDimensions dimensions(ort_, input);
bool include_delimiter = (str_keep_pattern.size() == 1) && (!str_keep_pattern[0].empty()); bool include_delimiter = (str_keep_pattern.size() == 1) && (!str_keep_pattern[0].empty());
@ -111,7 +111,7 @@ ONNXTensorElementDataType CustomOpStringECMARegexSplitWithOffsets::GetOutputType
case 3: case 3:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default: default:
ORT_CXX_API_THROW(MakeString( ORTX_CXX_API_THROW(MakeString(
"StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."), "StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."),
ORT_INVALID_ARGUMENT); ORT_INVALID_ARGUMENT);
} }

Просмотреть файл

@ -15,7 +15,7 @@ struct KernelStringECMARegexSplitWithOffsets : BaseKernel {
bool ignore_case_; bool ignore_case_;
}; };
struct CustomOpStringECMARegexSplitWithOffsets : Ort::CustomOpBase<CustomOpStringECMARegexSplitWithOffsets, KernelStringECMARegexSplitWithOffsets> { struct CustomOpStringECMARegexSplitWithOffsets : OrtW::CustomOpBase<CustomOpStringECMARegexSplitWithOffsets, KernelStringECMARegexSplitWithOffsets> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;

Просмотреть файл

@ -23,7 +23,7 @@ void KernelStringHash::Compute(OrtKernelContext* context) {
// Verifications // Verifications
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets); OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1) if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1)
ORT_CXX_API_THROW(MakeString( ORTX_CXX_API_THROW(MakeString(
"num_buckets must contain only one element. It has ", "num_buckets must contain only one element. It has ",
num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT); num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
@ -56,7 +56,7 @@ ONNXTensorElementDataType CustomOpStringHash::GetInputType(size_t index) const {
case 1: case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default: default:
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
} }
}; };
@ -82,7 +82,7 @@ void KernelStringHashFast::Compute(OrtKernelContext* context) {
// Verifications // Verifications
OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets); OrtTensorDimensions num_buckets_dimensions(ort_, num_buckets);
if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1) if (num_buckets_dimensions.size() != 1 || num_buckets_dimensions[0] != 1)
ORT_CXX_API_THROW(MakeString( ORTX_CXX_API_THROW(MakeString(
"num_buckets must contain only one element. It has ", "num_buckets must contain only one element. It has ",
num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT); num_buckets_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
@ -115,7 +115,7 @@ ONNXTensorElementDataType CustomOpStringHashFast::GetInputType(size_t index) con
case 1: case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default: default:
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
} }
}; };

Просмотреть файл

@ -11,7 +11,7 @@ struct KernelStringHash : BaseKernel {
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringHash : Ort::CustomOpBase<CustomOpStringHash, KernelStringHash> { struct CustomOpStringHash : OrtW::CustomOpBase<CustomOpStringHash, KernelStringHash> {
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;
@ -24,7 +24,7 @@ struct KernelStringHashFast : BaseKernel {
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringHashFast : Ort::CustomOpBase<CustomOpStringHashFast, KernelStringHashFast> { struct CustomOpStringHashFast : OrtW::CustomOpBase<CustomOpStringHashFast, KernelStringHashFast> {
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -20,18 +20,18 @@ void KernelStringJoin::Compute(OrtKernelContext* context) {
// Check input // Check input
OrtTensorDimensions dimensions_sep(ort_, input_sep); OrtTensorDimensions dimensions_sep(ort_, input_sep);
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1) if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
ORT_CXX_API_THROW("Input 2 is the separator, it should have 1 element.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Input 2 is the separator, it should have 1 element.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions_axis(ort_, input_axis); OrtTensorDimensions dimensions_axis(ort_, input_axis);
if (dimensions_axis.size() != 1 || dimensions_axis[0] != 1) if (dimensions_axis.size() != 1 || dimensions_axis[0] != 1)
ORT_CXX_API_THROW("Input 3 is the axis, it should have 1 element.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Input 3 is the axis, it should have 1 element.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions(ort_, input_X); OrtTensorDimensions dimensions(ort_, input_X);
if (dimensions.size() == 0) { if (dimensions.size() == 0) {
// dimensions size 0 means input 1 is scalar, input 1 must have 1 element. See issue: https://github.com/onnx/onnx/issues/3724 // dimensions size 0 means input 1 is scalar, input 1 must have 1 element. See issue: https://github.com/onnx/onnx/issues/3724
if (X.size() != 1) if (X.size() != 1)
ORT_CXX_API_THROW(MakeString("Input 1's dimensions size is 0 (scalar), it must has 1 element but it has ", X.size()), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Input 1's dimensions size is 0 (scalar), it must has 1 element but it has ", X.size()), ORT_INVALID_ARGUMENT);
} else { } else {
if (*axis < 0 || *axis >= static_cast<int64_t>(dimensions.size())) if (*axis < 0 || *axis >= static_cast<int64_t>(dimensions.size()))
ORT_CXX_API_THROW(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("axis must be positive and smaller than the number of dimension but it is ", *axis), ORT_INVALID_ARGUMENT);
} }
std::vector<int64_t> dimensions_out(dimensions.size() > 1 ? dimensions.size() - 1 : 1); std::vector<int64_t> dimensions_out(dimensions.size() > 1 ? dimensions.size() - 1 : 1);
@ -102,7 +102,7 @@ ONNXTensorElementDataType CustomOpStringJoin::GetInputType(size_t index) const {
case 2: case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default: default:
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
} }
}; };

Просмотреть файл

@ -11,7 +11,7 @@ struct KernelStringJoin : BaseKernel {
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringJoin : Ort::CustomOpBase<CustomOpStringJoin, KernelStringJoin> { struct CustomOpStringJoin : OrtW::CustomOpBase<CustomOpStringJoin, KernelStringJoin> {
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -11,7 +11,7 @@ struct KernelStringLength : BaseKernel {
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringLength : Ort::CustomOpBase<CustomOpStringLength, KernelStringLength> { struct CustomOpStringLength : OrtW::CustomOpBase<CustomOpStringLength, KernelStringLength> {
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -11,7 +11,7 @@ struct KernelStringLower : BaseKernel {
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringLower : Ort::CustomOpBase<CustomOpStringLower, KernelStringLower> { struct CustomOpStringLower : OrtW::CustomOpBase<CustomOpStringLower, KernelStringLower> {
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -16,7 +16,7 @@ KernelStringMapping::KernelStringMapping(const OrtApi& api, const OrtKernelInfo*
auto items = SplitString(line, "\t", true); auto items = SplitString(line, "\t", true);
if (items.size() != 2) { if (items.size() != 2) {
ORT_CXX_API_THROW(std::string("[StringMapping]: Should only exist two items in one line, find error in line: ") + std::string(line), ORT_INVALID_GRAPH); ORTX_CXX_API_THROW(std::string("[StringMapping]: Should only exist two items in one line, find error in line: ") + std::string(line), ORT_INVALID_GRAPH);
} }
map_[std::string(items[0])] = std::string(items[1]); map_[std::string(items[0])] = std::string(items[1]);
} }

Просмотреть файл

@ -14,7 +14,7 @@ struct KernelStringMapping : BaseKernel {
std::unordered_map<std::string, std::string> map_; std::unordered_map<std::string, std::string> map_;
}; };
struct CustomOpStringMapping : Ort::CustomOpBase<CustomOpStringMapping, KernelStringMapping> { struct CustomOpStringMapping : OrtW::CustomOpBase<CustomOpStringMapping, KernelStringMapping> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;

Просмотреть файл

@ -20,13 +20,13 @@ void KernelStringSplit::Compute(OrtKernelContext* context) {
// Setup output // Setup output
OrtTensorDimensions dimensions_sep(ort_, input_sep); OrtTensorDimensions dimensions_sep(ort_, input_sep);
if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1) if (dimensions_sep.size() != 1 || dimensions_sep[0] != 1)
ORT_CXX_API_THROW("Input 2 is the delimiter, it has 1 element.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Input 2 is the delimiter, it has 1 element.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions_skip_empty(ort_, input_skip_empty); OrtTensorDimensions dimensions_skip_empty(ort_, input_skip_empty);
if (dimensions_skip_empty.size() != 1 || dimensions_skip_empty[0] != 1) if (dimensions_skip_empty.size() != 1 || dimensions_skip_empty[0] != 1)
ORT_CXX_API_THROW("Input 3 is skip_empty, it has 1 element.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Input 3 is skip_empty, it has 1 element.", ORT_INVALID_ARGUMENT);
OrtTensorDimensions dimensions(ort_, input_X); OrtTensorDimensions dimensions(ort_, input_X);
if (dimensions.size() != 1) if (dimensions.size() != 1)
ORT_CXX_API_THROW("Only 1D tensor are supported as input.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Only 1D tensor are supported as input.", ORT_INVALID_ARGUMENT);
std::vector<std::string> words; std::vector<std::string> words;
std::vector<int64_t> indices; std::vector<int64_t> indices;
@ -112,7 +112,7 @@ ONNXTensorElementDataType CustomOpStringSplit::GetInputType(size_t index) const
case 2: case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL; return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
default: default:
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
} }
}; };
@ -128,6 +128,6 @@ ONNXTensorElementDataType CustomOpStringSplit::GetOutputType(size_t index) const
case 1: case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
default: default:
ORT_CXX_API_THROW(MakeString("[StringSplit] Unexpected output index ", index), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("[StringSplit] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
} }
}; };

Просмотреть файл

@ -11,7 +11,7 @@ struct KernelStringSplit : BaseKernel {
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringSplit : Ort::CustomOpBase<CustomOpStringSplit, KernelStringSplit> { struct CustomOpStringSplit : OrtW::CustomOpBase<CustomOpStringSplit, KernelStringSplit> {
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -41,7 +41,7 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
vector_len_ = ParseVectorLen(lines[0]); vector_len_ = ParseVectorLen(lines[0]);
if (vector_len_ == 0) { if (vector_len_ == 0) {
ORT_CXX_API_THROW(MakeString("The mapped value of string input cannot be empty: ", lines[0]), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("The mapped value of string input cannot be empty: ", lines[0]), ORT_INVALID_ARGUMENT);
} }
std::vector<int64_t> values(vector_len_); std::vector<int64_t> values(vector_len_);
@ -49,7 +49,7 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
auto kv = SplitString(line, "\t", true); auto kv = SplitString(line, "\t", true);
if (kv.size() != 2) { if (kv.size() != 2) {
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
} }
ParseValues(kv[1], values); ParseValues(kv[1], values);
@ -62,14 +62,14 @@ void StringToVectorImpl::ParseMappingTable(std::string& map) {
void StringToVectorImpl::ParseUnkownValue(std::string& unk) { void StringToVectorImpl::ParseUnkownValue(std::string& unk) {
auto unk_strs = SplitString(unk, " ", true); auto unk_strs = SplitString(unk, " ", true);
if (unk_strs.size() != vector_len_) { if (unk_strs.size() != vector_len_) {
ORT_CXX_API_THROW(MakeString("Incompatible dimension: required vector length of unknown_value should be: ", vector_len_), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Incompatible dimension: required vector length of unknown_value should be: ", vector_len_), ORT_INVALID_ARGUMENT);
} }
for (auto& str : unk_strs) { for (auto& str : unk_strs) {
int64_t value; int64_t value;
auto [end, ec] = std::from_chars(str.data(), str.data() + str.size(), value); auto [end, ec] = std::from_chars(str.data(), str.data() + str.size(), value);
if (end != str.data() + str.size()) { if (end != str.data() + str.size()) {
ORT_CXX_API_THROW(MakeString("Failed to parse unknown_value when processing the number: ", str), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Failed to parse unknown_value when processing the number: ", str), ORT_INVALID_ARGUMENT);
} }
unk_value_.push_back(value); unk_value_.push_back(value);
@ -80,7 +80,7 @@ size_t StringToVectorImpl::ParseVectorLen(const std::string_view& line) {
auto kv = SplitString(line, "\t", true); auto kv = SplitString(line, "\t", true);
if (kv.size() != 2) { if (kv.size() != 2) {
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
} }
auto value_strs = SplitString(kv[1], " ", true); auto value_strs = SplitString(kv[1], " ", true);
@ -94,7 +94,7 @@ void StringToVectorImpl::ParseValues(const std::string_view& v, std::vector<int6
for (size_t i = 0; i < value_strs.size(); i++) { for (size_t i = 0; i < value_strs.size(); i++) {
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value); auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
if (end != value_strs[i].data() + value_strs[i].size()) { if (end != value_strs[i].data() + value_strs[i].size()) {
ORT_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);
} }
values[i] = value; values[i] = value;
} }

Просмотреть файл

@ -36,7 +36,7 @@ struct KernelStringToVector : BaseKernel {
std::shared_ptr<StringToVectorImpl> impl_; std::shared_ptr<StringToVectorImpl> impl_;
}; };
struct CustomOpStringToVector : Ort::CustomOpBase<CustomOpStringToVector, KernelStringToVector> { struct CustomOpStringToVector : OrtW::CustomOpBase<CustomOpStringToVector, KernelStringToVector> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;

Просмотреть файл

@ -11,7 +11,7 @@ struct KernelStringUpper : BaseKernel {
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpStringUpper : Ort::CustomOpBase<CustomOpStringUpper, KernelStringUpper> { struct CustomOpStringUpper : OrtW::CustomOpBase<CustomOpStringUpper, KernelStringUpper> {
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(size_t index) const; ONNXTensorElementDataType GetInputType(size_t index) const;

Просмотреть файл

@ -29,7 +29,7 @@ std::vector<std::string> VectorToStringImpl::Compute(const void* input, const Or
output_dim = input_dim; output_dim = input_dim;
} else { } else {
if (input_dim.IsScalar() || input_dim[input_dim.size() - 1] != static_cast<int64_t>(vector_len_)) { if (input_dim.IsScalar() || input_dim[input_dim.size() - 1] != static_cast<int64_t>(vector_len_)) {
ORT_CXX_API_THROW(MakeString("Incompatible dimension: required vector length should be ", vector_len_), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Incompatible dimension: required vector length should be ", vector_len_), ORT_INVALID_ARGUMENT);
} }
output_dim = input_dim; output_dim = input_dim;
@ -70,7 +70,7 @@ void VectorToStringImpl::ParseMappingTable(std::string& map) {
auto kv = SplitString(line, "\t", true); auto kv = SplitString(line, "\t", true);
if (kv.size() != 2) { if (kv.size() != 2) {
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
} }
ParseValues(kv[1], values); ParseValues(kv[1], values);
@ -83,7 +83,7 @@ size_t VectorToStringImpl::ParseVectorLen(const std::string_view& line) {
auto kv = SplitString(line, "\t", true); auto kv = SplitString(line, "\t", true);
if (kv.size() != 2) { if (kv.size() != 2) {
ORT_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Failed to parse mapping_table when processing the line: ", line), ORT_INVALID_ARGUMENT);
} }
auto value_strs = SplitString(kv[1], " ", true); auto value_strs = SplitString(kv[1], " ", true);
@ -97,7 +97,7 @@ void VectorToStringImpl::ParseValues(const std::string_view& v, std::vector<int6
for (size_t i = 0; i < value_strs.size(); i++) { for (size_t i = 0; i < value_strs.size(); i++) {
auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value); auto [end, ec] = std::from_chars(value_strs[i].data(), value_strs[i].data() + value_strs[i].size(), value);
if (end != value_strs[i].data() + value_strs[i].size()) { if (end != value_strs[i].data() + value_strs[i].size()) {
ORT_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Failed to parse map when processing the number: ", value_strs[i]), ORT_INVALID_ARGUMENT);
} }
values[i] = value; values[i] = value;
} }

Просмотреть файл

@ -40,7 +40,7 @@ struct KernelVectorToString : BaseKernel {
std::shared_ptr<VectorToStringImpl> impl_; std::shared_ptr<VectorToStringImpl> impl_;
}; };
struct CustomOpVectorToString : Ort::CustomOpBase<CustomOpVectorToString, KernelVectorToString> { struct CustomOpVectorToString : OrtW::CustomOpBase<CustomOpVectorToString, KernelVectorToString> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;

Просмотреть файл

@ -100,7 +100,7 @@ void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
OrtTensorDimensions dimensions(ort_, input); OrtTensorDimensions dimensions(ort_, input);
if (dimensions.size() != 1 && dimensions[0] != 1) { if (dimensions.size() != 1 && dimensions[0] != 1) {
ORT_CXX_API_THROW("[BasicTokenizer]: only support string scalar.", ORT_INVALID_GRAPH); ORTX_CXX_API_THROW("[BasicTokenizer]: only support string scalar.", ORT_INVALID_GRAPH);
} }
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size()); OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());

Просмотреть файл

@ -27,7 +27,7 @@ struct KernelBasicTokenizer : BaseKernel {
std::shared_ptr<BasicTokenizer> tokenizer_; std::shared_ptr<BasicTokenizer> tokenizer_;
}; };
struct CustomOpBasicTokenizer : Ort::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> { struct CustomOpBasicTokenizer : OrtW::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;

Просмотреть файл

@ -33,7 +33,7 @@ int32_t BertTokenizerVocab::FindTokenId(const ustring& token) {
auto it = vocab_.find(utf8_token); auto it = vocab_.find(utf8_token);
if (it == vocab_.end()) { if (it == vocab_.end()) {
ORT_CXX_API_THROW("[BertTokenizerVocab]: can not find tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION); ORTX_CXX_API_THROW("[BertTokenizerVocab]: can not find tokens: " + std::string(token), ORT_RUNTIME_EXCEPTION);
} }
return it->second; return it->second;
@ -305,7 +305,7 @@ void KernelBertTokenizer::Compute(OrtKernelContext* context) {
GetTensorMutableDataString(api_, ort_, context, input, input_data); GetTensorMutableDataString(api_, ort_, context, input, input_data);
if (input_data.size() != 1 && input_data.size() != 2) { if (input_data.size() != 1 && input_data.size() != 2) {
ORT_CXX_API_THROW("[BertTokenizer]: only support one or two query.", ORT_INVALID_GRAPH); ORTX_CXX_API_THROW("[BertTokenizer]: only support one or two query.", ORT_INVALID_GRAPH);
} }
std::vector<int64_t> input_ids; std::vector<int64_t> input_ids;
std::vector<int64_t> token_type_ids; std::vector<int64_t> token_type_ids;
@ -365,7 +365,7 @@ void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
GetTensorMutableDataString(api_, ort_, context, input, input_data); GetTensorMutableDataString(api_, ort_, context, input, input_data);
if (input_data.size() != 2) { if (input_data.size() != 2) {
ORT_CXX_API_THROW("[HfBertTokenizer]: Support only two input strings.", ORT_INVALID_GRAPH); ORTX_CXX_API_THROW("[HfBertTokenizer]: Support only two input strings.", ORT_INVALID_GRAPH);
} }
std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0])); std::vector<ustring> tokens1 = tokenizer_->Tokenize(ustring(input_data[0]));

Просмотреть файл

@ -97,7 +97,7 @@ struct KernelBertTokenizer : BaseKernel {
std::unique_ptr<BertTokenizer> tokenizer_; std::unique_ptr<BertTokenizer> tokenizer_;
}; };
struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> { struct CustomOpBertTokenizer : OrtW::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;
@ -111,7 +111,7 @@ struct KernelHfBertTokenizer : KernelBertTokenizer {
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpHfBertTokenizer : Ort::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> { struct CustomOpHfBertTokenizer : OrtW::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;

Просмотреть файл

@ -146,7 +146,7 @@ void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
OrtTensorDimensions ids_dim(ort_, ids); OrtTensorDimensions ids_dim(ort_, ids);
if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) { if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) {
ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH); ORTX_CXX_API_THROW("[BertTokenizerDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
} }
// const int64_t* p_row_indices = ort_row_indices_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(ort_row_indices); // const int64_t* p_row_indices = ort_row_indices_dim.empty() ? nullptr : ort_.GetTensorData<int64_t>(ort_row_indices);
@ -155,7 +155,7 @@ void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
if (use_indices_ && if (use_indices_ &&
(!((positions_dim.Size() == 0) || (!((positions_dim.Size() == 0) ||
(positions_dim.size() == 2 && positions_dim[1] == 2)))) { (positions_dim.size() == 2 && positions_dim[1] == 2)))) {
ORT_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices", ORT_INVALID_GRAPH); ORTX_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices", ORT_INVALID_GRAPH);
} }
const int64_t* p_positions = positions_dim.Size() == 0 ? nullptr : ort_.GetTensorData<int64_t>(positions); const int64_t* p_positions = positions_dim.Size() == 0 ? nullptr : ort_.GetTensorData<int64_t>(positions);

Просмотреть файл

@ -42,7 +42,7 @@ struct KernelBertTokenizerDecoder : BaseKernel {
bool clean_up_tokenization_spaces_; bool clean_up_tokenization_spaces_;
}; };
struct CustomOpBertTokenizerDecoder : Ort::CustomOpBase<CustomOpBertTokenizerDecoder, KernelBertTokenizerDecoder> { struct CustomOpBertTokenizerDecoder : OrtW::CustomOpBase<CustomOpBertTokenizerDecoder, KernelBertTokenizerDecoder> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;

Просмотреть файл

@ -12,13 +12,13 @@
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) { KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model"); model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
if (model_data_.empty()) { if (model_data_.empty()) {
ORT_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
} }
void* model_ptr = SetModel(reinterpret_cast<const unsigned char*>(model_data_.data()), static_cast<int>(model_data_.size())); void* model_ptr = SetModel(reinterpret_cast<const unsigned char*>(model_data_.data()), static_cast<int>(model_data_.size()));
if (model_ptr == nullptr) { if (model_ptr == nullptr) {
ORT_CXX_API_THROW("Invalid model", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Invalid model", ORT_INVALID_ARGUMENT);
} }
model_ = std::shared_ptr<void>(model_ptr, FreeModel); model_ = std::shared_ptr<void>(model_ptr, FreeModel);
@ -35,7 +35,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
// TODO: fix this scalar check. // TODO: fix this scalar check.
if (dimensions.Size() != 1 && dimensions[0] != 1) { if (dimensions.Size() != 1 && dimensions[0] != 1) {
ORT_CXX_API_THROW("We only support string scalar.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("We only support string scalar.", ORT_INVALID_ARGUMENT);
} }
std::vector<std::string> input_data; std::vector<std::string> input_data;
@ -47,7 +47,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), static_cast<int>(input_string.size()), output_str.get(), nullptr, nullptr, max_length, model_.get()); int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), static_cast<int>(input_string.size()), output_str.get(), nullptr, nullptr, max_length, model_.get());
if (output_length < 0) { if (output_length < 0) {
ORT_CXX_API_THROW(MakeString("splitting input:\"", input_string, "\" failed"), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("splitting input:\"", input_string, "\" failed"), ORT_INVALID_ARGUMENT);
} }
// inline split output_str by newline '\n' // inline split output_str by newline '\n'
@ -75,7 +75,7 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
output_dimensions[0] = output_sentences.size(); output_dimensions[0] = output_sentences.size();
OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dimensions.data(), output_dimensions.size()); OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dimensions.data(), output_dimensions.size());
Ort::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size())); OrtW::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size()));
} }
void* CustomOpBlingFireSentenceBreaker::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const { void* CustomOpBlingFireSentenceBreaker::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {

Просмотреть файл

@ -25,7 +25,7 @@ struct KernelBlingFireSentenceBreaker : BaseKernel {
int max_sentence; int max_sentence;
}; };
struct CustomOpBlingFireSentenceBreaker : Ort::CustomOpBase<CustomOpBlingFireSentenceBreaker, KernelBlingFireSentenceBreaker> { struct CustomOpBlingFireSentenceBreaker : OrtW::CustomOpBase<CustomOpBlingFireSentenceBreaker, KernelBlingFireSentenceBreaker> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;

Просмотреть файл

@ -29,7 +29,7 @@ class SpecialTokenMap {
auto it = token_map_.find(p_str); auto it = token_map_.find(p_str);
if (it != token_map_.end()) { if (it != token_map_.end()) {
if (it->second != p_id) { if (it->second != p_id) {
ORT_CXX_API_THROW("Duplicate special tokens.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Duplicate special tokens.", ORT_INVALID_ARGUMENT);
} }
} else { } else {
token_map_[p_str] = p_id; token_map_[p_str] = p_id;
@ -84,7 +84,7 @@ class SpecialTokenMap {
SpecialTokenInfo(ustring p_str, int p_id) SpecialTokenInfo(ustring p_str, int p_id)
: str(std::move(p_str)), id(p_id) { : str(std::move(p_str)), id(p_id) {
if (str.empty()) { if (str.empty()) {
ORT_CXX_API_THROW("Empty special token.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Empty special token.", ORT_INVALID_ARGUMENT);
} }
} }
}; };
@ -147,7 +147,7 @@ class VocabData {
if ((line[0] == '#') && (index == 0)) continue; if ((line[0] == '#') && (index == 0)) continue;
auto pos = line.find(' '); auto pos = line.find(' ');
if (pos == std::string::npos) { if (pos == std::string::npos) {
ORT_CXX_API_THROW("Cannot know how to parse line: " + line, ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Cannot know how to parse line: " + line, ORT_INVALID_ARGUMENT);
} }
std::string w1 = line.substr(0, pos); std::string w1 = line.substr(0, pos);
std::string w2 = line.substr(pos + 1); std::string w2 = line.substr(pos + 1);
@ -231,14 +231,14 @@ class VocabData {
int TokenToID(const std::string& input) const { int TokenToID(const std::string& input) const {
auto it = vocab_map_.find(input); auto it = vocab_map_.find(input);
if (it == vocab_map_.end()) { if (it == vocab_map_.end()) {
ORT_CXX_API_THROW("Token not found: " + input, ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Token not found: " + input, ORT_INVALID_ARGUMENT);
} }
return it->second; return it->second;
} }
const std::string& IdToToken(int id) const { const std::string& IdToToken(int id) const {
if ((id < 0) || (static_cast<size_t>(id) >= id2token_map_.size())) { if ((id < 0) || (static_cast<size_t>(id) >= id2token_map_.size())) {
ORT_CXX_API_THROW("Invalid ID: " + std::to_string(id), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Invalid ID: " + std::to_string(id), ORT_INVALID_ARGUMENT);
} }
return id2token_map_[id]; return id2token_map_[id];
} }
@ -247,7 +247,7 @@ class VocabData {
int GetVocabIndex(const std::string& str) { int GetVocabIndex(const std::string& str) {
auto it = vocab_map_.find(str); auto it = vocab_map_.find(str);
if (it == vocab_map_.end()) { if (it == vocab_map_.end()) {
ORT_CXX_API_THROW("Cannot find word in vocabulary: " + str, ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("Cannot find word in vocabulary: " + str, ORT_INVALID_ARGUMENT);
} }
return it->second; return it->second;
} }
@ -467,12 +467,12 @@ KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo* i
: BaseKernel(api, info) { : BaseKernel(api, info) {
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab"); std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
if (vocab.empty()) { if (vocab.empty()) {
ORT_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
} }
std::string merges = ort_.KernelInfoGetAttribute<std::string>(info, "merges"); std::string merges = ort_.KernelInfoGetAttribute<std::string>(info, "merges");
if (merges.empty()) { if (merges.empty()) {
ORT_CXX_API_THROW("merges shouldn't be empty.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("merges shouldn't be empty.", ORT_INVALID_ARGUMENT);
} }
if (!TryToGetAttribute<int64_t>("padding_length", padding_length_)) { if (!TryToGetAttribute<int64_t>("padding_length", padding_length_)) {
@ -480,7 +480,7 @@ KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo* i
} }
if (padding_length_ != -1 && padding_length_ <= 0) { if (padding_length_ != -1 && padding_length_ <= 0) {
ORT_CXX_API_THROW("padding_length should be more than 0 or equal -1", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("padding_length should be more than 0 or equal -1", ORT_INVALID_ARGUMENT);
} }
std::stringstream vocabu_stream(vocab); std::stringstream vocabu_stream(vocab);

Просмотреть файл

@ -17,7 +17,7 @@ struct KernelBpeTokenizer : BaseKernel {
std::shared_ptr<VocabData> bbpe_tokenizer_; std::shared_ptr<VocabData> bbpe_tokenizer_;
}; };
struct CustomOpBpeTokenizer : Ort::CustomOpBase<CustomOpBpeTokenizer, KernelBpeTokenizer> { struct CustomOpBpeTokenizer : OrtW::CustomOpBase<CustomOpBpeTokenizer, KernelBpeTokenizer> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;

Просмотреть файл

@ -16,7 +16,7 @@ struct KernelSentencepieceDecoder : BaseKernel {
model_proto.ParseFromArray(model_blob.data(), static_cast<int>(model_blob.size())); model_proto.ParseFromArray(model_blob.data(), static_cast<int>(model_blob.size()));
sentencepiece::util::Status status = tokenizer_.Load(model_proto); sentencepiece::util::Status status = tokenizer_.Load(model_proto);
if (!status.ok()){ if (!status.ok()){
ORT_CXX_API_THROW(MakeString( ORTX_CXX_API_THROW(MakeString(
"Failed to create SentencePieceProcessor instance. Error code is ", "Failed to create SentencePieceProcessor instance. Error code is ",
(int)status.code(), ". Message is '", status.error_message(), "'."), (int)status.code(), ". Message is '", status.error_message(), "'."),
ORT_INVALID_PROTOBUF); ORT_INVALID_PROTOBUF);
@ -29,7 +29,7 @@ struct KernelSentencepieceDecoder : BaseKernel {
OrtTensorDimensions ids_dim(ort_, ids); OrtTensorDimensions ids_dim(ort_, ids);
if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) { if (!((ids_dim.size() == 1) || (ids_dim.size() == 2 && ids_dim[0] == 1))) {
ORT_CXX_API_THROW("[SentencePieceDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH); ORTX_CXX_API_THROW("[SentencePieceDecoder]: Expect ids dimension [n] or [1,n].", ORT_INVALID_GRAPH);
} }
auto count = ids_dim[0]; auto count = ids_dim[0];
@ -41,7 +41,7 @@ struct KernelSentencepieceDecoder : BaseKernel {
[](auto _id) { return static_cast<int>(_id); }); [](auto _id) { return static_cast<int>(_id); });
auto status = tokenizer_.Decode(tids, &decoded_string); auto status = tokenizer_.Decode(tids, &decoded_string);
if (!status.ok()){ if (!status.ok()){
ORT_CXX_API_THROW("[SentencePieceDecoder] model decoding failed.", ORT_RUNTIME_EXCEPTION); ORTX_CXX_API_THROW("[SentencePieceDecoder] model decoding failed.", ORT_RUNTIME_EXCEPTION);
} }
std::vector<std::string> result = {decoded_string}; std::vector<std::string> result = {decoded_string};
@ -53,7 +53,7 @@ struct KernelSentencepieceDecoder : BaseKernel {
sentencepiece::SentencePieceProcessor tokenizer_; sentencepiece::SentencePieceProcessor tokenizer_;
}; };
struct CustomOpSentencepieceDecoder : Ort::CustomOpBase<CustomOpSentencepieceDecoder, KernelSentencepieceDecoder> { struct CustomOpSentencepieceDecoder : OrtW::CustomOpBase<CustomOpSentencepieceDecoder, KernelSentencepieceDecoder> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const { void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info); return CreateKernelImpl(api, info);
} }

Просмотреть файл

@ -23,7 +23,7 @@ KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(const OrtApi& api, co
(int)status.code(), ". Message is '", status.error_message(), "'.")); (int)status.code(), ". Message is '", status.error_message(), "'."));
} }
static void _check_dimension_constant(Ort::CustomOpApi ort, const OrtValue* ort_value, const char* name) { static void _check_dimension_constant(OrtW::CustomOpApi ort, const OrtValue* ort_value, const char* name) {
OrtTensorDimensions dimensions(ort, ort_value); OrtTensorDimensions dimensions(ort, ort_value);
if (dimensions.size() != 1 || dimensions[0] != 1) if (dimensions.size() != 1 || dimensions[0] != 1)
throw std::runtime_error(MakeString( throw std::runtime_error(MakeString(

Просмотреть файл

@ -16,7 +16,7 @@ struct KernelSentencepieceTokenizer : BaseKernel {
sentencepiece::SentencePieceProcessor tokenizer_; sentencepiece::SentencePieceProcessor tokenizer_;
}; };
struct CustomOpSentencepieceTokenizer : Ort::CustomOpBase<CustomOpSentencepieceTokenizer, KernelSentencepieceTokenizer> { struct CustomOpSentencepieceTokenizer : OrtW::CustomOpBase<CustomOpSentencepieceTokenizer, KernelSentencepieceTokenizer> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;

Просмотреть файл

@ -72,7 +72,7 @@ void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string,
rows.push_back(indices.size()); rows.push_back(indices.size());
} else if (text_index == existing_rows[row_index]) { } else if (text_index == existing_rows[row_index]) {
if (row_index >= n_existing_rows) if (row_index >= n_existing_rows)
ORT_CXX_API_THROW(MakeString( ORTX_CXX_API_THROW(MakeString(
"row_index=", row_index, " is out of range=", n_existing_rows, "."), ORT_INVALID_ARGUMENT); "row_index=", row_index, " is out of range=", n_existing_rows, "."), ORT_INVALID_ARGUMENT);
rows.push_back(indices.size()); rows.push_back(indices.size());
++row_index; ++row_index;
@ -181,7 +181,7 @@ ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetInputType(size_t index)
case 1: case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default: default:
ORT_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
} }
}; };
@ -198,6 +198,6 @@ ONNXTensorElementDataType CustomOpWordpieceTokenizer::GetOutputType(size_t index
case 3: case 3:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
default: default:
ORT_CXX_API_THROW(MakeString("[WordpieceTokenizer] Unexpected output index ", index), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("[WordpieceTokenizer] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
} }
}; };

Просмотреть файл

@ -21,7 +21,7 @@ struct KernelWordpieceTokenizer : BaseKernel {
std::unordered_map<std::u32string, int32_t> vocab_; std::unordered_map<std::u32string, int32_t> vocab_;
}; };
struct CustomOpWordpieceTokenizer : Ort::CustomOpBase<CustomOpWordpieceTokenizer, KernelWordpieceTokenizer> { struct CustomOpWordpieceTokenizer : OrtW::CustomOpBase<CustomOpWordpieceTokenizer, KernelWordpieceTokenizer> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const; const char* GetName() const;
size_t GetInputTypeCount() const; size_t GetInputTypeCount() const;

Просмотреть файл

@ -42,4 +42,4 @@ struct hash<ustring> {
return standard_hash(static_cast<u32string>(__str)); return standard_hash(static_cast<u32string>(__str));
} }
}; };
} // namespace std } // namespace std

Просмотреть файл

@ -12,7 +12,7 @@ void KernelDecodeImage::Compute(OrtKernelContext* context) {
const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL); const OrtValue* const inputs = ort_.KernelContext_GetInput(context, 0ULL);
OrtTensorDimensions dimensions(ort_, inputs); OrtTensorDimensions dimensions(ort_, inputs);
if (dimensions.size() != 1ULL) { if (dimensions.size() != 1ULL) {
ORT_CXX_API_THROW("[DecodeImage]: Raw image bytes with 1D shape expected.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("[DecodeImage]: Raw image bytes with 1D shape expected.", ORT_INVALID_ARGUMENT);
} }
OrtTensorTypeAndShapeInfo* input_info = ort_.GetTensorTypeAndShape(inputs); OrtTensorTypeAndShapeInfo* input_info = ort_.GetTensorTypeAndShape(inputs);
@ -26,7 +26,7 @@ void KernelDecodeImage::Compute(OrtKernelContext* context) {
const cv::Mat decoded_image = cv::imdecode(encoded_image, cv::IMREAD_COLOR); const cv::Mat decoded_image = cv::imdecode(encoded_image, cv::IMREAD_COLOR);
if (decoded_image.data == nullptr) { if (decoded_image.data == nullptr) {
ORT_CXX_API_THROW("[DecodeImage] Invalid input. Failed to decode image.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("[DecodeImage] Invalid input. Failed to decode image.", ORT_INVALID_ARGUMENT);
}; };
// Setup output & copy to destination // Setup output & copy to destination

Просмотреть файл

@ -15,7 +15,7 @@ struct KernelDecodeImage : BaseKernel {
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
}; };
struct CustomOpDecodeImage : Ort::CustomOpBase<CustomOpDecodeImage, KernelDecodeImage> { struct CustomOpDecodeImage : OrtW::CustomOpBase<CustomOpDecodeImage, KernelDecodeImage> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const { void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new KernelDecodeImage(api); return new KernelDecodeImage(api);
} }
@ -37,7 +37,7 @@ struct CustomOpDecodeImage : Ort::CustomOpBase<CustomOpDecodeImage, KernelDecode
case 0: case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
default: default:
ORT_CXX_API_THROW(MakeString("Invalid input index ", index), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Invalid input index ", index), ORT_INVALID_ARGUMENT);
} }
} }
@ -50,7 +50,7 @@ struct CustomOpDecodeImage : Ort::CustomOpBase<CustomOpDecodeImage, KernelDecode
case 0: case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
default: default:
ORT_CXX_API_THROW(MakeString("Invalid output index ", index), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Invalid output index ", index), ORT_INVALID_ARGUMENT);
} }
} }
}; };

Просмотреть файл

@ -15,7 +15,7 @@ void KernelEncodeImage ::Compute(OrtKernelContext* context) {
if (dimensions_bgr.size() != 3 || dimensions_bgr[2] != 3) { if (dimensions_bgr.size() != 3 || dimensions_bgr[2] != 3) {
// expect {H, W, C} as that's the inverse of what decode_image produces. // expect {H, W, C} as that's the inverse of what decode_image produces.
// we have no way to check if it's BGR or RGB though // we have no way to check if it's BGR or RGB though
ORT_CXX_API_THROW("[EncodeImage] requires rank 3 BGR input in channels last format.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("[EncodeImage] requires rank 3 BGR input in channels last format.", ORT_INVALID_ARGUMENT);
} }
// Get data & the length // Get data & the length
@ -29,7 +29,7 @@ void KernelEncodeImage ::Compute(OrtKernelContext* context) {
// don't know output size ahead of time so need to encode and then copy to output // don't know output size ahead of time so need to encode and then copy to output
std::vector<uint8_t> encoded_image; std::vector<uint8_t> encoded_image;
if (!cv::imencode(extension_, bgr_image, encoded_image)) { if (!cv::imencode(extension_, bgr_image, encoded_image)) {
ORT_CXX_API_THROW("[EncodeImage] Image encoding failed.", ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW("[EncodeImage] Image encoding failed.", ORT_INVALID_ARGUMENT);
} }
// Setup output & copy to destination // Setup output & copy to destination

Просмотреть файл

@ -27,12 +27,12 @@ struct KernelEncodeImage : BaseKernel {
/// Converts rank 3 BGR input with channels last ordering to the requested file type. /// Converts rank 3 BGR input with channels last ordering to the requested file type.
/// Default is 'jpg' /// Default is 'jpg'
/// </summary> /// </summary>
struct CustomOpEncodeImage : Ort::CustomOpBase<CustomOpEncodeImage, KernelEncodeImage> { struct CustomOpEncodeImage : OrtW::CustomOpBase<CustomOpEncodeImage, KernelEncodeImage> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const { void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
Ort::CustomOpApi op_api{api}; OrtW::CustomOpApi op_api{api};
std::string format = op_api.KernelInfoGetAttribute<std::string>(info, "format"); std::string format = op_api.KernelInfoGetAttribute<std::string>(info, "format");
if (format != "jpg" && format != "png") { if (format != "jpg" && format != "png") {
ORT_CXX_API_THROW("[EncodeImage] 'format' attribute value must be 'jpg' or 'png'.", ORT_RUNTIME_EXCEPTION); ORTX_CXX_API_THROW("[EncodeImage] 'format' attribute value must be 'jpg' or 'png'.", ORT_RUNTIME_EXCEPTION);
} }
return new KernelEncodeImage(api, format); return new KernelEncodeImage(api, format);
@ -51,7 +51,7 @@ struct CustomOpEncodeImage : Ort::CustomOpBase<CustomOpEncodeImage, KernelEncode
case 0: case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
default: default:
ORT_CXX_API_THROW(MakeString("Invalid input index ", index), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Invalid input index ", index), ORT_INVALID_ARGUMENT);
} }
} }
@ -64,7 +64,7 @@ struct CustomOpEncodeImage : Ort::CustomOpBase<CustomOpEncodeImage, KernelEncode
case 0: case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
default: default:
ORT_CXX_API_THROW(MakeString("Invalid output index ", index), ORT_INVALID_ARGUMENT); ORTX_CXX_API_THROW(MakeString("Invalid output index ", index), ORT_INVALID_ARGUMENT);
} }
} }
}; };

Просмотреть файл

@ -167,7 +167,7 @@ struct PyCustomOpDefImpl : public PyCustomOpDef {
} }
static py::object BuildPyObjFromTensor( static py::object BuildPyObjFromTensor(
const OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, const OrtValue* value, const OrtApi& api, OrtW::CustomOpApi& ort, OrtKernelContext* context, const OrtValue* value,
const shape_t& shape, ONNXTensorElementDataType dtype) { const shape_t& shape, ONNXTensorElementDataType dtype) {
std::vector<npy_intp> npy_dims; std::vector<npy_intp> npy_dims;
for (auto n : shape) { for (auto n : shape) {

Просмотреть файл

@ -42,12 +42,12 @@ struct PyCustomOpKernel {
private: private:
const OrtApi& api_; const OrtApi& api_;
Ort::CustomOpApi ort_; OrtW::CustomOpApi ort_;
uint64_t obj_id_; uint64_t obj_id_;
std::map<std::string, std::string> attrs_values_; std::map<std::string, std::string> attrs_values_;
}; };
struct PyCustomOpFactory : Ort::CustomOpBase<PyCustomOpFactory, PyCustomOpKernel> { struct PyCustomOpFactory : OrtW::CustomOpBase<PyCustomOpFactory, PyCustomOpKernel> {
PyCustomOpFactory() { PyCustomOpFactory() {
// STL vector needs it. // STL vector needs it.
} }

Просмотреть файл

@ -7,6 +7,8 @@
#include "onnxruntime_extensions.h" #include "onnxruntime_extensions.h"
#include "ocos.h" #include "ocos.h"
using namespace OrtW;
struct OrtCustomOpDomainDeleter { struct OrtCustomOpDomainDeleter {
explicit OrtCustomOpDomainDeleter(const OrtApi* ort_api) { explicit OrtCustomOpDomainDeleter(const OrtApi* ort_api) {
ort_api_ = ort_api; ort_api_ = ort_api;

Просмотреть файл

@ -4,6 +4,7 @@
#pragma once #pragma once
#include <math.h> #include <math.h>
#include "onnxruntime_cxx_api.h"
const char* GetLibraryPath(); const char* GetLibraryPath();

Просмотреть файл

@ -1,8 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. // Licensed under the MIT License.
#include "onnxruntime_cxx_api.h"
#include <filesystem> #include <filesystem>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ocos.h" #include "ocos.h"
@ -50,7 +47,7 @@ struct KernelOne : BaseKernel {
} }
}; };
struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> { struct CustomOpOne : OrtW::CustomOpBase<CustomOpOne, KernelOne> {
const char* GetName() const { const char* GetName() const {
return "CustomOpOne"; return "CustomOpOne";
}; };
@ -93,7 +90,7 @@ struct KernelTwo : BaseKernel {
} }
}; };
struct CustomOpTwo : Ort::CustomOpBase<CustomOpTwo, KernelTwo> { struct CustomOpTwo : OrtW::CustomOpBase<CustomOpTwo, KernelTwo> {
const char* GetName() const { const char* GetName() const {
return "CustomOpTwo"; return "CustomOpTwo";
}; };
@ -138,7 +135,7 @@ struct KernelThree : BaseKernel {
std::string substr_; std::string substr_;
}; };
struct CustomOpThree : Ort::CustomOpBase<CustomOpThree, KernelThree> { struct CustomOpThree : OrtW::CustomOpBase<CustomOpThree, KernelThree> {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const { void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return CreateKernelImpl(api, info); return CreateKernelImpl(api, info);
}; };
@ -188,8 +185,7 @@ void _assert_eq(Ort::Value& output_tensor, const std::vector<T>& expected, size_
} }
void GetTensorMutableDataString(const OrtApi& api, const OrtValue* value, std::vector<std::string>& output) { void GetTensorMutableDataString(const OrtApi& api, const OrtValue* value, std::vector<std::string>& output) {
Ort::CustomOpApi ort(api); OrtTensorDimensions dimensions(OrtW::CustomOpApi(api), value);
OrtTensorDimensions dimensions(ort, value);
size_t len = static_cast<size_t>(dimensions.Size()); size_t len = static_cast<size_t>(dimensions.Size());
size_t data_len; size_t data_len;
Ort::ThrowOnError(api, api.GetStringTensorDataLength(value, &data_len)); Ort::ThrowOnError(api, api.GetStringTensorDataLength(value, &data_len));
@ -397,7 +393,7 @@ TEST(ustring, tensor_operator) {
} }
EXPECT_EQ(err_code, ORT_OK); EXPECT_EQ(err_code, ORT_OK);
Ort::CustomOpApi custom_api(*api); OrtW::CustomOpApi custom_api(*api);
std::vector<int64_t> dim{2, 2}; std::vector<int64_t> dim{2, 2};
status = api->CreateTensorAsOrtValue(allocator, dim.data(), dim.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, &tensor); status = api->CreateTensorAsOrtValue(allocator, dim.data(), dim.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, &tensor);