Add ability to prevent exception propagation if building as part of ORT when ORT has exceptions disabled (#368)
* Add ability to prevent exception propagation with top level try/catch hander macros. If combined build with ORT has exceptions disabled in ORT but ort-ext has an operator that requires exceptions, we enable exceptions in ort-ext but prevent them propagating up via try/catch in the entry points that ORT can call - RegisterCustomOps - CustomOpBase constructor and Compute Removed some places in CustomOpApi that threw is OpKernelInfo* was nullptr but standardizing all kernels to store the OpKernelInfo provided in the ctor. Added unit tests - need to validate on more platforms and add CI for build where we don't want to allow exceptions to propagate * Update pyop * Update CMakeLists.txt Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> * Update includes/exceptions.h Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> * Update includes/exceptions.h Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> * Update includes/onnxruntime_customop.hpp Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> * Merge with main and update Address PR comments Fix some issues. * Delete local file * Fix pyop update * Add CI Address PR comments --------- Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
This commit is contained in:
Родитель
b375cb57e6
Коммит
5e44a7c3c9
|
@ -449,3 +449,32 @@ jobs:
|
|||
parameters:
|
||||
xcWorkspacePath: '$(Build.SourcesDirectory)/test/ios/OrtExtensionsUsage/OrtExtensionsUsage.xcworkspace'
|
||||
scheme: 'OrtExtensionsUsage'
|
||||
|
||||
|
||||
#####################################
|
||||
# Linux prevent exception propagation
|
||||
#####################################
|
||||
- job: Linux_Prevent_Exception_Propagation
|
||||
pool:
|
||||
vmImage: 'ubuntu-latest'
|
||||
|
||||
steps:
|
||||
# Simulate an embedded build as part of ORT with exceptions disabled by manually setting CMAKE_CXX_FLAGS and
|
||||
# using _OCOS_PREVENT_EXCEPTION_PROPAGATION_OVERRIDE. The build should re-enable exceptions within ort-ext
|
||||
# but prevent them from propagating. Unit tests are run to validate this.
|
||||
- script: '
|
||||
./build_lib.sh --enable_cxx_tests --onnxruntime_version 1.14.0 --config RelWithDebInfo
|
||||
--cmake_extra_defines
|
||||
_OCOS_PREVENT_EXCEPTION_PROPAGATION_OVERRIDE=ON OCOS_ENABLE_CPP_EXCEPTIONS=OFF
|
||||
CMAKE_CXX_FLAGS="-fno-exceptions -fno-unwind-tables -fno-asynchronous-unwind-tables"
|
||||
'
|
||||
|
||||
displayName: Build ort-ext with exception propagation disabled
|
||||
|
||||
# As an extra validation check CMakeCache.txt as well
|
||||
- script: |
|
||||
grep "^_OCOS_PREVENT_EXCEPTION_PROPAGATION.*ON$" build/Linux/RelWithDebInfo/CMakeCache.txt
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Exception propogation was not enabled correctly."
|
||||
exit 1
|
||||
fi
|
||||
|
|
|
@ -150,6 +150,12 @@ endif()
|
|||
# External dependencies
|
||||
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/externals ${PROJECT_SOURCE_DIR}/cmake)
|
||||
|
||||
# PROJECT_IS_TOP_LEVEL is available since 3.21
|
||||
get_property(not_top DIRECTORY PROPERTY PARENT_DIRECTORY)
|
||||
if(not_top AND ONNXRUNTIME_ROOT)
|
||||
set(_ONNXRUNTIME_EMBEDDED TRUE)
|
||||
endif()
|
||||
|
||||
if(OCOS_ENABLE_SELECTED_OPLIST)
|
||||
# Need to ensure _selectedoplist.cmake file is already generated in folder: ${PROJECT_SOURCE_DIR}/cmake/
|
||||
# You could run gen_selectedops.py in folder: tools/ to generate _selectedoplist.cmake
|
||||
|
@ -158,6 +164,48 @@ if(OCOS_ENABLE_SELECTED_OPLIST)
|
|||
include(_selectedoplist)
|
||||
endif()
|
||||
|
||||
set(_OCOS_EXCEPTIONS_REQUIRED OFF)
|
||||
if (OCOS_ENABLE_GPT2_TOKENIZER OR
|
||||
OCOS_ENABLE_WORDPIECE_TOKENIZER OR
|
||||
OCOS_ENABLE_BLINGFIRE OR
|
||||
OCOS_ENABLE_SPM_TOKENIZER OR
|
||||
(OCOS_ENABLE_CV2 OR OCOS_ENABLE_OPENCV_CODECS OR OCOS_ENABLE_VISION))
|
||||
set(_OCOS_EXCEPTIONS_REQUIRED ON)
|
||||
endif()
|
||||
|
||||
# Special case an embedded build with ORT exceptions disabled but custom ops that require exceptions.
|
||||
# Allow using an override so we can do a direct build of ort-ext in a CI without having to embed it in an ORT build.
|
||||
set(_OCOS_PREVENT_EXCEPTION_PROPAGATION OFF)
|
||||
if (_OCOS_PREVENT_EXCEPTION_PROPAGATION_OVERRIDE)
|
||||
set(_OCOS_PREVENT_EXCEPTION_PROPAGATION ${_OCOS_PREVENT_EXCEPTION_PROPAGATION_OVERRIDE})
|
||||
elseif(_ONNXRUNTIME_EMBEDDED AND onnxruntime_DISABLE_EXCEPTIONS AND _OCOS_EXCEPTIONS_REQUIRED)
|
||||
set(_OCOS_PREVENT_EXCEPTION_PROPAGATION ON)
|
||||
endif()
|
||||
|
||||
if (_OCOS_PREVENT_EXCEPTION_PROPAGATION)
|
||||
message(STATUS "Embedded build as part of ONNX Runtime with exceptions disabled. "
|
||||
"Extensions will be built with exceptions enabled due to included custom ops "
|
||||
"using 3rd party libraries that require exceptions.")
|
||||
|
||||
if (NOT OCOS_ENABLE_CPP_EXCEPTIONS)
|
||||
message(WARNING "Enabling C++ exception support as custom ops included in the build require them to be enabled.")
|
||||
set(OCOS_ENABLE_CPP_EXCEPTIONS ON)
|
||||
endif()
|
||||
|
||||
# undo the flags that ORT has set to disable exceptions.
|
||||
# see https://github.com/microsoft/onnxruntime/blob/b1abb8c656c597bf221bd85682ae3d9e350d9aba/cmake/adjust_global_compile_flags.cmake#L160-L169
|
||||
if(MSVC)
|
||||
string(REPLACE "/EHs-c-" "/EHsc" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
||||
else()
|
||||
string(REPLACE "-fno-exceptions" "-fexceptions" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
|
||||
string(REPLACE "-fno-unwind-tables" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
|
||||
string(REPLACE "-fno-asynchronous-unwind-tables" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
|
||||
endif()
|
||||
|
||||
# the ort-ext code has to provide a barrier between the exception enabled custom op code and ORT.
|
||||
add_compile_definitions(OCOS_PREVENT_EXCEPTION_PROPAGATION)
|
||||
endif()
|
||||
|
||||
if(NOT OCOS_ENABLE_CPP_EXCEPTIONS)
|
||||
add_compile_definitions(OCOS_NO_EXCEPTIONS ORT_NO_EXCEPTIONS)
|
||||
endif()
|
||||
|
@ -182,14 +230,7 @@ endfunction()
|
|||
|
||||
# set default MSVC warning level to 3 for external dependencies
|
||||
set_msvc_c_cpp_compiler_warning_level(3)
|
||||
|
||||
# PROJECT_IS_TOP_LEVEL is available until 3.21
|
||||
get_property(not_top DIRECTORY PROPERTY PARENT_DIRECTORY)
|
||||
if(not_top AND ONNXRUNTIME_ROOT)
|
||||
set(_ONNXRUNTIME_EMBEDDED TRUE)
|
||||
endif()
|
||||
include(ext_ortlib)
|
||||
|
||||
include(gsl)
|
||||
|
||||
macro(standardize_output_folder bin_target)
|
||||
|
@ -214,7 +255,7 @@ endif()
|
|||
|
||||
# ### scan all source files
|
||||
set(TARGET_SRC_NOEXCEPTION)
|
||||
file(GLOB TARGET_SRC "operators/*.cc" "operators/*.h")
|
||||
file(GLOB TARGET_SRC "operators/*.cc" "operators/*.h" "includes/*.h*")
|
||||
|
||||
if(OCOS_ENABLE_TF_STRING)
|
||||
set(farmhash_SOURCE_DIR ${PROJECT_SOURCE_DIR}/cmake/externals/farmhash)
|
||||
|
@ -553,7 +594,7 @@ endforeach()
|
|||
if(OCOS_ENABLE_CTEST)
|
||||
if (OCOS_ENABLE_SELECTED_OPLIST)
|
||||
# currently the tests don't handle operator exclusion cleanly.
|
||||
message(FATAL "Due to usage of OCOS_ENABLE_SELECTED_OPLIST excluding operators the tests are unable to be built and run")
|
||||
message(FATAL_ERROR "Due to usage of OCOS_ENABLE_SELECTED_OPLIST excluding operators the tests are unable to be built and run")
|
||||
endif()
|
||||
|
||||
# Enable CTest
|
||||
|
@ -593,7 +634,8 @@ if(OCOS_ENABLE_CTEST)
|
|||
target_link_directories(extensions_test PRIVATE ${ONNXRUNTIME_LIB_DIR})
|
||||
endif()
|
||||
|
||||
target_link_libraries(extensions_test PRIVATE ocos_operators extensions_shared onnxruntime gtest_main ${ocos_libraries} ${LINUX_CC_FLAGS})
|
||||
target_link_libraries(extensions_test PRIVATE ocos_operators extensions_shared onnxruntime gtest_main gmock_main
|
||||
${ocos_libraries} ${LINUX_CC_FLAGS})
|
||||
|
||||
# Copy ONNXRuntime DLLs into bin folder for testing on Windows platform
|
||||
if(WIN32)
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(OCOS_NO_EXCEPTIONS) || defined(OCOS_PREVENT_EXCEPTION_PROPAGATION)
|
||||
#if defined(__ANDROID__)
|
||||
#include <android/log.h>
|
||||
#else
|
||||
#include <iostream>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#include <stdexcept>
|
||||
|
||||
#include "onnxruntime_c_api.h"
|
||||
|
||||
namespace OrtW {
|
||||
|
||||
// All C++ methods that can fail will throw an exception of this type
|
||||
struct Exception : std::exception {
|
||||
Exception(std::string message, OrtErrorCode code) : message_{std::move(message)}, code_{code} {}
|
||||
|
||||
OrtErrorCode GetOrtErrorCode() const { return code_; }
|
||||
const char* what() const noexcept override { return message_.c_str(); }
|
||||
|
||||
private:
|
||||
std::string message_;
|
||||
OrtErrorCode code_;
|
||||
};
|
||||
|
||||
#if defined(OCOS_NO_EXCEPTIONS) || defined(OCOS_PREVENT_EXCEPTION_PROPAGATION)
|
||||
inline void PrintFinalMessage(const char* file, int line, const char* msg) {
|
||||
#if defined(__ANDROID__)
|
||||
__android_log_print(ANDROID_LOG_ERROR, "onnxruntime-extensions", "Exception in %s line %d: %s", file, line, msg);
|
||||
#else
|
||||
std::cerr << "Exception in " << file << " line " << line << ": " << msg << std::endl;
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef OCOS_NO_EXCEPTIONS
|
||||
#define ORTX_CXX_API_THROW(string, code) \
|
||||
do { \
|
||||
OrtW::PrintFinalMessage(__FILE__, __LINE__, OrtW::Exception(string, code).what()); \
|
||||
abort(); \
|
||||
} while (false)
|
||||
|
||||
#define OCOS_TRY if (true)
|
||||
#define OCOS_CATCH(x) else if (false)
|
||||
#define OCOS_RETHROW
|
||||
// In order to ignore the catch statement when a specific exception (not ... ) is caught and referred
|
||||
// in the body of the catch statements, it is necessary to wrap the body of the catch statement into
|
||||
// a lambda function. otherwise the exception referred will be undefined and cause build break
|
||||
#define OCOS_HANDLE_EXCEPTION(func)
|
||||
#else
|
||||
#define ORTX_CXX_API_THROW(string, code) \
|
||||
throw OrtW::Exception(string, code)
|
||||
|
||||
#define OCOS_TRY try
|
||||
#define OCOS_CATCH(x) catch (x)
|
||||
#define OCOS_RETHROW throw;
|
||||
#define OCOS_HANDLE_EXCEPTION(func) func()
|
||||
#endif
|
||||
|
||||
inline void ThrowOnError(const OrtApi& ort, OrtStatus* status) {
|
||||
if (status) {
|
||||
std::string error_message = ort.GetErrorMessage(status);
|
||||
OrtErrorCode error_code = ort.GetErrorCode(status);
|
||||
ort.ReleaseStatus(status);
|
||||
ORTX_CXX_API_THROW(std::move(error_message), error_code);
|
||||
}
|
||||
}
|
||||
} // namespace OrtW
|
||||
|
||||
// macros to wrap entry points that ORT calls where we may need to prevent exceptions propagating upwards to ORT
|
||||
#define OCOS_API_IMPL_BEGIN \
|
||||
OCOS_TRY {
|
||||
// if exceptions are disabled (a 3rd party library could throw so we need to handle that)
|
||||
// or we're preventing exception propagation, log and abort().
|
||||
#if defined(OCOS_NO_EXCEPTIONS) || defined(OCOS_PREVENT_EXCEPTION_PROPAGATION)
|
||||
#define OCOS_API_IMPL_END \
|
||||
} \
|
||||
OCOS_CATCH(const std::exception& ex) { \
|
||||
OCOS_HANDLE_EXCEPTION([&]() { \
|
||||
OrtW::PrintFinalMessage(__FILE__, __LINE__, ex.what()); \
|
||||
abort(); \
|
||||
}); \
|
||||
}
|
||||
#else
|
||||
// rethrow.
|
||||
#define OCOS_API_IMPL_END \
|
||||
} \
|
||||
OCOS_CATCH(const std::exception&) { \
|
||||
OCOS_HANDLE_EXCEPTION([&]() { \
|
||||
OCOS_RETHROW; \
|
||||
}); \
|
||||
}
|
||||
#endif
|
|
@ -18,28 +18,29 @@ constexpr const char* c_OpDomain = "ai.onnx.contrib";
|
|||
constexpr const char* c_ComMsExtOpDomain = "com.microsoft.extensions";
|
||||
|
||||
struct BaseKernel {
|
||||
BaseKernel(const OrtApi& api) : api_(api), info_(nullptr), ort_(api_) {}
|
||||
BaseKernel(const OrtApi& api, const OrtKernelInfo* info) : api_(api), info_(info), ort_(api_) {}
|
||||
BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept : api_(api), info_(info), ort_(api_) {
|
||||
}
|
||||
|
||||
bool HasAttribute(const char* name) const;
|
||||
bool HasAttribute(const char* name) const noexcept;
|
||||
|
||||
template <class T>
|
||||
bool TryToGetAttribute(const char* name, T& value);
|
||||
bool TryToGetAttribute(const char* name, T& value) const noexcept;
|
||||
|
||||
template <class T>
|
||||
T TryToGetAttributeWithDefault(const char* name, T default_value) {
|
||||
T TryToGetAttributeWithDefault(const char* name, T default_value) const noexcept {
|
||||
T& result = default_value;
|
||||
TryToGetAttribute(name, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim, const std::vector<int64_t>& data);
|
||||
void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
|
||||
const std::vector<int64_t>& data);
|
||||
|
||||
protected:
|
||||
OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status);
|
||||
OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept;
|
||||
const OrtApi& api_;
|
||||
OrtW::CustomOpApi ort_;
|
||||
const OrtKernelInfo* info_;
|
||||
const OrtKernelInfo& info_;
|
||||
};
|
||||
|
||||
struct OrtTensorDimensions : std::vector<int64_t> {
|
||||
|
|
|
@ -9,56 +9,17 @@
|
|||
#include <cstddef>
|
||||
#include <array>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <type_traits>
|
||||
|
||||
#ifdef ORT_NO_EXCEPTIONS
|
||||
#include <cstdio>
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_c_api.h"
|
||||
|
||||
|
||||
#include "exceptions.h"
|
||||
|
||||
namespace OrtW {
|
||||
|
||||
// All C++ methods that can fail will throw an exception of this type
|
||||
struct Exception : std::exception {
|
||||
Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
|
||||
|
||||
OrtErrorCode GetOrtErrorCode() const { return code_; }
|
||||
const char* what() const noexcept override { return message_.c_str(); }
|
||||
|
||||
private:
|
||||
std::string message_;
|
||||
OrtErrorCode code_;
|
||||
};
|
||||
|
||||
#ifdef ORT_NO_EXCEPTIONS
|
||||
#define ORTX_CXX_API_THROW(string, code) \
|
||||
do { \
|
||||
fprintf(stderr, "%s\n", \
|
||||
OrtW::Exception(string, code).what()); \
|
||||
abort(); \
|
||||
} while (false)
|
||||
#else
|
||||
#define ORTX_CXX_API_THROW(string, code) \
|
||||
throw OrtW::Exception(string, code)
|
||||
#endif
|
||||
|
||||
inline void ThrowOnError(const OrtApi& ort, OrtStatus* status) {
|
||||
if (status) {
|
||||
std::string error_message = ort.GetErrorMessage(status);
|
||||
OrtErrorCode error_code = ort.GetErrorCode(status);
|
||||
ort.ReleaseStatus(status);
|
||||
ORTX_CXX_API_THROW(std::move(error_message), error_code);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// Custom OPs (only needed to implement custom OPs)
|
||||
//
|
||||
|
@ -72,7 +33,8 @@ struct CustomOpApi {
|
|||
size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
|
||||
ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const;
|
||||
size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
|
||||
void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const;
|
||||
void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values,
|
||||
size_t dim_values_length) const;
|
||||
void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const;
|
||||
|
||||
template <typename T>
|
||||
|
@ -85,7 +47,8 @@ struct CustomOpApi {
|
|||
size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
|
||||
const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const;
|
||||
size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const;
|
||||
OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count) const;
|
||||
OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values,
|
||||
size_t dim_count) const;
|
||||
|
||||
void ThrowOnError(OrtStatus* status) const {
|
||||
OrtW::ThrowOnError(api_, status);
|
||||
|
@ -99,18 +62,44 @@ template <typename TOp, typename TKernel>
|
|||
struct CustomOpBase : OrtCustomOp {
|
||||
CustomOpBase() {
|
||||
OrtCustomOp::version = 10; // The minimum ORT version supported
|
||||
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
|
||||
OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
|
||||
OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) {
|
||||
void* result = nullptr;
|
||||
OCOS_API_IMPL_BEGIN
|
||||
result = static_cast<const TOp*>(this_)->CreateKernel(*api, *info);
|
||||
OCOS_API_IMPL_END
|
||||
return result;
|
||||
};
|
||||
|
||||
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
|
||||
OrtCustomOp::GetName = [](const OrtCustomOp* this_) noexcept {
|
||||
return static_cast<const TOp*>(this_)->GetName();
|
||||
};
|
||||
|
||||
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
|
||||
OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
|
||||
OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) noexcept {
|
||||
return static_cast<const TOp*>(this_)->GetExecutionProviderType();
|
||||
};
|
||||
|
||||
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
|
||||
OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
|
||||
OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) noexcept {
|
||||
return static_cast<const TOp*>(this_)->GetInputTypeCount();
|
||||
};
|
||||
|
||||
OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) noexcept {
|
||||
return static_cast<const TOp*>(this_)->GetInputType(index);
|
||||
};
|
||||
|
||||
OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) noexcept {
|
||||
return static_cast<const TOp*>(this_)->GetOutputTypeCount();
|
||||
};
|
||||
|
||||
OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) noexcept {
|
||||
return static_cast<const TOp*>(this_)->GetOutputType(index);
|
||||
};
|
||||
|
||||
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
|
||||
OCOS_API_IMPL_BEGIN
|
||||
static_cast<TKernel*>(op_kernel)->Compute(context);
|
||||
OCOS_API_IMPL_END
|
||||
};
|
||||
|
||||
OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 26409)
|
||||
|
@ -119,26 +108,29 @@ struct CustomOpBase : OrtCustomOp {
|
|||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
#pragma warning(pop)
|
||||
#endif
|
||||
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
|
||||
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
|
||||
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
|
||||
return static_cast<const TOp*>(this_)->GetInputCharacteristic(index);
|
||||
};
|
||||
|
||||
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
|
||||
return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index);
|
||||
};
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
TKernel* CreateKernelImpl(Args&&... args) const {
|
||||
// default implementation. we can't use a virtual function as the layout of this struct has to be aligned with
|
||||
// OrtCustomOp, but a derived class can override by creating a function with the same name and signature,
|
||||
// calling this base class implementation as needed. e.g. see CustomOpThree in the unit test code
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo& info) const {
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 26409)
|
||||
#endif
|
||||
return new TKernel(std::forward<Args>(args)...);
|
||||
return new TKernel(api, info);
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
#pragma warning(pop)
|
||||
#endif
|
||||
}
|
||||
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api);
|
||||
}
|
||||
|
||||
// Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
|
||||
const char* GetExecutionProviderType() const { return nullptr; }
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
#include <cstdint>
|
||||
|
||||
struct KernelImageDecoder : BaseKernel {
|
||||
KernelImageDecoder(const OrtApi& api) : BaseKernel(api) {}
|
||||
KernelImageDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
|
@ -33,7 +33,7 @@ struct KernelImageDecoder : BaseKernel {
|
|||
// Decode the image
|
||||
const std::vector<int32_t> encoded_image_sizes{1, static_cast<int32_t>(encoded_image_data_len)};
|
||||
const cv::Mat encoded_image(encoded_image_sizes, CV_8UC1,
|
||||
const_cast<void*>(static_cast<const void*>(encoded_image_data)));
|
||||
const_cast<void*>(static_cast<const void*>(encoded_image_data)));
|
||||
const cv::Mat decoded_image = cv::imdecode(encoded_image, cv::IMREAD_COLOR);
|
||||
|
||||
// Setup output & copy to destination
|
||||
|
@ -41,18 +41,14 @@ struct KernelImageDecoder : BaseKernel {
|
|||
const int64_t colors = 3;
|
||||
|
||||
const std::vector<int64_t> output_dimensions{decoded_image_size.height, decoded_image_size.width, colors};
|
||||
OrtValue *const output_value = ort_.KernelContext_GetOutput(
|
||||
context, 0, output_dimensions.data(), output_dimensions.size());
|
||||
OrtValue* const output_value = ort_.KernelContext_GetOutput(
|
||||
context, 0, output_dimensions.data(), output_dimensions.size());
|
||||
uint8_t* const decoded_image_data = ort_.GetTensorMutableData<uint8_t>(output_value);
|
||||
memcpy(decoded_image_data, decoded_image.data, decoded_image.total() * decoded_image.elemSize());
|
||||
}
|
||||
};
|
||||
|
||||
struct CustomOpImageDecoder : OrtW::CustomOpBase<CustomOpImageDecoder, KernelImageDecoder> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelImageDecoder(api);
|
||||
}
|
||||
|
||||
const char* GetName() const {
|
||||
return "ImageDecoder";
|
||||
}
|
||||
|
|
|
@ -3,9 +3,8 @@
|
|||
|
||||
#include "string_tensor.h"
|
||||
|
||||
|
||||
struct KernelImageReader : BaseKernel {
|
||||
KernelImageReader(const OrtApi& api) : BaseKernel(api) {
|
||||
KernelImageReader(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
|
@ -45,7 +44,7 @@ struct CustomOpImageReader : OrtW::CustomOpBase<CustomOpImageReader, KernelImage
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||
}
|
||||
|
||||
const char* GetName() const{
|
||||
const char* GetName() const {
|
||||
return "ImageReader";
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
#include <opencv2/core.hpp>
|
||||
#include <opencv2/imgproc.hpp>
|
||||
|
||||
|
||||
struct KernelGaussianBlur : BaseKernel {
|
||||
KernelGaussianBlur(const OrtApi& api) : BaseKernel(api) {
|
||||
KernelGaussianBlur(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
|
@ -47,7 +46,7 @@ struct KernelGaussianBlur : BaseKernel {
|
|||
sigma[0], sigma[1], cv::BORDER_DEFAULT);
|
||||
|
||||
OrtValue* image_y = ort_.KernelContext_GetOutput(context,
|
||||
0, input_data_dimensions.data(), input_data_dimensions.size());
|
||||
0, input_data_dimensions.data(), input_data_dimensions.size());
|
||||
float* p_output_image = ort_.GetTensorMutableData<float>(image_y);
|
||||
memcpy(p_output_image, output_image.data, output_image.total() * output_image.elemSize());
|
||||
}
|
||||
|
@ -76,7 +75,7 @@ struct CustomOpGaussianBlur : OrtW::CustomOpBase<CustomOpGaussianBlur, KernelGau
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
|
||||
const char* GetName() const{
|
||||
const char* GetName() const {
|
||||
return "GaussianBlur";
|
||||
}
|
||||
};
|
||||
|
|
|
@ -6,9 +6,8 @@
|
|||
#include <dlib/matrix.h>
|
||||
#include "ocos.h"
|
||||
|
||||
|
||||
struct KernelInverse : BaseKernel {
|
||||
KernelInverse(const OrtApi& api) : BaseKernel(api) {
|
||||
KernelInverse(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
|
@ -23,7 +22,7 @@ struct KernelInverse : BaseKernel {
|
|||
}
|
||||
|
||||
OrtValue* output0 = ort_.KernelContext_GetOutput(
|
||||
context, 0, dimensions.data(), dimensions.size());
|
||||
context, 0, dimensions.data(), dimensions.size());
|
||||
float* out0 = ort_.GetTensorMutableData<float>(output0);
|
||||
|
||||
dlib::matrix<float> dm_x(dimensions[0], dimensions[1]);
|
||||
|
|
|
@ -6,10 +6,10 @@
|
|||
#include "ocos.h"
|
||||
|
||||
struct KernelNegPos : BaseKernel {
|
||||
KernelNegPos(const OrtApi& api) : BaseKernel(api) {
|
||||
KernelNegPos(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context){
|
||||
void Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
|
||||
const float* X = ort_.GetTensorData<float>(input_X);
|
||||
|
@ -40,22 +40,22 @@ struct KernelNegPos : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpNegPos : OrtW::CustomOpBase<CustomOpNegPos, KernelNegPos> {
|
||||
const char* GetName() const{
|
||||
const char* GetName() const {
|
||||
return "NegPos";
|
||||
}
|
||||
|
||||
size_t GetInputTypeCount() const{
|
||||
size_t GetInputTypeCount() const {
|
||||
return 1;
|
||||
}
|
||||
|
||||
ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
|
||||
size_t GetOutputTypeCount() const{
|
||||
|
||||
size_t GetOutputTypeCount() const {
|
||||
return 2;
|
||||
}
|
||||
|
||||
|
||||
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
|
|
|
@ -3,7 +3,8 @@
|
|||
|
||||
#include "segment_extraction.hpp"
|
||||
|
||||
KernelSegmentExtraction::KernelSegmentExtraction(const OrtApi& api) : BaseKernel(api) {
|
||||
KernelSegmentExtraction::KernelSegmentExtraction(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
|
||||
|
@ -11,7 +12,7 @@ void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
|
|||
const int64_t* p_data = ort_.GetTensorData<int64_t>(input);
|
||||
OrtTensorDimensions input_dim(ort_, input);
|
||||
if (!((input_dim.size() == 1) || (input_dim.size() == 2 && input_dim[0] == 1))) {
|
||||
ORTX_CXX_API_THROW("[SegmentExtraction]: Expect input dimension [n] or [1,n]." , ORT_INVALID_GRAPH);
|
||||
ORTX_CXX_API_THROW("[SegmentExtraction]: Expect input dimension [n] or [1,n].", ORT_INVALID_GRAPH);
|
||||
}
|
||||
|
||||
std::vector<std::int64_t> segment_value;
|
||||
|
@ -35,8 +36,8 @@ void KernelSegmentExtraction::Compute(OrtKernelContext* context) {
|
|||
|
||||
std::vector<int64_t> segment_value_dim({static_cast<int64_t>(segment_value.size())});
|
||||
std::vector<int64_t> segment_position_dim({static_cast<int64_t>(segment_value.size()), 2});
|
||||
SetOutput(context, 0, segment_position_dim, segment_position);
|
||||
SetOutput(context, 1, segment_value_dim, segment_value);
|
||||
SetOutput(context, 0, segment_position_dim, segment_position);
|
||||
SetOutput(context, 1, segment_value_dim, segment_value);
|
||||
}
|
||||
|
||||
size_t CustomOpSegmentExtraction::GetInputTypeCount() const {
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelSegmentExtraction : BaseKernel {
|
||||
KernelSegmentExtraction(const OrtApi& api);
|
||||
KernelSegmentExtraction(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
|
|
|
@ -20,8 +20,9 @@ void KernelSegmentSum_Compute(OrtW::CustomOpApi& ort_, OrtKernelContext* context
|
|||
ORTX_CXX_API_THROW("segment_ids must a single tensor", ORT_INVALID_GRAPH);
|
||||
if (dim_data[0] != dim_seg[0])
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"First dimensions of data and segment_ids should be the same, data shape: ", dim_data,
|
||||
" segment_ids shape: ", dim_seg), ORT_INVALID_GRAPH);
|
||||
"First dimensions of data and segment_ids should be the same, data shape: ", dim_data,
|
||||
" segment_ids shape: ", dim_seg),
|
||||
ORT_INVALID_GRAPH);
|
||||
|
||||
int64_t last_seg = p_segment_ids[dim_seg[0] - 1];
|
||||
OrtTensorDimensions dim_out = dim_data;
|
||||
|
@ -43,8 +44,9 @@ void KernelSegmentSum_Compute(OrtW::CustomOpApi& ort_, OrtKernelContext* context
|
|||
for (; begin != end; ++p_seg) {
|
||||
if ((p_seg != p_segment_ids) && (*p_seg != *(p_seg - 1)) && (*p_seg != *(p_seg - 1) + 1))
|
||||
ORTX_CXX_API_THROW(MakeString("segment_ids must be increasing but found ",
|
||||
*(p_seg - 1), " and ", *p_seg, " at position ",
|
||||
std::distance(p_segment_ids, p_seg), "."), ORT_RUNTIME_EXCEPTION);
|
||||
*(p_seg - 1), " and ", *p_seg, " at position ",
|
||||
std::distance(p_segment_ids, p_seg), "."),
|
||||
ORT_RUNTIME_EXCEPTION);
|
||||
p_out = p_output + *p_seg * in_stride;
|
||||
p_out_end = p_out + in_stride;
|
||||
for (; p_out != p_out_end; ++p_out, ++begin)
|
||||
|
@ -52,7 +54,7 @@ void KernelSegmentSum_Compute(OrtW::CustomOpApi& ort_, OrtKernelContext* context
|
|||
}
|
||||
}
|
||||
|
||||
KernelSegmentSum::KernelSegmentSum(const OrtApi& api) : BaseKernel(api) {
|
||||
KernelSegmentSum::KernelSegmentSum(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelSegmentSum::Compute(OrtKernelContext* context) {
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelSegmentSum : BaseKernel {
|
||||
KernelSegmentSum(const OrtApi& api);
|
||||
KernelSegmentSum(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
|
|
|
@ -3,14 +3,11 @@
|
|||
#include <sstream>
|
||||
#include "ocos.h"
|
||||
|
||||
bool BaseKernel::HasAttribute(const char* name) const {
|
||||
if (info_ == nullptr) {
|
||||
ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
bool BaseKernel::HasAttribute(const char* name) const noexcept {
|
||||
size_t size;
|
||||
std::string out;
|
||||
// Crashes here.
|
||||
OrtStatus* status = api_.KernelInfoGetAttribute_string(info_, name, nullptr, &size);
|
||||
OrtStatus* status = api_.KernelInfoGetAttribute_string(&info_, name, nullptr, &size);
|
||||
auto r = api_.GetErrorCode(status);
|
||||
bool has = (r == ORT_INVALID_ARGUMENT) || (r == ORT_OK);
|
||||
if (has) {
|
||||
|
@ -26,7 +23,7 @@ bool BaseKernel::HasAttribute(const char* name) const {
|
|||
return true;
|
||||
}
|
||||
|
||||
OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) {
|
||||
OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept {
|
||||
if (status == nullptr) {
|
||||
return ORT_OK;
|
||||
}
|
||||
|
@ -35,22 +32,19 @@ OrtErrorCode BaseKernel::GetErrorCodeAndRelease(OrtStatusPtr status) {
|
|||
return error_code;
|
||||
}
|
||||
|
||||
void BaseKernel::SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim, const std::vector<int64_t>& data) {
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(ctx, output_idx, dim.data(), dim.size());
|
||||
int64_t * data_ptr = ort_.GetTensorMutableData<int64_t>(output);
|
||||
for (size_t i = 0; i < data.size(); i++) {
|
||||
data_ptr[i] = data[i];
|
||||
}
|
||||
void BaseKernel::SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
|
||||
const std::vector<int64_t>& data) {
|
||||
OrtValue* output = ort_.KernelContext_GetOutput(ctx, output_idx, dim.data(), dim.size());
|
||||
int64_t* data_ptr = ort_.GetTensorMutableData<int64_t>(output);
|
||||
for (size_t i = 0; i < data.size(); i++) {
|
||||
data_ptr[i] = data[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
|
||||
if (info_ == nullptr) {
|
||||
ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) const noexcept {
|
||||
size_t size = 0;
|
||||
OrtStatus* status = api_.KernelInfoGetAttribute_string(info_, name, nullptr, &size);
|
||||
OrtStatus* status = api_.KernelInfoGetAttribute_string(&info_, name, nullptr, &size);
|
||||
|
||||
// The status should be a nullptr when querying for the size.
|
||||
if (status != nullptr) {
|
||||
|
@ -59,7 +53,7 @@ bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
|
|||
}
|
||||
|
||||
value.resize(size);
|
||||
status = api_.KernelInfoGetAttribute_string(info_, name, &value[0], &size);
|
||||
status = api_.KernelInfoGetAttribute_string(&info_, name, &value[0], &size);
|
||||
if (GetErrorCodeAndRelease(status) != ORT_OK) {
|
||||
return false;
|
||||
}
|
||||
|
@ -69,31 +63,19 @@ bool BaseKernel::TryToGetAttribute(const char* name, std::string& value) {
|
|||
}
|
||||
|
||||
template <>
|
||||
bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) {
|
||||
if (info_ == nullptr) {
|
||||
ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(info_, name, &value)) == ORT_OK;
|
||||
bool BaseKernel::TryToGetAttribute(const char* name, int64_t& value) const noexcept {
|
||||
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(&info_, name, &value)) == ORT_OK;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool BaseKernel::TryToGetAttribute(const char* name, float& value) {
|
||||
if (info_ == nullptr) {
|
||||
ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(info_, name, &value)) == ORT_OK;
|
||||
bool BaseKernel::TryToGetAttribute(const char* name, float& value) const noexcept {
|
||||
return GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_float(&info_, name, &value)) == ORT_OK;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool BaseKernel::TryToGetAttribute(const char* name, bool& value) {
|
||||
if (info_ == nullptr) {
|
||||
ORTX_CXX_API_THROW("Kernel was incorrectly initialized, pointer info_ cannot be null.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
bool BaseKernel::TryToGetAttribute(const char* name, bool& value) const noexcept {
|
||||
int64_t origin_value = 0;
|
||||
if (GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(info_, name, &origin_value)) != ORT_OK) {
|
||||
if (GetErrorCodeAndRelease(api_.KernelInfoGetAttribute_int64(&info_, name, &origin_value)) != ORT_OK) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -8,8 +8,7 @@
|
|||
#include <codecvt>
|
||||
#include <algorithm>
|
||||
|
||||
|
||||
KernelMaskedFill::KernelMaskedFill(const OrtApi& api, const OrtKernelInfo* /*info*/) : BaseKernel(api) {
|
||||
KernelMaskedFill::KernelMaskedFill(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelMaskedFill::Compute(OrtKernelContext* context) {
|
||||
|
@ -29,7 +28,7 @@ void KernelMaskedFill::Compute(OrtKernelContext* context) {
|
|||
}
|
||||
|
||||
std::vector<std::string> value;
|
||||
const bool * mask = nullptr;
|
||||
const bool* mask = nullptr;
|
||||
|
||||
GetTensorMutableDataString(api_, ort_, context, input_value, value);
|
||||
mask = ort_.GetTensorData<bool>(input_mask);
|
||||
|
@ -51,10 +50,6 @@ void KernelMaskedFill::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, result, output);
|
||||
}
|
||||
|
||||
void* CustomOpMaskedFill::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
};
|
||||
|
||||
const char* CustomOpMaskedFill::GetName() const { return "MaskedFill"; };
|
||||
|
||||
size_t CustomOpMaskedFill::GetInputTypeCount() const {
|
||||
|
@ -69,7 +64,8 @@ ONNXTensorElementDataType CustomOpMaskedFill::GetInputType(size_t index) const {
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
||||
default:
|
||||
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||
}};
|
||||
}
|
||||
};
|
||||
|
||||
size_t CustomOpMaskedFill::GetOutputTypeCount() const {
|
||||
return 1;
|
||||
|
|
|
@ -8,14 +8,14 @@
|
|||
#include <unordered_map>
|
||||
|
||||
struct KernelMaskedFill : BaseKernel {
|
||||
KernelMaskedFill(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelMaskedFill(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, std::string> map_;
|
||||
};
|
||||
|
||||
struct CustomOpMaskedFill : OrtW::CustomOpBase<CustomOpMaskedFill, KernelMaskedFill> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
#include "op_equal_impl.hpp"
|
||||
#include <string>
|
||||
|
||||
KernelStringEqual::KernelStringEqual(const OrtApi& api) : BaseKernel(api) {
|
||||
KernelStringEqual::KernelStringEqual(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringEqual::Compute(OrtKernelContext* context) {
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringEqual : BaseKernel {
|
||||
KernelStringEqual(const OrtApi& api);
|
||||
KernelStringEqual(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
|
|
|
@ -12,7 +12,8 @@ void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
|
|||
|
||||
if (d_length.size() != 1)
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"First input must have one dimension not ", d_length.size(), "."), ORT_INVALID_ARGUMENT);
|
||||
"First input must have one dimension not ", d_length.size(), "."),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
int64_t n_els = d_length[0] - 1;
|
||||
int64_t n_values = p_n_elements[n_els];
|
||||
std::vector<int64_t> shape{n_values, 2};
|
||||
|
@ -58,7 +59,8 @@ ONNXTensorElementDataType CustomOpRaggedTensorToSparse::GetInputType(size_t /*in
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
};
|
||||
|
||||
CommonRaggedTensorToDense::CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
CommonRaggedTensorToDense::CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void CommonRaggedTensorToDense::GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims) {
|
||||
|
@ -77,8 +79,9 @@ int64_t CommonRaggedTensorToDense::GetMaxCol(int64_t n, const int64_t* p_indices
|
|||
return max_col;
|
||||
}
|
||||
|
||||
KernelRaggedTensorToDense::KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
|
||||
missing_value_ = HasAttribute("missing_value") ? ort_.KernelInfoGetAttribute<int64_t>(info, "missing_value") : -1;
|
||||
KernelRaggedTensorToDense::KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: CommonRaggedTensorToDense(api, info) {
|
||||
missing_value_ = HasAttribute("missing_value") ? ort_.KernelInfoGetAttribute<int64_t>(&info, "missing_value") : -1;
|
||||
}
|
||||
|
||||
void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
||||
|
@ -104,8 +107,9 @@ void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
|||
pos_end = pos + max_col;
|
||||
if (pos_end > shape_out_size)
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
|
||||
" - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
|
||||
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
|
||||
" - i=", i, " size=", size, "."),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
|
||||
dense[pos] = p_values[j];
|
||||
}
|
||||
|
@ -127,10 +131,6 @@ ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetOutputType(size_t /*in
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
};
|
||||
|
||||
void* CustomOpRaggedTensorToDense::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
};
|
||||
|
||||
const char* CustomOpRaggedTensorToDense::GetName() const {
|
||||
return "RaggedTensorToDense";
|
||||
};
|
||||
|
@ -139,7 +139,7 @@ ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetInputType(size_t /*ind
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
};
|
||||
|
||||
KernelStringRaggedTensorToDense::KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
|
||||
KernelStringRaggedTensorToDense::KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info) : CommonRaggedTensorToDense(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
||||
|
@ -162,8 +162,9 @@ void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
|
|||
pos_end = pos + max_col;
|
||||
if (pos_end > shape_out_size)
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
|
||||
" - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
|
||||
"Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
|
||||
" - i=", i, " size=", size, "."),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
|
||||
dense[static_cast<size_t>(pos)] = input[static_cast<size_t>(j)];
|
||||
}
|
||||
|
@ -186,10 +187,6 @@ ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetOutputType(size_
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
};
|
||||
|
||||
void* CustomOpStringRaggedTensorToDense::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
};
|
||||
|
||||
const char* CustomOpStringRaggedTensorToDense::GetName() const {
|
||||
return "StringRaggedTensorToDense";
|
||||
};
|
||||
|
|
|
@ -6,8 +6,8 @@
|
|||
#include "ocos.h"
|
||||
|
||||
struct KernelRaggedTensorToSparse : BaseKernel {
|
||||
KernelRaggedTensorToSparse(const OrtApi& api)
|
||||
: BaseKernel(api) {}
|
||||
KernelRaggedTensorToSparse(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info) {}
|
||||
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
@ -21,7 +21,7 @@ struct CustomOpRaggedTensorToSparse : OrtW::CustomOpBase<CustomOpRaggedTensorToS
|
|||
};
|
||||
|
||||
struct CommonRaggedTensorToDense : BaseKernel {
|
||||
CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info);
|
||||
CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info);
|
||||
|
||||
protected:
|
||||
void GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims);
|
||||
|
@ -29,7 +29,7 @@ struct CommonRaggedTensorToDense : BaseKernel {
|
|||
};
|
||||
|
||||
struct KernelRaggedTensorToDense : CommonRaggedTensorToDense {
|
||||
KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
|
@ -41,20 +41,19 @@ struct CustomOpRaggedTensorToDense : OrtW::CustomOpBase<CustomOpRaggedTensorToDe
|
|||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
const char* GetName() const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
};
|
||||
|
||||
struct KernelStringRaggedTensorToDense : CommonRaggedTensorToDense {
|
||||
KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpStringRaggedTensorToDense : OrtW::CustomOpBase<CustomOpStringRaggedTensorToDense, KernelStringRaggedTensorToDense> {
|
||||
struct CustomOpStringRaggedTensorToDense : OrtW::CustomOpBase<CustomOpStringRaggedTensorToDense,
|
||||
KernelStringRaggedTensorToDense> {
|
||||
size_t GetInputTypeCount() const;
|
||||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
const char* GetName() const;
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* /* info */) const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
};
|
||||
|
|
|
@ -8,8 +8,9 @@
|
|||
#include "re2/re2.h"
|
||||
#include "string_tensor.h"
|
||||
|
||||
KernelStringRegexReplace::KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
global_replace_ = HasAttribute("global_replace") ? ort_.KernelInfoGetAttribute<int64_t>(info_, "global_replace") : 1;
|
||||
KernelStringRegexReplace::KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info) {
|
||||
global_replace_ = HasAttribute("global_replace") ? ort_.KernelInfoGetAttribute<int64_t>(&info_, "global_replace") : 1;
|
||||
}
|
||||
|
||||
void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
|
||||
|
@ -28,12 +29,14 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
|
|||
OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
|
||||
if (pattern_dimensions.size() != 1 || pattern_dimensions[0] != 1)
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"pattern (second input) must contain only one element. It has ",
|
||||
pattern_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
|
||||
"pattern (second input) must contain only one element. It has ",
|
||||
pattern_dimensions.size(), " dimensions."),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
if (rewrite_dimensions.size() != 1 || rewrite_dimensions[0] != 1)
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"rewrite (third input) must contain only one element. It has ",
|
||||
rewrite_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
|
||||
"rewrite (third input) must contain only one element. It has ",
|
||||
rewrite_dimensions.size(), " dimensions."),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
if (str_pattern[0].empty())
|
||||
ORTX_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_ARGUMENT);
|
||||
|
||||
|
@ -62,10 +65,6 @@ void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, str_input, output);
|
||||
}
|
||||
|
||||
void* CustomOpStringRegexReplace::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
};
|
||||
|
||||
const char* CustomOpStringRegexReplace::GetName() const { return "StringRegexReplace"; };
|
||||
|
||||
size_t CustomOpStringRegexReplace::GetInputTypeCount() const {
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringRegexReplace : BaseKernel {
|
||||
KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
protected:
|
||||
|
@ -15,7 +15,6 @@ struct KernelStringRegexReplace : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpStringRegexReplace : OrtW::CustomOpBase<CustomOpStringRegexReplace, KernelStringRegexReplace> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -7,7 +7,8 @@
|
|||
#include <vector>
|
||||
#include <cmath>
|
||||
|
||||
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
||||
|
@ -24,13 +25,13 @@ void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
|||
// Verifications
|
||||
OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern);
|
||||
if (str_pattern.size() != 1)
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"pattern (second input) must contain only one element. It has ",
|
||||
str_pattern.size(), " values."), ORT_INVALID_ARGUMENT);
|
||||
ORTX_CXX_API_THROW(MakeString("pattern (second input) must contain only one element. It has ",
|
||||
str_pattern.size(), " values."),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
if (str_keep_pattern.size() > 1)
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"Third input must contain only one element. It has ",
|
||||
str_keep_pattern.size(), " values."), ORT_INVALID_ARGUMENT);
|
||||
ORTX_CXX_API_THROW(MakeString("Third input must contain only one element. It has ",
|
||||
str_keep_pattern.size(), " values."),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
if (str_pattern[0].empty())
|
||||
ORTX_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_ARGUMENT);
|
||||
|
||||
|
@ -50,8 +51,8 @@ void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
|||
std::vector<int64_t> begin_offsets;
|
||||
std::vector<int64_t> end_offsets;
|
||||
RegexSplitImpl(str_input[static_cast<size_t>(i)], reg,
|
||||
include_delimiter, keep_reg,
|
||||
tokens, begin_offsets, end_offsets);
|
||||
include_delimiter, keep_reg,
|
||||
tokens, begin_offsets, end_offsets);
|
||||
all_tokens.insert(all_tokens.end(), tokens.begin(), tokens.end());
|
||||
for (size_t j = 0; j < begin_offsets.size(); ++j) {
|
||||
all_begin_offsets.push_back(begin_offsets[j]);
|
||||
|
@ -79,10 +80,6 @@ void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
|||
memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
|
||||
}
|
||||
|
||||
void* CustomOpStringRegexSplitWithOffsets::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
};
|
||||
|
||||
const char* CustomOpStringRegexSplitWithOffsets::GetName() const { return "StringRegexSplitWithOffsets"; };
|
||||
|
||||
size_t CustomOpStringRegexSplitWithOffsets::GetInputTypeCount() const {
|
||||
|
@ -106,7 +103,7 @@ ONNXTensorElementDataType CustomOpStringRegexSplitWithOffsets::GetOutputType(siz
|
|||
case 3:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
default:
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."), ORT_INVALID_ARGUMENT);
|
||||
ORTX_CXX_API_THROW(MakeString("StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -8,12 +8,11 @@
|
|||
|
||||
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
|
||||
struct KernelStringRegexSplitWithOffsets : BaseKernel {
|
||||
KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpStringRegexSplitWithOffsets : OrtW::CustomOpBase<CustomOpStringRegexSplitWithOffsets, KernelStringRegexSplitWithOffsets> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -8,8 +8,7 @@
|
|||
#include <codecvt>
|
||||
#include <algorithm>
|
||||
|
||||
|
||||
KernelStringConcat::KernelStringConcat(const OrtApi& api) : BaseKernel(api) {
|
||||
KernelStringConcat::KernelStringConcat(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringConcat::Compute(OrtKernelContext* context) {
|
||||
|
@ -53,4 +52,4 @@ size_t CustomOpStringConcat::GetOutputTypeCount() const {
|
|||
|
||||
ONNXTensorElementDataType CustomOpStringConcat::GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
|
||||
};
|
||||
};
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringConcat : BaseKernel {
|
||||
KernelStringConcat(const OrtApi& api);
|
||||
KernelStringConcat(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
|
|
|
@ -7,10 +7,10 @@
|
|||
#include <regex>
|
||||
#include "string_tensor.h"
|
||||
|
||||
KernelStringECMARegexReplace::KernelStringECMARegexReplace(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
KernelStringECMARegexReplace::KernelStringECMARegexReplace(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info) {
|
||||
global_replace_ = TryToGetAttributeWithDefault("global_replace", true);
|
||||
ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
|
||||
|
||||
}
|
||||
|
||||
void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
|
||||
|
@ -24,19 +24,18 @@ void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
|
|||
GetTensorMutableDataString(api_, ort_, context, pattern, str_pattern);
|
||||
GetTensorMutableDataString(api_, ort_, context, rewrite, str_rewrite);
|
||||
|
||||
|
||||
// Verifications
|
||||
OrtTensorDimensions pattern_dimensions(ort_, pattern);
|
||||
OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
|
||||
if (pattern_dimensions.Size() != 1) {
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"pattern (second input) must contain only one element. It has ",
|
||||
pattern_dimensions.size(), " dimensions."), ORT_INVALID_GRAPH);
|
||||
ORTX_CXX_API_THROW(MakeString("pattern (second input) must contain only one element. It has ",
|
||||
pattern_dimensions.size(), " dimensions."),
|
||||
ORT_INVALID_GRAPH);
|
||||
}
|
||||
if (rewrite_dimensions.Size() != 1) {
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"rewrite (third input) must contain only one element. It has ",
|
||||
rewrite_dimensions.size(), " dimensions."), ORT_INVALID_GRAPH);
|
||||
ORTX_CXX_API_THROW(MakeString("rewrite (third input) must contain only one element. It has ",
|
||||
rewrite_dimensions.size(), " dimensions."),
|
||||
ORT_INVALID_GRAPH);
|
||||
}
|
||||
if (str_pattern[0].empty()) {
|
||||
ORTX_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_GRAPH);
|
||||
|
@ -70,10 +69,6 @@ void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, str_input, output);
|
||||
}
|
||||
|
||||
void* CustomOpStringECMARegexReplace::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
};
|
||||
|
||||
const char* CustomOpStringECMARegexReplace::GetName() const { return "StringECMARegexReplace"; };
|
||||
|
||||
size_t CustomOpStringECMARegexReplace::GetInputTypeCount() const {
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringECMARegexReplace : BaseKernel {
|
||||
KernelStringECMARegexReplace(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelStringECMARegexReplace(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
protected:
|
||||
|
@ -16,7 +16,6 @@ struct KernelStringECMARegexReplace : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpStringECMARegexReplace : OrtW::CustomOpBase<CustomOpStringECMARegexReplace, KernelStringECMARegexReplace> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -9,8 +9,9 @@
|
|||
#include "string_ecmaregex_split.hpp"
|
||||
#include "string_tensor.h"
|
||||
|
||||
|
||||
KernelStringECMARegexSplitWithOffsets::KernelStringECMARegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
KernelStringECMARegexSplitWithOffsets::KernelStringECMARegexSplitWithOffsets(const OrtApi& api,
|
||||
const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info) {
|
||||
ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
|
||||
}
|
||||
|
||||
|
@ -28,9 +29,13 @@ void KernelStringECMARegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
|||
// Verifications
|
||||
OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern);
|
||||
if (str_pattern.size() != 1)
|
||||
ORTX_CXX_API_THROW(MakeString("pattern (second input) must contain only one element. It has ", str_pattern.size(), " values."), ORT_INVALID_GRAPH);
|
||||
ORTX_CXX_API_THROW(MakeString("pattern (second input) must contain only one element. It has ", str_pattern.size(),
|
||||
" values."),
|
||||
ORT_INVALID_GRAPH);
|
||||
if (str_keep_pattern.size() > 1)
|
||||
ORTX_CXX_API_THROW(MakeString("Third input must contain only one element. It has ", str_keep_pattern.size(), " values."), ORT_INVALID_GRAPH);
|
||||
ORTX_CXX_API_THROW(MakeString("Third input must contain only one element. It has ", str_keep_pattern.size(),
|
||||
" values."),
|
||||
ORT_INVALID_GRAPH);
|
||||
if (str_pattern[0].empty())
|
||||
ORTX_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_GRAPH);
|
||||
|
||||
|
@ -84,10 +89,6 @@ void KernelStringECMARegexSplitWithOffsets::Compute(OrtKernelContext* context) {
|
|||
memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
|
||||
}
|
||||
|
||||
void* CustomOpStringECMARegexSplitWithOffsets::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
};
|
||||
|
||||
const char* CustomOpStringECMARegexSplitWithOffsets::GetName() const { return "StringECMARegexSplitWithOffsets"; };
|
||||
|
||||
size_t CustomOpStringECMARegexSplitWithOffsets::GetInputTypeCount() const {
|
||||
|
@ -112,7 +113,7 @@ ONNXTensorElementDataType CustomOpStringECMARegexSplitWithOffsets::GetOutputType
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
default:
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
"StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -9,14 +9,14 @@
|
|||
|
||||
// See https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/regex_split_with_offsets.md.
|
||||
struct KernelStringECMARegexSplitWithOffsets : BaseKernel {
|
||||
KernelStringECMARegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelStringECMARegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
bool ignore_case_;
|
||||
};
|
||||
|
||||
struct CustomOpStringECMARegexSplitWithOffsets : OrtW::CustomOpBase<CustomOpStringECMARegexSplitWithOffsets, KernelStringECMARegexSplitWithOffsets> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
@ -48,7 +48,7 @@ void ECMARegexSplitImpl(const std::string& input, const std::regex& pattern,
|
|||
end_offsets.push_back(prev_pos + matched_length);
|
||||
}
|
||||
|
||||
//no mather include the delimiter, we should skip it
|
||||
// no mather include the delimiter, we should skip it
|
||||
prev_pos += matched_length;
|
||||
}
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
#include "string_hash.hpp"
|
||||
|
||||
|
||||
KernelStringHash::KernelStringHash(const OrtApi& api) : BaseKernel(api) {
|
||||
KernelStringHash::KernelStringHash(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringHash::Compute(OrtKernelContext* context) {
|
||||
|
@ -68,7 +68,7 @@ ONNXTensorElementDataType CustomOpStringHash::GetOutputType(size_t /*index*/) co
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
};
|
||||
|
||||
KernelStringHashFast::KernelStringHashFast(const OrtApi& api) : BaseKernel(api) {
|
||||
KernelStringHashFast::KernelStringHashFast(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringHashFast::Compute(OrtKernelContext* context) {
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringHash : BaseKernel {
|
||||
KernelStringHash(const OrtApi& api);
|
||||
KernelStringHash(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
|
@ -20,7 +20,7 @@ struct CustomOpStringHash : OrtW::CustomOpBase<CustomOpStringHash, KernelStringH
|
|||
};
|
||||
|
||||
struct KernelStringHashFast : BaseKernel {
|
||||
KernelStringHashFast(const OrtApi& api);
|
||||
KernelStringHashFast(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
#include "string_join.hpp"
|
||||
#include "string_tensor.h"
|
||||
|
||||
KernelStringJoin::KernelStringJoin(const OrtApi& api) : BaseKernel(api) {
|
||||
KernelStringJoin::KernelStringJoin(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringJoin::Compute(OrtKernelContext* context) {
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringJoin : BaseKernel {
|
||||
KernelStringJoin(const OrtApi& api);
|
||||
KernelStringJoin(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
#include <algorithm>
|
||||
#include "ustring.h"
|
||||
|
||||
KernelStringLength::KernelStringLength(const OrtApi& api) : BaseKernel(api) {
|
||||
KernelStringLength::KernelStringLength(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringLength::Compute(OrtKernelContext* context) {
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringLength : BaseKernel {
|
||||
KernelStringLength(const OrtApi& api);
|
||||
KernelStringLength(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
KernelStringLower::KernelStringLower(const OrtApi& api) : BaseKernel(api) {
|
||||
KernelStringLower::KernelStringLower(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringLower::Compute(OrtKernelContext* context) {
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringLower : BaseKernel {
|
||||
KernelStringLower(const OrtApi& api);
|
||||
KernelStringLower(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
|
|
|
@ -8,11 +8,10 @@
|
|||
#include <codecvt>
|
||||
#include <algorithm>
|
||||
|
||||
|
||||
KernelStringMapping::KernelStringMapping(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api) {
|
||||
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
|
||||
KernelStringMapping::KernelStringMapping(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
std::string map = ort_.KernelInfoGetAttribute<std::string>(&info, "map");
|
||||
auto lines = SplitString(map, "\n", true);
|
||||
for (const auto& line: lines) {
|
||||
for (const auto& line : lines) {
|
||||
auto items = SplitString(line, "\t", true);
|
||||
|
||||
if (items.size() != 2) {
|
||||
|
@ -30,7 +29,7 @@ void KernelStringMapping::Compute(OrtKernelContext* context) {
|
|||
|
||||
OrtTensorDimensions dimensions(ort_, input);
|
||||
|
||||
for (auto& str: input_data) {
|
||||
for (auto& str : input_data) {
|
||||
if (map_.find(str) != map_.end()) {
|
||||
str = map_[str];
|
||||
}
|
||||
|
@ -41,10 +40,6 @@ void KernelStringMapping::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, input_data, output);
|
||||
}
|
||||
|
||||
void* CustomOpStringMapping::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
};
|
||||
|
||||
const char* CustomOpStringMapping::GetName() const { return "StringMapping"; };
|
||||
|
||||
size_t CustomOpStringMapping::GetInputTypeCount() const {
|
||||
|
|
|
@ -8,14 +8,14 @@
|
|||
#include <unordered_map>
|
||||
|
||||
struct KernelStringMapping : BaseKernel {
|
||||
KernelStringMapping(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelStringMapping(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, std::string> map_;
|
||||
};
|
||||
|
||||
struct CustomOpStringMapping : OrtW::CustomOpBase<CustomOpStringMapping, KernelStringMapping> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
#include "string_split.hpp"
|
||||
#include "string_tensor.h"
|
||||
|
||||
KernelStringSplit::KernelStringSplit(const OrtApi& api) : BaseKernel(api) {
|
||||
KernelStringSplit::KernelStringSplit(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringSplit::Compute(OrtKernelContext* context) {
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringSplit : BaseKernel {
|
||||
KernelStringSplit(const OrtApi& api);
|
||||
KernelStringSplit(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
#include "string_to_vector.hpp"
|
||||
#include "string_tensor.h"
|
||||
|
||||
|
||||
StringToVectorImpl::StringToVectorImpl(std::string& map, std::string& unk) {
|
||||
ParseMappingTable(map);
|
||||
ParseUnkownValue(unk);
|
||||
|
@ -12,7 +11,7 @@ StringToVectorImpl::StringToVectorImpl(std::string& map, std::string& unk) {
|
|||
|
||||
std::vector<std::vector<int64_t>> StringToVectorImpl::Compute(std::vector<std::string>& str_input, const OrtTensorDimensions& input_dim, OrtTensorDimensions& output_dim) {
|
||||
std::vector<std::vector<int64_t>> result;
|
||||
|
||||
|
||||
// Set output dimension
|
||||
output_dim = input_dim;
|
||||
output_dim.push_back(vector_len_);
|
||||
|
@ -100,10 +99,10 @@ void StringToVectorImpl::ParseValues(const std::string_view& v, std::vector<int6
|
|||
}
|
||||
}
|
||||
|
||||
KernelStringToVector::KernelStringToVector(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
|
||||
KernelStringToVector::KernelStringToVector(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
std::string map = ort_.KernelInfoGetAttribute<std::string>(&info, "map");
|
||||
// unk_value is string here because KernelInfoGetAttribute doesn't support returning vector
|
||||
std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unk");
|
||||
std::string unk = ort_.KernelInfoGetAttribute<std::string>(&info, "unk");
|
||||
|
||||
impl_ = std::make_shared<StringToVectorImpl>(map, unk);
|
||||
}
|
||||
|
@ -132,10 +131,6 @@ void KernelStringToVector::Compute(OrtKernelContext* context) {
|
|||
}
|
||||
}
|
||||
|
||||
void* CustomOpStringToVector::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
};
|
||||
|
||||
const char* CustomOpStringToVector::GetName() const { return "StringToVector"; };
|
||||
|
||||
size_t CustomOpStringToVector::GetInputTypeCount() const {
|
||||
|
|
|
@ -9,7 +9,6 @@
|
|||
#include "ocos.h"
|
||||
#include "string_utils.h"
|
||||
|
||||
|
||||
class StringToVectorImpl {
|
||||
public:
|
||||
StringToVectorImpl(std::string& map, std::string& unk);
|
||||
|
@ -29,7 +28,7 @@ class StringToVectorImpl {
|
|||
};
|
||||
|
||||
struct KernelStringToVector : BaseKernel {
|
||||
KernelStringToVector(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelStringToVector(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
|
@ -37,7 +36,6 @@ struct KernelStringToVector : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpStringToVector : OrtW::CustomOpBase<CustomOpStringToVector, KernelStringToVector> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
KernelStringUpper::KernelStringUpper(const OrtApi& api) : BaseKernel(api) {
|
||||
KernelStringUpper::KernelStringUpper(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void KernelStringUpper::Compute(OrtKernelContext* context) {
|
||||
|
@ -17,7 +17,7 @@ void KernelStringUpper::Compute(OrtKernelContext* context) {
|
|||
GetTensorMutableDataString(api_, ort_, context, input_X, X);
|
||||
|
||||
for (size_t i = 0; i < X.size(); ++i) {
|
||||
std::transform(X[i].begin(), X[i].end(), X[i].begin(), [](char c){ return static_cast<char>(::toupper(c)); });
|
||||
std::transform(X[i].begin(), X[i].end(), X[i].begin(), [](char c) { return static_cast<char>(::toupper(c)); });
|
||||
}
|
||||
|
||||
// Fills the output
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
#include "string_utils.h"
|
||||
|
||||
struct KernelStringUpper : BaseKernel {
|
||||
KernelStringUpper(const OrtApi& api);
|
||||
KernelStringUpper(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
|
|
|
@ -4,12 +4,11 @@
|
|||
#include "vector_to_string.hpp"
|
||||
#include "string_tensor.h"
|
||||
|
||||
|
||||
namespace std {
|
||||
|
||||
template <class T>
|
||||
size_t hash<std::vector<T>>::operator()(const vector<T>& __vector) const noexcept {
|
||||
return util::Hash(reinterpret_cast<const char *>(__vector.data()), __vector.size() * sizeof(T));
|
||||
return util::Hash(reinterpret_cast<const char*>(__vector.data()), __vector.size() * sizeof(T));
|
||||
}
|
||||
|
||||
template struct hash<std::vector<std::string>>;
|
||||
|
@ -38,7 +37,7 @@ std::vector<std::string> VectorToStringImpl::Compute(const void* input, const Or
|
|||
|
||||
std::vector<int64_t> key(vector_len_);
|
||||
for (int64_t i = 0; i < input_dim.Size(); i = static_cast<int64_t>(i + vector_len_)) {
|
||||
//construct key
|
||||
// construct key
|
||||
for (size_t j = 0; j < vector_len_; j++) {
|
||||
key[j] = ptr[j];
|
||||
}
|
||||
|
@ -103,9 +102,9 @@ void VectorToStringImpl::ParseValues(const std::string_view& v, std::vector<int6
|
|||
}
|
||||
}
|
||||
|
||||
KernelVectorToString::KernelVectorToString(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
std::string map = ort_.KernelInfoGetAttribute<std::string>(info, "map");
|
||||
std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unk");
|
||||
KernelVectorToString::KernelVectorToString(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
std::string map = ort_.KernelInfoGetAttribute<std::string>(&info, "map");
|
||||
std::string unk = ort_.KernelInfoGetAttribute<std::string>(&info, "unk");
|
||||
|
||||
// TODO: support more type when we can get input type from OrtKernelInfo
|
||||
impl_ = std::make_shared<VectorToStringImpl>(map, unk);
|
||||
|
@ -124,10 +123,6 @@ void KernelVectorToString::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, mapping_result, output);
|
||||
}
|
||||
|
||||
void* CustomOpVectorToString::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
};
|
||||
|
||||
const char* CustomOpVectorToString::GetName() const { return "VectorToString"; };
|
||||
|
||||
size_t CustomOpVectorToString::GetInputTypeCount() const {
|
||||
|
|
|
@ -30,10 +30,10 @@ class VectorToStringImpl {
|
|||
std::unordered_map<std::vector<int64_t>, std::string> map_;
|
||||
std::string unk_value_;
|
||||
size_t vector_len_;
|
||||
};
|
||||
};
|
||||
|
||||
struct KernelVectorToString : BaseKernel {
|
||||
KernelVectorToString(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelVectorToString(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
|
@ -41,7 +41,6 @@ struct KernelVectorToString : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpVectorToString : OrtW::CustomOpBase<CustomOpVectorToString, KernelVectorToString> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -9,13 +9,13 @@
|
|||
#include <codecvt>
|
||||
#include <algorithm>
|
||||
|
||||
BasicTokenizer::BasicTokenizer(bool do_lower_case, bool tokenize_chinese_chars, bool strip_accents, bool tokenize_punctuation, bool remove_control_chars):
|
||||
do_lower_case_(do_lower_case),
|
||||
strip_accents_(strip_accents),
|
||||
tokenize_chinese_chars_(tokenize_chinese_chars),
|
||||
tokenize_punctuation_(tokenize_punctuation),
|
||||
remove_control_chars_(remove_control_chars)
|
||||
{}
|
||||
BasicTokenizer::BasicTokenizer(bool do_lower_case, bool tokenize_chinese_chars, bool strip_accents,
|
||||
bool tokenize_punctuation, bool remove_control_chars)
|
||||
: do_lower_case_(do_lower_case),
|
||||
strip_accents_(strip_accents),
|
||||
tokenize_chinese_chars_(tokenize_chinese_chars),
|
||||
tokenize_punctuation_(tokenize_punctuation),
|
||||
remove_control_chars_(remove_control_chars) {}
|
||||
|
||||
std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
|
||||
std::vector<ustring> result;
|
||||
|
@ -42,7 +42,7 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
|
|||
|
||||
if (do_lower_case_) {
|
||||
for (auto& c : text) {
|
||||
c = ToLower(c);
|
||||
c = ToLower(c);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -57,7 +57,7 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
|
|||
continue;
|
||||
}
|
||||
|
||||
// 0x2019 unicode is not punctuation in some Linux platform,
|
||||
// 0x2019 unicode is not punctuation in some Linux platform,
|
||||
// to be consistent, take it as punctuation.
|
||||
if (tokenize_punctuation_ && IsPunct(c)) {
|
||||
push_current_token_and_clear();
|
||||
|
@ -82,7 +82,7 @@ std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
|
|||
return result;
|
||||
}
|
||||
|
||||
KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
|
||||
bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
|
||||
bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
|
||||
|
@ -109,10 +109,6 @@ void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, result, output);
|
||||
}
|
||||
|
||||
void* CustomOpBasicTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
};
|
||||
|
||||
const char* CustomOpBasicTokenizer::GetName() const { return "BasicTokenizer"; };
|
||||
|
||||
size_t CustomOpBasicTokenizer::GetInputTypeCount() const {
|
||||
|
|
|
@ -9,7 +9,8 @@
|
|||
|
||||
class BasicTokenizer {
|
||||
public:
|
||||
BasicTokenizer(bool do_lower_case, bool tokenize_chinese_chars, bool strip_accents, bool tokenize_punctuation, bool remove_control_chars);
|
||||
BasicTokenizer(bool do_lower_case, bool tokenize_chinese_chars, bool strip_accents, bool tokenize_punctuation,
|
||||
bool remove_control_chars);
|
||||
std::vector<ustring> Tokenize(ustring text);
|
||||
|
||||
private:
|
||||
|
@ -21,14 +22,14 @@ class BasicTokenizer {
|
|||
};
|
||||
|
||||
struct KernelBasicTokenizer : BaseKernel {
|
||||
KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
std::shared_ptr<BasicTokenizer> tokenizer_;
|
||||
};
|
||||
|
||||
struct CustomOpBasicTokenizer : OrtW::CustomOpBase<CustomOpBasicTokenizer, KernelBasicTokenizer> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -43,13 +43,10 @@ WordpieceTokenizer::WordpieceTokenizer(
|
|||
std::shared_ptr<BertTokenizerVocab> vocab,
|
||||
ustring unk_token,
|
||||
ustring suffix_indicator,
|
||||
int max_input_chars_per_word
|
||||
) :
|
||||
max_input_chars_per_word_(max_input_chars_per_word),
|
||||
suffix_indicator_(std::move(suffix_indicator)),
|
||||
unk_token_(std::move(unk_token)),
|
||||
vocab_(std::move(vocab))
|
||||
{
|
||||
int max_input_chars_per_word) : max_input_chars_per_word_(max_input_chars_per_word),
|
||||
suffix_indicator_(std::move(suffix_indicator)),
|
||||
unk_token_(std::move(unk_token)),
|
||||
vocab_(std::move(vocab)) {
|
||||
unk_token_id_ = vocab_->FindTokenId(unk_token_);
|
||||
}
|
||||
|
||||
|
@ -190,21 +187,17 @@ BertTokenizer::BertTokenizer(
|
|||
bool strip_accents,
|
||||
ustring suffix_indicator,
|
||||
int32_t max_len,
|
||||
const std::string& truncation_strategy
|
||||
) :
|
||||
max_length_(max_len),
|
||||
do_basic_tokenize_(do_basic_tokenize),
|
||||
truncate_(std::make_unique<TruncateStrategy>(truncation_strategy))
|
||||
{
|
||||
|
||||
const std::string& truncation_strategy) : max_length_(max_len),
|
||||
do_basic_tokenize_(do_basic_tokenize),
|
||||
truncate_(std::make_unique<TruncateStrategy>(truncation_strategy)) {
|
||||
vocab_ = std::make_shared<BertTokenizerVocab>(vocab);
|
||||
|
||||
if (do_basic_tokenize) {
|
||||
basic_tokenizer_ = std::make_unique<BasicTokenizer>(
|
||||
do_lower_case, tokenize_chinese_chars, strip_accents, true, true);
|
||||
do_lower_case, tokenize_chinese_chars, strip_accents, true, true);
|
||||
}
|
||||
wordpiece_tokenizer_ = std::make_unique<WordpieceTokenizer>(
|
||||
vocab_, unk_token, suffix_indicator);
|
||||
vocab_, unk_token, suffix_indicator);
|
||||
|
||||
unk_token_id_ = vocab_->FindTokenId(unk_token);
|
||||
sep_token_id_ = vocab_->FindTokenId(sep_token);
|
||||
|
@ -276,8 +269,8 @@ TruncateStrategy::TruncateStrategy(std::string_view strategy_name) : strategy_(T
|
|||
}
|
||||
}
|
||||
|
||||
KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
|
||||
KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab_file");
|
||||
bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
|
||||
bool do_basic_tokenize = TryToGetAttributeWithDefault("do_basic_tokenize", true);
|
||||
std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
|
||||
|
@ -288,14 +281,15 @@ KernelBertTokenizer::KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo*
|
|||
bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
|
||||
bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
|
||||
std::string suffix_indicator = TryToGetAttributeWithDefault("suffix_indicator", std::string("##"));
|
||||
std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name", std::string("longest_first"));
|
||||
std::string truncation_strategy_name = TryToGetAttributeWithDefault("truncation_strategy_name",
|
||||
std::string("longest_first"));
|
||||
int32_t max_len = static_cast<int32_t>(TryToGetAttributeWithDefault("max_length", int64_t(-1)));
|
||||
|
||||
tokenizer_ = std::make_unique<BertTokenizer>(
|
||||
vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
|
||||
ustring(sep_token), ustring(pad_token), ustring(cls_token),
|
||||
ustring(mask_token), tokenize_chinese_chars, strip_accents,
|
||||
ustring(suffix_indicator), max_len, truncation_strategy_name);
|
||||
vocab, do_lower_case, do_basic_tokenize, ustring(unk_token),
|
||||
ustring(sep_token), ustring(pad_token), ustring(cls_token),
|
||||
ustring(mask_token), tokenize_chinese_chars, strip_accents,
|
||||
ustring(suffix_indicator), max_len, truncation_strategy_name);
|
||||
}
|
||||
|
||||
void KernelBertTokenizer::Compute(OrtKernelContext* context) {
|
||||
|
@ -334,10 +328,6 @@ void KernelBertTokenizer::Compute(OrtKernelContext* context) {
|
|||
SetOutput(context, 2, output_dim, attention_mask);
|
||||
}
|
||||
|
||||
void* CustomOpBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
}
|
||||
|
||||
const char* CustomOpBertTokenizer::GetName() const { return "BertTokenizer"; }
|
||||
|
||||
size_t CustomOpBertTokenizer::GetInputTypeCount() const {
|
||||
|
@ -356,11 +346,12 @@ ONNXTensorElementDataType CustomOpBertTokenizer::GetOutputType(size_t /* index *
|
|||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
}
|
||||
|
||||
KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info) : KernelBertTokenizer(api, info) {}
|
||||
KernelHfBertTokenizer::KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: KernelBertTokenizer(api, info) {}
|
||||
|
||||
void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
const OrtValue *const input = ort_.KernelContext_GetInput(context, 0);
|
||||
const OrtValue* const input = ort_.KernelContext_GetInput(context, 0);
|
||||
std::vector<std::string> input_data;
|
||||
GetTensorMutableDataString(api_, ort_, context, input, input_data);
|
||||
|
||||
|
@ -380,7 +371,7 @@ void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
|
|||
const std::vector<int64_t> inner_dims{1LL};
|
||||
for (int32_t i = 0; i < 3; ++i) {
|
||||
OrtValue* const value = ort_.KernelContext_GetOutput(context, i, outer_dims.data(), outer_dims.size());
|
||||
OrtTensorTypeAndShapeInfo *const info = ort_.GetTensorTypeAndShape(value);
|
||||
OrtTensorTypeAndShapeInfo* const info = ort_.GetTensorTypeAndShape(value);
|
||||
ort_.SetDimensions(info, inner_dims.data(), inner_dims.size());
|
||||
ort_.ReleaseTensorTypeAndShapeInfo(info);
|
||||
}
|
||||
|
@ -390,10 +381,6 @@ void KernelHfBertTokenizer::Compute(OrtKernelContext* context) {
|
|||
SetOutput(context, 2, outer_dims, token_type_ids);
|
||||
}
|
||||
|
||||
void* CustomOpHfBertTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
}
|
||||
|
||||
const char* CustomOpHfBertTokenizer::GetName() const { return "HfBertTokenizer"; }
|
||||
|
||||
size_t CustomOpHfBertTokenizer::GetInputTypeCount() const {
|
||||
|
|
|
@ -42,8 +42,8 @@ class TruncateStrategy final {
|
|||
class WordpieceTokenizer final {
|
||||
public:
|
||||
WordpieceTokenizer(
|
||||
std::shared_ptr<BertTokenizerVocab> vocab, ustring unk_token,
|
||||
ustring suffix_indicator, int max_input_chars_per_word = 100);
|
||||
std::shared_ptr<BertTokenizerVocab> vocab, ustring unk_token,
|
||||
ustring suffix_indicator, int max_input_chars_per_word = 100);
|
||||
std::vector<ustring> Tokenize(const ustring& text);
|
||||
std::vector<ustring> Tokenize(const std::vector<ustring>& tokens);
|
||||
std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
|
||||
|
@ -90,7 +90,7 @@ class BertTokenizer final {
|
|||
};
|
||||
|
||||
struct KernelBertTokenizer : BaseKernel {
|
||||
KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
protected:
|
||||
|
@ -98,7 +98,6 @@ struct KernelBertTokenizer : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpBertTokenizer : OrtW::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
@ -107,12 +106,11 @@ struct CustomOpBertTokenizer : OrtW::CustomOpBase<CustomOpBertTokenizer, KernelB
|
|||
};
|
||||
|
||||
struct KernelHfBertTokenizer : KernelBertTokenizer {
|
||||
KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpHfBertTokenizer : OrtW::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -7,16 +7,13 @@ BertTokenizerDecoder::BertTokenizerDecoder(
|
|||
std::string pad_token,
|
||||
std::string cls_token,
|
||||
std::string mask_token,
|
||||
std::string suffix_indicator
|
||||
) :
|
||||
unk_token_(unk_token),
|
||||
suffix_indicator_(suffix_indicator),
|
||||
raw_vocab_(vocab)
|
||||
{
|
||||
std::string suffix_indicator) : unk_token_(unk_token),
|
||||
suffix_indicator_(suffix_indicator),
|
||||
raw_vocab_(vocab) {
|
||||
auto tokens = SplitString(raw_vocab_, "\n", true);
|
||||
vocab_.reserve(tokens.size());
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
auto& token = tokens[i];
|
||||
auto& token = tokens[i];
|
||||
if (token == unk_token) {
|
||||
unk_token_id_ = static_cast<int32_t>(i);
|
||||
}
|
||||
|
@ -82,7 +79,6 @@ std::string BertTokenizerDecoder::Decode(const std::vector<int64_t>& ids, bool s
|
|||
}
|
||||
|
||||
bool BertTokenizerDecoder::RemoveTokenizeSpace(int64_t pre_token_id, int64_t new_token_id) {
|
||||
|
||||
if (pre_token_id < 0) {
|
||||
return true;
|
||||
}
|
||||
|
@ -123,8 +119,8 @@ bool BertTokenizerDecoder::RemoveTokenizeSpace(int64_t pre_token_id, int64_t new
|
|||
return false;
|
||||
}
|
||||
|
||||
KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab_file");
|
||||
KernelBertTokenizerDecoder::KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab_file");
|
||||
std::string unk_token = TryToGetAttributeWithDefault("unk_token", std::string("[UNK]"));
|
||||
std::string sep_token = TryToGetAttributeWithDefault("sep_token", std::string("[SEP]"));
|
||||
std::string pad_token = TryToGetAttributeWithDefault("pad_token", std::string("[PAD]"));
|
||||
|
@ -154,7 +150,7 @@ void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
|
|||
OrtTensorDimensions positions_dim(ort_, positions);
|
||||
if (use_indices_ &&
|
||||
(!((positions_dim.Size() == 0) ||
|
||||
(positions_dim.size() == 2 && positions_dim[1] == 2)))) {
|
||||
(positions_dim.size() == 2 && positions_dim[1] == 2)))) {
|
||||
ORTX_CXX_API_THROW("[BertTokenizerDecoder]: Expect positions empty or a [n, 2] matrix when use indices", ORT_INVALID_GRAPH);
|
||||
}
|
||||
|
||||
|
@ -181,10 +177,6 @@ void KernelBertTokenizerDecoder::Compute(OrtKernelContext* context) {
|
|||
FillTensorDataString(api_, ort_, context, result, output);
|
||||
}
|
||||
|
||||
void* CustomOpBertTokenizerDecoder::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
};
|
||||
|
||||
const char* CustomOpBertTokenizerDecoder::GetName() const { return "BertTokenizerDecoder"; };
|
||||
|
||||
size_t CustomOpBertTokenizerDecoder::GetInputTypeCount() const {
|
||||
|
|
|
@ -10,11 +10,10 @@
|
|||
#include "string_utils.h"
|
||||
#include "string_tensor.h"
|
||||
|
||||
|
||||
class BertTokenizerDecoder {
|
||||
public:
|
||||
BertTokenizerDecoder(std::string vocab, std::string unk_token, std::string sep_token, std::string pad_token,
|
||||
std::string cls_token,std::string mask_token,std::string suffix_indicator);
|
||||
std::string cls_token, std::string mask_token, std::string suffix_indicator);
|
||||
std::string Decode(const std::vector<int64_t>& ids, bool skip_special_tokens, bool clean_up_tokenization_spaces);
|
||||
|
||||
private:
|
||||
|
@ -33,8 +32,9 @@ class BertTokenizerDecoder {
|
|||
};
|
||||
|
||||
struct KernelBertTokenizerDecoder : BaseKernel {
|
||||
KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelBertTokenizerDecoder(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
std::shared_ptr<BertTokenizerDecoder> decoder_;
|
||||
bool use_indices_;
|
||||
|
@ -43,10 +43,9 @@ struct KernelBertTokenizerDecoder : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpBertTokenizerDecoder : OrtW::CustomOpBase<CustomOpBertTokenizerDecoder, KernelBertTokenizerDecoder> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
size_t GetOutputTypeCount() const;
|
||||
ONNXTensorElementDataType GetOutputType(size_t index) const;
|
||||
};
|
||||
};
|
||||
|
|
|
@ -9,8 +9,9 @@
|
|||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
|
||||
model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
|
||||
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info), max_sentence(-1) {
|
||||
model_data_ = ort_.KernelInfoGetAttribute<std::string>(&info, "model");
|
||||
if (model_data_.empty()) {
|
||||
ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
@ -24,7 +25,7 @@ KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api
|
|||
model_ = std::shared_ptr<void>(model_ptr, FreeModel);
|
||||
|
||||
if (HasAttribute("max_sentence")) {
|
||||
max_sentence = static_cast<int>(ort_.KernelInfoGetAttribute<int64_t>(info, "max_sentence"));
|
||||
max_sentence = static_cast<int>(ort_.KernelInfoGetAttribute<int64_t>(&info, "max_sentence"));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -78,10 +79,6 @@ void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
|
|||
OrtW::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size()));
|
||||
}
|
||||
|
||||
void* CustomOpBlingFireSentenceBreaker::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
};
|
||||
|
||||
const char* CustomOpBlingFireSentenceBreaker::GetName() const { return "BlingFireSentenceBreaker"; };
|
||||
|
||||
size_t CustomOpBlingFireSentenceBreaker::GetInputTypeCount() const {
|
||||
|
|
|
@ -16,8 +16,9 @@ extern "C" int FreeModel(void* ModelPtr);
|
|||
extern "C" void* SetModel(const unsigned char* pImgBytes, int ModelByteCount);
|
||||
|
||||
struct KernelBlingFireSentenceBreaker : BaseKernel {
|
||||
KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
using ModelPtr = std::shared_ptr<void>;
|
||||
ModelPtr model_;
|
||||
|
@ -25,8 +26,8 @@ struct KernelBlingFireSentenceBreaker : BaseKernel {
|
|||
int max_sentence;
|
||||
};
|
||||
|
||||
struct CustomOpBlingFireSentenceBreaker : OrtW::CustomOpBase<CustomOpBlingFireSentenceBreaker, KernelBlingFireSentenceBreaker> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
struct CustomOpBlingFireSentenceBreaker : OrtW::CustomOpBase<CustomOpBlingFireSentenceBreaker,
|
||||
KernelBlingFireSentenceBreaker> {
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -64,14 +64,14 @@ bool IsEmptyUstring(const ustring& str) {
|
|||
return std::all_of(str.begin(), str.end(), [](char32_t ch) { return IsInUnicodeSpace(ch); });
|
||||
}
|
||||
|
||||
KernelClipBpeTokenizer::KernelClipBpeTokenizer(const OrtApi& api, const OrtKernelInfo* info)
|
||||
KernelClipBpeTokenizer::KernelClipBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info) {
|
||||
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
|
||||
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");
|
||||
if (vocab.empty()) {
|
||||
ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
std::string merges = ort_.KernelInfoGetAttribute<std::string>(info, "merges");
|
||||
std::string merges = ort_.KernelInfoGetAttribute<std::string>(&info, "merges");
|
||||
if (merges.empty()) {
|
||||
ORTX_CXX_API_THROW("merges shouldn't be empty.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
@ -202,10 +202,6 @@ void KernelClipBpeTokenizer::Compute(OrtKernelContext* context) {
|
|||
}
|
||||
}
|
||||
|
||||
void* CustomOpClipBpeTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
}
|
||||
|
||||
const char* CustomOpClipBpeTokenizer::GetName() const {
|
||||
return "CLIPTokenizer";
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
class VocabData;
|
||||
|
||||
struct KernelClipBpeTokenizer : BaseKernel {
|
||||
KernelClipBpeTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelClipBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
|
@ -18,7 +18,6 @@ struct KernelClipBpeTokenizer : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpClipBpeTokenizer : OrtW::CustomOpBase<CustomOpClipBpeTokenizer, KernelClipBpeTokenizer> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
#include "string_tensor.h"
|
||||
#include "unicode.h"
|
||||
|
||||
//Note: the following logic comes from CPython: unicodetype_db.h (_PyUnicode_IsWhitespace)
|
||||
// Note: the following logic comes from CPython: unicodetype_db.h (_PyUnicode_IsWhitespace)
|
||||
bool IsUnicodeSpace(char32_t ch) {
|
||||
switch (ch) {
|
||||
case 0x0009:
|
||||
|
@ -63,14 +63,14 @@ bool IsEmptyUString(const ustring& str) {
|
|||
return std::all_of(str.begin(), str.end(), [](char32_t ch) { return IsUnicodeSpace(ch); });
|
||||
}
|
||||
|
||||
KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo* info)
|
||||
KernelBpeTokenizer::KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info) {
|
||||
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
|
||||
std::string vocab = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");
|
||||
if (vocab.empty()) {
|
||||
ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
std::string merges = ort_.KernelInfoGetAttribute<std::string>(info, "merges");
|
||||
std::string merges = ort_.KernelInfoGetAttribute<std::string>(&info, "merges");
|
||||
if (merges.empty()) {
|
||||
ORTX_CXX_API_THROW("merges shouldn't be empty.", ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
@ -182,10 +182,6 @@ void KernelBpeTokenizer::Compute(OrtKernelContext* context) {
|
|||
}
|
||||
}
|
||||
|
||||
void* CustomOpBpeTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
}
|
||||
|
||||
const char* CustomOpBpeTokenizer::GetName() const {
|
||||
return "GPT2Tokenizer";
|
||||
}
|
||||
|
|
|
@ -2,11 +2,10 @@
|
|||
#include "ocos.h"
|
||||
#include "ustring.h"
|
||||
|
||||
|
||||
class VocabData;
|
||||
|
||||
struct KernelBpeTokenizer : BaseKernel {
|
||||
KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelBpeTokenizer(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
|
@ -18,7 +17,6 @@ struct KernelBpeTokenizer : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpBpeTokenizer : OrtW::CustomOpBase<CustomOpBpeTokenizer, KernelBpeTokenizer> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -10,16 +10,15 @@
|
|||
#include "sentencepiece_model.pb.h"
|
||||
|
||||
struct KernelSentencepieceDecoder : BaseKernel {
|
||||
KernelSentencepieceDecoder(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
std::string model_blob = ort_.KernelInfoGetAttribute<std::string>(info, "model");
|
||||
KernelSentencepieceDecoder(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
std::string model_blob = ort_.KernelInfoGetAttribute<std::string>(&info, "model");
|
||||
sentencepiece::ModelProto model_proto;
|
||||
model_proto.ParseFromArray(model_blob.data(), static_cast<int>(model_blob.size()));
|
||||
sentencepiece::util::Status status = tokenizer_.Load(model_proto);
|
||||
if (!status.ok()){
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"Failed to create SentencePieceProcessor instance. Error code is ",
|
||||
(int)status.code(), ". Message is '", status.error_message(), "'."),
|
||||
ORT_INVALID_PROTOBUF);
|
||||
if (!status.ok()) {
|
||||
ORTX_CXX_API_THROW(MakeString("Failed to create SentencePieceProcessor instance. Error code is ",
|
||||
(int)status.code(), ". Message is '", status.error_message(), "'."),
|
||||
ORT_INVALID_PROTOBUF);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -40,7 +39,7 @@ struct KernelSentencepieceDecoder : BaseKernel {
|
|||
std::back_inserter(tids),
|
||||
[](auto _id) { return static_cast<int>(_id); });
|
||||
auto status = tokenizer_.Decode(tids, &decoded_string);
|
||||
if (!status.ok()){
|
||||
if (!status.ok()) {
|
||||
ORTX_CXX_API_THROW("[SentencePieceDecoder] model decoding failed.", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
|
||||
|
@ -54,10 +53,6 @@ struct KernelSentencepieceDecoder : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpSentencepieceDecoder : OrtW::CustomOpBase<CustomOpSentencepieceDecoder, KernelSentencepieceDecoder> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
}
|
||||
|
||||
const char* GetName() const {
|
||||
return "SentencepieceDecoder";
|
||||
}
|
||||
|
|
|
@ -7,8 +7,9 @@
|
|||
#include "string_tensor.h"
|
||||
#include "base64.h"
|
||||
|
||||
KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
std::string model_as_string = ort_.KernelInfoGetAttribute<std::string>(info, "model");
|
||||
KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info) {
|
||||
std::string model_as_string = ort_.KernelInfoGetAttribute<std::string>(&info, "model");
|
||||
sentencepiece::ModelProto model_proto;
|
||||
std::vector<uint8_t> model_as_bytes;
|
||||
if (base64_decode(model_as_string, model_as_bytes)) {
|
||||
|
@ -18,16 +19,16 @@ KernelSentencepieceTokenizer::KernelSentencepieceTokenizer(const OrtApi& api, co
|
|||
}
|
||||
sentencepiece::util::Status status = tokenizer_.Load(model_proto);
|
||||
if (!status.ok())
|
||||
throw std::runtime_error(MakeString(
|
||||
"Failed to create SentencePieceProcessor instance. Error code is ",
|
||||
(int)status.code(), ". Message is '", status.error_message(), "'."));
|
||||
ORTX_CXX_API_THROW(MakeString("Failed to create SentencePieceProcessor instance. Error code is ",
|
||||
(int)status.code(), ". Message is '", status.error_message(), "'."),
|
||||
ORT_FAIL);
|
||||
}
|
||||
|
||||
static void _check_dimension_constant(OrtW::CustomOpApi ort, const OrtValue* ort_value, const char* name) {
|
||||
OrtTensorDimensions dimensions(ort, ort_value);
|
||||
if (dimensions.size() != 1 || dimensions[0] != 1)
|
||||
throw std::runtime_error(MakeString(
|
||||
name, " must contain only one element. It has ", dimensions.size(), " dimensions."));
|
||||
ORTX_CXX_API_THROW(MakeString(name, " must contain only one element. It has ", dimensions.size(), " dimensions."),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
|
||||
void KernelSentencepieceTokenizer::Compute(OrtKernelContext* context) {
|
||||
|
@ -64,8 +65,7 @@ void KernelSentencepieceTokenizer::Compute(OrtKernelContext* context) {
|
|||
for (size_t i = 0; i < str_input.size(); ++i) {
|
||||
std::vector<int> inloop;
|
||||
if (!tokenizer_.Encode(str_input[i].c_str(), &inloop).ok())
|
||||
throw std::runtime_error(MakeString(
|
||||
"Unable to encode string '", str_input[i], "'."));
|
||||
ORTX_CXX_API_THROW(MakeString("Unable to encode string '", str_input[i], "'."), ORT_INVALID_ARGUMENT);
|
||||
indices.push_back(content.size());
|
||||
|
||||
if (*p_add_rev) {
|
||||
|
@ -103,10 +103,6 @@ void KernelSentencepieceTokenizer::Compute(OrtKernelContext* context) {
|
|||
memcpy(ptr_indices, indices.data(), indices.size() * sizeof(int64_t));
|
||||
}
|
||||
|
||||
void* CustomOpSentencepieceTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
};
|
||||
|
||||
const char* CustomOpSentencepieceTokenizer::GetName() const {
|
||||
return "SentencepieceTokenizer";
|
||||
};
|
||||
|
@ -128,7 +124,7 @@ ONNXTensorElementDataType CustomOpSentencepieceTokenizer::GetInputType(size_t in
|
|||
case 5:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
||||
default:
|
||||
throw std::runtime_error(MakeString("Unexpected input index ", index));
|
||||
ORTX_CXX_API_THROW(MakeString("Unexpected input index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -143,6 +139,6 @@ ONNXTensorElementDataType CustomOpSentencepieceTokenizer::GetOutputType(size_t i
|
|||
case 1:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
default:
|
||||
throw std::runtime_error(MakeString("[SentencepieceTokenizer] Unexpected output index ", index));
|
||||
ORTX_CXX_API_THROW(MakeString("[SentencepieceTokenizer] Unexpected output index ", index), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -7,17 +7,16 @@
|
|||
#include "string_utils.h"
|
||||
#include "sentencepiece_processor.h"
|
||||
|
||||
|
||||
struct KernelSentencepieceTokenizer : BaseKernel {
|
||||
KernelSentencepieceTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelSentencepieceTokenizer(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
sentencepiece::SentencePieceProcessor tokenizer_;
|
||||
};
|
||||
|
||||
struct CustomOpSentencepieceTokenizer : OrtW::CustomOpBase<CustomOpSentencepieceTokenizer, KernelSentencepieceTokenizer> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
struct CustomOpSentencepieceTokenizer : OrtW::CustomOpBase<CustomOpSentencepieceTokenizer,
|
||||
KernelSentencepieceTokenizer> {
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
|
|
@ -4,13 +4,16 @@
|
|||
#include "wordpiece_tokenizer.hpp"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
KernelWordpieceTokenizer::KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
KernelWordpieceTokenizer::KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo& info)
|
||||
: BaseKernel(api, info) {
|
||||
// https://github.com/tensorflow/text/blob/master/docs/api_docs/python/text/WordpieceTokenizer.md
|
||||
// https://github.com/tensorflow/text/blob/master/tensorflow_text/python/ops/bert_tokenizer.py
|
||||
std::string vocab_as_string = ort_.KernelInfoGetAttribute<std::string>(info, "vocab");
|
||||
std::string suffix_indicator = ort_.KernelInfoGetAttribute<std::string>(info, "suffix_indicator");
|
||||
std::string unk = ort_.KernelInfoGetAttribute<std::string>(info, "unknown_token");
|
||||
max_input_chars_per_word_ = HasAttribute("max_input_chars_per_word") ? ort_.KernelInfoGetAttribute<int64_t>(info, "max_input_chars_per_word") : 200;
|
||||
std::string vocab_as_string = ort_.KernelInfoGetAttribute<std::string>(&info, "vocab");
|
||||
std::string suffix_indicator = ort_.KernelInfoGetAttribute<std::string>(&info, "suffix_indicator");
|
||||
std::string unk = ort_.KernelInfoGetAttribute<std::string>(&info, "unknown_token");
|
||||
max_input_chars_per_word_ = HasAttribute("max_input_chars_per_word")
|
||||
? ort_.KernelInfoGetAttribute<int64_t>(&info, "max_input_chars_per_word")
|
||||
: 200;
|
||||
suffix_indicator_ = ustring(suffix_indicator);
|
||||
unk_token_ = ustring(unk);
|
||||
|
||||
|
@ -73,7 +76,8 @@ void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string,
|
|||
} else if (text_index == existing_rows[row_index]) {
|
||||
if (row_index >= n_existing_rows)
|
||||
ORTX_CXX_API_THROW(MakeString(
|
||||
"row_index=", row_index, " is out of range=", n_existing_rows, "."), ORT_INVALID_ARGUMENT);
|
||||
"row_index=", row_index, " is out of range=", n_existing_rows, "."),
|
||||
ORT_INVALID_ARGUMENT);
|
||||
rows.push_back(indices.size());
|
||||
++row_index;
|
||||
}
|
||||
|
@ -162,10 +166,6 @@ void KernelWordpieceTokenizer::Compute(OrtKernelContext* context) {
|
|||
ptr_row_lengths[i] = row_begins[static_cast<size_t>(i)];
|
||||
}
|
||||
|
||||
void* CustomOpWordpieceTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
};
|
||||
|
||||
const char* CustomOpWordpieceTokenizer::GetName() const {
|
||||
return "WordpieceTokenizer";
|
||||
};
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
#include "string_tensor.h"
|
||||
|
||||
struct KernelWordpieceTokenizer : BaseKernel {
|
||||
KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo* info);
|
||||
KernelWordpieceTokenizer(const OrtApi& api, const OrtKernelInfo& info);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
|
@ -22,7 +22,6 @@ struct KernelWordpieceTokenizer : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpWordpieceTokenizer : OrtW::CustomOpBase<CustomOpWordpieceTokenizer, KernelWordpieceTokenizer> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
|
||||
const char* GetName() const;
|
||||
size_t GetInputTypeCount() const;
|
||||
ONNXTensorElementDataType GetInputType(size_t index) const;
|
||||
|
@ -43,4 +42,4 @@ void KernelWordpieceTokenizer_Tokenizer(const std::unordered_map<std::u32string,
|
|||
std::vector<int64_t>& rows,
|
||||
const int64_t* existing_rows = nullptr,
|
||||
int64_t n_existing_rows = 0,
|
||||
int64_t max_input_chars_per_word = 200);
|
||||
int64_t max_input_chars_per_word = 200);
|
||||
|
|
|
@ -10,16 +10,12 @@
|
|||
|
||||
namespace ort_extensions {
|
||||
struct KernelDecodeImage : BaseKernel {
|
||||
KernelDecodeImage(const OrtApi& api) : BaseKernel(api) {}
|
||||
KernelDecodeImage(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {}
|
||||
|
||||
void Compute(OrtKernelContext* context);
|
||||
};
|
||||
|
||||
struct CustomOpDecodeImage : OrtW::CustomOpBase<CustomOpDecodeImage, KernelDecodeImage> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new KernelDecodeImage(api);
|
||||
}
|
||||
|
||||
void KernelDestroy(void* op_kernel) {
|
||||
delete static_cast<KernelDecodeImage*>(op_kernel);
|
||||
}
|
||||
|
|
|
@ -10,15 +10,20 @@
|
|||
|
||||
namespace ort_extensions {
|
||||
struct KernelEncodeImage : BaseKernel {
|
||||
KernelEncodeImage(const OrtApi& api, const std::string& format)
|
||||
: BaseKernel{api},
|
||||
extension_{std::string(".") + format} {
|
||||
KernelEncodeImage(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel{api, info} {
|
||||
OrtW::CustomOpApi op_api{api};
|
||||
std::string format = op_api.KernelInfoGetAttribute<std::string>(&info, "format");
|
||||
if (format != "jpg" && format != "png") {
|
||||
ORTX_CXX_API_THROW("[EncodeImage] 'format' attribute value must be 'jpg' or 'png'.", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
|
||||
extension_ = std::string(".") + format;
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
const std::string extension_;
|
||||
std::string extension_;
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
|
@ -28,16 +33,6 @@ struct KernelEncodeImage : BaseKernel {
|
|||
/// Default is 'jpg'
|
||||
/// </summary>
|
||||
struct CustomOpEncodeImage : OrtW::CustomOpBase<CustomOpEncodeImage, KernelEncodeImage> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
OrtW::CustomOpApi op_api{api};
|
||||
std::string format = op_api.KernelInfoGetAttribute<std::string>(info, "format");
|
||||
if (format != "jpg" && format != "png") {
|
||||
ORTX_CXX_API_THROW("[EncodeImage] 'format' attribute value must be 'jpg' or 'png'.", ORT_RUNTIME_EXCEPTION);
|
||||
}
|
||||
|
||||
return new KernelEncodeImage(api, format);
|
||||
}
|
||||
|
||||
const char* GetName() const {
|
||||
return "EncodeImage";
|
||||
}
|
||||
|
|
|
@ -210,7 +210,7 @@ typedef struct {
|
|||
std::vector<int64_t> dimensions;
|
||||
} InputInformation;
|
||||
|
||||
PyCustomOpKernel::PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo* info,
|
||||
PyCustomOpKernel::PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo& info,
|
||||
uint64_t id, const std::vector<std::string>& attrs)
|
||||
: api_(api),
|
||||
ort_(api_),
|
||||
|
@ -218,7 +218,7 @@ PyCustomOpKernel::PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo* info,
|
|||
size_t size;
|
||||
for (std::vector<std::string>::const_iterator it = attrs.begin(); it != attrs.end(); ++it) {
|
||||
size = 0;
|
||||
OrtStatus* status = api_.KernelInfoGetAttribute_string(info, it->c_str(), nullptr, &size);
|
||||
OrtStatus* status = api_.KernelInfoGetAttribute_string(&info, it->c_str(), nullptr, &size);
|
||||
if ((status != nullptr) && api_.GetErrorCode(status) != ORT_INVALID_ARGUMENT) {
|
||||
std::string error_message(api_.GetErrorMessage(status));
|
||||
api_.ReleaseStatus(status);
|
||||
|
@ -231,7 +231,7 @@ PyCustomOpKernel::PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo* info,
|
|||
}
|
||||
attrs_values_[*it] = "";
|
||||
attrs_values_[*it].resize(size);
|
||||
status = api_.KernelInfoGetAttribute_string(info, it->c_str(), &(attrs_values_[*it][0]), &size);
|
||||
status = api_.KernelInfoGetAttribute_string(&info, it->c_str(), &(attrs_values_[*it][0]), &size);
|
||||
if ((status != nullptr) && (api_.GetErrorCode(status) != ORT_OK)) {
|
||||
api_.ReleaseStatus(status);
|
||||
throw std::runtime_error(MakeString(
|
||||
|
@ -373,7 +373,7 @@ void PyCustomOpDef::AddOp(const PyCustomOpDef* cod) {
|
|||
op_domain = cod->op_type.substr(0, dm_pos);
|
||||
op = cod->op_type.substr(dm_pos + 2, -1);
|
||||
}
|
||||
|
||||
|
||||
// No need to protect against concurrent access, GIL is doing that.
|
||||
auto val = std::make_pair(op_domain, std::vector<PyCustomOpFactory>());
|
||||
const auto [it_domain_op, success] = PyOp_container().insert(val);
|
||||
|
@ -386,7 +386,7 @@ const PyCustomOpFactory* PyCustomOpDef_FetchPyCustomOps(size_t num) {
|
|||
EnablePyCustomOps(false);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
auto it = PyOp_container().find(c_OpDomain);
|
||||
if (it != PyOp_container().end()) {
|
||||
const std::vector<PyCustomOpFactory>& ref = it->second;
|
||||
|
@ -414,7 +414,7 @@ bool EnablePyCustomOps(bool enabled) {
|
|||
OrtStatusPtr RegisterPythonDomainAndOps(OrtSessionOptions* options, const OrtApi* ortApi){
|
||||
OrtCustomOpDomain* domain = nullptr;
|
||||
OrtStatus* status = nullptr;
|
||||
|
||||
|
||||
for (auto const& val_pair: PyOp_container()) {
|
||||
if (val_pair.first == c_OpDomain) {
|
||||
continue; // Register this domain in the second iteration.
|
||||
|
|
|
@ -37,7 +37,7 @@ struct PyCustomOpDef {
|
|||
};
|
||||
|
||||
struct PyCustomOpKernel {
|
||||
PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo* info, uint64_t id, const std::vector<std::string>& attrs);
|
||||
PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo& info, uint64_t id, const std::vector<std::string>& attrs);
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
|
@ -49,7 +49,7 @@ struct PyCustomOpKernel {
|
|||
|
||||
struct PyCustomOpFactory : OrtW::CustomOpBase<PyCustomOpFactory, PyCustomOpKernel> {
|
||||
PyCustomOpFactory() {
|
||||
// STL vector needs it.
|
||||
// STL vector needs it.
|
||||
}
|
||||
|
||||
PyCustomOpFactory(const PyCustomOpDef* opdef, const std::string& domain, const std::string& op) {
|
||||
|
@ -60,8 +60,8 @@ struct PyCustomOpFactory : OrtW::CustomOpBase<PyCustomOpFactory, PyCustomOpKerne
|
|||
op_type_ = op;
|
||||
}
|
||||
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info, opdef_->obj_id, opdef_->attrs);
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo& info) const {
|
||||
return new PyCustomOpKernel(api, info, opdef_->obj_id, opdef_->attrs);
|
||||
};
|
||||
|
||||
const char* GetName() const {
|
||||
|
|
|
@ -60,15 +60,19 @@ class ExternalCustomOps {
|
|||
};
|
||||
|
||||
extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op) {
|
||||
OCOS_API_IMPL_BEGIN
|
||||
ExternalCustomOps::instance().Add(c_op);
|
||||
OCOS_API_IMPL_END
|
||||
return true;
|
||||
}
|
||||
|
||||
extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
|
||||
OrtStatus* status = nullptr;
|
||||
OCOS_API_IMPL_BEGIN
|
||||
|
||||
OrtCustomOpDomain* domain = nullptr;
|
||||
const OrtApi* ortApi = api->GetApi(ORT_API_VERSION);
|
||||
std::set<std::string> pyop_nameset;
|
||||
OrtStatus* status = nullptr;
|
||||
|
||||
#if defined(PYTHON_OP_SUPPORT)
|
||||
if (status = RegisterPythonDomainAndOps(options, ortApi); status) {
|
||||
|
@ -170,5 +174,9 @@ extern "C" ORTX_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptio
|
|||
}
|
||||
}
|
||||
|
||||
return ortApi->AddCustomOpDomain(options, domain);
|
||||
status = ortApi->AddCustomOpDomain(options, domain);
|
||||
|
||||
OCOS_API_IMPL_END
|
||||
|
||||
return status;
|
||||
}
|
||||
|
|
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -0,0 +1,120 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#include <filesystem>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
#include "ocos.h"
|
||||
#include "test_kernel.hpp"
|
||||
|
||||
// throw in ctor which will be called during model load
|
||||
struct ExceptionalKernel1 : BaseKernel {
|
||||
ExceptionalKernel1(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
ORTX_CXX_API_THROW("Throw in ctor", ORT_FAIL);
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {}
|
||||
};
|
||||
|
||||
// throw in Compute which will be called during model execution
|
||||
struct ExceptionalKernel2 : BaseKernel {
|
||||
ExceptionalKernel2(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
ORTX_CXX_API_THROW("Throw in Compute", ORT_FAIL);
|
||||
}
|
||||
};
|
||||
|
||||
struct ExceptionalCustomOp1 : OrtW::CustomOpBase<ExceptionalCustomOp1, ExceptionalKernel1> {
|
||||
const char* GetName() const { return "ExceptionalCustomOp1"; };
|
||||
size_t GetInputTypeCount() const { return 1; };
|
||||
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
size_t GetOutputTypeCount() const { return 1; };
|
||||
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
};
|
||||
|
||||
struct ExceptionalCustomOp2 : OrtW::CustomOpBase<ExceptionalCustomOp2, ExceptionalKernel2> {
|
||||
const char* GetName() const { return "ExceptionalCustomOp2"; };
|
||||
size_t GetInputTypeCount() const { return 1; };
|
||||
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
size_t GetOutputTypeCount() const { return 1; };
|
||||
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
|
||||
};
|
||||
|
||||
extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api);
|
||||
static ExceptionalCustomOp1 custom_op1;
|
||||
static ExceptionalCustomOp2 custom_op2;
|
||||
|
||||
// test a call to an entry point wrapped with OCOS_API_IMPL_BEGIN/OCOS_API_IMPL_END behaves as expected.
|
||||
// the throw in the ctor of ExceptionalCustomOp1 should be triggered during model loading.
|
||||
TEST(Exceptions, TestApiTryCatch_ThrowInModelLoad) {
|
||||
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
|
||||
AddExternalCustomOp(&custom_op1);
|
||||
|
||||
Ort::SessionOptions session_options;
|
||||
RegisterCustomOps((OrtSessionOptions*)session_options, OrtGetApiBase());
|
||||
|
||||
std::filesystem::path model("data/exceptional_custom_op1.onnx");
|
||||
auto fail_fn = [&]() {
|
||||
Ort::Session session(*ort_env, model.c_str(), session_options);
|
||||
};
|
||||
|
||||
// if no exceptions, the ORTX_CXX_API_THROW will trigger the log+abort
|
||||
// if no exception propagation, the OCOS_API_IMPL_END will trigger the log+abort
|
||||
#if defined(OCOS_NO_EXCEPTIONS) || defined(OCOS_PREVENT_EXCEPTION_PROPAGATION)
|
||||
// the exception should be caught and logged, and the process should abort so the exception is not propagated up.
|
||||
// log output needs to be manually checked
|
||||
// can test on Linux but not Windows.
|
||||
#if !defined(_WIN32)
|
||||
EXPECT_EXIT(fail_fn(), ::testing::KilledBySignal(SIGABRT), ".*");
|
||||
#endif
|
||||
#else
|
||||
// ORT catches the exceptions thrown by the custom op and rethrows them as Ort::Exception
|
||||
EXPECT_THROW(fail_fn(), Ort::Exception);
|
||||
#endif
|
||||
}
|
||||
|
||||
// test a call to an entry point wrapped with OCOS_API_IMPL_BEGIN/OCOS_API_IMPL_END behaves as expected.
|
||||
// the throw in the Compute of ExceptionalCustomOp2 should be triggered during model execution.
|
||||
TEST(Exceptions, TestApiTryCatch_ThrowInModelExecution) {
|
||||
auto ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "Default");
|
||||
AddExternalCustomOp(&custom_op2);
|
||||
|
||||
Ort::SessionOptions session_options;
|
||||
RegisterCustomOps((OrtSessionOptions*)session_options, OrtGetApiBase());
|
||||
|
||||
std::filesystem::path model("data/exceptional_custom_op2.onnx");
|
||||
Ort::Session session(*ort_env, model.c_str(), session_options);
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
|
||||
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||
const char* input_names[] = {"A"};
|
||||
const char* output_names[] = {"B"};
|
||||
|
||||
std::vector<int64_t> dims = {2};
|
||||
std::vector<float> input = {0.f, 1.f};
|
||||
std::vector<Ort::Value> ort_input;
|
||||
ort_input.push_back(Ort::Value::CreateTensor<float>(memory_info, input.data(), input.size(),
|
||||
dims.data(), dims.size()));
|
||||
|
||||
auto fail_fn = [&]() {
|
||||
// executing the model should call Compute of the custom op, which should throw
|
||||
std::vector<Ort::Value> ort_outputs;
|
||||
ort_outputs = session.Run(Ort::RunOptions{nullptr}, input_names, ort_input.data(), ort_input.size(),
|
||||
output_names, 1);
|
||||
};
|
||||
|
||||
// if no exceptions, the ORTX_CXX_API_THROW will trigger the log+abort
|
||||
// if no exception propagation, the OCOS_API_IMPL_END will trigger the log+abort
|
||||
#if defined(OCOS_NO_EXCEPTIONS) || defined(OCOS_PREVENT_EXCEPTION_PROPAGATION)
|
||||
// can test on Linux but not Windows
|
||||
#if !defined(_WIN32)
|
||||
EXPECT_EXIT(fail_fn(), ::testing::KilledBySignal(SIGABRT), ".*");
|
||||
#endif
|
||||
#else
|
||||
// ORT catches the exceptions thrown by the custom op and rethrows them as Ort::Exception
|
||||
EXPECT_THROW(fail_fn(), Ort::Exception);
|
||||
#endif
|
||||
}
|
|
@ -1,40 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "exceptions.h"
|
||||
|
||||
#define TEST_MAIN main
|
||||
|
||||
#if defined(__APPLE__)
|
||||
#include <TargetConditionals.h>
|
||||
#if TARGET_OS_SIMULATOR || TARGET_OS_IOS
|
||||
#undef TEST_MAIN
|
||||
#define TEST_MAIN main_no_link_ // there is a UI test app for iOS.
|
||||
#endif
|
||||
#include <TargetConditionals.h>
|
||||
#if TARGET_OS_SIMULATOR || TARGET_OS_IOS
|
||||
#undef TEST_MAIN
|
||||
#define TEST_MAIN main_no_link_ // there is a UI test app for iOS.
|
||||
#endif
|
||||
|
||||
// currently this is the only place with a try/catch. Move the macros to common code if that changes.
|
||||
#ifdef OCOS_NO_EXCEPTIONS
|
||||
#define OCOS_TRY if (true)
|
||||
#define OCOS_CATCH(x) else if (false)
|
||||
#define OCOS_RETHROW
|
||||
// In order to ignore the catch statement when a specific exception (not ... ) is caught and referred
|
||||
// in the body of the catch statements, it is necessary to wrap the body of the catch statement into
|
||||
// a lambda function. otherwise the exception referred will be undefined and cause build break
|
||||
#define OCOS_HANDLE_EXCEPTION(func)
|
||||
#else
|
||||
#define OCOS_TRY try
|
||||
#define OCOS_CATCH(x) catch (x)
|
||||
#define OCOS_RETHROW throw;
|
||||
#define OCOS_HANDLE_EXCEPTION(func) func()
|
||||
#endif
|
||||
|
||||
|
||||
namespace {
|
||||
void FixCurrentDir() {
|
||||
// adjust for the Google Test Adapter in Visual Studio not setting the current path to $(ProjectDir),
|
||||
|
|
|
@ -21,7 +21,7 @@ const char* GetLibraryPath() {
|
|||
}
|
||||
|
||||
struct KernelOne : BaseKernel {
|
||||
KernelOne(const OrtApi& api) : BaseKernel(api) {
|
||||
KernelOne(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context) {
|
||||
|
@ -67,7 +67,7 @@ struct CustomOpOne : OrtW::CustomOpBase<CustomOpOne, KernelOne> {
|
|||
};
|
||||
|
||||
struct KernelTwo : BaseKernel {
|
||||
KernelTwo(const OrtApi& api) : BaseKernel(api) {
|
||||
KernelTwo(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
}
|
||||
void Compute(OrtKernelContext* context) {
|
||||
// Setup inputs
|
||||
|
@ -110,7 +110,7 @@ struct CustomOpTwo : OrtW::CustomOpBase<CustomOpTwo, KernelTwo> {
|
|||
};
|
||||
|
||||
struct KernelThree : BaseKernel {
|
||||
KernelThree(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
|
||||
KernelThree(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
|
||||
if (!TryToGetAttribute("substr", substr_)) {
|
||||
substr_ = "";
|
||||
}
|
||||
|
@ -137,8 +137,12 @@ struct KernelThree : BaseKernel {
|
|||
};
|
||||
|
||||
struct CustomOpThree : OrtW::CustomOpBase<CustomOpThree, KernelThree> {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return CreateKernelImpl(api, info);
|
||||
// This is example code to show how to override the CustomOpBase::CreateKernel method even though it is not virtual.
|
||||
// The CustomOpBase implementation will call the CreateKernel of the first class specified in the template,
|
||||
// and from there it's also possible to call the base CreateKernel as per below.
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo& info) const {
|
||||
std::cout << "Called CreateKernel override" << std::endl;
|
||||
return OrtW::CustomOpBase<CustomOpThree, KernelThree>::CreateKernel(api, info);
|
||||
};
|
||||
const char* GetName() const {
|
||||
return "CustomOpThree";
|
||||
|
@ -189,11 +193,11 @@ void GetTensorMutableDataString(const OrtApi& api, const OrtValue* value, std::v
|
|||
OrtTensorDimensions dimensions(OrtW::CustomOpApi(api), value);
|
||||
size_t len = static_cast<size_t>(dimensions.Size());
|
||||
size_t data_len;
|
||||
Ort::ThrowOnError(api, api.GetStringTensorDataLength(value, &data_len));
|
||||
OrtW::ThrowOnError(api, api.GetStringTensorDataLength(value, &data_len));
|
||||
output.resize(len);
|
||||
std::vector<char> result(data_len + len + 1, '\0');
|
||||
std::vector<size_t> offsets(len);
|
||||
Ort::ThrowOnError(api, api.GetStringTensorContent(value, (void*)result.data(), data_len, offsets.data(), offsets.size()));
|
||||
OrtW::ThrowOnError(api, api.GetStringTensorContent(value, (void*)result.data(), data_len, offsets.data(), offsets.size()));
|
||||
output.resize(len);
|
||||
for (int64_t i = (int64_t)len - 1; i >= 0; --i) {
|
||||
if (i < len - 1)
|
||||
|
|
Загрузка…
Ссылка в новой задаче